Refactor/asr tts (#1939)

* refactor: update ASR and TTS implementations

* fix lint

* Integrating asr/tts models w/ new security config

* update documents

* add arbitrary whisper transcriptor support

* update documents

* fix lint

* add mimo tts
This commit is contained in:
Hua Audio
2026-04-01 06:21:21 +02:00
committed by GitHub
parent ff90a65814
commit 0f395ce110
48 changed files with 3527 additions and 358 deletions
+53 -17
View File
@@ -18,6 +18,8 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/audio/asr"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
@@ -31,7 +33,6 @@ import (
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type AgentLoop struct {
@@ -51,7 +52,7 @@ type AgentLoop struct {
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
transcriber asr.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
hookRuntime hookRuntime
@@ -159,6 +160,13 @@ func registerSharedTools(
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
var ttsProvider tts.TTSProvider
if cfg.Tools.IsToolEnabled("send_tts") {
ttsProvider = tts.DetectTTS(cfg)
if ttsProvider == nil {
logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil)
}
}
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
@@ -269,6 +277,10 @@ func registerSharedTools(
agent.Tools.Register(sendFileTool)
}
if ttsProvider != nil {
agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil))
}
// Skill discovery and installation tools
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
@@ -1059,10 +1071,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
agent.Tools.SetMediaStore(s)
}
}
registry.ForEachTool("send_tts", func(t tools.Tool) {
if st, ok := t.(*tools.SendTTSTool); ok {
st.SetMediaStore(s)
}
})
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
func (al *AgentLoop) SetTranscriber(t asr.Transcriber) {
al.transcriber = t
}
@@ -1083,19 +1100,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
// Transcribe each audio media ref in order.
var transcriptions []string
var keptMedia []string
for _, ref := range msg.Media {
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
keptMedia = append(keptMedia, ref)
continue
}
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
keptMedia = append(keptMedia, ref)
continue
}
result, err := al.transcriber.Transcribe(ctx, path)
if err != nil {
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
transcriptions = append(transcriptions, "")
keptMedia = append(keptMedia, ref)
continue
}
transcriptions = append(transcriptions, result.Text)
@@ -1115,15 +1136,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
}
text := transcriptions[idx]
idx++
if text == "" {
return match
}
return "[voice: " + text + "]"
})
// Append any remaining transcriptions not matched by an annotation.
for ; idx < len(transcriptions); idx++ {
newContent += "\n[voice: " + transcriptions[idx] + "]"
if transcriptions[idx] != "" {
newContent += "\n[voice: " + transcriptions[idx] + "]"
}
}
msg.Content = newContent
msg.Media = keptMedia
return msg, true
}
@@ -2464,6 +2491,28 @@ turnLoop:
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
// Send ForUser if not silent and has content.
// For ResponseHandled tools, send regardless of SendResponse setting,
// since they've already handled the response (e.g., send_tts, send_file).
shouldSendForUser := !toolResult.Silent && toolResult.ForUser != "" &&
(ts.opts.SendResponse || toolResult.ResponseHandled)
if shouldSendForUser {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: toolResult.ForUser,
Metadata: map[string]string{
"is_tool_call": "true",
},
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": toolName,
"content_len": len(toolResult.ForUser),
})
}
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
@@ -2509,19 +2558,6 @@ turnLoop:
allResponsesHandled = false
}
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": toolName,
"content_len": len(toolResult.ForUser),
})
}
contentForLLM := toolResult.ContentForLLM()
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
+166
View File
@@ -0,0 +1,166 @@
# ASR (Automatic Speech Recognition)
This package handles speech-to-text for PicoClaw voice input.
If you are new to ASR setup, the simplest mental model is:
1. Add one or more ASR-capable entries to `model_list`.
2. Point `voice.model_name` at the one you want to use.
3. Put the API key in `.security.yml`.
## Quick Recommendation
For most new users, start with one of these:
| Provider | Example model | Why start here |
| --- | --- | --- |
| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Fast Whisper-style transcription and a straightforward OpenAI-compatible API. Groq currently advertises a free tier with 2,000 requests per day. |
| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | Easy setup and strong speech-to-text quality. ElevenLabs currently advertises a free plan that includes speech-to-text usage. |
Pricing and free-plan limits can change, so check the linked pricing pages before depending on them in production.
## How ASR Configuration Works
PicoClaw does not keep ASR API keys inside the `voice` section.
Instead:
- `voice.model_name` chooses a named entry from `model_list`.
- The matching `model_list` entry describes the actual provider and model.
- `.security.yml` stores the API key for that named model entry.
This is the recommended pattern because it is explicit, reusable, and consistent with the rest of PicoClaw's model configuration.
## Recommended Setup
### Option A: Groq Whisper
`config.json`
```json
{
"voice": {
"model_name": "groq-asr",
"echo_transcription": true
},
"model_list": [
{
"model_name": "groq-asr",
"model": "groq/whisper-large-v3-turbo"
}
]
}
```
`.security.yml`
```yaml
model_list:
groq-asr:
api_keys:
- "gsk_your_groq_key"
```
Notes:
- You can omit `api_base` and PicoClaw will use Groq's default API base automatically.
- If you set `api_base` manually for Groq Whisper, both of these forms work:
- `https://api.groq.com/openai/v1`
- `https://api.groq.com/openai/v1/audio/transcriptions`
- Any OpenAI-compatible Whisper model name containing `whisper` can use the Whisper transcription path, not only `whisper-large-v3-turbo`.
### Option B: ElevenLabs
`config.json`
```json
{
"voice": {
"model_name": "elevenlabs-asr",
"echo_transcription": true
},
"model_list": [
{
"model_name": "elevenlabs-asr",
"model": "elevenlabs/scribe_v1"
}
]
}
```
`.security.yml`
```yaml
model_list:
elevenlabs-asr:
api_keys:
- "sk-elevenlabs-your-key"
```
### Option C: OpenAI Whisper
`config.json`
```json
{
"voice": {
"model_name": "openai-asr"
},
"model_list": [
{
"model_name": "openai-asr",
"model": "openai/whisper-1"
}
]
}
```
`.security.yml`
```yaml
model_list:
openai-asr:
api_keys:
- "sk-openai-your-key"
```
## Other ASR-Capable Model Types
PicoClaw currently supports three main ASR routes:
| Route | Example models | Behavior |
| --- | --- | --- |
| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. |
| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. |
| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. |
If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first.
## How PicoClaw Chooses a Transcriber
`DetectTranscriber` resolves ASR in this order:
1. **Preferred path**: resolve `voice.model_name` against `model_list`.
2. If that resolved model is:
- `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber.
- an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber.
- an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`.
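3. **Fallback path**: if `voice.model_name` is not set, or it cannot be resolved to a supported ASR model, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries (ElevenLabs and Whisper-style models that have an API key).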
Fallback scanning exists for backward compatibility. New configurations should set `voice.model_name` explicitly.
## Common Mistakes
- Defining an ASR model in `model_list` but forgetting to set `voice.model_name`.
- Putting the API key in `voice` instead of `.security.yml`.
- Using a non-ASR model and expecting Whisper-style transcription behavior.
- Setting a custom `api_base` that points to the wrong provider endpoint.
## Minimal Checklist
Before testing voice input, make sure:
- `voice.model_name` matches a `model_list[].model_name`.
- The matching `.security.yml` entry contains a valid API key.
- The selected model is actually ASR-capable.
- Voice input is enabled for the channel you are using.
+166
View File
@@ -0,0 +1,166 @@
# ASR(自动语音识别)
这个目录负责 PicoClaw 的语音转文字能力。
如果你是第一次配置 ASR,可以参考如下步骤:
1. 在 `model_list` 里添加一个或多个支持 ASR 的模型条目。
2. 将 `voice.model_name` 指向你想使用的那个条目。
3. 在 `.security.yml` 里配置对应的 API Key。
## 快速推荐
对于大多数新用户,建议先从下面两种开始:
| 提供商 | 示例模型 | 推荐理由 |
| --- | --- | --- |
| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Whisper 风格转录速度快,并且提供 OpenAI 兼容接口,配置比较直接。Groq 目前官方提供2000请求每日的免费套餐。 |
| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | 上手简单,语音转文字质量也不错。ElevenLabs 目前官方免费套餐包含 STT 用量。 |
价格和免费额度可能会变化,正式使用前请以官网定价页为准。
## ASR 配置是如何工作的
PicoClaw 不会把 ASR 的 API Key 放在 `voice` 配置里。
推荐的方式是:
- `voice.model_name` 用来选择 `model_list` 里的某个命名模型。
- `model_list` 条目描述真实的提供商和模型。
- `.security.yml` 负责保存该模型条目的 API Key。
这种方式更明确、更安全,也和 PicoClaw 其他模型配置方式保持一致。
## 推荐配置方式
### 方案 A:Groq Whisper
`config.json`
```json
{
"voice": {
"model_name": "groq-asr",
"echo_transcription": true
},
"model_list": [
{
"model_name": "groq-asr",
"model": "groq/whisper-large-v3-turbo"
}
]
}
```
`.security.yml`
```yaml
model_list:
groq-asr:
api_keys:
- "gsk_your_groq_key"
```
说明:
- 你可以不写 `api_base`,PicoClaw 会自动使用 Groq 默认接口地址。
- 如果你手动设置 Groq Whisper 的 `api_base`,下面两种写法都可以:
- `https://api.groq.com/openai/v1`
- `https://api.groq.com/openai/v1/audio/transcriptions`
- 只要是 OpenAI 兼容、并且模型名里包含 `whisper` 的模型,都可以走 Whisper 转录路径,不仅限于 `whisper-large-v3-turbo`
### 方案 B:ElevenLabs
`config.json`
```json
{
"voice": {
"model_name": "elevenlabs-asr",
"echo_transcription": true
},
"model_list": [
{
"model_name": "elevenlabs-asr",
"model": "elevenlabs/scribe_v1"
}
]
}
```
`.security.yml`
```yaml
model_list:
elevenlabs-asr:
api_keys:
- "sk-elevenlabs-your-key"
```
### 方案 C:OpenAI Whisper
`config.json`
```json
{
"voice": {
"model_name": "openai-asr"
},
"model_list": [
{
"model_name": "openai-asr",
"model": "openai/whisper-1"
}
]
}
```
`.security.yml`
```yaml
model_list:
openai-asr:
api_keys:
- "sk-openai-your-key"
```
## 其他支持 ASR 的模型类型
PicoClaw 目前主要支持三种 ASR 路径:
| 路径 | 示例模型 | 行为说明 |
| --- | --- | --- |
| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
| Whisper 接口模型 | `openai/whisper-1`、`groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 |
| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview`、`gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 |
如果你不确定该选哪种,建议优先使用 Groq Whisper 或 ElevenLabs。
## PicoClaw 如何选择转录器
`DetectTranscriber` 会按下面顺序选择 ASR:
1. **首选路径**:根据 `voice.model_name` 在 `model_list` 中找到对应模型。
2. 如果找到的模型属于以下类型:
- `elevenlabs/...`,则使用 ElevenLabs transcriber。
- OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。
- 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`
3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。
回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.model_name`
## 常见错误
- 在 `model_list` 里定义了 ASR 模型,但忘了设置 `voice.model_name`。
- 把 API Key 写进了 `voice`,而不是 `.security.yml`
- 选择了不支持 ASR 的模型,却期望得到 Whisper 风格的转录结果。
- 自定义了错误的 `api_base`,导致请求打到错误的接口地址。
## 最小检查清单
在测试语音输入前,请确认:
- `voice.model_name` 能正确匹配某个 `model_list[].model_name`
- `.security.yml` 中对应条目已经配置了有效 API Key。
- 你选择的模型确实支持 ASR。
- 你当前使用的频道已经启用了语音输入能力。
+252
View File
@@ -0,0 +1,252 @@
package asr
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3/pkg/media/oggwriter"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
)
// speechAccumulator collects the audio chunks of one speaker in one voice
// session into a single Ogg file until the utterance is considered finished.
type speechAccumulator struct {
// writer is the Ogg writer that Push feeds RTP packets into.
writer *oggwriter.OggWriter
// file is the path of the Ogg file behind writer; it is removed after processing.
file string
// lastAudioAt is refreshed on every accepted Push and read by the silence check.
lastAudioAt time.Time
// mu guards closed and lastAudioAt and serializes access to writer.
mu sync.Mutex
// closed is set by Close; Push drops chunks once it is true.
closed bool
// chatID, speakerID, sessionID and channel are copied from the first chunk
// of the utterance and used to address the resulting bus messages.
chatID string
speakerID string
sessionID string
channel string
}
// Push wraps one audio chunk in an RTP packet and appends it to the Ogg file.
// Chunks arriving after Close are dropped. A write failure is logged and the
// accumulator stays open.
func (a *speechAccumulator) Push(chunk bus.AudioChunk) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.closed {
		return
	}
	// Record activity before writing so the silence check sees this chunk.
	a.lastAudioAt = time.Now()

	header := rtp.Header{
		SequenceNumber: uint16(chunk.Sequence),
		Timestamp:      chunk.Timestamp,
		SSRC:           1, // Stable arbitrary dummy
	}
	writeErr := a.writer.WriteRTP(&rtp.Packet{Header: header, Payload: chunk.Data})
	if writeErr == nil {
		return
	}
	logger.ErrorCF("voice-agent", "Failed to write RTP", map[string]any{"error": writeErr})
}
// Close finalizes the Ogg file and marks the accumulator as closed so that
// later Push calls are ignored. It is safe to call more than once; only the
// first call touches the writer.
//
// An accumulator without a writer is simply marked closed instead of
// panicking, and a failure to finalize the file is logged rather than silently
// discarded, since a truncated file is the likely cause of a later
// transcription error.
func (a *speechAccumulator) Close() {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.closed {
		return
	}
	a.closed = true

	if a.writer == nil {
		return
	}
	if err := a.writer.Close(); err != nil {
		logger.ErrorCF("voice-agent", "Failed to close Ogg writer", map[string]any{
			"error": err,
			"file":  a.file,
		})
	}
}
// Agent consumes audio chunks from the message bus, groups them into
// per-speaker utterances, transcribes finished utterances and publishes the
// result back onto the bus.
type Agent struct {
// bus is the source of audio chunks and the sink for resulting messages.
bus *bus.MessageBus
// transcriber converts a finished utterance file to text; may be nil, in
// which case utterances are discarded with an error log.
transcriber Transcriber
// mu guards sessions.
mu sync.Mutex
sessions map[string]*speechAccumulator // keyed by sessionID_speakerID
}
// NewAgent builds an Agent that reads from and publishes to mb and uses t for
// transcription. The returned Agent is idle until Start is called.
func NewAgent(mb *bus.MessageBus, t Transcriber) *Agent {
	agent := &Agent{sessions: map[string]*speechAccumulator{}}
	agent.bus = mb
	agent.transcriber = t
	return agent
}
// Start launches the background goroutines of the agent: the chunk listener,
// the silence detector, and a watcher that discards all in-flight sessions
// (closing and deleting their temporary files) once ctx is cancelled.
// Start itself returns immediately.
func (a *Agent) Start(ctx context.Context) {
	logger.InfoCF("voice-agent", "Started Voice Agent orchestrator", nil)

	go a.listenChunks(ctx)
	go a.vadTick(ctx)

	// dropSessions closes every open accumulator and forgets it.
	dropSessions := func() {
		a.mu.Lock()
		defer a.mu.Unlock()
		for id, session := range a.sessions {
			session.Close()
			os.Remove(session.file)
			delete(a.sessions, id)
		}
	}

	// Cleanup sessions on shutdown
	go func() {
		<-ctx.Done()
		dropSessions()
		logger.InfoCF("voice-agent", "Cleaned up voice sessions on shutdown", nil)
	}()
}
// listenChunks forwards every audio chunk received from the bus to
// handleChunk. It returns when ctx is cancelled or the chunk channel is closed.
func (a *Agent) listenChunks(ctx context.Context) {
	incoming := a.bus.AudioChunksChan()
	for {
		select {
		case chunk, open := <-incoming:
			if !open {
				return
			}
			a.handleChunk(chunk)
		case <-ctx.Done():
			return
		}
	}
}
// handleChunk routes one audio chunk to the accumulator of its
// (session, speaker) pair, creating the accumulator and its temporary Ogg
// file on first use. Chunks whose format is not "opus" are ignored.
func (a *Agent) handleChunk(chunk bus.AudioChunk) {
	// Only accept Opus-encoded audio
	if chunk.Format != "opus" {
		logger.DebugCF("voice-agent", "Ignoring unsupported audio format", map[string]any{"format": chunk.Format})
		return
	}

	key := fmt.Sprintf("%s_%s", chunk.SessionID, chunk.SpeakerID)

	a.mu.Lock()
	acc, found := a.sessions[key]
	if found {
		a.mu.Unlock()
		acc.Push(chunk)
		return
	}

	// First chunk of a new utterance: open a fresh, uniquely named Ogg file.
	oggName := fmt.Sprintf("voice_%s_%d.ogg", key, time.Now().UnixNano())
	oggPath := filepath.Join(os.TempDir(), oggName)
	ogg, err := oggwriter.New(oggPath, uint32(chunk.SampleRate), uint16(chunk.Channels))
	if err != nil {
		a.mu.Unlock()
		logger.ErrorCF("voice-agent", "Failed to create OggWriter", map[string]any{"error": err})
		return
	}

	acc = &speechAccumulator{
		writer:      ogg,
		file:        oggPath,
		lastAudioAt: time.Now(),
		chatID:      chunk.ChatID,
		speakerID:   chunk.SpeakerID,
		sessionID:   chunk.SessionID,
		channel:     chunk.Channel,
	}
	a.sessions[key] = acc
	logger.DebugCF("voice-agent", "Started accumulating voice", map[string]any{"key": key, "file": oggPath})
	a.mu.Unlock()

	// Push takes the accumulator's own lock, so it runs outside a.mu.
	acc.Push(chunk)
}
// vadTick runs the silence check every 500ms until ctx is cancelled.
func (a *Agent) vadTick(ctx context.Context) {
	const pollInterval = 500 * time.Millisecond

	poll := time.NewTicker(pollInterval)
	defer poll.Stop()

	for {
		select {
		case <-poll.C:
			a.checkSilence(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// checkSilence closes every session that has received no audio for more than
// 1.5 seconds, removes it from the session map, and hands it to
// processUtterance on its own goroutine.
func (a *Agent) checkSilence(ctx context.Context) {
	const silenceTimeout = 1500 * time.Millisecond

	a.mu.Lock()
	now := time.Now()
	var ended []*speechAccumulator
	for key, acc := range a.sessions {
		acc.mu.Lock()
		idle := now.Sub(acc.lastAudioAt)
		acc.mu.Unlock()

		if idle <= silenceTimeout {
			continue
		}
		acc.Close()
		delete(a.sessions, key)
		ended = append(ended, acc)
	}
	a.mu.Unlock()

	// Transcription is slow, so it runs after the session lock is released.
	for i := range ended {
		go a.processUtterance(ctx, ended[i])
	}
}
// processUtterance transcribes the finished utterance in acc.file and acts on
// the text. The temporary audio file is removed when the function returns,
// whatever the outcome.
//
// Outcomes:
//   - no transcriber configured, transcription error, or a missing/blank
//     transcription: the utterance is dropped (and logged);
//   - the text contains a "leave" phrase: a leave voice-control command and a
//     goodbye outbound message are published;
//   - otherwise the text is published as an inbound message tagged
//     is_voice=true, with an instruction asking for a speech-friendly reply.
func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
	defer os.Remove(acc.file)

	logger.InfoCF("voice-agent", "User finished speaking, transcribing...", map[string]any{"file": acc.file})

	if a.transcriber == nil {
		logger.ErrorCF("voice-agent", "No STT configured!", nil)
		return
	}

	res, err := a.transcriber.Transcribe(ctx, acc.file)
	if err != nil {
		logger.ErrorCF("voice-agent", "Transcription failed", map[string]any{"error": err})
		return
	}

	// A nil response or whitespace-only text carries no speech; forwarding it
	// would send the agent a prompt that contains only the system suffix.
	if res == nil || strings.TrimSpace(res.Text) == "" {
		logger.DebugCF("voice-agent", "Ignored empty transcription", map[string]any{"file": acc.file})
		return
	}

	logger.InfoCF("voice-agent", "Transcription result", map[string]any{"text": res.Text, "duration": res.Duration})

	channelType := acc.channel
	if channelType == "" {
		channelType = "discord" // fallback for legacy chunks
	}

	if isLeaveVoiceCommand(res.Text) {
		logger.InfoCF("voice-agent", "Voice command triggered: leave", nil)
		if err := a.bus.PublishVoiceControl(ctx, bus.VoiceControl{
			SessionID: acc.sessionID,
			Type:      "command",
			Action:    "leave",
		}); err != nil {
			logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
		}
		if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
			Channel: channelType,
			ChatID:  acc.chatID,
			Content: "Goodbye! Leaving the voice channel.",
		}); err != nil {
			logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
		}
		return
	}

	oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."

	if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
		Channel:  channelType,
		SenderID: acc.speakerID,
		ChatID:   acc.chatID,
		Content:  res.Text + oralPrompt,
		Peer:     bus.Peer{Kind: "channel", ID: acc.chatID},
		Metadata: map[string]string{
			"is_voice": "true",
		},
	}); err != nil {
		logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
	}
}

// isLeaveVoiceCommand reports whether transcript, compared case-insensitively,
// contains one of the phrases that ask the bot to leave the voice channel.
func isLeaveVoiceCommand(transcript string) bool {
	text := strings.ToLower(strings.TrimSpace(transcript))
	for _, phrase := range []string{
		"leave the voice channel",
		"leave voice",
		"disconnect voice",
		"leave the channel",
		"leave channel",
	} {
		if strings.Contains(text, phrase) {
			return true
		}
	}
	return false
}
+196
View File
@@ -0,0 +1,196 @@
package asr
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/pion/webrtc/v3/pkg/media/oggwriter"
"github.com/sipeed/picoclaw/pkg/bus"
)
// fakeTranscriber is a test double for Transcriber: it records the path it
// was asked to transcribe and returns either a canned text or a canned error.
type fakeTranscriber struct {
	text     string // text returned on success
	err      error  // when non-nil, returned instead of a response
	lastPath string // path passed to the most recent Transcribe call
}

// Name identifies the test double.
func (f *fakeTranscriber) Name() string {
	return "fake"
}

// Transcribe remembers audioFilePath and answers with the configured result.
func (f *fakeTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
	f.lastPath = audioFilePath
	if f.err == nil {
		return &TranscriptionResponse{Text: f.text}, nil
	}
	return nil, f.err
}
// waitForFileRemoval polls every 10ms until path no longer exists or timeout
// elapses. After the timeout it fails the test if the file can still be
// stat'ed successfully.
func waitForFileRemoval(t *testing.T, path string, timeout time.Duration) {
	t.Helper()

	gone := func() bool {
		_, statErr := os.Stat(path)
		return os.IsNotExist(statErr)
	}

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		if gone() {
			return
		}
	}

	// One last look after the deadline before declaring failure.
	if _, statErr := os.Stat(path); statErr == nil {
		t.Fatalf("expected file to be removed: %s", path)
	}
}
// TestAgentHandleChunkCreatesSession checks that the first Opus chunk of a
// (session, speaker) pair registers an accumulator under "session_speaker".
func TestAgentHandleChunkCreatesSession(t *testing.T) {
	t.Parallel()

	mb := bus.NewMessageBus()
	defer mb.Close()
	agent := NewAgent(mb, &fakeTranscriber{})

	agent.handleChunk(bus.AudioChunk{
		SessionID:  "sess",
		SpeakerID:  "speaker",
		ChatID:     "chat",
		Channel:    "discord",
		Sequence:   1,
		Timestamp:  1,
		SampleRate: 48000,
		Channels:   2,
		Format:     "opus",
		Data:       []byte{0xF8, 0xFF, 0xFE},
	})

	agent.mu.Lock()
	acc, created := agent.sessions["sess_speaker"]
	agent.mu.Unlock()
	if !created {
		t.Fatal("expected session to be created")
	}

	// Release the temporary Ogg file the agent opened for this session.
	acc.Close()
	_ = os.Remove(acc.file)
}
// TestAgentHandleChunkIgnoresUnsupportedFormat checks that a non-Opus chunk
// does not create a session.
func TestAgentHandleChunkIgnoresUnsupportedFormat(t *testing.T) {
	t.Parallel()

	mb := bus.NewMessageBus()
	defer mb.Close()
	agent := NewAgent(mb, &fakeTranscriber{})

	agent.handleChunk(bus.AudioChunk{Format: "pcm"})

	agent.mu.Lock()
	open := len(agent.sessions)
	agent.mu.Unlock()

	if open != 0 {
		t.Fatalf("expected no sessions, got %d", open)
	}
}
// TestAgentProcessUtteranceLeaveCommand checks that a transcription containing
// a leave phrase publishes a leave voice control and a goodbye message, and
// that the utterance file is deleted afterwards.
func TestAgentProcessUtteranceLeaveCommand(t *testing.T) {
	t.Parallel()

	mb := bus.NewMessageBus()
	defer mb.Close()

	agent := NewAgent(mb, &fakeTranscriber{text: "please leave the voice channel now"})

	audioPath := filepath.Join(t.TempDir(), "voice.ogg")
	if err := os.WriteFile(audioPath, []byte("data"), 0o600); err != nil {
		t.Fatalf("write temp file: %v", err)
	}

	agent.processUtterance(context.Background(), &speechAccumulator{
		file:      audioPath,
		chatID:    "chat",
		speakerID: "speaker",
		sessionID: "sess",
		channel:   "discord",
	})

	select {
	case ctrl := <-mb.VoiceControlsChan():
		isLeave := ctrl.Action == "leave" && ctrl.Type == "command" && ctrl.SessionID == "sess"
		if !isLeave {
			t.Fatalf("unexpected voice control: %#v", ctrl)
		}
	case <-time.After(250 * time.Millisecond):
		t.Fatal("expected voice control publish")
	}

	select {
	case out := <-mb.OutboundChan():
		if !strings.Contains(out.Content, "Leaving the voice channel") {
			t.Fatalf("unexpected outbound content: %q", out.Content)
		}
	case <-time.After(250 * time.Millisecond):
		t.Fatal("expected outbound publish")
	}

	_, statErr := os.Stat(audioPath)
	if !os.IsNotExist(statErr) {
		t.Fatalf("expected temp file to be removed")
	}
}
// TestAgentCheckSilencePublishesInboundAndCleansUp seeds a session whose last
// audio is two seconds old and checks that the silence check transcribes it,
// publishes the text as a voice inbound message on the session's channel, and
// removes the audio file.
func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
	t.Parallel()

	mb := bus.NewMessageBus()
	defer mb.Close()

	agent := NewAgent(mb, &fakeTranscriber{text: "hello there"})

	audioPath := filepath.Join(t.TempDir(), "voice.ogg")
	ogg, err := oggwriter.New(audioPath, 48000, 2)
	if err != nil {
		t.Fatalf("create ogg writer: %v", err)
	}

	stale := &speechAccumulator{
		writer:      ogg,
		file:        audioPath,
		lastAudioAt: time.Now().Add(-2 * time.Second),
		chatID:      "chat",
		speakerID:   "speaker",
		sessionID:   "sess",
		channel:     "slack",
	}
	agent.mu.Lock()
	agent.sessions["sess_speaker"] = stale
	agent.mu.Unlock()

	agent.checkSilence(context.Background())

	select {
	case msg := <-mb.InboundChan():
		if got := msg.Channel; got != "slack" {
			t.Fatalf("unexpected inbound channel: %q", got)
		}
		if !strings.Contains(msg.Content, "hello there") {
			t.Fatalf("unexpected inbound content: %q", msg.Content)
		}
		if msg.Metadata["is_voice"] != "true" {
			t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
		}
	case <-time.After(500 * time.Millisecond):
		t.Fatal("expected inbound publish")
	}

	// The file is deleted by the processing goroutine, so poll for it.
	waitForFileRemoval(t, audioPath, 500*time.Millisecond)
}
+131
View File
@@ -0,0 +1,131 @@
package asr
import (
"context"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Transcriber converts an audio file into text.
type Transcriber interface {
// Name returns a short identifier for the transcriber implementation.
Name() string
// Transcribe transcribes the audio file at audioFilePath and returns the
// recognized text, or an error if transcription fails.
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
}
// TranscriptionResponse is the result of a transcription request.
type TranscriptionResponse struct {
// Text is the recognized speech.
Text string `json:"text"`
// Language is optional and omitted from JSON when empty.
Language string `json:"language,omitempty"`
// Duration is optional; presumably the audio length in seconds — TODO confirm units.
Duration float64 `json:"duration,omitempty"`
}
// supportsAudioTranscription reports whether the protocol prefix of model is
// one for which an audio-capable chat model transcriber may be created.
func supportsAudioTranscription(model string) bool {
	switch protocol, _ := providers.ExtractProtocol(model); protocol {
	case "openai", "azure", "azure-openai",
		"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
		"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
		"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
		"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
		"coding-plan", "alibaba-coding", "qwen-coding":
		// These protocols all go through the OpenAI-compatible or Azure provider
		// path in providers.CreateProviderFromConfig, so they are the only ones
		// that can supply the audio media payload shape expected by
		// NewAudioModelTranscriber.
		// TODO: Further restrict this by modelID, since not every model under
		// these protocols supports audio transcription.
		return true
	}
	return false
}
// supportsWhisperTranscription reports whether the protocol prefix of model
// is on the allow-list for the Whisper transcription path.
func supportsWhisperTranscription(model string) bool {
	switch protocol, _ := providers.ExtractProtocol(model); protocol {
	case "openai", "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
		"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
		"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
		"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
		"coding-plan", "alibaba-coding", "qwen-coding", "mimo":
		return true
	}
	return false
}
// whisperModelID returns the model ID (the part after the protocol prefix)
// when modelCfg describes a usable Whisper model: it must have an API key, a
// protocol accepted by supportsWhisperTranscription, and a model ID that
// contains "whisper" in any letter case. Otherwise it returns "".
func whisperModelID(modelCfg *config.ModelConfig) string {
	if modelCfg == nil {
		return ""
	}
	if modelCfg.APIKey() == "" || !supportsWhisperTranscription(modelCfg.Model) {
		return ""
	}

	_, id := providers.ExtractProtocol(strings.TrimSpace(modelCfg.Model))
	if !strings.Contains(strings.ToLower(id), "whisper") {
		return ""
	}
	return id
}
// transcriberFromModelConfig picks the transcriber for an explicitly selected
// model entry. Precedence: ElevenLabs (needs an API key), then Whisper-style
// models, then audio-capable chat models. It returns nil when none applies.
func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
	if modelCfg == nil {
		return nil
	}

	protocol, _ := providers.ExtractProtocol(modelCfg.Model)
	switch {
	case protocol == "elevenlabs" && modelCfg.APIKey() != "":
		return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
	case whisperModelID(modelCfg) != "":
		return NewWhisperTranscriber(modelCfg)
	case supportsAudioTranscription(modelCfg.Model):
		return NewAudioModelTranscriber(modelCfg)
	}
	return nil
}
// fallbackTranscriberFromModelConfig is the stricter variant used when
// scanning model_list without an explicit selection: it only recognizes
// ElevenLabs entries with an API key and Whisper-style models, never generic
// audio-capable chat models. It returns nil for anything else.
func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
	if modelCfg == nil {
		return nil
	}

	protocol, _ := providers.ExtractProtocol(modelCfg.Model)
	isElevenLabs := protocol == "elevenlabs"
	if isElevenLabs && modelCfg.APIKey() != "" {
		return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
	}

	if whisperModelID(modelCfg) == "" {
		return nil
	}
	return NewWhisperTranscriber(modelCfg)
}
// DetectTranscriber selects the Transcriber described by cfg.
//
// If cfg.Voice.ModelName is set and resolves to a model entry that yields a
// transcriber, that transcriber is returned. In every other case (name unset,
// lookup error, or unsupported model) the entries of cfg.ModelList are
// scanned in order and the first legacy auto-detectable ASR entry wins.
// It returns nil when cfg is nil or nothing matches.
func DetectTranscriber(cfg *config.Config) Transcriber {
	if cfg == nil {
		return nil
	}

	if name := strings.TrimSpace(cfg.Voice.ModelName); name != "" {
		if selected, err := cfg.GetModelConfig(name); err == nil {
			if tr := transcriberFromModelConfig(selected); tr != nil {
				return tr
			}
		}
	}

	// Fall back to compatibility scanning for legacy auto-detected ASR providers.
	for i := range cfg.ModelList {
		tr := fallbackTranscriberFromModelConfig(cfg.ModelList[i])
		if tr == nil {
			continue
		}
		return tr
	}
	return nil
}
@@ -1,4 +1,4 @@
package voice
package asr
import (
"testing"
@@ -33,26 +33,68 @@ func TestDetectTranscriber(t *testing.T) {
wantName: "audio-model",
},
{
name: "groq via model list",
name: "voice model name alias selects elevenlabs transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
ModelName: "my-asr-model",
Model: "elevenlabs/scribe_v1",
APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
},
},
},
wantName: "elevenlabs",
},
{
name: "voice model name alias selects whisper transcriber for groq",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
ModelName: "my-asr-model",
Model: "groq/whisper-large-v3",
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
},
},
},
wantName: "whisper",
},
{
name: "openai whisper alias selects whisper transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
ModelName: "my-asr-model",
Model: "openai/whisper-1",
APIKeys: config.SimpleSecureStrings("sk-openai-model"),
},
},
},
wantName: "whisper",
},
{
name: "whisper via model list fallback",
cfg: &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "openai", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("sk-openai")},
{
ModelName: "groq",
Model: "groq/llama-3.3-70b",
Model: "groq/whisper-large-v3-turbo",
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
},
},
},
wantName: "groq",
wantName: "whisper",
},
{
name: "voice model name selects non-gemini audio model transcriber",
name: "voice model name alias selects non-gemini audio model transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "voice-openai-audio"},
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
ModelName: "voice-openai-audio",
ModelName: "my-asr-model",
Model: "openai/gpt-4o-audio-preview",
APIKeys: config.SimpleSecureStrings("sk-openai"),
},
@@ -92,7 +134,7 @@ func TestDetectTranscriber(t *testing.T) {
name: "groq model list entry without key is skipped",
cfg: &config.Config{
ModelList: []*config.ModelConfig{
{Model: "groq/llama-3.3-70b"},
{Model: "groq/whisper-large-v3"},
},
},
wantNil: true,
@@ -103,12 +145,12 @@ func TestDetectTranscriber(t *testing.T) {
ModelList: []*config.ModelConfig{
{
ModelName: "groq",
Model: "groq/llama-3.3-70b",
Model: "groq/whisper-large-v3",
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
},
},
},
wantName: "groq",
wantName: "whisper",
},
{
name: "missing voice model name config returns nil",
@@ -127,15 +169,17 @@ func TestDetectTranscriber(t *testing.T) {
{
name: "elevenlabs voice config key",
cfg: &config.Config{
Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
ModelList: []*config.ModelConfig{
{Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
},
},
wantName: "elevenlabs",
},
{
name: "elevenlabs takes priority over groq model list",
cfg: &config.Config{
Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
ModelList: []*config.ModelConfig{
{Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
{
ModelName: "groq",
Model: "groq/llama-3.3-70b",
@@ -149,10 +193,10 @@ func TestDetectTranscriber(t *testing.T) {
name: "voice model name takes priority over elevenlabs",
cfg: &config.Config{
Voice: config.VoiceConfig{
ModelName: "voice-gemini",
ElevenLabsAPIKey: "sk_elevenlabs_test",
ModelName: "voice-gemini",
},
ModelList: []*config.ModelConfig{
{Model: "elevenlabs", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
{
ModelName: "voice-gemini",
Model: "gemini/gemini-2.5-flash",
@@ -1,4 +1,4 @@
package voice
package asr
import (
"context"
@@ -1,4 +1,4 @@
package voice
package asr
import (
"context"
@@ -1,4 +1,4 @@
package voice
package asr
import (
"bytes"
@@ -23,12 +23,16 @@ type ElevenLabsTranscriber struct {
httpClient *http.Client
}
func NewElevenLabsTranscriber(apiKey string) *ElevenLabsTranscriber {
func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber {
logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""})
if apiBase == "" {
apiBase = "https://api.elevenlabs.io"
}
return &ElevenLabsTranscriber{
apiKey: apiKey,
apiBase: "https://api.elevenlabs.io",
apiBase: apiBase,
httpClient: &http.Client{
Timeout: 120 * time.Second,
},
@@ -1,4 +1,4 @@
package voice
package asr
import (
"context"
@@ -14,7 +14,7 @@ import (
var _ Transcriber = (*ElevenLabsTranscriber)(nil)
func TestElevenLabsTranscriberName(t *testing.T) {
tr := NewElevenLabsTranscriber("sk_test")
tr := NewElevenLabsTranscriber("sk_test", "")
if got := tr.Name(); got != "elevenlabs" {
t.Errorf("Name() = %q, want %q", got, "elevenlabs")
}
@@ -43,7 +43,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
tr := NewElevenLabsTranscriber("sk_test")
tr := NewElevenLabsTranscriber("sk_test", "")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
@@ -64,7 +64,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
tr := NewElevenLabsTranscriber("sk_bad")
tr := NewElevenLabsTranscriber("sk_bad", "")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
@@ -74,7 +74,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
})
t.Run("missing file", func(t *testing.T) {
tr := NewElevenLabsTranscriber("sk_test")
tr := NewElevenLabsTranscriber("sk_test", "")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
+245
View File
@@ -0,0 +1,245 @@
package asr
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// WhisperTranscriber posts audio as multipart form data to an
// "/audio/transcriptions" HTTP endpoint and decodes the JSON reply into a
// TranscriptionResponse.
type WhisperTranscriber struct {
// apiKey is sent as "Authorization: Bearer <key>" when non-empty.
apiKey string
// apiBase is the endpoint base URL; the constructor trims trailing slashes.
apiBase string
// modelID is written to the multipart "model" field.
modelID string
// providerName is attached to log entries as the "provider" field.
providerName string
// httpClient performs the request; its timeout bounds the whole upload.
httpClient *http.Client
}
// NewWhisperTranscriber builds a transcriber from a model_list entry.
//
// The entry's Model string is split with providers.ExtractProtocol into a
// protocol (used as the provider name in logs) and a model ID; when no model ID
// comes back, the trimmed Model string is used as-is. The API base comes from
// providers.ResolveAPIBase. It returns nil when modelCfg is nil or when no
// model ID can be determined.
func NewWhisperTranscriber(modelCfg *config.ModelConfig) *WhisperTranscriber {
	if modelCfg == nil {
		return nil
	}

	providerName, model := providers.ExtractProtocol(modelCfg.Model)
	if model == "" {
		model = strings.TrimSpace(modelCfg.Model)
	}

	transcriber := newWhisperTranscriber(
		modelCfg.APIKey(),
		providers.ResolveAPIBase(modelCfg),
		model,
		providerName,
	)
	if transcriber != nil {
		// Never log the key itself, only whether one is present.
		logger.DebugCF("voice", "Creating whisper transcriber", map[string]any{
			"api_base": transcriber.apiBase,
			"has_key":  transcriber.apiKey != "",
			"model":    transcriber.modelID,
			"provider": transcriber.providerName,
		})
	}
	return transcriber
}
// NewGroqTranscriber returns a WhisperTranscriber preconfigured for Groq's
// OpenAI-compatible base URL with provider name "groq". It returns nil when
// modelID is empty.
func NewGroqTranscriber(apiKey, modelID string) *WhisperTranscriber {
return newWhisperTranscriber(apiKey, "https://api.groq.com/openai/v1", modelID, "groq")
}
// newWhisperTranscriber is the shared constructor behind the exported ones.
// An empty modelID yields nil; an empty providerName defaults to "whisper".
// Trailing slashes are stripped from apiBase and the HTTP client gets a
// 60-second timeout.
func newWhisperTranscriber(apiKey, apiBase, modelID, providerName string) *WhisperTranscriber {
	switch {
	case modelID == "":
		return nil
	case providerName == "":
		providerName = "whisper"
	}

	client := &http.Client{Timeout: 60 * time.Second}
	return &WhisperTranscriber{
		apiKey:       apiKey,
		apiBase:      strings.TrimRight(apiBase, "/"),
		modelID:      modelID,
		providerName: providerName,
		httpClient:   client,
	}
}
// transcriptionURL returns the full endpoint URL. If the configured base
// already ends in "/audio/transcriptions" it is used unchanged (minus trailing
// slashes); otherwise that suffix is appended.
func (t *WhisperTranscriber) transcriptionURL() string {
	const endpoint = "/audio/transcriptions"

	trimmed := strings.TrimRight(t.apiBase, "/")
	if !strings.HasSuffix(trimmed, endpoint) {
		trimmed += endpoint
	}
	return trimmed
}
// TranscribeData transcribes an audio clip that is already in memory.
//
// data is uploaded as the multipart "file" part under the given filename,
// together with the "model" and "response_format" ("json") fields. Every
// failure while assembling the form is logged and returned wrapped.
func (t *WhisperTranscriber) TranscribeData(
	ctx context.Context,
	data []byte,
	filename string,
) (*TranscriptionResponse, error) {
	logger.InfoCF("voice", "Starting whisper transcription from memory", map[string]any{
		"bytes":    len(data),
		"filename": filename,
		"model":    t.modelID,
		"provider": t.providerName,
	})

	payload := new(bytes.Buffer)
	form := multipart.NewWriter(payload)

	filePart, err := form.CreateFormFile("file", filename)
	if err != nil {
		logger.ErrorCF("voice", "Failed to create whisper form file", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}

	_, err = io.Copy(filePart, bytes.NewReader(data))
	if err != nil {
		logger.ErrorCF("voice", "Failed to copy whisper file content", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to copy file content: %w", err)
	}

	err = form.WriteField("model", t.modelID)
	if err != nil {
		logger.ErrorCF("voice", "Failed to write whisper model field", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}

	err = form.WriteField("response_format", "json")
	if err != nil {
		logger.ErrorCF("voice", "Failed to write whisper response_format field", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to write response_format field: %w", err)
	}

	// Close writes the terminating boundary; the body is incomplete without it.
	err = form.Close()
	if err != nil {
		logger.ErrorCF("voice", "Failed to close whisper multipart writer", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to close multipart writer: %w", err)
	}

	return t.doRequest(ctx, payload, form.FormDataContentType(), int64(len(data)))
}
// Transcribe transcribes the audio file at audioFilePath.
//
// The file is read fully into an in-memory multipart body (part name "file",
// filename set to the path's base name) alongside the "model" and
// "response_format" ("json") fields, then sent via doRequest. Errors are
// returned wrapped; the file is closed on every path.
func (t *WhisperTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
	logger.InfoCF("voice", "Starting whisper transcription", map[string]any{
		"audio_file": audioFilePath,
		"model":      t.modelID,
		"provider":   t.providerName,
	})

	src, err := os.Open(audioFilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open audio file %s: %w", audioFilePath, err)
	}
	defer src.Close()

	// The on-disk size is only used for the debug log in doRequest.
	stat, err := src.Stat()
	if err != nil {
		return nil, fmt.Errorf("failed to stat audio file %s: %w", audioFilePath, err)
	}

	payload := new(bytes.Buffer)
	form := multipart.NewWriter(payload)

	filePart, err := form.CreateFormFile("file", filepath.Base(audioFilePath))
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}

	_, err = io.Copy(filePart, src)
	if err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}

	err = form.WriteField("model", t.modelID)
	if err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}

	err = form.WriteField("response_format", "json")
	if err != nil {
		return nil, fmt.Errorf("failed to write response_format field: %w", err)
	}

	err = form.Close()
	if err != nil {
		return nil, fmt.Errorf("failed to close multipart writer: %w", err)
	}

	return t.doRequest(ctx, payload, form.FormDataContentType(), stat.Size())
}
// doRequest POSTs a prepared multipart body to the transcription endpoint and
// decodes the JSON reply.
//
// contentType must be the multipart writer's FormDataContentType (it carries
// the boundary). fileSize is used only for logging. The Authorization header
// is omitted when no API key is configured. Any status other than 200 is
// returned as an error that includes the response body.
func (t *WhisperTranscriber) doRequest(
	ctx context.Context,
	requestBody *bytes.Buffer,
	contentType string,
	fileSize int64,
) (*TranscriptionResponse, error) {
	endpoint := t.transcriptionURL()

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, requestBody)
	if err != nil {
		logger.ErrorCF("voice", "Failed to create whisper request", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", contentType)
	if t.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+t.apiKey)
	}

	logger.DebugCF("voice", "Sending whisper transcription request", map[string]any{
		"file_size_bytes":    fileSize,
		"model":              t.modelID,
		"provider":           t.providerName,
		"request_size_bytes": requestBody.Len(),
		"url":                endpoint,
	})

	httpResp, err := t.httpClient.Do(httpReq)
	if err != nil {
		logger.ErrorCF("voice", "Failed to send whisper request", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	// Read the body before checking the status so error replies can be reported.
	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		logger.ErrorCF("voice", "Failed to read whisper response", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if httpResp.StatusCode != http.StatusOK {
		logger.ErrorCF("voice", "Whisper API error", map[string]any{
			"provider":    t.providerName,
			"response":    string(raw),
			"status_code": httpResp.StatusCode,
		})
		return nil, fmt.Errorf("API error (status %d): %s", httpResp.StatusCode, string(raw))
	}

	transcription := new(TranscriptionResponse)
	if err = json.Unmarshal(raw, transcription); err != nil {
		logger.ErrorCF("voice", "Failed to unmarshal whisper response", map[string]any{"error": err})
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	logger.InfoCF("voice", "Whisper transcription completed successfully", map[string]any{
		"duration_seconds":      transcription.Duration,
		"language":              transcription.Language,
		"provider":              t.providerName,
		"text_length":           len(transcription.Text),
		"transcription_preview": utils.Truncate(transcription.Text, 50),
	})

	return transcription, nil
}
// Name returns the fixed identifier "whisper", regardless of the provider name
// the transcriber was constructed with.
func (t *WhisperTranscriber) Name() string {
return "whisper"
}
+102
View File
@@ -0,0 +1,102 @@
package asr
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
// TestWhisperTranscriberTranscribeDataUsesConfiguredModel checks that a
// transcriber built from an "openai/whisper-1" model entry sends the bearer
// key, posts to "/audio/transcriptions", and puts the bare model ID
// ("whisper-1") in the multipart "model" field.
func TestWhisperTranscriberTranscribeDataUsesConfiguredModel(t *testing.T) {
var gotModel string
var gotPath string
// Fake endpoint: records the request path and the "model" form field.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
if got := r.Header.Get("Authorization"); got != "Bearer sk-openai-test" {
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-openai-test")
}
reader, err := r.MultipartReader()
if err != nil {
t.Fatalf("MultipartReader() error: %v", err)
}
// Walk every part; only the "model" field is retained.
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("NextPart() error: %v", err)
}
data, err := io.ReadAll(part)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if part.FormName() == "model" {
gotModel = string(data)
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "hello from whisper"}); err != nil {
t.Fatalf("Encode() error: %v", err)
}
}))
defer server.Close()
tr := NewWhisperTranscriber(&config.ModelConfig{
Model: "openai/whisper-1",
APIBase: server.URL,
APIKeys: config.SimpleSecureStrings("sk-openai-test"),
})
// Swap in the test server's client so the request reaches the fake endpoint.
tr.httpClient = server.Client()
resp, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg")
if err != nil {
t.Fatalf("TranscribeData() error: %v", err)
}
if resp.Text != "hello from whisper" {
t.Errorf("Text = %q, want %q", resp.Text, "hello from whisper")
}
if gotModel != "whisper-1" {
t.Errorf("model field = %q, want %q", gotModel, "whisper-1")
}
if gotPath != "/audio/transcriptions" {
t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
}
}
// TestWhisperTranscriberUsesEndpointAPIBaseWithoutDoubleAppend checks that an
// APIBase which already ends in "/audio/transcriptions" is used as-is rather
// than having the suffix appended a second time.
func TestWhisperTranscriberUsesEndpointAPIBaseWithoutDoubleAppend(t *testing.T) {
var gotPath string
// Fake endpoint: records only the request path.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"}); err != nil {
t.Fatalf("Encode() error: %v", err)
}
}))
defer server.Close()
tr := NewWhisperTranscriber(&config.ModelConfig{
Model: "groq/whisper-large-v3",
APIBase: server.URL + "/audio/transcriptions",
APIKeys: config.SimpleSecureStrings("sk-groq-test"),
})
tr.httpClient = server.Client()
if _, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg"); err != nil {
t.Fatalf("TranscribeData() error: %v", err)
}
if gotPath != "/audio/transcriptions" {
t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
}
}
+57
View File
@@ -0,0 +1,57 @@
package audio
import (
"bytes"
"fmt"
"io"
)
// DecodeOggOpus walks an Ogg stream page by page, reassembles the packets
// described by each page's segment (lacing) table, and passes every packet to
// onFrame, except those beginning with "OpusHead" or "OpusTags".
//
// A packet may span several segments and several pages: a lacing value of 255
// means "more data follows", any smaller value terminates the packet.
//
// The slice handed to onFrame aliases an internal buffer that is reused for
// the next packet, so callers must copy it if they keep it past the callback.
//
// A clean or truncated end of input while reading a page header ends decoding
// with a nil error; a packet still being assembled at that point is dropped.
// Truncation inside a segment table or segment payload is an error, as is a
// page that does not start with "OggS". An error returned by onFrame stops
// decoding and is returned unchanged.
func DecodeOggOpus(r io.Reader, onFrame func([]byte) error) error {
	const (
		pageHeaderLen   = 27  // fixed-size part of a page header
		continuedLacing = 255 // lacing value meaning "packet continues"
	)

	var (
		pageHeader [pageHeaderLen]byte
		chunk      [continuedLacing]byte
		pending    bytes.Buffer
	)

	for {
		_, err := io.ReadFull(r, pageHeader[:])
		switch {
		case err == io.EOF, err == io.ErrUnexpectedEOF:
			return nil
		case err != nil:
			return fmt.Errorf("failed to read ogg header: %w", err)
		}

		if !bytes.Equal(pageHeader[:4], []byte("OggS")) {
			return fmt.Errorf("invalid ogg magic string")
		}

		// The last header byte holds the number of lacing values that follow.
		lacings := make([]byte, int(pageHeader[pageHeaderLen-1]))
		if _, err := io.ReadFull(r, lacings); err != nil {
			return fmt.Errorf("failed to read segment table: %w", err)
		}

		for _, size := range lacings {
			if _, err := io.ReadFull(r, chunk[:size]); err != nil {
				return fmt.Errorf("failed to read segment data: %w", err)
			}
			pending.Write(chunk[:size])

			// Keep accumulating while the packet continues; an empty
			// terminated packet carries nothing to emit.
			if size == continuedLacing || pending.Len() == 0 {
				continue
			}

			payload := pending.Bytes()
			isStreamHeader := bytes.HasPrefix(payload, []byte("OpusHead")) ||
				bytes.HasPrefix(payload, []byte("OpusTags"))
			if !isStreamHeader {
				if err := onFrame(payload); err != nil {
					return err
				}
			}
			pending.Reset()
		}
	}
}
+146
View File
@@ -0,0 +1,146 @@
package audio
import (
"bytes"
"reflect"
"strings"
"testing"
)
// buildOggPage assembles a minimal Ogg page for tests: a 27-byte header whose
// only populated fields are the "OggS" capture pattern and the segment count
// (byte 26), followed by the lacing table and then the payload bytes.
func buildOggPage(lacingVals []byte, data []byte) []byte {
	hdr := make([]byte, 27)
	copy(hdr, "OggS")
	// All other header bytes stay zero; the decoder under test ignores them.
	hdr[26] = byte(len(lacingVals))
	return bytes.Join([][]byte{hdr, lacingVals, data}, nil)
}
// TestDecodeOggOpus_ValidParsing feeds a two-page stream through DecodeOggOpus
// and checks that the OpusHead/OpusTags packets are skipped, that a packet
// split over several segments is reassembled, and that a packet continued on
// the next page is joined correctly.
func TestDecodeOggOpus_ValidParsing(t *testing.T) {
var b bytes.Buffer
// Packet 1: Single segment, length 50
pkt1 := bytes.Repeat([]byte{1}, 50)
// Packet 2: Multi-segment (255 + 10 = 265 bytes)
pkt2Part1 := bytes.Repeat([]byte{2}, 255)
pkt2Part2 := bytes.Repeat([]byte{2}, 10)
// Packet 3: Continued across pages. Page 1 gets 255, Page 2 gets 20. Total 275 bytes.
pkt3Part1 := bytes.Repeat([]byte{3}, 255)
pkt3Part2 := bytes.Repeat([]byte{3}, 20)
// Page 1: OpusHead (skip), OpusTags (skip), pkt1, pkt2, pkt3Part1
// The trailing 255 leaves packet 3 open so it must be completed by page 2.
page1Lacing := []byte{8, 8, 50, 255, 10, 255}
page1Data := bytes.Join([][]byte{
[]byte("OpusHead"),
[]byte("OpusTags"),
pkt1,
pkt2Part1, pkt2Part2,
pkt3Part1,
}, nil)
// Page 2: pkt3Part2, pkt4 (length 10)
pkt4 := bytes.Repeat([]byte{4}, 10)
page2Lacing := []byte{20, 10}
page2Data := bytes.Join([][]byte{
pkt3Part2,
pkt4,
}, nil)
b.Write(buildOggPage(page1Lacing, page1Data))
b.Write(buildOggPage(page2Lacing, page2Data))
var frames [][]byte
err := DecodeOggOpus(&b, func(frame []byte) error {
// making a copy to store as DecodeOggOpus might reuse backing array
cpy := make([]byte, len(frame))
copy(cpy, frame)
frames = append(frames, cpy)
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expectedFrames := [][]byte{
pkt1,
append(pkt2Part1, pkt2Part2...),
append(pkt3Part1, pkt3Part2...),
pkt4,
}
if len(frames) != len(expectedFrames) {
t.Fatalf("expected %d frames, got %d", len(expectedFrames), len(frames))
}
for i, expected := range expectedFrames {
if !reflect.DeepEqual(frames[i], expected) {
t.Errorf("frame %d mismatch:\nexp: %v\ngot: %v", i, expected, frames[i])
}
}
}
// TestDecodeOggOpus_Errors covers malformed and truncated inputs. Every case
// except "short header" must fail with an error containing errContains; a
// header cut short is treated as end of stream and must return nil.
func TestDecodeOggOpus_Errors(t *testing.T) {
tests := []struct {
name string
data []byte
errContains string
}{
{
name: "invalid magic string",
data: []byte(
"OggX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
),
errContains: "invalid ogg magic string",
},
{
// errContains is unused for this case; see the special-casing below.
name: "short header",
data: []byte("Ogg"),
errContains: "failed to read ogg header",
},
{
name: "eof in segment table",
data: func() []byte {
h := make([]byte, 27)
copy(h, "OggS")
h[26] = 5 // expects 5 bytes of segment table, but none provided
return h
}(),
errContains: "failed to read segment table",
},
{
name: "eof in segment data",
data: func() []byte {
h := make([]byte, 27, 28)
copy(h, "OggS")
h[26] = 1
return append(h, 100) // expects 100 bytes of data, but none provided
}(),
errContains: "failed to read segment data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := DecodeOggOpus(bytes.NewReader(tt.data), func(b []byte) error { return nil })
// A truncated header is swallowed by the decoder, so no error is expected.
if tt.name == "short header" {
if err != nil {
t.Errorf("expected no error (io.EOF/ErrUnexpectedEOF swallowed), got %v", err)
}
return
}
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.errContains)
}
if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("expected error to contain %q, got: %q", tt.errContains, err.Error())
}
})
}
}
+96
View File
@@ -0,0 +1,96 @@
package audio
import (
"strings"
"unicode"
)
// isSentenceTerminator reports whether r ends a sentence. The CJK and
// fullwidth forms are written as escapes so they survive tooling that strips
// or mangles non-ASCII punctuation.
func isSentenceTerminator(r rune) bool {
	switch r {
	case '.', '!', '?',
		'\u3002', // ideographic full stop
		'\uFF01', // fullwidth exclamation mark
		'\uFF1F': // fullwidth question mark
		return true
	}
	return false
}

// SplitSentences splits text into sentence-sized chunks suitable for TTS
// synthesis.
//
// A chunk ends at a newline or at sentence-ending punctuation: '.', '!', '?',
// the ideographic full stop (U+3002), or the fullwidth exclamation and
// question marks (U+FF01, U+FF1F). A '.' between two digits (as in "3.14") is
// not a boundary, and a run of consecutive terminators ("...", "?!") stays
// attached to the sentence it closes. Chunks are whitespace-trimmed and empty
// ones are dropped. Chunks shorter than 15 runes are then merged with their
// neighbour to prevent choppy playback. Empty input returns nil.
func SplitSentences(text string) []string {
	if text == "" {
		return nil
	}

	var sentences []string
	var current strings.Builder

	// flush moves the trimmed pending text, if any, into sentences.
	flush := func() {
		if s := strings.TrimSpace(current.String()); s != "" {
			sentences = append(sentences, s)
		}
		current.Reset()
	}

	runes := []rune(text)
	for i := 0; i < len(runes); i++ {
		r := runes[i]
		if r == '\n' {
			// The newline itself is a boundary and is not kept.
			flush()
			continue
		}
		current.WriteRune(r)
		if !isSentenceTerminator(r) {
			continue
		}
		// Avoid splitting on decimal numbers like "3.14".
		if r == '.' && i > 0 && unicode.IsDigit(runes[i-1]) &&
			i+1 < len(runes) && unicode.IsDigit(runes[i+1]) {
			continue
		}
		// Consume contiguous punctuation clusters (e.g. "..." or "?!").
		for i+1 < len(runes) && isSentenceTerminator(runes[i+1]) {
			i++
			current.WriteRune(runes[i])
		}
		flush()
	}
	// Flush whatever follows the last terminator.
	flush()

	// Merge very short fragments with the next sentence.
	return mergeShorties(sentences, 15)
}

// mergeShorties merges sentences shorter than minLen runes with the following
// sentence, joining with a single space. Fragments keep accumulating until the
// combined text reaches minLen. A short fragment left over at the end is
// appended to the last emitted sentence, or emitted on its own if there is
// none. Inputs with zero or one sentence are returned unchanged.
func mergeShorties(sentences []string, minLen int) []string {
	if len(sentences) <= 1 {
		return sentences
	}

	var merged []string
	var buf string
	for _, s := range sentences {
		if buf != "" {
			buf += " " + s
			if len([]rune(buf)) >= minLen {
				merged = append(merged, buf)
				buf = ""
			}
		} else if len([]rune(s)) < minLen {
			buf = s
		} else {
			merged = append(merged, s)
		}
	}

	if buf != "" {
		if len(merged) > 0 {
			merged[len(merged)-1] += " " + buf
		} else {
			merged = append(merged, buf)
		}
	}
	return merged
}
+69
View File
@@ -0,0 +1,69 @@
package audio
import (
"reflect"
"testing"
)
// TestSplitSentences is a table-driven check of SplitSentences covering empty
// input, decimal numbers, newline boundaries, punctuation clusters, and the
// merging of short fragments with their neighbours.
func TestSplitSentences(t *testing.T) {
tests := []struct {
name string
in string
want []string
}{
{
name: "empty input",
in: "",
want: nil,
},
{
name: "single sentence",
in: "Hello world.",
want: []string{"Hello world."},
},
{
name: "decimal numbers do not split",
in: "The value is 3.14 today. Keep watching closely.",
want: []string{"The value is 3.14 today.", "Keep watching closely."},
},
{
name: "newline boundary",
in: "This is line number one\nThis is line number two",
want: []string{"This is line number one", "This is line number two"},
},
{
name: "newline with surrounding spaces",
in: " This is the first line \n This is the second line ",
want: []string{"This is the first line", "This is the second line"},
},
{
name: "trailing punctuation consumed",
in: "Please wait a moment... What on earth?! That is perfectly fine.",
want: []string{"Please wait a moment...", "What on earth?!", "That is perfectly fine."},
},
{
// "Hi." is below the merge threshold, so it joins the next sentence.
name: "short leading fragment merges with next",
in: "Hi. This is a longer sentence.",
want: []string{"Hi. This is a longer sentence."},
},
{
name: "consecutive short fragments keep merging",
in: "A. B. C. This is the real sentence.",
want: []string{"A. B. C. This is the real sentence."},
},
{
// A short fragment with nothing after it is appended to the previous sentence.
name: "short trailing fragment merges back",
in: "This sentence is long enough. End.",
want: []string{"This sentence is long enough. End."},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SplitSentences(tc.in)
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("SplitSentences(%q) = %#v, want %#v", tc.in, got, tc.want)
}
})
}
}
+137
View File
@@ -0,0 +1,137 @@
# TTS (Text-to-Speech)
This package handles speech synthesis for PicoClaw.
If you are new to TTS setup, the simplest workflow is:
1. Add a TTS-capable entry to `model_list`.
2. Point `voice.tts_model_name` at that entry.
3. Put the API key in `.security.yml`.
## Quick Recommendation
For most users, these are the best starting points:
| Provider | Why start here |
| --- | --- |
| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | Best-supported path in PicoClaw today. The current TTS implementation is built around the OpenAI-compatible `/audio/speech` API shape, and OpenAI is the safest default. |
| [Xiaomi MiMo](https://platform.xiaomimimo.com) | A good second option if you want an OpenAI-compatible provider endpoint and are already using MiMo models in the rest of your stack. |
## How TTS Configuration Works
PicoClaw does not keep TTS API keys inside `voice`.
Instead:
- `voice.tts_model_name` selects a named entry from `model_list`.
- That `model_list` entry provides the provider, model ID, API base, and proxy settings.
- `.security.yml` stores the API key for the same named model entry.
This is the recommended and supported configuration pattern.
## Recommended Setup
### Option A: OpenAI
`config.json`
```json
{
"voice": {
"tts_model_name": "openai-tts"
},
"model_list": [
{
"model_name": "openai-tts",
"model": "openai/tts-1"
}
]
}
```
`.security.yml`
```yaml
model_list:
openai-tts:
api_keys:
- "sk-openai-your-key"
```
### Option B: Xiaomi MiMo
`config.json`
```json
{
"voice": {
"tts_model_name": "mimo-tts"
},
"model_list": [
{
"model_name": "mimo-tts",
"model": "mimo/mimo-v2-tts"
}
]
}
```
`.security.yml`
```yaml
model_list:
mimo-tts:
api_keys:
- "your-mimo-key"
```
If you use a custom MiMo endpoint, you can also set `api_base` explicitly. Otherwise PicoClaw will use the provider default.
## What PicoClaw Sends Today
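For every model except those with the `mimo/` prefix, the TTS runtime sends an OpenAI-compatible speech request with the defaults listed below. Models with the `mimo/` prefix are instead sent to MiMo's `/chat/completions` endpoint with `mp3` output and the `default_zh` voice. OpenAI-compatible defaults: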
- Endpoint: `/audio/speech`
- Response format: `opus`
- Voice: `alloy`
- Model: taken from the selected `model_list` entry
That means:
- `openai/tts-1` works naturally.
- Other OpenAI-compatible providers can work if they accept the same request format.
- PicoClaw currently does not expose a user-facing config field for changing the TTS voice from `alloy`.
## How PicoClaw Chooses a TTS Provider
`DetectTTS` resolves TTS in this order:
1. **Preferred path**: resolve `voice.tts_model_name` against `model_list`.
2. If a matching model entry exists and has an API key, PicoClaw creates an OpenAI-compatible TTS provider using that model's settings.
3. **Fallback path**: if `voice.tts_model_name` is not set or cannot be resolved, PicoClaw scans `model_list` for the first entry whose model string contains `tts` and has an API key.
Fallback scanning exists for compatibility. New configs should set `voice.tts_model_name` explicitly.
## Notes About API Base Handling
PicoClaw normalizes the configured base URL for TTS:
- For OpenAI, a base like `https://api.openai.com` or `https://api.openai.com/v1` becomes `https://api.openai.com/v1/audio/speech`.
- For other OpenAI-compatible providers, PicoClaw preserves the configured base path and ensures it ends with `/audio/speech`.
- If `api_base` is omitted, PicoClaw uses the provider default base when the model prefix is known.
## Common Mistakes
- Setting `voice.tts_model_name` to a name that does not exist in `model_list`.
- Adding a TTS model but forgetting to put its API key in `.security.yml`.
- Assuming PicoClaw will automatically use provider-specific custom voices.
- Using a provider endpoint that is not compatible with the OpenAI `/audio/speech` request format.
## Minimal Checklist
Before testing `send_tts`, make sure:
- `voice.tts_model_name` matches a `model_list[].model_name`.
- The matching `.security.yml` entry contains a valid API key.
- The chosen provider supports an OpenAI-compatible speech synthesis endpoint.
- Your selected model is actually a TTS-capable model.
+137
View File
@@ -0,0 +1,137 @@
# TTS(文本转语音)
这个目录负责 PicoClaw 的语音合成能力。
如果你是第一次配置 TTS,可以参照下面这个流程:
1. 在 `model_list` 里添加一个支持 TTS 的模型。
2. 用 `voice.tts_model_name` 指向这个模型。
3. 在 `.security.yml` 里配置对应的 API Key。
## 快速推荐
对于大多数用户,建议优先从下面两种开始:
| 提供商 | 推荐理由 |
| --- | --- |
| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | 这是 PicoClaw 当前最稳定、最直接的 TTS 路径。当前实现就是围绕 OpenAI 兼容的 `/audio/speech` 接口格式构建的,所以 OpenAI 是最稳妥的默认选择。 |
| [Xiaomi MiMo](https://platform.xiaomimimo.com) | 由于响应速度和语音音色对于中国用户更友好,MiMo 是一个不错的第二选择。 |
## TTS 配置是如何工作的
PicoClaw 不会把 TTS 的 API Key 放在 `voice` 配置里。
推荐方式是:
- `voice.tts_model_name` 用来选择 `model_list` 里的某个命名模型。
- 对应的 `model_list` 条目提供真实的 provider、model ID、`api_base` 和代理配置。
- `.security.yml` 负责保存该模型条目的 API Key。
这是当前推荐且受支持的配置方式。
## 推荐配置方式
### 方案 A:OpenAI
`config.json`
```json
{
"voice": {
"tts_model_name": "openai-tts"
},
"model_list": [
{
"model_name": "openai-tts",
"model": "openai/tts-1"
}
]
}
```
`.security.yml`
```yaml
model_list:
openai-tts:
api_keys:
- "sk-openai-your-key"
```
### 方案 B:Xiaomi MiMo
`config.json`
```json
{
"voice": {
"tts_model_name": "mimo-tts"
},
"model_list": [
{
"model_name": "mimo-tts",
"model": "mimo/mimo-v2-tts"
}
]
}
```
`.security.yml`
```yaml
model_list:
mimo-tts:
api_keys:
- "your-mimo-key"
```
如果你使用自定义的 MiMo 接口地址,也可以显式设置 `api_base`。如果不设置,PicoClaw 会自动使用该 provider 的默认地址。
## PicoClaw 当前实际发送的 TTS 请求
除 `mimo/` 前缀的模型外,当前 TTS 运行时使用的是 OpenAI 兼容的语音合成请求,并带有以下默认值(`mimo/` 模型则会发送到 MiMo 的 `/chat/completions` 接口,输出格式为 `mp3`,voice 为 `default_zh`):
- Endpoint:`/audio/speech`
- 返回格式:`opus`
- Voice:`alloy`
- Model:来自你所选中的 `model_list` 条目
这意味着:
- `openai/tts-1` 可以自然工作。
- 其他 OpenAI 兼容 provider 也可能可用,前提是它们接受相同的请求格式。
- PicoClaw 目前还没有对用户暴露一个配置项来修改 TTS voice,当前固定为 `alloy`。
## PicoClaw 如何选择 TTS Provider
`DetectTTS` 会按下面顺序选择 TTS:
1. **首选路径**:根据 `voice.tts_model_name` 在 `model_list` 中找到对应模型。
2. 如果找到了匹配条目,并且它有 API Key,PicoClaw 就会使用这个模型条目的配置创建一个 OpenAI 兼容的 TTS provider。
3. **回退路径**:如果没有设置 `voice.tts_model_name`,或者该名字无法解析,PicoClaw 会扫描 `model_list`,选中第一个模型字符串里包含 `tts` 且带有 API Key 的条目。
回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.tts_model_name`。
## 关于 API Base 的处理方式
PicoClaw 会对 TTS 的 `api_base` 做规范化处理:
- 对 OpenAI 来说,像 `https://api.openai.com` 或 `https://api.openai.com/v1` 这样的地址,会自动变成 `https://api.openai.com/v1/audio/speech`。
- 对其他 OpenAI 兼容 provider,PicoClaw 会尽量保留你提供的基础路径,只确保它最终以 `/audio/speech` 结尾。
- 如果没有设置 `api_base`,并且模型前缀是已知 provider,PicoClaw 会自动使用该 provider 的默认地址。
## 常见错误
- `voice.tts_model_name` 指向了一个不存在的 `model_list` 名称。
- 在 `model_list` 里定义了 TTS 模型,但忘了在 `.security.yml` 中配置对应 API Key。
- 误以为 PicoClaw 会自动支持 provider 自定义 voice 参数。
- 使用了不兼容 OpenAI `/audio/speech` 请求格式的接口地址。
## 最小检查清单
在测试 `send_tts` 之前,请确认:
- `voice.tts_model_name` 能正确匹配某个 `model_list[].model_name`。
- `.security.yml` 中对应条目已经配置了有效 API Key。
- 你所选的 provider 支持 OpenAI 兼容的语音合成接口。
- 你选择的模型本身确实支持 TTS。
+162
View File
@@ -0,0 +1,162 @@
package tts
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// MimoTTSProvider synthesizes speech by POSTing a chat-completions style JSON
// request with an "audio" section and decoding the base64 audio in the reply.
type MimoTTSProvider struct {
// apiKey is sent in the "Api-Key" request header.
apiKey string
// apiBase is the full endpoint URL the request is POSTed to.
apiBase string
// voice and format populate the request's "audio" object.
voice string
format string
// model is sent as the request's "model" field.
model string
httpClient *http.Client
}
// NewMimoTTSProvider creates a MiMo TTS provider.
//
// apiBase is normalized into a full chat-completions endpoint:
//   - empty: the official "https://api.xiaomimimo.com/v1/chat/completions";
//   - an absolute URL on api.xiaomimimo.com: the path is forced under "/v1"
//     and made to end in "/chat/completions";
//   - an absolute URL on any other host: the path is kept and only
//     "/chat/completions" is appended when missing;
//   - anything that does not parse as an absolute URL: the suffix is appended
//     to the raw string when missing.
//
// An empty (or all-whitespace) model defaults to "mimo-v2-tts". A proxyURL
// that fails to parse is logged and ignored. The HTTP client uses a 60-second
// timeout; voice and format are fixed to "default_zh" and "mp3".
func NewMimoTTSProvider(apiKey string, apiBase string, model string, proxyURL string) *MimoTTSProvider {
	const completionsSuffix = "/chat/completions"

	endpoint := apiBase
	if endpoint == "" {
		endpoint = "https://api.xiaomimimo.com/v1/chat/completions"
	} else if parsed, err := url.Parse(endpoint); err == nil && parsed.Scheme != "" && parsed.Host != "" {
		p := parsed.Path
		if parsed.Host == "api.xiaomimimo.com" {
			switch p {
			case "", "/", "/v1", "/v1/":
				// Bare host or version root: the suffix is added below.
				p = "/v1"
			default:
				if !strings.HasPrefix(p, "/") {
					p = "/" + p
				}
				if !strings.HasPrefix(p, "/v1/") {
					p = "/v1" + strings.TrimSuffix(p, "/")
				}
			}
		}
		if !strings.HasSuffix(p, completionsSuffix) {
			p = strings.TrimSuffix(p, "/") + completionsSuffix
		}
		parsed.Path = p
		endpoint = parsed.String()
	} else if !strings.HasSuffix(endpoint, completionsSuffix) {
		// Not an absolute URL: fall back to plain string handling.
		endpoint = strings.TrimSuffix(endpoint, "/") + completionsSuffix
	}

	modelID := strings.TrimSpace(model)
	if modelID == "" {
		modelID = "mimo-v2-tts"
	}

	httpClient := &http.Client{Timeout: 60 * time.Second}
	if proxyURL != "" {
		proxy, err := url.Parse(proxyURL)
		if err != nil {
			logger.WarnF(
				"NewMimoTTSProvider: invalid proxy URL; proceeding without proxy",
				map[string]any{"proxyURL": proxyURL, "error": err},
			)
		} else {
			httpClient.Transport = &http.Transport{Proxy: http.ProxyURL(proxy)}
		}
	}

	return &MimoTTSProvider{
		apiKey:  apiKey,
		apiBase: endpoint,
		// "mimo_default" now appears to alias "default_en", which does not
		// work for Chinese text. "default_zh" handles both English and
		// Chinese and is likely the intended default.
		voice:      "default_zh",
		format:     "mp3",
		model:      modelID,
		httpClient: httpClient,
	}
}
// Name returns the provider identifier "mimo-tts".
func (t *MimoTTSProvider) Name() string {
return "mimo-tts"
}
// Synthesize converts text to audio via the MiMo endpoint.
//
// The text is sent as a single assistant message in a non-streaming request
// whose "audio" object carries the configured format and voice. The reply is
// read in full, the base64 payload at choices[0].message.audio.data is
// decoded, and the bytes are returned as an in-memory reader (Close is a
// no-op). A non-200 status, undecodable JSON, missing audio data, or invalid
// base64 each produce an error.
func (t *MimoTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
	logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text), "provider": t.Name()})

	encoded, err := json.Marshal(map[string]any{
		"model": t.model,
		"messages": []map[string]string{
			{"role": "assistant", "content": text},
		},
		"audio": map[string]string{
			"format": t.format,
			"voice":  t.voice,
		},
		"stream": false,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, t.apiBase, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Api-Key", t.apiKey)

	httpResp, err := t.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if httpResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API error (status %d): %s", httpResp.StatusCode, string(raw))
	}

	// Only the audio payload of each choice is needed from the reply.
	var reply struct {
		Choices []struct {
			Message struct {
				Audio struct {
					Data string `json:"data"`
				} `json:"audio"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err = json.Unmarshal(raw, &reply); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	if len(reply.Choices) == 0 || reply.Choices[0].Message.Audio.Data == "" {
		return nil, fmt.Errorf("invalid TTS response: missing audio data")
	}

	audio, err := base64.StdEncoding.DecodeString(reply.Choices[0].Message.Audio.Data)
	if err != nil {
		return nil, fmt.Errorf("failed to decode audio data: %w", err)
	}
	return io.NopCloser(bytes.NewReader(audio)), nil
}
+126
View File
@@ -0,0 +1,126 @@
package tts
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
// OpenAITTSProvider synthesizes speech by POSTing a JSON request to an
// "/audio/speech" endpoint and streaming back the response body.
type OpenAITTSProvider struct {
// apiKey is sent as "Authorization: Bearer <key>".
apiKey string
// apiBase is the full endpoint URL the request is POSTed to.
apiBase string
// voice is sent as the request's "voice" field.
voice string
// model is sent as the request's "model" field.
model string
httpClient *http.Client
}
// NewOpenAITTSProvider creates a TTS provider that uses the OpenAI speech
// request format.
//
// apiBase is normalized into a full "/audio/speech" endpoint so that a missing
// "/v1" cannot produce a malformed URL such as
// "https://api.openai.com/audio/speech":
//   - empty: "https://api.openai.com/v1/audio/speech";
//   - an absolute URL on api.openai.com: the path is forced under "/v1" and
//     made to end in "/audio/speech";
//   - an absolute URL on any other host (e.g. a proxy): the path is kept and
//     only "/audio/speech" is appended when missing;
//   - anything that does not parse as an absolute URL: the suffix is appended
//     to the raw string when missing.
//
// The HTTP client comes from common.NewHTTPClient(proxyURL) with a 60-second
// timeout. An empty (or all-whitespace) model defaults to "tts-1"; the voice
// is fixed to "alloy".
func NewOpenAITTSProvider(apiKey string, apiBase string, proxyURL string, model string) *OpenAITTSProvider {
	const speechSuffix = "/audio/speech"

	endpoint := apiBase
	if endpoint == "" {
		endpoint = "https://api.openai.com/v1/audio/speech"
	} else if parsed, err := url.Parse(endpoint); err == nil && parsed.Scheme != "" && parsed.Host != "" {
		p := parsed.Path
		if parsed.Host == "api.openai.com" {
			// Official host: guarantee exactly one "/v1" prefix.
			switch p {
			case "", "/", "/v1":
				p = "/v1"
			default:
				if !strings.HasPrefix(p, "/") {
					p = "/" + p
				}
				if !strings.HasPrefix(p, "/v1/") {
					p = "/v1" + strings.TrimSuffix(p, "/")
				}
			}
		}
		if !strings.HasSuffix(p, speechSuffix) {
			p = strings.TrimSuffix(p, "/") + speechSuffix
		}
		parsed.Path = p
		endpoint = parsed.String()
	} else if !strings.HasSuffix(endpoint, speechSuffix) {
		// Not an absolute URL: fall back to plain string handling.
		endpoint = strings.TrimSuffix(endpoint, "/") + speechSuffix
	}

	httpClient := common.NewHTTPClient(proxyURL)
	httpClient.Timeout = 60 * time.Second

	modelID := strings.TrimSpace(model)
	if modelID == "" {
		modelID = "tts-1"
	}

	return &OpenAITTSProvider{
		apiKey:     apiKey,
		apiBase:    endpoint,
		voice:      "alloy",
		model:      modelID,
		httpClient: httpClient,
	}
}
// Name returns the provider identifier "openai-tts".
func (t *OpenAITTSProvider) Name() string {
return "openai-tts"
}
// Synthesize requests Opus-encoded speech for text and returns the HTTP
// response body as a stream.
//
// On success the caller owns the returned ReadCloser and must close it. On a
// non-200 status the body is consumed and closed here, and its content is
// included in the returned error.
func (t *OpenAITTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
	logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text)})

	encoded, err := json.Marshal(map[string]any{
		"model":           t.model,
		"input":           text,
		"voice":           t.voice,
		"response_format": "opus",
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, t.apiBase, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+t.apiKey)

	httpResp, err := t.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	if httpResp.StatusCode == http.StatusOK {
		// The body is handed to the caller unread; do not close it here.
		return httpResp.Body, nil
	}

	defer httpResp.Body.Close()
	// Best effort: a failed read still yields an error carrying the status.
	detail, _ := io.ReadAll(httpResp.Body)
	return nil, fmt.Errorf("API error (status %d): %s", httpResp.StatusCode, string(detail))
}
+151
View File
@@ -0,0 +1,151 @@
package tts
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
// TTSProvider is the abstraction over a text-to-speech backend.
type TTSProvider interface {
// Name returns the backend's identifier. SynthesizeAndStore compares it
// against "mimo-tts" to choose the stored file extension and content type.
Name() string
// Synthesize converts text into an audio stream. Callers in this package
// close the returned stream once they have finished reading it.
Synthesize(ctx context.Context, text string) (io.ReadCloser, error)
}
// providerFromModelConfig turns a model configuration into a TTS provider.
// It yields nil when the configuration is nil or carries no API key. Models
// whose protocol prefix is "mimo" get a Mimo provider; every other protocol
// falls back to the OpenAI-compatible provider.
func providerFromModelConfig(mc *config.ModelConfig) TTSProvider {
	if mc == nil {
		return nil
	}
	if mc.APIKey() == "" {
		return nil
	}

	scheme, name := providers.ExtractProtocol(mc.Model)
	if name == "" {
		// No model ID could be split off: use the raw model string instead.
		name = strings.TrimSpace(mc.Model)
	}

	if scheme == "mimo" {
		return NewMimoTTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), name, mc.Proxy)
	}
	return NewOpenAITTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), mc.Proxy, name)
}
// DetectTTS selects a TTS provider from cfg, or returns nil when none can be
// built.
//
// Resolution order:
//  1. The model named by cfg.Voice.TTSModelName, if it is set, resolves, and
//     yields a provider.
//  2. Otherwise the first entry of cfg.ModelList whose model string contains
//     "tts" (case-insensitive), has an API key, and yields a provider.
func DetectTTS(cfg *config.Config) TTSProvider {
	if cfg == nil {
		return nil
	}

	if modelName := strings.TrimSpace(cfg.Voice.TTSModelName); modelName != "" {
		if mc, err := cfg.GetModelConfig(modelName); err == nil {
			if provider := providerFromModelConfig(mc); provider != nil {
				return provider
			}
		}
	}

	for _, mc := range cfg.ModelList {
		// ModelList holds pointers; skip nil entries instead of
		// dereferencing them below.
		if mc == nil {
			continue
		}
		if strings.Contains(strings.ToLower(mc.Model), "tts") && mc.APIKey() != "" {
			if provider := providerFromModelConfig(mc); provider != nil {
				return provider
			}
		}
	}
	return nil
}
// SynthesizeAndStore runs text through provider, spools the resulting audio
// into a temp file under media.TempDir(), registers that file with store and
// returns the media reference.
//
// The stored file is ".ogg" / "audio/ogg" unless the provider's name is
// "mimo-tts", in which case it is ".mp3" / "audio/mpeg". An empty filename is
// replaced by a timestamped default, and a missing or different extension on
// filename is normalized to the chosen one. The temp file is deleted on every
// failure path and kept once the store has accepted it.
func SynthesizeAndStore(
	ctx context.Context,
	provider TTSProvider,
	store media.MediaStore,
	text string,
	filename string,
	channel string,
	chatID string,
) (string, error) {
	// Argument validation, checked in a fixed order.
	switch {
	case provider == nil:
		return "", fmt.Errorf("tts provider is not configured")
	case store == nil:
		return "", fmt.Errorf("media store not configured")
	case channel == "" || chatID == "":
		return "", fmt.Errorf("no target channel/chat available")
	case strings.TrimSpace(text) == "":
		return "", fmt.Errorf("text is required")
	}

	audioStream, synthErr := provider.Synthesize(ctx, text)
	if synthErr != nil {
		return "", fmt.Errorf("tts synthesize failed: %w", synthErr)
	}
	defer audioStream.Close()

	if mkErr := os.MkdirAll(media.TempDir(), 0o700); mkErr != nil {
		return "", fmt.Errorf("failed to create media temp dir: %w", mkErr)
	}

	// Container format depends on which backend produced the audio.
	fileExt, contentType := ".ogg", "audio/ogg"
	if provider.Name() == "mimo-tts" {
		fileExt, contentType = ".mp3", "audio/mpeg"
	}

	tmp, createErr := os.CreateTemp(media.TempDir(), "tts-*"+fileExt)
	if createErr != nil {
		return "", fmt.Errorf("failed to create temp file: %w", createErr)
	}
	// The temp file is discarded unless ownership passes to the store below.
	handedOff := false
	defer func() {
		if !handedOff {
			_ = os.Remove(tmp.Name())
		}
	}()

	if _, copyErr := io.Copy(tmp, audioStream); copyErr != nil {
		tmp.Close()
		return "", fmt.Errorf("failed to write tts audio: %w", copyErr)
	}
	if closeErr := tmp.Close(); closeErr != nil {
		return "", fmt.Errorf("failed to close tts audio file: %w", closeErr)
	}

	// Settle on a display filename that carries the right extension.
	filename = strings.TrimSpace(filename)
	if filename == "" {
		filename = fmt.Sprintf("tts-%d%s", time.Now().Unix(), fileExt)
	}
	switch currentExt := strings.ToLower(filepath.Ext(filename)); {
	case currentExt == "":
		filename += fileExt
	case currentExt != fileExt:
		filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + fileExt
	}

	meta := media.MediaMeta{
		Filename:    filename,
		ContentType: contentType,
		Source:      "tool:send_tts",
	}
	scope := fmt.Sprintf("tool:send_tts:%s:%s:%d", channel, chatID, time.Now().UnixNano())
	ref, storeErr := store.Store(tmp.Name(), meta, scope)
	if storeErr != nil {
		return "", fmt.Errorf("failed to register audio: %w", storeErr)
	}

	handedOff = true
	return ref, nil
}
+247
View File
@@ -0,0 +1,247 @@
package tts
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// TestNewOpenAITTSProvider_APIBaseNormalization checks how the constructor
// rewrites a configured API base into the full speech endpoint URL.
func TestNewOpenAITTSProvider_APIBaseNormalization(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name   string
		input  string
		expect string
	}
	scenarios := []scenario{
		{name: "empty base", input: "", expect: "https://api.openai.com/v1/audio/speech"},
		{
			name:   "official host no path",
			input:  "https://api.openai.com",
			expect: "https://api.openai.com/v1/audio/speech",
		},
		{
			name:   "official host v1",
			input:  "https://api.openai.com/v1",
			expect: "https://api.openai.com/v1/audio/speech",
		},
		{
			name:   "official host v1 slash",
			input:  "https://api.openai.com/v1/",
			expect: "https://api.openai.com/v1/audio/speech",
		},
		{
			name:   "non-openai host preserves base path",
			input:  "https://proxy.example.com/base",
			expect: "https://proxy.example.com/base/audio/speech",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()
			got := NewOpenAITTSProvider("key", sc.input, "", "").apiBase
			if got != sc.expect {
				t.Fatalf("apiBase mismatch: got %q, want %q", got, sc.expect)
			}
		})
	}
}
// TestOpenAITTSProvider_SynthesizeSuccess drives Synthesize against a local
// HTTP server and checks the request path, headers and JSON body, plus the
// audio bytes passed back to the caller.
func TestOpenAITTSProvider_SynthesizeSuccess(t *testing.T) {
	t.Parallel()

	// Request details captured by the fake endpoint.
	var (
		seenPath        string
		seenAuth        string
		seenContentType string
		seenBody        map[string]any
	)
	capture := func(w http.ResponseWriter, r *http.Request) {
		seenPath = r.URL.Path
		seenAuth = r.Header.Get("Authorization")
		seenContentType = r.Header.Get("Content-Type")
		raw, _ := io.ReadAll(r.Body)
		_ = r.Body.Close()
		_ = json.Unmarshal(raw, &seenBody)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("audio-bytes"))
	}
	srv := httptest.NewServer(http.HandlerFunc(capture))
	defer srv.Close()

	tts := NewOpenAITTSProvider("k123", srv.URL, "", "")
	audioStream, err := tts.Synthesize(context.Background(), "hello")
	if err != nil {
		t.Fatalf("Synthesize failed: %v", err)
	}
	defer audioStream.Close()

	payload, err := io.ReadAll(audioStream)
	if err != nil {
		t.Fatalf("read stream failed: %v", err)
	}

	if seenPath != "/audio/speech" {
		t.Fatalf("request path mismatch: got %q", seenPath)
	}
	if seenAuth != "Bearer k123" {
		t.Fatalf("authorization mismatch: got %q", seenAuth)
	}
	if seenContentType != "application/json" {
		t.Fatalf("content-type mismatch: got %q", seenContentType)
	}

	bodyMatches := seenBody["model"] == "tts-1" &&
		seenBody["voice"] == "alloy" &&
		seenBody["response_format"] == "opus" &&
		seenBody["input"] == "hello"
	if !bodyMatches {
		bodyJSON, _ := json.Marshal(seenBody)
		t.Fatalf("request body mismatch: %s", string(bodyJSON))
	}

	if got := string(payload); got != "audio-bytes" {
		t.Fatalf("response body mismatch: got %q", got)
	}
}
// TestOpenAITTSProvider_SynthesizeNon200 checks that a non-200 response
// becomes an error carrying both the status code and the response body.
func TestOpenAITTSProvider_SynthesizeNon200(t *testing.T) {
	t.Parallel()

	alwaysFail := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("nope"))
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysFail))
	defer srv.Close()

	tts := NewOpenAITTSProvider("k123", srv.URL, "", "")
	_, err := tts.Synthesize(context.Background(), "hello")
	if err == nil {
		t.Fatal("expected error")
	}

	const wantFragment = "API error (status 500): nope"
	if !strings.Contains(err.Error(), wantFragment) {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestNewOpenAITTSProvider_UsesConfiguredModel checks that an explicit model
// name is kept and that a non-OpenAI base still gets the speech path appended.
func TestNewOpenAITTSProvider_UsesConfiguredModel(t *testing.T) {
	t.Parallel()

	const (
		wantModel = "mimo-v2-tts"
		wantBase  = "https://api.xiaomimimo.com/v1/audio/speech"
	)

	p := NewOpenAITTSProvider("key", "https://api.xiaomimimo.com/v1", "", wantModel)
	if got := p.model; got != wantModel {
		t.Fatalf("model mismatch: got %q, want %q", got, wantModel)
	}
	if got := p.apiBase; got != wantBase {
		t.Fatalf("apiBase mismatch: got %q", got)
	}
}
func TestDetectTTS_UsesMimoProviderForMimoModels(t *testing.T) {
t.Parallel()
provider := DetectTTS(&config.Config{
Voice: config.VoiceConfig{TTSModelName: "mimo-tts"},
ModelList: []*config.ModelConfig{
{
ModelName: "mimo-tts",
Model: "mimo/mimo-v2-tts",
APIKeys: config.SimpleSecureStrings("sk-mimo"),
},
},
})
ttsProvider, ok := provider.(*MimoTTSProvider)
if !ok {
t.Fatalf("DetectTTS() type = %T, want *MimoTTSProvider", provider)
}
if ttsProvider.model != "mimo-v2-tts" {
t.Fatalf("model mismatch: got %q, want %q", ttsProvider.model, "mimo-v2-tts")
}
if ttsProvider.apiBase != "https://api.xiaomimimo.com/v1/chat/completions" {
t.Fatalf("apiBase mismatch: got %q", ttsProvider.apiBase)
}
}
// stubTTSProvider is a test double for TTSProvider: it reports a configurable
// name and always synthesizes the fixed payload "audio".
type stubTTSProvider struct {
	name string
}

// Name returns the name the stub was constructed with.
func (s stubTTSProvider) Name() string {
	return s.name
}

// Synthesize ignores its arguments and returns a stream over the literal
// bytes "audio".
func (s stubTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
	payload := strings.NewReader("audio")
	return io.NopCloser(payload), nil
}
// TestSynthesizeAndStore_UsesOggMetadataByDefault checks that audio from a
// provider not named "mimo-tts" is stored as Ogg: content type, on-disk
// extension and metadata filename all agree.
func TestSynthesizeAndStore_UsesOggMetadataByDefault(t *testing.T) {
	t.Parallel()

	const (
		wantType = "audio/ogg"
		wantExt  = ".ogg"
	)

	mediaStore := media.NewFileMediaStore()
	provider := stubTTSProvider{name: "openai-tts"}
	ref, err := SynthesizeAndStore(context.Background(), provider, mediaStore, "hello", "", "discord", "chat123")
	if err != nil {
		t.Fatalf("SynthesizeAndStore failed: %v", err)
	}

	storedPath, meta, err := mediaStore.ResolveWithMeta(ref)
	if err != nil {
		t.Fatalf("ResolveWithMeta failed: %v", err)
	}

	if got := meta.ContentType; got != wantType {
		t.Fatalf("ContentType = %q, want %q", got, wantType)
	}
	if got := filepath.Ext(storedPath); got != wantExt {
		t.Fatalf("stored file extension = %q, want %q", got, wantExt)
	}
	if got := filepath.Ext(meta.Filename); got != wantExt {
		t.Fatalf("filename extension = %q, want %q", got, wantExt)
	}
}
// TestSynthesizeAndStore_UsesMp3MetadataForMimo checks that audio from a
// provider named "mimo-tts" is stored as MP3: content type, on-disk extension
// and metadata filename all agree.
func TestSynthesizeAndStore_UsesMp3MetadataForMimo(t *testing.T) {
	t.Parallel()

	const (
		wantType = "audio/mpeg"
		wantExt  = ".mp3"
	)

	mediaStore := media.NewFileMediaStore()
	provider := stubTTSProvider{name: "mimo-tts"}
	ref, err := SynthesizeAndStore(context.Background(), provider, mediaStore, "hello", "", "discord", "chat123")
	if err != nil {
		t.Fatalf("SynthesizeAndStore failed: %v", err)
	}

	storedPath, meta, err := mediaStore.ResolveWithMeta(ref)
	if err != nil {
		t.Fatalf("ResolveWithMeta failed: %v", err)
	}

	if got := meta.ContentType; got != wantType {
		t.Fatalf("ContentType = %q, want %q", got, wantType)
	}
	if got := filepath.Ext(storedPath); got != wantExt {
		t.Fatalf("stored file extension = %q, want %q", got, wantExt)
	}
	if got := filepath.Ext(meta.Filename); got != wantExt {
		t.Fatalf("filename extension = %q, want %q", got, wantExt)
	}
}
+28
View File
@@ -34,6 +34,8 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
outboundMedia chan OutboundMediaMessage
audioChunks chan AudioChunk
voiceControls chan VoiceControl
closeOnce sync.Once
done chan struct{}
@@ -47,6 +49,8 @@ func NewMessageBus() *MessageBus {
inbound: make(chan InboundMessage, defaultBusBufferSize),
outbound: make(chan OutboundMessage, defaultBusBufferSize),
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
done: make(chan struct{}),
}
}
@@ -103,6 +107,22 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
return mb.outboundMedia
}
// PublishAudioChunk hands chunk to the bus's audio-chunk channel through the
// shared publish helper and returns whatever error that helper reports.
func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error {
	err := publish(ctx, mb, mb.audioChunks, chunk)
	return err
}
// AudioChunksChan exposes the audio-chunk channel as receive-only so that
// consumers can read published chunks but cannot send on or close it.
func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
	var recvOnly <-chan AudioChunk = mb.audioChunks
	return recvOnly
}
// PublishVoiceControl hands ctrl to the bus's voice-control channel through
// the shared publish helper and returns whatever error that helper reports.
func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error {
	err := publish(ctx, mb, mb.voiceControls, ctrl)
	return err
}
// VoiceControlsChan exposes the voice-control channel as receive-only so that
// consumers can read published controls but cannot send on or close it.
func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl {
	var recvOnly <-chan VoiceControl = mb.voiceControls
	return recvOnly
}
// SetStreamDelegate registers a StreamDelegate (typically the channel Manager).
func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
mb.streamDelegate.Store(d)
@@ -132,6 +152,8 @@ func (mb *MessageBus) Close() {
close(mb.inbound)
close(mb.outbound)
close(mb.outboundMedia)
close(mb.audioChunks)
close(mb.voiceControls)
// clean up any remaining messages in channels
drained := 0
@@ -144,6 +166,12 @@ func (mb *MessageBus) Close() {
for range mb.outboundMedia {
drained++
}
for range mb.audioChunks {
drained++
}
for range mb.voiceControls {
drained++
}
if drained > 0 {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
+27 -4
View File
@@ -30,10 +30,11 @@ type InboundMessage struct {
}
type OutboundMessage struct {
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// MediaPart describes a single media attachment to send.
@@ -51,3 +52,25 @@ type OutboundMediaMessage struct {
ChatID string `json:"chat_id"`
Parts []MediaPart `json:"parts"`
}
// AudioChunk represents a chunk of streaming voice data.
// It is the payload carried on the bus's audio-chunk channel.
type AudioChunk struct {
// SessionID groups chunks that belong to one voice session; the Discord
// receiver uses the form "discord_vc_<guildID>".
SessionID string `json:"session_id"`
SpeakerID string `json:"speaker_id"` // User ID or SSRC
ChatID string `json:"chat_id"` // Where to respond
Channel string `json:"channel"` // Source channel type (e.g. "discord")
// Sequence is a producer-assigned counter; the Discord receiver increments
// it once per published chunk.
Sequence uint64 `json:"sequence"`
// Timestamp is copied from the incoming packet by the Discord receiver.
// NOTE(review): presumably the RTP timestamp — confirm against discordgo.
Timestamp uint32 `json:"timestamp"`
// SampleRate and Channels describe the audio layout; the Discord receiver
// sets them to 48000 and 2.
SampleRate int `json:"sample_rate"`
Channels int `json:"channels"`
Format string `json:"format"` // "opus", "pcm", etc
// Data holds the encoded audio bytes in the encoding named by Format.
Data []byte `json:"data"`
}
// VoiceControl represents state or commands for voice sessions.
// It is the payload carried on the bus's voice-control channel.
type VoiceControl struct {
// SessionID identifies the voice session the message refers to; Discord
// sessions use the form "discord_vc_<guildID>".
SessionID string `json:"session_id"`
// ChatID is left unset by the Discord receiver's "listening" state message.
ChatID string `json:"chat_id"`
Type string `json:"type"` // "state", "command"
Action string `json:"action"` // "idle", "listening", "start", "stop", "leave"
}
+170
View File
@@ -3,6 +3,7 @@ package discord
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
@@ -14,6 +15,8 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/audio"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -42,6 +45,15 @@ type DiscordChannel struct {
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
botUserID string // stored for mention checking
bus *bus.MessageBus
tts tts.TTSProvider
voiceMu sync.RWMutex
voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
// TTS interruption: cancel active playback when user speaks
ttsMu sync.Mutex
cancelTTS context.CancelFunc
ttsPlayID uint64
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
config: cfg,
ctx: context.Background(),
typingStop: make(map[string]chan struct{}),
bus: bus,
voiceSSRC: make(map[string]map[uint32]string),
}, nil
}
@@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
c.session.AddHandler(c.handleMessage)
go c.listenVoiceControl(c.ctx)
if err := c.session.Open(); err != nil {
return fmt.Errorf("failed to open discord session: %w", err)
}
@@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
return nil, nil
}
if c.tts != nil {
if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
// Cancel any previous TTS playback
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
}
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
c.ttsPlayID++
playID := c.ttsPlayID
c.cancelTTS = ttsCancel
c.ttsMu.Unlock()
go c.playTTS(ttsCtx, vc, msg.Content, playID)
}
}
}
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
if err != nil {
return nil, err
@@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if c.handleVoiceCommand(s, m) {
return
}
content := m.Content
// In guild (group) channels, apply unified group trigger filtering
@@ -630,3 +669,134 @@ func (c *DiscordChannel) stripBotMention(text string) string {
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
return strings.TrimSpace(text)
}
// listenVoiceControl consumes voice-control messages from the bus until ctx
// is canceled or the control channel is closed.
//
// Only "leave" commands addressed to a Discord voice session (session ID of
// the form "discord_vc_<guildID>") are acted on: the guild's voice
// connection, if any, is disconnected. All other messages are ignored.
func (c *DiscordChannel) listenVoiceControl(ctx context.Context) {
	const sessionPrefix = "discord_vc_"

	for {
		select {
		case <-ctx.Done():
			return
		case ctrl, ok := <-c.bus.VoiceControlsChan():
			if !ok {
				return
			}
			if ctrl.Type != "command" || ctrl.Action != "leave" {
				continue
			}
			if !strings.HasPrefix(ctrl.SessionID, sessionPrefix) {
				continue
			}

			guildID := strings.TrimPrefix(ctrl.SessionID, sessionPrefix)
			vc, exists := c.session.VoiceConnections[guildID]
			if !exists || vc == nil {
				continue
			}
			// Report a failed disconnect instead of dropping the error,
			// matching how the "!vc leave" text command logs it.
			if err := vc.Disconnect(ctx); err != nil {
				logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
					"guild": guildID,
					"error": err,
				})
			}
		}
	}
}
// playTTS speaks text into the voice connection vc, one sentence at a time.
// While one sentence is being played, the next one is synthesized in the
// background. Playback stops early when ctx is canceled. playID identifies
// this playback so that the cleanup below only clears c.cancelTTS when no
// newer playback has replaced it.
func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) {
// Clear cancelTTS when playback finishes (normal or interrupted),
// but only if the current play ID is still this playback's ID, i.e. no
// newer playback has taken over the slot.
// NOTE(review): the stored cancel func is dropped here without being called,
// so the derived context is only released once its parent is canceled — confirm
// this is intended.
defer func() {
c.ttsMu.Lock()
if c.ttsPlayID == playID {
c.cancelTTS = nil
}
c.ttsMu.Unlock()
}()
sentences := audio.SplitSentences(text)
if len(sentences) == 0 {
return
}
logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)})
// Pipeline: prefetch next sentence's audio while playing current
type ttResult struct {
stream io.ReadCloser
err error
}
// prefetch carries the result for the sentence about to be played; it is nil
// for the first sentence, which is synthesized inline.
var prefetch chan ttResult
// Ensure any in-flight prefetch is drained on exit to prevent stream leaks,
// but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends.
defer func() {
if prefetch != nil {
select {
case result := <-prefetch:
if result.stream != nil {
result.stream.Close()
}
case <-time.After(100 * time.Millisecond):
// Timed out waiting for a prefetched result; avoid blocking on exit.
}
}
}()
for i, sentence := range sentences {
// Check for cancellation (interruption)
select {
case <-ctx.Done():
logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i})
return
default:
}
// Start prefetching the NEXT sentence while we process the current one
var nextPrefetch chan ttResult
if i+1 < len(sentences) {
// Buffer of one lets the goroutine deliver its result and exit even
// if nobody ever receives it.
nextPrefetch = make(chan ttResult, 1)
nextSentence := sentences[i+1]
go func() {
s, e := c.tts.Synthesize(ctx, nextSentence)
nextPrefetch <- ttResult{s, e}
}()
}
// Get the current sentence's audio
var stream io.ReadCloser
var err error
if prefetch != nil {
// Use prefetched result from previous iteration, but be responsive to cancellation.
var result ttResult
select {
case result = <-prefetch:
stream, err = result.stream, result.err
case <-ctx.Done():
// Context canceled while waiting for prefetched audio; abort playback.
// NOTE(review): nextPrefetch started earlier in this iteration is not
// covered by the deferred drain (which only looks at prefetch), so a
// stream it produces after this return would not be closed.
logger.InfoCF(
"discord",
"TTS interrupted while waiting for prefetched audio",
map[string]any{"at_sentence": i},
)
return
}
} else {
// First sentence: synthesize directly
stream, err = c.tts.Synthesize(ctx, sentence)
}
if err != nil {
if stream != nil {
stream.Close()
}
logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i})
// Skip this sentence but keep the pipeline moving to the next one.
prefetch = nextPrefetch
continue
}
// A playback error is logged and the loop moves on to the next sentence.
if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil {
logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i})
}
stream.Close()
prefetch = nextPrefetch
}
}
// VoiceCapabilities declares that the Discord channel supports both ASR and
// TTS.
func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+6 -1
View File
@@ -1,6 +1,7 @@
package discord
import (
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -8,6 +9,10 @@ import (
func init() {
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewDiscordChannel(cfg.Channels.Discord, b)
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
if err == nil {
ch.tts = tts.DetectTTS(cfg)
}
return ch, err
})
}
+313
View File
@@ -0,0 +1,313 @@
package discord
import (
"context"
"fmt"
"io"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/audio"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
// setVoiceUserID remembers which user an SSRC belongs to within a guild.
// Calls with an empty userID are ignored. Safe for concurrent use: the map is
// updated under the voice write lock.
func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) {
	if userID == "" {
		return
	}

	c.voiceMu.Lock()
	defer c.voiceMu.Unlock()

	// Create the guild's table lazily on first use.
	if _, known := c.voiceSSRC[guildID]; !known {
		c.voiceSSRC[guildID] = make(map[uint32]string)
	}
	c.voiceSSRC[guildID][ssrc] = userID
}
// voiceUserID looks up the user recorded for an SSRC within a guild and
// returns "" when either the guild or the SSRC is unknown. The lookup runs
// under the voice read lock.
func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string {
	c.voiceMu.RLock()
	defer c.voiceMu.RUnlock()

	// An unknown guild yields a nil inner map, and reading from a nil map
	// yields the zero value, so both "not found" cases collapse to "".
	return c.voiceSSRC[guildID][ssrc]
}
// handleVoiceCommand intercepts the "!vc join" and "!vc leave" text commands.
// It returns true when the message was one of those commands (it has then
// been fully handled, including the reply to the user) and false for any
// other message, which the caller should keep processing.
func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool {
	switch m.Content {
	case "!vc join":
		c.handleVoiceJoin(s, m)
		return true
	case "!vc leave":
		c.handleVoiceLeave(s, m)
		return true
	}
	return false
}

// handleVoiceJoin joins the voice channel the message author is currently in,
// starts the receive loop for it, and reports the outcome in the text channel.
func (c *DiscordChannel) handleVoiceJoin(s *discordgo.Session, m *discordgo.MessageCreate) {
	vs, err := s.State.VoiceState(m.GuildID, m.Author.ID)
	if err != nil || vs == nil {
		sendVoiceNotice(
			s,
			m.ChannelID,
			"You need to be in a voice channel first!",
			"Failed to send voice channel requirement message",
		)
		return
	}

	logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID})
	vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false)
	if err != nil {
		sendVoiceNotice(
			s,
			m.ChannelID,
			fmt.Sprintf("Failed to join voice channel: %v", err),
			"Failed to send voice join error message",
		)
		return
	}

	// The receive loop runs in the background for this voice connection.
	go c.receiveVoice(vc, m.GuildID, m.ChannelID)
	sendVoiceNotice(
		s,
		m.ChannelID,
		"Joined Voice Channel! Listening for audio...",
		"Failed to send voice join success message",
	)
}

// handleVoiceLeave disconnects from the guild's voice connection, if there is
// one, and reports the outcome in the text channel.
func (c *DiscordChannel) handleVoiceLeave(s *discordgo.Session, m *discordgo.MessageCreate) {
	vc, exists := s.VoiceConnections[m.GuildID]
	if !exists || vc == nil {
		sendVoiceNotice(s, m.ChannelID, "Not in a voice channel.", "Failed to send voice not-in-channel message")
		return
	}

	// A failed disconnect is logged but the user is still told we left.
	if err := vc.Disconnect(c.ctx); err != nil {
		logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
			"guild": m.GuildID,
			"error": err,
		})
	}
	sendVoiceNotice(s, m.ChannelID, "Left Voice Channel.", "Failed to send voice leave success message")
}

// sendVoiceNotice posts text to channelID and, if the send fails, logs
// failureLog together with the channel and the error.
func sendVoiceNotice(s *discordgo.Session, channelID, text, failureLog string) {
	if _, sendErr := s.ChannelMessageSend(channelID, text); sendErr != nil {
		logger.InfoCF("discord", failureLog, map[string]any{
			"channel": channelID,
			"error":   sendErr,
		})
	}
}
// VoiceReceiveActive reports whether vc is non-nil and has an Opus receive
// channel set.
func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool {
	if vc == nil {
		return false
	}
	return vc.OpusRecv != nil
}
// streamOggOpusToDiscord decodes the Ogg/Opus stream r and forwards each
// decoded frame to vc.OpusSend, stopping with ctx.Err() as soon as ctx is
// canceled. The speaking flag is raised for the duration of the call. A panic
// raised anywhere during playback is converted into an error.
func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) {
	// Sending on OpusSend after it has been closed (e.g. on disconnect)
	// panics; turn that into an ordinary error for the caller.
	defer func() {
		if recovered := recover(); recovered != nil {
			retErr = fmt.Errorf("voice connection closed during playback")
		}
	}()

	// Mark the bot as speaking while frames are sent, and clear it on exit.
	vc.Speaking(true)
	defer vc.Speaking(false)

	forwardFrame := func(frame []byte) error {
		select {
		case vc.OpusSend <- frame:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	retErr = audio.DecodeOggOpus(r, forwardFrame)
	return retErr
}
// receiveVoice is the receive loop for one guild voice connection. It reads
// Opus packets from vc.OpusRecv, maps each packet's SSRC to a Discord user,
// applies the sender allowlist, and publishes accepted packets on the bus as
// AudioChunks addressed to chatID. It also cancels active TTS playback when
// incoming voice is detected. The loop ends when c.ctx is canceled or the
// receive channel is closed.
func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) {
logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID})
// Speaking updates carry the SSRC -> user ID association; record it so that
// packets (which only carry an SSRC) can be attributed to a user below.
vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) {
if vs == nil {
return
}
c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID)
})
// Drop this guild's SSRC table when the loop exits.
defer func() {
c.voiceMu.Lock()
delete(c.voiceSSRC, guildID)
c.voiceMu.Unlock()
}()
// Send a few silence frames shortly after joining.
// NOTE(review): presumably this is needed before Discord starts delivering
// inbound audio to the bot — confirm.
go func(ctx context.Context, vc *discordgo.VoiceConnection) {
// Recover from potential panics if OpusSend is closed mid-send.
defer func() {
if rec := recover(); rec != nil {
logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{
"error": rec,
"guild": guildID,
})
}
}()
// If the voice connection or OpusSend are not available, nothing to do.
if vc == nil || vc.OpusSend == nil {
return
}
time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle
// Abort if the context has already been canceled.
select {
case <-ctx.Done():
return
default:
}
vc.Speaking(true)
defer vc.Speaking(false)
// Five frames, 20ms apart.
silenceFrame := []byte{0xF8, 0xFF, 0xFE}
for i := 0; i < 5; i++ {
select {
case <-ctx.Done():
return
case vc.OpusSend <- silenceFrame:
}
time.Sleep(20 * time.Millisecond)
}
logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID})
}(c.ctx, vc)
sessionID := fmt.Sprintf("discord_vc_%s", guildID)
// Announce that this session is now listening. The publish error is not checked.
c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{
SessionID: sessionID,
Type: "state",
Action: "listening",
})
// sequence numbers the chunks published by this loop; interruptCount and
// lastInterruptAt implement the TTS-interruption debounce below.
var sequence uint64 = 0
var interruptCount int
var lastInterruptAt time.Time
for {
select {
case <-c.ctx.Done():
return
case p, ok := <-vc.OpusRecv:
if !ok {
logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID})
// Cancel any TTS that may still be playing
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
c.cancelTTS = nil
}
c.ttsMu.Unlock()
return
}
if p == nil {
logger.DebugCF("discord", "Received nil Opus packet", nil)
continue
}
if len(p.Opus) == 0 {
logger.DebugCF("discord", "Received empty Opus packet", map[string]any{
"seq": p.Sequence,
"ssrc": p.SSRC,
})
continue
}
logger.DebugCF("discord", "Received Opus packet", map[string]any{
"seq": p.Sequence,
"len": len(p.Opus),
"ssrc": p.SSRC,
})
// Interruption detection: if voice packets arrive while TTS is playing,
// cancel TTS after a short debounce. The counter resets whenever more than
// 500ms pass between packets, and TTS is canceled once 3 packets have been
// counted without such a gap.
// NOTE(review): this runs before the user-mapping and allowlist checks
// below, so packets from unmapped or disallowed speakers also count.
now := time.Now()
if now.Sub(lastInterruptAt) > 500*time.Millisecond {
interruptCount = 0
}
interruptCount++
lastInterruptAt = now
if interruptCount >= 3 {
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
c.cancelTTS = nil
logger.InfoCF("discord", "TTS interrupted by user voice", nil)
}
c.ttsMu.Unlock()
interruptCount = 0
}
// Attribute the packet to a user; packets whose SSRC has not been seen
// in a speaking update yet are dropped.
userID := c.voiceUserID(guildID, p.SSRC)
if userID == "" {
logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{
"ssrc": p.SSRC,
"guild": guildID,
})
continue
}
sender := bus.SenderInfo{
Platform: "discord",
PlatformID: userID,
CanonicalID: identity.BuildCanonicalID("discord", userID),
}
if !c.IsAllowedSender(sender) {
logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{
"user_id": userID,
"guild": guildID,
})
continue
}
// Only packets that pass every check consume a sequence number.
sequence++
chunk := bus.AudioChunk{
SessionID: sessionID,
SpeakerID: userID,
ChatID: chatID,
Channel: "discord",
Sequence: sequence,
Timestamp: p.Timestamp,
SampleRate: 48000,
Channels: 2,
Format: "opus",
Data: p.Opus,
}
// Bound the publish so a slow consumer cannot stall the receive loop;
// a chunk that cannot be published within 100ms is logged and dropped.
ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond)
err := c.bus.PublishAudioChunk(ctx, chunk)
cancel()
if err != nil {
logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{
"guild": guildID,
"sessionID": sessionID,
"sequence": sequence,
"error": err.Error(),
})
}
}
}
}
+7
View File
@@ -6,6 +6,8 @@ import (
"strings"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"github.com/sipeed/picoclaw/pkg/channels"
)
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
@@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
}
}
}
// VoiceCapabilities declares that the Feishu channel supports both ASR and
// TTS.
func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+5
View File
@@ -684,3 +684,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string {
},
})
}
// VoiceCapabilities declares that the LINE channel supports both ASR and TTS.
func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+5
View File
@@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.
cleaned = strings.TrimLeft(cleaned, ",:; ")
return strings.TrimSpace(cleaned)
}
// VoiceCapabilities declares that the Matrix channel supports both ASR and
// TTS.
func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+5
View File
@@ -1104,3 +1104,8 @@ func truncate(s string, n int) string {
}
return string(runes[:n]) + "..."
}
// VoiceCapabilities declares that the OneBot channel supports both ASR and
// TTS.
func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+5
View File
@@ -1002,3 +1002,8 @@ func sanitizeURLs(text string) string {
return scheme + domain + path
})
}
// VoiceCapabilities declares that the QQ channel supports both ASR and TTS.
func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+5
View File
@@ -1133,3 +1133,8 @@ func cryptoRandInt() int {
_, _ = rand.Read(b[:])
return int(binary.BigEndian.Uint32(b[:])) | 1 // ensure non-zero
}
// VoiceCapabilities declares that the Telegram channel supports both ASR and
// TTS.
func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+58
View File
@@ -0,0 +1,58 @@
package channels
// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech)
// are available for a channel under the current configuration.
type VoiceCapabilities struct {
// ASR is true when speech-to-text is available for the channel.
ASR bool
// TTS is true when text-to-speech is available for the channel.
TTS bool
}
// VoiceCapabilityProvider is an optional interface for channels that want to
// explicitly declare their ASR/TTS support.
// DetectVoiceCapabilities prefers this declaration over its built-in
// fallbacks, and still clears any capability whose provider is not available.
type VoiceCapabilityProvider interface {
// VoiceCapabilities returns what the channel itself supports, without
// regard to whether an ASR or TTS provider is actually configured.
VoiceCapabilities() VoiceCapabilities
}
// asrCapableChannels lists, by channel name, the channels treated as
// ASR-capable. DetectVoiceCapabilities consults it only for channels that do
// not implement VoiceCapabilityProvider.
//
// Deprecated: Channels should implement VoiceCapabilityProvider instead.
// To be removed once all existing capable channels conform to the interface.
var asrCapableChannels = map[string]bool{
"discord": true,
"telegram": true,
"matrix": true,
"qq": true,
"weixin": true,
"line": true,
"feishu": true,
"onebot": true,
}
// DetectVoiceCapabilities works out which voice features are usable on a
// channel. A capability is reported only when the channel supports it AND the
// matching provider is available (asrAvailable / ttsAvailable).
//
// Channels implementing VoiceCapabilityProvider are asked directly. For all
// others, ASR support comes from the legacy asrCapableChannels table keyed by
// channelName, and TTS support is assumed when the channel is a MediaSender.
// A nil channel has no capabilities.
func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities {
	if ch == nil {
		return VoiceCapabilities{}
	}

	declarer, selfDescribing := ch.(VoiceCapabilityProvider)
	if !selfDescribing {
		// Legacy path: infer support from the name table and the interfaces
		// the channel happens to implement.
		_, sendsMedia := ch.(MediaSender)
		return VoiceCapabilities{
			ASR: asrAvailable && asrCapableChannels[channelName],
			TTS: ttsAvailable && sendsMedia,
		}
	}

	// Declared support is masked by provider availability.
	result := declarer.VoiceCapabilities()
	result.ASR = result.ASR && asrAvailable
	result.TTS = result.TTS && ttsAvailable
	return result
}
+5
View File
@@ -402,3 +402,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
return nil, nil
}
// VoiceCapabilities declares that the Weixin channel supports both ASR and
// TTS.
func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities {
	var caps channels.VoiceCapabilities
	caps.ASR = true
	caps.TTS = true
	return caps
}
+6 -3
View File
@@ -558,9 +558,9 @@ type DevicesConfig struct {
}
type VoiceConfig struct {
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
}
// ModelConfig represents a model-centric provider configuration.
@@ -829,6 +829,7 @@ type ToolsConfig struct {
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SpawnStatus ToolConfig `json:"spawn_status" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
SPI ToolConfig `json:"spi" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPI_"`
@@ -1281,6 +1282,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.WebFetch.Enabled
case "send_file":
return t.SendFile.Enabled
case "send_tts":
return t.SendTTS.Enabled
case "write_file":
return t.WriteFile.Enabled
case "mcp":
+3
View File
@@ -434,6 +434,9 @@ func DefaultConfig() *Config {
SendFile: ToolConfig{
Enabled: true,
},
SendTTS: ToolConfig{
Enabled: false,
},
MCP: MCPConfig{
ToolConfig: ToolConfig{
Enabled: false,
+51 -3
View File
@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
"sort"
"strings"
"sync"
"sync/atomic"
@@ -13,6 +14,8 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/audio/asr"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
@@ -41,7 +44,6 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
@@ -61,6 +63,7 @@ type services struct {
ChannelManager *channels.Manager
DeviceService *devices.Service
HealthServer *health.Server
VoiceAgentCancel context.CancelFunc
manualReloadChan chan struct{}
reloading atomic.Bool
authToken string
@@ -70,6 +73,27 @@ type startupBlockedProvider struct {
reason string
}
// logChannelVoiceCapabilities writes one info-level log entry per enabled
// channel with the ASR/TTS capabilities detected for it. Channel names are
// sorted first so the entries appear in a stable order. A nil manager is a
// no-op, and names the manager cannot resolve to a channel are skipped.
func logChannelVoiceCapabilities(cm *channels.Manager, asrAvailable bool, ttsAvailable bool) {
	if cm == nil {
		return
	}

	enabled := cm.GetEnabledChannels()
	sort.Strings(enabled)

	for _, channelName := range enabled {
		if channel, found := cm.GetChannel(channelName); found {
			detected := channels.DetectVoiceCapabilities(channelName, channel, asrAvailable, ttsAvailable)
			fields := map[string]any{
				"channel": channelName,
				"asr":     detected.ASR,
				"tts":     detected.TTS,
			}
			logger.InfoCF("voice", "Channel voice capabilities", fields)
		}
	}
}
func (p *startupBlockedProvider) Chat(
_ context.Context,
_ []providers.Message,
@@ -337,11 +361,14 @@ func setupAndStartServices(
agentLoop.SetChannelManager(runningServices.ChannelManager)
agentLoop.SetMediaStore(runningServices.MediaStore)
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
transcriber := asr.DetectTranscriber(cfg)
if transcriber != nil {
agentLoop.SetTranscriber(transcriber)
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
ttsAvailable := tts.DetectTTS(cfg) != nil
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
@@ -358,6 +385,16 @@ func setupAndStartServices(
return nil, fmt.Errorf("error starting channels: %w", err)
}
logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
if transcriber != nil {
// Start Voice Agent Orchestrator after channels are ready.
vaCtx, vaCancel := context.WithCancel(context.Background())
runningServices.VoiceAgentCancel = vaCancel
voiceAgent := asr.NewAgent(msgBus, transcriber)
voiceAgent.Start(vaCtx)
}
fmt.Printf(
"✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n",
cfg.Gateway.Host,
@@ -387,6 +424,9 @@ func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Dura
if !isReload && runningServices.ChannelManager != nil {
runningServices.ChannelManager.StopAll(shutdownCtx)
}
if runningServices.VoiceAgentCancel != nil {
runningServices.VoiceAgentCancel()
}
if runningServices.DeviceService != nil {
runningServices.DeviceService.Stop()
}
@@ -563,14 +603,22 @@ func restartServices(
fmt.Println(" ✓ Device event service restarted")
}
transcriber := voice.DetectTranscriber(cfg)
transcriber := asr.DetectTranscriber(cfg)
al.SetTranscriber(transcriber)
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
// Start Voice Agent Orchestrator on reload
vaCtx, vaCancel := context.WithCancel(context.Background())
runningServices.VoiceAgentCancel = vaCancel
voiceAgent := asr.NewAgent(msgBus, transcriber)
voiceAgent.Start(vaCtx)
} else {
logger.InfoCF("voice", "Transcription disabled", nil)
}
ttsAvailable := tts.DetectTTS(cfg) != nil
logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
// NOTE: PID file is written once at startup and not updated on reload.
// Changing the gateway listen address requires a full restart.
+13
View File
@@ -98,6 +98,19 @@ func ExtractProtocol(model string) (protocol, modelID string) {
return protocol, modelID
}
// ResolveAPIBase picks the API base URL for a model configuration.
//
// A non-blank cfg.APIBase wins; otherwise the default base for the protocol
// extracted from cfg.Model is used. In both cases trailing "/" characters are
// removed. A nil cfg yields "".
func ResolveAPIBase(cfg *config.ModelConfig) string {
	if cfg == nil {
		return ""
	}

	base := strings.TrimSpace(cfg.APIBase)
	if base == "" {
		// Nothing configured explicitly: fall back to the protocol default.
		protocol, _ := ExtractProtocol(cfg.Model)
		base = getDefaultAPIBase(protocol)
	}
	return strings.TrimRight(base, "/")
}
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
+82
View File
@@ -0,0 +1,82 @@
package tools
import (
"context"
"strings"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/media"
)
// SendTTSTool is the "send_tts" tool: it turns text into speech audio and
// returns the stored audio reference as media in its result.
type SendTTSTool struct {
// provider is the TTS backend handed to tts.SynthesizeAndStore by Execute.
provider tts.TTSProvider
// mediaStore is handed to tts.SynthesizeAndStore by Execute; it can be
// replaced after construction through SetMediaStore.
mediaStore media.MediaStore
}
// NewSendTTSTool returns a SendTTSTool wired to the given TTS provider and
// media store. Both values are stored as-is, without validation.
func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
	tool := new(SendTTSTool)
	tool.provider = provider
	tool.mediaStore = store
	return tool
}
// Name returns the tool's name, "send_tts".
func (t *SendTTSTool) Name() string {
	const toolName = "send_tts"
	return toolName
}
// Description returns the one-sentence summary of what the tool does.
func (t *SendTTSTool) Description() string {
	const summary = "Synthesize speech from text and send it as an audio file to the user."
	return summary
}
// Parameters returns the argument schema of the tool as a nested map: an
// object with a required string "text" and an optional string "filename".
func (t *SendTTSTool) Parameters() map[string]any {
	textParam := map[string]any{
		"type":        "string",
		"description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
	}
	filenameParam := map[string]any{
		"type":        "string",
		"description": "Optional filename for the audio file (e.g., response.ogg).",
	}

	schema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"text":     textParam,
			"filename": filenameParam,
		},
		"required": []string{"text"},
	}
	return schema
}
// SetMediaStore replaces the media store used by later Execute calls.
// The assignment is not synchronized.
func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
// Execute synthesizes the "text" argument into audio and returns a result
// that carries the stored audio reference.
//
// The text is whitespace-trimmed; an empty or non-string value produces an
// error result. The optional "filename" argument is forwarded unchanged (""
// when absent or not a string), together with the channel and chat ID taken
// from ctx. Any error from tts.SynthesizeAndStore is returned as an error
// result wrapping that error.
func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	rawText, _ := args["text"].(string)
	speech := strings.TrimSpace(rawText)
	if speech == "" {
		return ErrorResult("text is required")
	}

	targetChannel := ToolChannel(ctx)
	targetChat := ToolChatID(ctx)
	name, _ := args["filename"].(string)

	audioRef, synthErr := tts.SynthesizeAndStore(ctx, t.provider, t.mediaStore, speech, name, targetChannel, targetChat)
	if synthErr != nil {
		return ErrorResult(synthErr.Error()).WithError(synthErr)
	}

	// The user-facing text is the trimmed input and the audio reference rides
	// along as media. ResponseHandled is set so the audio is sent immediately
	// without LLM intervention.
	result := &ToolResult{
		ForLLM:          "TTS audio sent",
		ForUser:         speech,
		Media:           []string{audioRef},
		ResponseHandled: true,
	}
	return result
}
-151
View File
@@ -1,151 +0,0 @@
package voice
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// GroqTranscriber transcribes audio files by posting them to an
// "/audio/transcriptions" HTTP endpoint.
type GroqTranscriber struct {
// apiKey is sent as the Bearer token in the Authorization header.
apiKey string
// apiBase is the URL prefix that "/audio/transcriptions" is appended to.
apiBase string
// httpClient performs the upload request.
httpClient *http.Client
}
// NewGroqTranscriber returns a transcriber that targets the Groq
// OpenAI-compatible API base and uses an HTTP client with a 60-second timeout.
// Only whether the key is non-empty is logged, never the key itself.
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
	logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})

	client := &http.Client{Timeout: 60 * time.Second}
	return &GroqTranscriber{
		apiKey:     apiKey,
		apiBase:    "https://api.groq.com/openai/v1",
		httpClient: client,
	}
}
// Transcribe uploads the audio file at audioFilePath as a multipart form to
// "<apiBase>/audio/transcriptions" and decodes the JSON reply.
//
// The form carries the file under "file", the model "whisper-large-v3" and
// response_format "json". The whole file is buffered in memory before the
// request is sent. Any non-200 status is reported as an error that includes
// the response body. Every failure is logged and returned wrapped; on error
// the returned response is nil.
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
	// logAndWrap logs "Failed to <action>" and returns the matching
	// "failed to <action>: <cause>" error wrapping cause.
	logAndWrap := func(action string, cause error) error {
		logger.ErrorCF("voice", "Failed to "+action, map[string]any{"error": cause})
		return fmt.Errorf("failed to %s: %w", action, cause)
	}

	logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})

	src, err := os.Open(audioFilePath)
	if err != nil {
		logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
		return nil, fmt.Errorf("failed to open audio file: %w", err)
	}
	defer src.Close()

	info, err := src.Stat()
	if err != nil {
		logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
		return nil, fmt.Errorf("failed to get file info: %w", err)
	}

	baseName := filepath.Base(audioFilePath)
	logger.DebugCF("voice", "Audio file details", map[string]any{
		"size_bytes": info.Size(),
		"file_name":  baseName,
	})

	// Assemble the multipart body: the file part first, then the text fields.
	var payload bytes.Buffer
	form := multipart.NewWriter(&payload)

	filePart, err := form.CreateFormFile("file", baseName)
	if err != nil {
		return nil, logAndWrap("create form file", err)
	}

	written, err := io.Copy(filePart, src)
	if err != nil {
		return nil, logAndWrap("copy file content", err)
	}
	logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": written})

	textFields := [][2]string{
		{"model", "whisper-large-v3"},
		{"response_format", "json"},
	}
	for _, field := range textFields {
		if err = form.WriteField(field[0], field[1]); err != nil {
			return nil, logAndWrap("write "+field[0]+" field", err)
		}
	}

	// Closing the writer appends the terminating multipart boundary.
	if err = form.Close(); err != nil {
		return nil, logAndWrap("close multipart writer", err)
	}

	endpoint := t.apiBase + "/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, &payload)
	if err != nil {
		return nil, logAndWrap("create request", err)
	}
	req.Header.Set("Content-Type", form.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+t.apiKey)

	logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
		"url":                endpoint,
		"request_size_bytes": payload.Len(),
		"file_size_bytes":    info.Size(),
	})

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, logAndWrap("send request", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, logAndWrap("read response", err)
	}

	if status := resp.StatusCode; status != http.StatusOK {
		logger.ErrorCF("voice", "API error", map[string]any{
			"status_code": status,
			"response":    string(raw),
		})
		return nil, fmt.Errorf("API error (status %d): %s", status, string(raw))
	}

	logger.DebugCF("voice", "Received response from Groq API", map[string]any{
		"status_code":         resp.StatusCode,
		"response_size_bytes": len(raw),
	})

	parsed := new(TranscriptionResponse)
	if err := json.Unmarshal(raw, parsed); err != nil {
		return nil, logAndWrap("unmarshal response", err)
	}

	logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
		"text_length":           len(parsed.Text),
		"language":              parsed.Language,
		"duration_seconds":      parsed.Duration,
		"transcription_preview": utils.Truncate(parsed.Text, 50),
	})

	return parsed, nil
}
// Name returns the provider name of this transcriber, "groq".
func (t *GroqTranscriber) Name() string {
	const providerName = "groq"
	return providerName
}
-84
View File
@@ -1,84 +0,0 @@
package voice
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
var _ Transcriber = (*GroqTranscriber)(nil)
// TestGroqTranscriberName checks that a Groq transcriber reports "groq" as
// its name.
func TestGroqTranscriberName(t *testing.T) {
	const want = "groq"

	got := NewGroqTranscriber("sk-test").Name()
	if got != want {
		t.Errorf("Name() = %q, want %q", got, want)
	}
}
func TestGroqTranscribe(t *testing.T) {
// Write a minimal fake audio file so the transcriber can open and send it.
tmpDir := t.TempDir()
audioPath := filepath.Join(tmpDir, "clip.ogg")
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
t.Fatalf("failed to write fake audio file: %v", err)
}
t.Run("success", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Authorization") != "Bearer sk-test" {
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
Text: "hello world",
Language: "en",
Duration: 1.5,
})
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-test")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
if err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
if resp.Text != "hello world" {
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
}
if resp.Language != "en" {
t.Errorf("Language = %q, want %q", resp.Language, "en")
}
})
t.Run("api error", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-bad")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
if err == nil {
t.Fatal("expected error for non-200 response, got nil")
}
})
t.Run("missing file", func(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
}
-68
View File
@@ -1,68 +0,0 @@
package voice
import (
"context"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Transcriber is implemented by speech-to-text backends.
type Transcriber interface {
// Name returns the name of the backend (for example "groq").
Name() string
// Transcribe converts the audio file at audioFilePath into text.
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
}
// TranscriptionResponse is the result of a transcription, in the JSON shape
// decoded from the transcription API reply.
type TranscriptionResponse struct {
// Text is the transcribed text.
Text string `json:"text"`
// Language is omitted from JSON when empty.
Language string `json:"language,omitempty"`
// Duration is logged as "duration_seconds" by the Groq transcriber and is
// omitted from JSON when zero.
Duration float64 `json:"duration,omitempty"`
}
// supportsAudioTranscription reports whether the protocol prefix of model is
// on the allow-list of protocols that may be used for audio transcription.
func supportsAudioTranscription(model string) bool {
	proto, _ := providers.ExtractProtocol(model)

	supported := false
	switch proto {
	// These protocols all go through the OpenAI-compatible or Azure provider
	// path in providers.CreateProviderFromConfig, so they are the only ones
	// that can supply the audio media payload shape expected by
	// NewAudioModelTranscriber.
	// TODO: Further restrict this by modelID, since not every model under
	// these protocols supports audio transcription.
	case "openai", "azure", "azure-openai",
		"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
		"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
		"vivgrid", "volcengine", "vllm",
		"qwen", "qwen-intl", "qwen-international", "dashscope-intl",
		"qwen-us", "dashscope-us",
		"mistral", "avian", "minimax", "longcat", "modelscope", "novita",
		"coding-plan", "alibaba-coding", "qwen-coding":
		supported = true
	}
	return supported
}
// DetectTranscriber chooses a Transcriber from cfg, or returns nil when none
// applies. Sources are tried in this order:
//
//  1. cfg.Voice.ModelName: if the name cannot be resolved, nil is returned at
//     once; if it resolves to a supported protocol, an audio-model
//     transcriber is returned; otherwise the search continues.
//  2. cfg.Voice.ElevenLabsAPIKey, when non-blank.
//  3. The first model-list entry with a "groq/" model and a non-empty API key.
func DetectTranscriber(cfg *config.Config) Transcriber {
	configured := strings.TrimSpace(cfg.Voice.ModelName)
	if configured != "" {
		modelCfg, err := cfg.GetModelConfig(configured)
		switch {
		case err != nil:
			// An unresolvable voice model disables transcription entirely;
			// the later fallbacks are intentionally not consulted.
			return nil
		case supportsAudioTranscription(modelCfg.Model):
			return NewAudioModelTranscriber(modelCfg)
		}
	}

	// ElevenLabs voice config (supports Scribe STT).
	if elevenKey := strings.TrimSpace(cfg.Voice.ElevenLabsAPIKey); elevenKey != "" {
		return NewElevenLabsTranscriber(elevenKey)
	}

	// Last resort: any model-list entry that uses the groq/ protocol.
	for _, entry := range cfg.ModelList {
		if !strings.HasPrefix(entry.Model, "groq/") || entry.APIKey() == "" {
			continue
		}
		return NewGroqTranscriber(entry.APIKey())
	}

	return nil
}