mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Refactor/asr tts (#1939)
* refactor: update ASR and TTS implementations * fix lint * Integrating asr/tts models w/ new security config * update documents * add arbitrary whisper transcriptor support * update documents * fix lint * add mimo tts
This commit is contained in:
+53
-17
@@ -18,6 +18,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
@@ -31,7 +33,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type AgentLoop struct {
|
||||
@@ -51,7 +52,7 @@ type AgentLoop struct {
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
transcriber asr.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
hookRuntime hookRuntime
|
||||
@@ -159,6 +160,13 @@ func registerSharedTools(
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
var ttsProvider tts.TTSProvider
|
||||
if cfg.Tools.IsToolEnabled("send_tts") {
|
||||
ttsProvider = tts.DetectTTS(cfg)
|
||||
if ttsProvider == nil {
|
||||
logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil)
|
||||
}
|
||||
}
|
||||
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
@@ -269,6 +277,10 @@ func registerSharedTools(
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
|
||||
if ttsProvider != nil {
|
||||
agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil))
|
||||
}
|
||||
|
||||
// Skill discovery and installation tools
|
||||
skills_enabled := cfg.Tools.IsToolEnabled("skills")
|
||||
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
|
||||
@@ -1059,10 +1071,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
agent.Tools.SetMediaStore(s)
|
||||
}
|
||||
}
|
||||
registry.ForEachTool("send_tts", func(t tools.Tool) {
|
||||
if st, ok := t.(*tools.SendTTSTool); ok {
|
||||
st.SetMediaStore(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
|
||||
func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
|
||||
func (al *AgentLoop) SetTranscriber(t asr.Transcriber) {
|
||||
al.transcriber = t
|
||||
}
|
||||
|
||||
@@ -1083,19 +1100,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
|
||||
|
||||
// Transcribe each audio media ref in order.
|
||||
var transcriptions []string
|
||||
var keptMedia []string
|
||||
for _, ref := range msg.Media {
|
||||
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
result, err := al.transcriber.Transcribe(ctx, path)
|
||||
if err != nil {
|
||||
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
|
||||
transcriptions = append(transcriptions, "")
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
transcriptions = append(transcriptions, result.Text)
|
||||
@@ -1115,15 +1136,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
|
||||
}
|
||||
text := transcriptions[idx]
|
||||
idx++
|
||||
if text == "" {
|
||||
return match
|
||||
}
|
||||
return "[voice: " + text + "]"
|
||||
})
|
||||
|
||||
// Append any remaining transcriptions not matched by an annotation.
|
||||
for ; idx < len(transcriptions); idx++ {
|
||||
newContent += "\n[voice: " + transcriptions[idx] + "]"
|
||||
if transcriptions[idx] != "" {
|
||||
newContent += "\n[voice: " + transcriptions[idx] + "]"
|
||||
}
|
||||
}
|
||||
|
||||
msg.Content = newContent
|
||||
msg.Media = keptMedia
|
||||
return msg, true
|
||||
}
|
||||
|
||||
@@ -2464,6 +2491,28 @@ turnLoop:
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
// Send ForUser if not silent and has content.
|
||||
// For ResponseHandled tools, send regardless of SendResponse setting,
|
||||
// since they've already handled the response (e.g., send_tts, send_file).
|
||||
shouldSendForUser := !toolResult.Silent && toolResult.ForUser != "" &&
|
||||
(ts.opts.SendResponse || toolResult.ResponseHandled)
|
||||
if shouldSendForUser {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: toolResult.ForUser,
|
||||
Metadata: map[string]string{
|
||||
"is_tool_call": "true",
|
||||
},
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
@@ -2509,19 +2558,6 @@ turnLoop:
|
||||
allResponsesHandled = false
|
||||
}
|
||||
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: toolResult.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
contentForLLM := toolResult.ContentForLLM()
|
||||
|
||||
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# ASR (Automatic Speech Recognition)
|
||||
|
||||
This package handles speech-to-text for PicoClaw voice input.
|
||||
|
||||
If you are new to ASR setup, the simplest mental model is:
|
||||
|
||||
1. Add one or more ASR-capable entries to `model_list`.
|
||||
2. Point `voice.model_name` at the one you want to use.
|
||||
3. Put the API key in `.security.yml`.
|
||||
|
||||
## Quick Recommendation
|
||||
|
||||
For most new users, start with one of these:
|
||||
|
||||
| Provider | Example model | Why start here |
|
||||
| --- | --- | --- |
|
||||
| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Fast Whisper-style transcription and a straightforward OpenAI-compatible API. Groq currently advertises a free tier plan for 2000 reqs/day. |
|
||||
| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | Easy setup and strong speech-to-text quality. ElevenLabs currently advertises a free plan that includes speech-to-text usage. |
|
||||
|
||||
Pricing and free-plan limits can change, so check the linked pricing pages before depending on them in production.
|
||||
|
||||
## How ASR Configuration Works
|
||||
|
||||
PicoClaw does not keep ASR API keys inside the `voice` section.
|
||||
|
||||
Instead:
|
||||
|
||||
- `voice.model_name` chooses a named entry from `model_list`.
|
||||
- The matching `model_list` entry describes the actual provider and model.
|
||||
- `.security.yml` stores the API key for that named model entry.
|
||||
|
||||
This is the recommended pattern because it is explicit, reusable, and consistent with the rest of PicoClaw's model configuration.
|
||||
|
||||
## Recommended Setup
|
||||
|
||||
### Option A: Groq Whisper
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "groq-asr",
|
||||
"echo_transcription": true
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "groq-asr",
|
||||
"model": "groq/whisper-large-v3-turbo"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
groq-asr:
|
||||
api_keys:
|
||||
- "gsk_your_groq_key"
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- You can omit `api_base` and PicoClaw will use Groq's default API base automatically.
|
||||
- If you set `api_base` manually for Groq Whisper, both of these forms work:
|
||||
- `https://api.groq.com/openai/v1`
|
||||
- `https://api.groq.com/openai/v1/audio/transcriptions`
|
||||
- Any OpenAI-compatible Whisper model name containing `whisper` can use the Whisper transcription path, not only `whisper-large-v3-turbo`.
|
||||
|
||||
### Option B: ElevenLabs
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "elevenlabs-asr",
|
||||
"echo_transcription": true
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "elevenlabs-asr",
|
||||
"model": "elevenlabs/scribe_v1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
elevenlabs-asr:
|
||||
api_keys:
|
||||
- "sk-elevenlabs-your-key"
|
||||
```
|
||||
|
||||
### Option C: OpenAI Whisper
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "openai-asr"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "openai-asr",
|
||||
"model": "openai/whisper-1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
openai-asr:
|
||||
api_keys:
|
||||
- "sk-openai-your-key"
|
||||
```
|
||||
|
||||
## Other ASR-Capable Model Types
|
||||
|
||||
PicoClaw currently supports three main ASR routes:
|
||||
|
||||
| Route | Example models | Behavior |
|
||||
| --- | --- | --- |
|
||||
| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. |
|
||||
| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. |
|
||||
| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. |
|
||||
|
||||
If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first.
|
||||
|
||||
## How PicoClaw Chooses a Transcriber
|
||||
|
||||
`DetectTranscriber` resolves ASR in this order:
|
||||
|
||||
1. **Preferred path**: resolve `voice.model_name` against `model_list`.
|
||||
2. If that resolved model is:
|
||||
- `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber.
|
||||
- an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber.
|
||||
- an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`.
|
||||
3. **Fallback path**: if `voice.model_name` is not set, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries.
|
||||
|
||||
Fallback scanning exists for backward compatibility. New configurations should set `voice.model_name` explicitly.
|
||||
|
||||
## Common Mistakes
|
||||
|
||||
- Defining an ASR model in `model_list` but forgetting to set `voice.model_name`.
|
||||
- Putting the API key in `voice` instead of `.security.yml`.
|
||||
- Using a non-ASR model and expecting Whisper-style transcription behavior.
|
||||
- Setting a custom `api_base` that points to the wrong provider endpoint.
|
||||
|
||||
## Minimal Checklist
|
||||
|
||||
Before testing voice input, make sure:
|
||||
|
||||
- `voice.model_name` matches a `model_list[].model_name`.
|
||||
- The matching `.security.yml` entry contains a valid API key.
|
||||
- The selected model is actually ASR-capable.
|
||||
- Voice input is enabled for the channel you are using.
|
||||
@@ -0,0 +1,166 @@
|
||||
# ASR(自动语音识别)
|
||||
|
||||
这个目录负责 PicoClaw 的语音转文字能力。
|
||||
|
||||
如果你是第一次配置 ASR,可以参考如下步骤:
|
||||
|
||||
1. 在 `model_list` 里添加一个或多个支持 ASR 的模型条目。
|
||||
2. 用 `voice.model_name` 指向你想使用的那个条目。
|
||||
3. 在 `.security.yml` 里配置对应的 API Key。
|
||||
|
||||
## 快速推荐
|
||||
|
||||
对于大多数新用户,建议先从下面两种开始:
|
||||
|
||||
| 提供商 | 示例模型 | 推荐理由 |
|
||||
| --- | --- | --- |
|
||||
| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Whisper 风格转录速度快,并且提供 OpenAI 兼容接口,配置比较直接。Groq 目前官方提供2000请求每日的免费套餐。 |
|
||||
| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | 上手简单,语音转文字质量也不错。ElevenLabs 目前官方免费套餐包含 STT 用量。 |
|
||||
|
||||
价格和免费额度可能会变化,正式使用前请以官网定价页为准。
|
||||
|
||||
## ASR 配置是如何工作的
|
||||
|
||||
PicoClaw 不会把 ASR 的 API Key 放在 `voice` 配置里。
|
||||
|
||||
推荐的方式是:
|
||||
|
||||
- `voice.model_name` 用来选择 `model_list` 里的某个命名模型。
|
||||
- `model_list` 条目描述真实的提供商和模型。
|
||||
- `.security.yml` 负责保存该模型条目的 API Key。
|
||||
|
||||
这种方式更明确、更安全,也和 PicoClaw 其他模型配置方式保持一致。
|
||||
|
||||
## 推荐配置方式
|
||||
|
||||
### 方案 A:Groq Whisper
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "groq-asr",
|
||||
"echo_transcription": true
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "groq-asr",
|
||||
"model": "groq/whisper-large-v3-turbo"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
groq-asr:
|
||||
api_keys:
|
||||
- "gsk_your_groq_key"
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 你可以不写 `api_base`,PicoClaw 会自动使用 Groq 默认接口地址。
|
||||
- 如果你手动设置 Groq Whisper 的 `api_base`,下面两种写法都可以:
|
||||
- `https://api.groq.com/openai/v1`
|
||||
- `https://api.groq.com/openai/v1/audio/transcriptions`
|
||||
- 只要是 OpenAI 兼容、并且模型名里包含 `whisper` 的模型,都可以走 Whisper 转录路径,不仅限于 `whisper-large-v3-turbo`。
|
||||
|
||||
### 方案 B:ElevenLabs
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "elevenlabs-asr",
|
||||
"echo_transcription": true
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "elevenlabs-asr",
|
||||
"model": "elevenlabs/scribe_v1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
elevenlabs-asr:
|
||||
api_keys:
|
||||
- "sk-elevenlabs-your-key"
|
||||
```
|
||||
|
||||
### 方案 C:OpenAI Whisper
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"model_name": "openai-asr"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "openai-asr",
|
||||
"model": "openai/whisper-1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
openai-asr:
|
||||
api_keys:
|
||||
- "sk-openai-your-key"
|
||||
```
|
||||
|
||||
## 其他支持 ASR 的模型类型
|
||||
|
||||
PicoClaw 目前主要支持三种 ASR 路径:
|
||||
|
||||
| 路径 | 示例模型 | 行为说明 |
|
||||
| --- | --- | --- |
|
||||
| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
|
||||
| Whisper 接口模型 | `openai/whisper-1`、`groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 |
|
||||
| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview`、`gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 |
|
||||
|
||||
如果你不确定该选哪种,建议优先使用 Groq Whisper 或 ElevenLabs。
|
||||
|
||||
## PicoClaw 如何选择转录器
|
||||
|
||||
`DetectTranscriber` 会按下面顺序选择 ASR:
|
||||
|
||||
1. **首选路径**:根据 `voice.model_name` 在 `model_list` 中找到对应模型。
|
||||
2. 如果找到的模型属于以下类型:
|
||||
- `elevenlabs/...`,则使用 ElevenLabs transcriber。
|
||||
- OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。
|
||||
- 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`。
|
||||
3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。
|
||||
|
||||
回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.model_name`。
|
||||
|
||||
## 常见错误
|
||||
|
||||
- 在 `model_list` 里定义了 ASR 模型,但忘了设置 `voice.model_name`。
|
||||
- 把 API Key 写进了 `voice`,而不是 `.security.yml`。
|
||||
- 选择了不支持 ASR 的模型,却期望得到 Whisper 风格的转录结果。
|
||||
- 自定义了错误的 `api_base`,导致请求打到错误的接口地址。
|
||||
|
||||
## 最小检查清单
|
||||
|
||||
在测试语音输入前,请确认:
|
||||
|
||||
- `voice.model_name` 能正确匹配某个 `model_list[].model_name`。
|
||||
- `.security.yml` 中对应条目已经配置了有效 API Key。
|
||||
- 你选择的模型确实支持 ASR。
|
||||
- 你当前使用的频道已经启用了语音输入能力。
|
||||
@@ -0,0 +1,252 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v3/pkg/media/oggwriter"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type speechAccumulator struct {
|
||||
writer *oggwriter.OggWriter
|
||||
file string
|
||||
lastAudioAt time.Time
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
chatID string
|
||||
speakerID string
|
||||
sessionID string
|
||||
channel string
|
||||
}
|
||||
|
||||
func (a *speechAccumulator) Push(chunk bus.AudioChunk) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.closed {
|
||||
return
|
||||
}
|
||||
|
||||
a.lastAudioAt = time.Now()
|
||||
|
||||
pkt := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: uint16(chunk.Sequence),
|
||||
Timestamp: chunk.Timestamp,
|
||||
SSRC: 1, // Stable arbitrary dummy
|
||||
},
|
||||
Payload: chunk.Data,
|
||||
}
|
||||
|
||||
if err := a.writer.WriteRTP(pkt); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to write RTP", map[string]any{"error": err})
|
||||
}
|
||||
}
|
||||
|
||||
func (a *speechAccumulator) Close() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if !a.closed {
|
||||
a.writer.Close()
|
||||
a.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
type Agent struct {
|
||||
bus *bus.MessageBus
|
||||
transcriber Transcriber
|
||||
|
||||
mu sync.Mutex
|
||||
sessions map[string]*speechAccumulator // keyed by sessionID_speakerID
|
||||
}
|
||||
|
||||
func NewAgent(mb *bus.MessageBus, t Transcriber) *Agent {
|
||||
return &Agent{
|
||||
bus: mb,
|
||||
transcriber: t,
|
||||
sessions: make(map[string]*speechAccumulator),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) Start(ctx context.Context) {
|
||||
logger.InfoCF("voice-agent", "Started Voice Agent orchestrator", nil)
|
||||
go a.listenChunks(ctx)
|
||||
go a.vadTick(ctx)
|
||||
|
||||
// Cleanup sessions on shutdown
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
a.mu.Lock()
|
||||
for key, acc := range a.sessions {
|
||||
acc.Close()
|
||||
os.Remove(acc.file)
|
||||
delete(a.sessions, key)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
logger.InfoCF("voice-agent", "Cleaned up voice sessions on shutdown", nil)
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Agent) listenChunks(ctx context.Context) {
|
||||
chunks := a.bus.AudioChunksChan()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case chunk, ok := <-chunks:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
a.handleChunk(chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) handleChunk(chunk bus.AudioChunk) {
|
||||
// Only accept Opus-encoded audio
|
||||
if chunk.Format != "opus" {
|
||||
logger.DebugCF("voice-agent", "Ignoring unsupported audio format", map[string]any{"format": chunk.Format})
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s_%s", chunk.SessionID, chunk.SpeakerID)
|
||||
|
||||
a.mu.Lock()
|
||||
acc, exists := a.sessions[key]
|
||||
if !exists {
|
||||
filename := filepath.Join(os.TempDir(), fmt.Sprintf("voice_%s_%d.ogg", key, time.Now().UnixNano()))
|
||||
writer, err := oggwriter.New(filename, uint32(chunk.SampleRate), uint16(chunk.Channels))
|
||||
if err != nil {
|
||||
a.mu.Unlock()
|
||||
logger.ErrorCF("voice-agent", "Failed to create OggWriter", map[string]any{"error": err})
|
||||
return
|
||||
}
|
||||
|
||||
acc = &speechAccumulator{
|
||||
writer: writer,
|
||||
file: filename,
|
||||
lastAudioAt: time.Now(),
|
||||
chatID: chunk.ChatID,
|
||||
speakerID: chunk.SpeakerID,
|
||||
sessionID: chunk.SessionID,
|
||||
channel: chunk.Channel,
|
||||
}
|
||||
a.sessions[key] = acc
|
||||
logger.DebugCF("voice-agent", "Started accumulating voice", map[string]any{"key": key, "file": filename})
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
acc.Push(chunk)
|
||||
}
|
||||
|
||||
func (a *Agent) vadTick(ctx context.Context) {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
a.checkSilence(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) checkSilence(ctx context.Context) {
|
||||
a.mu.Lock()
|
||||
now := time.Now()
|
||||
var finished []*speechAccumulator
|
||||
|
||||
for key, acc := range a.sessions {
|
||||
acc.mu.Lock()
|
||||
last := acc.lastAudioAt
|
||||
acc.mu.Unlock()
|
||||
|
||||
if now.Sub(last) > 1500*time.Millisecond {
|
||||
acc.Close()
|
||||
delete(a.sessions, key)
|
||||
finished = append(finished, acc)
|
||||
}
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
for _, acc := range finished {
|
||||
go a.processUtterance(ctx, acc)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
|
||||
defer os.Remove(acc.file)
|
||||
|
||||
logger.InfoCF("voice-agent", "User finished speaking, transcribing...", map[string]any{"file": acc.file})
|
||||
|
||||
if a.transcriber == nil {
|
||||
logger.ErrorCF("voice-agent", "No STT configured!", nil)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := a.transcriber.Transcribe(ctx, acc.file)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice-agent", "Transcription failed", map[string]any{"error": err})
|
||||
return
|
||||
}
|
||||
|
||||
if res.Text == "" {
|
||||
logger.DebugCF("voice-agent", "Ignored empty transcription", map[string]any{"file": acc.file})
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("voice-agent", "Transcription result", map[string]any{"text": res.Text, "duration": res.Duration})
|
||||
|
||||
channelType := acc.channel
|
||||
if channelType == "" {
|
||||
channelType = "discord" // fallback for legacy chunks
|
||||
}
|
||||
|
||||
text := strings.ToLower(strings.TrimSpace(res.Text))
|
||||
if strings.Contains(text, "leave the voice channel") || strings.Contains(text, "leave voice") ||
|
||||
strings.Contains(text, "disconnect voice") || strings.Contains(text, "leave the channel") ||
|
||||
strings.Contains(text, "leave channel") {
|
||||
logger.InfoCF("voice-agent", "Voice command triggered: leave", nil)
|
||||
if err := a.bus.PublishVoiceControl(ctx, bus.VoiceControl{
|
||||
SessionID: acc.sessionID,
|
||||
Type: "command",
|
||||
Action: "leave",
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
|
||||
}
|
||||
if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: channelType,
|
||||
ChatID: acc.chatID,
|
||||
Content: "Goodbye! Leaving the voice channel.",
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
|
||||
|
||||
if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
|
||||
Channel: channelType,
|
||||
SenderID: acc.speakerID,
|
||||
ChatID: acc.chatID,
|
||||
Content: res.Text + oralPrompt,
|
||||
Peer: bus.Peer{Kind: "channel", ID: acc.chatID},
|
||||
Metadata: map[string]string{
|
||||
"is_voice": "true",
|
||||
},
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v3/pkg/media/oggwriter"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
)
|
||||
|
||||
type fakeTranscriber struct {
|
||||
text string
|
||||
err error
|
||||
lastPath string
|
||||
}
|
||||
|
||||
func (f *fakeTranscriber) Name() string { return "fake" }
|
||||
|
||||
func (f *fakeTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
f.lastPath = audioFilePath
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return &TranscriptionResponse{Text: f.text}, nil
|
||||
}
|
||||
|
||||
func waitForFileRemoval(t *testing.T, path string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
t.Fatalf("expected file to be removed: %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentHandleChunkCreatesSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
agent := NewAgent(mb, &fakeTranscriber{})
|
||||
|
||||
chunk := bus.AudioChunk{
|
||||
SessionID: "sess",
|
||||
SpeakerID: "speaker",
|
||||
ChatID: "chat",
|
||||
Channel: "discord",
|
||||
Sequence: 1,
|
||||
Timestamp: 1,
|
||||
SampleRate: 48000,
|
||||
Channels: 2,
|
||||
Format: "opus",
|
||||
Data: []byte{0xF8, 0xFF, 0xFE},
|
||||
}
|
||||
|
||||
agent.handleChunk(chunk)
|
||||
|
||||
key := "sess_speaker"
|
||||
agent.mu.Lock()
|
||||
acc, ok := agent.sessions[key]
|
||||
agent.mu.Unlock()
|
||||
if !ok {
|
||||
t.Fatal("expected session to be created")
|
||||
}
|
||||
|
||||
acc.Close()
|
||||
_ = os.Remove(acc.file)
|
||||
}
|
||||
|
||||
func TestAgentHandleChunkIgnoresUnsupportedFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
agent := NewAgent(mb, &fakeTranscriber{})
|
||||
|
||||
chunk := bus.AudioChunk{Format: "pcm"}
|
||||
agent.handleChunk(chunk)
|
||||
|
||||
agent.mu.Lock()
|
||||
count := len(agent.sessions)
|
||||
agent.mu.Unlock()
|
||||
if count != 0 {
|
||||
t.Fatalf("expected no sessions, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentProcessUtteranceLeaveCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
tr := &fakeTranscriber{text: "please leave the voice channel now"}
|
||||
agent := NewAgent(mb, tr)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
filePath := filepath.Join(tmpDir, "voice.ogg")
|
||||
if err := os.WriteFile(filePath, []byte("data"), 0o600); err != nil {
|
||||
t.Fatalf("write temp file: %v", err)
|
||||
}
|
||||
|
||||
acc := &speechAccumulator{
|
||||
file: filePath,
|
||||
chatID: "chat",
|
||||
speakerID: "speaker",
|
||||
sessionID: "sess",
|
||||
channel: "discord",
|
||||
}
|
||||
|
||||
agent.processUtterance(context.Background(), acc)
|
||||
|
||||
select {
|
||||
case ctrl := <-mb.VoiceControlsChan():
|
||||
if ctrl.Action != "leave" || ctrl.Type != "command" || ctrl.SessionID != "sess" {
|
||||
t.Fatalf("unexpected voice control: %#v", ctrl)
|
||||
}
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Fatal("expected voice control publish")
|
||||
}
|
||||
|
||||
select {
|
||||
case out := <-mb.OutboundChan():
|
||||
if !strings.Contains(out.Content, "Leaving the voice channel") {
|
||||
t.Fatalf("unexpected outbound content: %q", out.Content)
|
||||
}
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Fatal("expected outbound publish")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filePath); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected temp file to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
tr := &fakeTranscriber{text: "hello there"}
|
||||
agent := NewAgent(mb, tr)
|
||||
|
||||
filePath := filepath.Join(t.TempDir(), "voice.ogg")
|
||||
writer, err := oggwriter.New(filePath, 48000, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("create ogg writer: %v", err)
|
||||
}
|
||||
|
||||
acc := &speechAccumulator{
|
||||
writer: writer,
|
||||
file: filePath,
|
||||
lastAudioAt: time.Now().Add(-2 * time.Second),
|
||||
chatID: "chat",
|
||||
speakerID: "speaker",
|
||||
sessionID: "sess",
|
||||
channel: "slack",
|
||||
}
|
||||
|
||||
agent.mu.Lock()
|
||||
agent.sessions["sess_speaker"] = acc
|
||||
agent.mu.Unlock()
|
||||
|
||||
agent.checkSilence(context.Background())
|
||||
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
if msg.Channel != "slack" {
|
||||
t.Fatalf("unexpected inbound channel: %q", msg.Channel)
|
||||
}
|
||||
if !strings.Contains(msg.Content, "hello there") {
|
||||
t.Fatalf("unexpected inbound content: %q", msg.Content)
|
||||
}
|
||||
if msg.Metadata["is_voice"] != "true" {
|
||||
t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected inbound publish")
|
||||
}
|
||||
|
||||
waitForFileRemoval(t, filePath, 500*time.Millisecond)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type Transcriber interface {
|
||||
Name() string
|
||||
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
|
||||
}
|
||||
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
func supportsAudioTranscription(model string) bool {
|
||||
protocol, _ := providers.ExtractProtocol(model)
|
||||
|
||||
switch protocol {
|
||||
case "openai", "azure", "azure-openai",
|
||||
"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// These protocols all go through the OpenAI-compatible or Azure provider path in
|
||||
// providers.CreateProviderFromConfig, so they are the only ones that can supply
|
||||
// the audio media payload shape expected by NewAudioModelTranscriber.
|
||||
|
||||
// TODO: Further restrict this by modelID, since not every model under these
|
||||
// protocols supports audio transcription.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func supportsWhisperTranscription(model string) bool {
|
||||
protocol, _ := providers.ExtractProtocol(model)
|
||||
|
||||
switch protocol {
|
||||
case "openai", "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding", "mimo":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func whisperModelID(modelCfg *config.ModelConfig) string {
|
||||
if modelCfg == nil || modelCfg.APIKey() == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if !supportsWhisperTranscription(modelCfg.Model) {
|
||||
return ""
|
||||
}
|
||||
|
||||
_, modelID := providers.ExtractProtocol(strings.TrimSpace(modelCfg.Model))
|
||||
if strings.Contains(strings.ToLower(modelID), "whisper") {
|
||||
return modelID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, _ := providers.ExtractProtocol(modelCfg.Model)
|
||||
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
|
||||
}
|
||||
if modelID := whisperModelID(modelCfg); modelID != "" {
|
||||
return NewWhisperTranscriber(modelCfg)
|
||||
}
|
||||
if supportsAudioTranscription(modelCfg.Model) {
|
||||
return NewAudioModelTranscriber(modelCfg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, _ := providers.ExtractProtocol(modelCfg.Model)
|
||||
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
|
||||
}
|
||||
if modelID := whisperModelID(modelCfg); modelID != "" {
|
||||
return NewWhisperTranscriber(modelCfg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
|
||||
// nil if no supported transcription provider is configured.
|
||||
func DetectTranscriber(cfg *config.Config) Transcriber {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
|
||||
modelCfg, err := cfg.GetModelConfig(modelName)
|
||||
if err == nil {
|
||||
if tr := transcriberFromModelConfig(modelCfg); tr != nil {
|
||||
return tr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to compatibility scanning for legacy auto-detected ASR providers.
|
||||
for _, mc := range cfg.ModelList {
|
||||
if tr := fallbackTranscriberFromModelConfig(mc); tr != nil {
|
||||
return tr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package voice
|
||||
package asr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -33,26 +33,68 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
wantName: "audio-model",
|
||||
},
|
||||
{
|
||||
name: "groq via model list",
|
||||
name: "voice model name alias selects elevenlabs transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-asr-model",
|
||||
Model: "elevenlabs/scribe_v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "elevenlabs",
|
||||
},
|
||||
{
|
||||
name: "voice model name alias selects whisper transcriber for groq",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-asr-model",
|
||||
Model: "groq/whisper-large-v3",
|
||||
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "whisper",
|
||||
},
|
||||
{
|
||||
name: "openai whisper alias selects whisper transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-asr-model",
|
||||
Model: "openai/whisper-1",
|
||||
APIKeys: config.SimpleSecureStrings("sk-openai-model"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "whisper",
|
||||
},
|
||||
{
|
||||
name: "whisper via model list fallback",
|
||||
cfg: &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "openai", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("sk-openai")},
|
||||
{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.3-70b",
|
||||
Model: "groq/whisper-large-v3-turbo",
|
||||
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "groq",
|
||||
wantName: "whisper",
|
||||
},
|
||||
{
|
||||
name: "voice model name selects non-gemini audio model transcriber",
|
||||
name: "voice model name alias selects non-gemini audio model transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "voice-openai-audio"},
|
||||
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "voice-openai-audio",
|
||||
ModelName: "my-asr-model",
|
||||
Model: "openai/gpt-4o-audio-preview",
|
||||
APIKeys: config.SimpleSecureStrings("sk-openai"),
|
||||
},
|
||||
@@ -92,7 +134,7 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
name: "groq model list entry without key is skipped",
|
||||
cfg: &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{Model: "groq/llama-3.3-70b"},
|
||||
{Model: "groq/whisper-large-v3"},
|
||||
},
|
||||
},
|
||||
wantNil: true,
|
||||
@@ -103,12 +145,12 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.3-70b",
|
||||
Model: "groq/whisper-large-v3",
|
||||
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "groq",
|
||||
wantName: "whisper",
|
||||
},
|
||||
{
|
||||
name: "missing voice model name config returns nil",
|
||||
@@ -127,15 +169,17 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
{
|
||||
name: "elevenlabs voice config key",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
|
||||
},
|
||||
},
|
||||
wantName: "elevenlabs",
|
||||
},
|
||||
{
|
||||
name: "elevenlabs takes priority over groq model list",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
|
||||
{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.3-70b",
|
||||
@@ -149,10 +193,10 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
name: "voice model name takes priority over elevenlabs",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{
|
||||
ModelName: "voice-gemini",
|
||||
ElevenLabsAPIKey: "sk_elevenlabs_test",
|
||||
ModelName: "voice-gemini",
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{Model: "elevenlabs", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
|
||||
{
|
||||
ModelName: "voice-gemini",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
@@ -1,4 +1,4 @@
|
||||
package voice
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
package voice
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package voice
|
||||
package asr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -23,12 +23,16 @@ type ElevenLabsTranscriber struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewElevenLabsTranscriber(apiKey string) *ElevenLabsTranscriber {
|
||||
func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber {
|
||||
logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""})
|
||||
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.elevenlabs.io"
|
||||
}
|
||||
|
||||
return &ElevenLabsTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: "https://api.elevenlabs.io",
|
||||
apiBase: apiBase,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
+5
-5
@@ -1,4 +1,4 @@
|
||||
package voice
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
var _ Transcriber = (*ElevenLabsTranscriber)(nil)
|
||||
|
||||
func TestElevenLabsTranscriberName(t *testing.T) {
|
||||
tr := NewElevenLabsTranscriber("sk_test")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
if got := tr.Name(); got != "elevenlabs" {
|
||||
t.Errorf("Name() = %q, want %q", got, "elevenlabs")
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewElevenLabsTranscriber("sk_test")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
@@ -64,7 +64,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewElevenLabsTranscriber("sk_bad")
|
||||
tr := NewElevenLabsTranscriber("sk_bad", "")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
@@ -74,7 +74,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := NewElevenLabsTranscriber("sk_test")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
@@ -0,0 +1,245 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type WhisperTranscriber struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
modelID string
|
||||
providerName string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewWhisperTranscriber(modelCfg *config.ModelConfig) *WhisperTranscriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, modelID := providers.ExtractProtocol(modelCfg.Model)
|
||||
if modelID == "" {
|
||||
modelID = strings.TrimSpace(modelCfg.Model)
|
||||
}
|
||||
|
||||
tr := newWhisperTranscriber(
|
||||
modelCfg.APIKey(),
|
||||
providers.ResolveAPIBase(modelCfg),
|
||||
modelID,
|
||||
protocol,
|
||||
)
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Creating whisper transcriber", map[string]any{
|
||||
"api_base": tr.apiBase,
|
||||
"has_key": tr.apiKey != "",
|
||||
"model": tr.modelID,
|
||||
"provider": tr.providerName,
|
||||
})
|
||||
return tr
|
||||
}
|
||||
|
||||
func NewGroqTranscriber(apiKey, modelID string) *WhisperTranscriber {
|
||||
return newWhisperTranscriber(apiKey, "https://api.groq.com/openai/v1", modelID, "groq")
|
||||
}
|
||||
|
||||
func newWhisperTranscriber(apiKey, apiBase, modelID, providerName string) *WhisperTranscriber {
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = "whisper"
|
||||
}
|
||||
return &WhisperTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
modelID: modelID,
|
||||
providerName: providerName,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WhisperTranscriber) transcriptionURL() string {
|
||||
base := strings.TrimRight(t.apiBase, "/")
|
||||
if strings.HasSuffix(base, "/audio/transcriptions") {
|
||||
return base
|
||||
}
|
||||
return base + "/audio/transcriptions"
|
||||
}
|
||||
|
||||
func (t *WhisperTranscriber) TranscribeData(
|
||||
ctx context.Context,
|
||||
data []byte,
|
||||
filename string,
|
||||
) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting whisper transcription from memory", map[string]any{
|
||||
"bytes": len(data),
|
||||
"filename": filename,
|
||||
"model": t.modelID,
|
||||
"provider": t.providerName,
|
||||
})
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create whisper form file", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
if _, copyErr := io.Copy(part, bytes.NewReader(data)); copyErr != nil {
|
||||
logger.ErrorCF("voice", "Failed to copy whisper file content", map[string]any{"error": copyErr})
|
||||
return nil, fmt.Errorf("failed to copy file content: %w", copyErr)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("model", t.modelID); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write whisper model field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("response_format", "json"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write whisper response_format field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write response_format field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to close whisper multipart writer", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), int64(len(data)))
|
||||
}
|
||||
|
||||
func (t *WhisperTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting whisper transcription", map[string]any{
|
||||
"audio_file": audioFilePath,
|
||||
"model": t.modelID,
|
||||
"provider": t.providerName,
|
||||
})
|
||||
|
||||
audioFile, err := os.Open(audioFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audio file %s: %w", audioFilePath, err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
fileInfo, err := audioFile.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat audio file %s: %w", audioFilePath, err)
|
||||
}
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
if _, copyErr := io.Copy(part, audioFile); copyErr != nil {
|
||||
return nil, fmt.Errorf("failed to copy audio data: %w", copyErr)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("model", t.modelID); err != nil {
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("response_format", "json"); err != nil {
|
||||
return nil, fmt.Errorf("failed to write response_format field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), fileInfo.Size())
|
||||
}
|
||||
|
||||
func (t *WhisperTranscriber) doRequest(
|
||||
ctx context.Context,
|
||||
requestBody *bytes.Buffer,
|
||||
contentType string,
|
||||
fileSize int64,
|
||||
) (*TranscriptionResponse, error) {
|
||||
url := t.transcriptionURL()
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, requestBody)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create whisper request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
if t.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+t.apiKey)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Sending whisper transcription request", map[string]any{
|
||||
"file_size_bytes": fileSize,
|
||||
"model": t.modelID,
|
||||
"provider": t.providerName,
|
||||
"request_size_bytes": requestBody.Len(),
|
||||
"url": url,
|
||||
})
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to send whisper request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to read whisper response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("voice", "Whisper API error", map[string]any{
|
||||
"provider": t.providerName,
|
||||
"response": string(body),
|
||||
"status_code": resp.StatusCode,
|
||||
})
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result TranscriptionResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to unmarshal whisper response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("voice", "Whisper transcription completed successfully", map[string]any{
|
||||
"duration_seconds": result.Duration,
|
||||
"language": result.Language,
|
||||
"provider": t.providerName,
|
||||
"text_length": len(result.Text),
|
||||
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||
})
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (t *WhisperTranscriber) Name() string {
|
||||
return "whisper"
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestWhisperTranscriberTranscribeDataUsesConfiguredModel(t *testing.T) {
|
||||
var gotModel string
|
||||
var gotPath string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer sk-openai-test" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-openai-test")
|
||||
}
|
||||
|
||||
reader, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
t.Fatalf("MultipartReader() error: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NextPart() error: %v", err)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error: %v", err)
|
||||
}
|
||||
|
||||
if part.FormName() == "model" {
|
||||
gotModel = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "hello from whisper"}); err != nil {
|
||||
t.Fatalf("Encode() error: %v", err)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tr := NewWhisperTranscriber(&config.ModelConfig{
|
||||
Model: "openai/whisper-1",
|
||||
APIBase: server.URL,
|
||||
APIKeys: config.SimpleSecureStrings("sk-openai-test"),
|
||||
})
|
||||
tr.httpClient = server.Client()
|
||||
|
||||
resp, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg")
|
||||
if err != nil {
|
||||
t.Fatalf("TranscribeData() error: %v", err)
|
||||
}
|
||||
if resp.Text != "hello from whisper" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text, "hello from whisper")
|
||||
}
|
||||
if gotModel != "whisper-1" {
|
||||
t.Errorf("model field = %q, want %q", gotModel, "whisper-1")
|
||||
}
|
||||
if gotPath != "/audio/transcriptions" {
|
||||
t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhisperTranscriberUsesEndpointAPIBaseWithoutDoubleAppend(t *testing.T) {
|
||||
var gotPath string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"}); err != nil {
|
||||
t.Fatalf("Encode() error: %v", err)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tr := NewWhisperTranscriber(&config.ModelConfig{
|
||||
Model: "groq/whisper-large-v3",
|
||||
APIBase: server.URL + "/audio/transcriptions",
|
||||
APIKeys: config.SimpleSecureStrings("sk-groq-test"),
|
||||
})
|
||||
tr.httpClient = server.Client()
|
||||
|
||||
if _, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg"); err != nil {
|
||||
t.Fatalf("TranscribeData() error: %v", err)
|
||||
}
|
||||
if gotPath != "/audio/transcriptions" {
|
||||
t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// DecodeOggOpus reads an Ogg format stream and extracts individual Opus payloads.
|
||||
// It calls onFrame for every complete Opus frame found in the stream.
|
||||
func DecodeOggOpus(r io.Reader, onFrame func([]byte) error) error {
|
||||
var packet bytes.Buffer
|
||||
header := make([]byte, 27)
|
||||
segment := make([]byte, 255)
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read ogg header: %w", err)
|
||||
}
|
||||
if string(header[:4]) != "OggS" {
|
||||
return fmt.Errorf("invalid ogg magic string")
|
||||
}
|
||||
|
||||
pageSegments := int(header[26])
|
||||
segmentTable := make([]byte, pageSegments)
|
||||
if _, err := io.ReadFull(r, segmentTable); err != nil {
|
||||
return fmt.Errorf("failed to read segment table: %w", err)
|
||||
}
|
||||
|
||||
for _, lacing := range segmentTable {
|
||||
if _, err := io.ReadFull(r, segment[:lacing]); err != nil {
|
||||
return fmt.Errorf("failed to read segment data: %w", err)
|
||||
}
|
||||
|
||||
packet.Write(segment[:lacing])
|
||||
|
||||
// If lacing is less than 255, the packet is complete
|
||||
if lacing < 255 {
|
||||
if packet.Len() > 0 {
|
||||
packetBytes := packet.Bytes()
|
||||
// Ignore Ogg Opus headers
|
||||
if !bytes.HasPrefix(packetBytes, []byte("OpusHead")) &&
|
||||
!bytes.HasPrefix(packetBytes, []byte("OpusTags")) {
|
||||
if err := onFrame(packetBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Start new packet
|
||||
packet.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// buildOggPage helper creates an Ogg page for testing.
|
||||
// lacingVals specifies the segment table, and data is the payload.
|
||||
func buildOggPage(lacingVals []byte, data []byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
// 27-byte Ogg header
|
||||
header := make([]byte, 27)
|
||||
copy(header[:4], "OggS")
|
||||
header[5] = 0 // type flag
|
||||
// For testing, we only care about OggS magic and page_segments (byte 26)
|
||||
header[26] = byte(len(lacingVals))
|
||||
buf.Write(header)
|
||||
buf.Write(lacingVals)
|
||||
buf.Write(data)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestDecodeOggOpus_ValidParsing(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
||||
// Packet 1: Single segment, length 50
|
||||
pkt1 := bytes.Repeat([]byte{1}, 50)
|
||||
// Packet 2: Multi-segment (255 + 10 = 265 bytes)
|
||||
pkt2Part1 := bytes.Repeat([]byte{2}, 255)
|
||||
pkt2Part2 := bytes.Repeat([]byte{2}, 10)
|
||||
// Packet 3: Continued across pages. Page 1 gets 255, Page 2 gets 20. Total 275 bytes.
|
||||
pkt3Part1 := bytes.Repeat([]byte{3}, 255)
|
||||
pkt3Part2 := bytes.Repeat([]byte{3}, 20)
|
||||
|
||||
// Page 1: OpusHead (skip), OpusTags (skip), pkt1, pkt2, pkt3Part1
|
||||
page1Lacing := []byte{8, 8, 50, 255, 10, 255}
|
||||
page1Data := bytes.Join([][]byte{
|
||||
[]byte("OpusHead"),
|
||||
[]byte("OpusTags"),
|
||||
pkt1,
|
||||
pkt2Part1, pkt2Part2,
|
||||
pkt3Part1,
|
||||
}, nil)
|
||||
|
||||
// Page 2: pkt3Part2, pkt4 (length 10)
|
||||
pkt4 := bytes.Repeat([]byte{4}, 10)
|
||||
page2Lacing := []byte{20, 10}
|
||||
page2Data := bytes.Join([][]byte{
|
||||
pkt3Part2,
|
||||
pkt4,
|
||||
}, nil)
|
||||
|
||||
b.Write(buildOggPage(page1Lacing, page1Data))
|
||||
b.Write(buildOggPage(page2Lacing, page2Data))
|
||||
|
||||
var frames [][]byte
|
||||
err := DecodeOggOpus(&b, func(frame []byte) error {
|
||||
// making a copy to store as DecodeOggOpus might reuse backing array
|
||||
cpy := make([]byte, len(frame))
|
||||
copy(cpy, frame)
|
||||
frames = append(frames, cpy)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedFrames := [][]byte{
|
||||
pkt1,
|
||||
append(pkt2Part1, pkt2Part2...),
|
||||
append(pkt3Part1, pkt3Part2...),
|
||||
pkt4,
|
||||
}
|
||||
|
||||
if len(frames) != len(expectedFrames) {
|
||||
t.Fatalf("expected %d frames, got %d", len(expectedFrames), len(frames))
|
||||
}
|
||||
|
||||
for i, expected := range expectedFrames {
|
||||
if !reflect.DeepEqual(frames[i], expected) {
|
||||
t.Errorf("frame %d mismatch:\nexp: %v\ngot: %v", i, expected, frames[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeOggOpus_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "invalid magic string",
|
||||
data: []byte(
|
||||
"OggX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
),
|
||||
errContains: "invalid ogg magic string",
|
||||
},
|
||||
{
|
||||
name: "short header",
|
||||
data: []byte("Ogg"),
|
||||
errContains: "failed to read ogg header",
|
||||
},
|
||||
{
|
||||
name: "eof in segment table",
|
||||
data: func() []byte {
|
||||
h := make([]byte, 27)
|
||||
copy(h, "OggS")
|
||||
h[26] = 5 // expects 5 bytes of segment table, but none provided
|
||||
return h
|
||||
}(),
|
||||
errContains: "failed to read segment table",
|
||||
},
|
||||
{
|
||||
name: "eof in segment data",
|
||||
data: func() []byte {
|
||||
h := make([]byte, 27, 28)
|
||||
copy(h, "OggS")
|
||||
h[26] = 1
|
||||
return append(h, 100) // expects 100 bytes of data, but none provided
|
||||
}(),
|
||||
errContains: "failed to read segment data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := DecodeOggOpus(bytes.NewReader(tt.data), func(b []byte) error { return nil })
|
||||
if tt.name == "short header" {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error (io.EOF/ErrUnexpectedEOF swallowed), got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.errContains)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("expected error to contain %q, got: %q", tt.errContains, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// SplitSentences splits text into sentence-sized chunks suitable for TTS synthesis.
|
||||
// It splits on sentence-ending punctuation (.!?\n, as well as CJK 。, !, ?) while avoiding false splits
|
||||
// on decimal numbers. Very short fragments are merged with
|
||||
// the next sentence to prevent choppy playback.
|
||||
func SplitSentences(text string) []string {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var sentences []string
|
||||
var current strings.Builder
|
||||
runes := []rune(text)
|
||||
|
||||
for i := 0; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
if r == '\n' {
|
||||
s := strings.TrimSpace(current.String())
|
||||
if s != "" {
|
||||
sentences = append(sentences, s)
|
||||
}
|
||||
current.Reset()
|
||||
continue
|
||||
}
|
||||
|
||||
current.WriteRune(r)
|
||||
|
||||
if r == '.' || r == '!' || r == '?' || r == '。' || r == '!' || r == '?' {
|
||||
// Avoid splitting on decimal numbers like "3.14"
|
||||
if r == '.' && i > 0 && unicode.IsDigit(runes[i-1]) &&
|
||||
i+1 < len(runes) && unicode.IsDigit(runes[i+1]) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Consume contiguous punctuation clusters (e.g., "..." or "?!").
|
||||
for i+1 < len(runes) && (runes[i+1] == '.' || runes[i+1] == '!' || runes[i+1] == '?' || runes[i+1] == '。' || runes[i+1] == '!' || runes[i+1] == '?') {
|
||||
i++
|
||||
current.WriteRune(runes[i])
|
||||
}
|
||||
|
||||
s := strings.TrimSpace(current.String())
|
||||
if s != "" {
|
||||
sentences = append(sentences, s)
|
||||
}
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if s := strings.TrimSpace(current.String()); s != "" {
|
||||
sentences = append(sentences, s)
|
||||
}
|
||||
|
||||
// Merge very short fragments with the next sentence
|
||||
return mergeShorties(sentences, 15)
|
||||
}
|
||||
|
||||
// mergeShorties merges sentences shorter than minLen characters with the following sentence.
|
||||
func mergeShorties(sentences []string, minLen int) []string {
|
||||
if len(sentences) <= 1 {
|
||||
return sentences
|
||||
}
|
||||
|
||||
var merged []string
|
||||
var buf string
|
||||
|
||||
for _, s := range sentences {
|
||||
if buf != "" {
|
||||
buf += " " + s
|
||||
if len([]rune(buf)) >= minLen {
|
||||
merged = append(merged, buf)
|
||||
buf = ""
|
||||
}
|
||||
} else if len([]rune(s)) < minLen {
|
||||
buf = s
|
||||
} else {
|
||||
merged = append(merged, s)
|
||||
}
|
||||
}
|
||||
|
||||
if buf != "" {
|
||||
if len(merged) > 0 {
|
||||
merged[len(merged)-1] += " " + buf
|
||||
} else {
|
||||
merged = append(merged, buf)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitSentences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
in: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single sentence",
|
||||
in: "Hello world.",
|
||||
want: []string{"Hello world."},
|
||||
},
|
||||
{
|
||||
name: "decimal numbers do not split",
|
||||
in: "The value is 3.14 today. Keep watching closely.",
|
||||
want: []string{"The value is 3.14 today.", "Keep watching closely."},
|
||||
},
|
||||
{
|
||||
name: "newline boundary",
|
||||
in: "This is line number one\nThis is line number two",
|
||||
want: []string{"This is line number one", "This is line number two"},
|
||||
},
|
||||
{
|
||||
name: "newline with surrounding spaces",
|
||||
in: " This is the first line \n This is the second line ",
|
||||
want: []string{"This is the first line", "This is the second line"},
|
||||
},
|
||||
{
|
||||
name: "trailing punctuation consumed",
|
||||
in: "Please wait a moment... What on earth?! That is perfectly fine.",
|
||||
want: []string{"Please wait a moment...", "What on earth?!", "That is perfectly fine."},
|
||||
},
|
||||
{
|
||||
name: "short leading fragment merges with next",
|
||||
in: "Hi. This is a longer sentence.",
|
||||
want: []string{"Hi. This is a longer sentence."},
|
||||
},
|
||||
{
|
||||
name: "consecutive short fragments keep merging",
|
||||
in: "A. B. C. This is the real sentence.",
|
||||
want: []string{"A. B. C. This is the real sentence."},
|
||||
},
|
||||
{
|
||||
name: "short trailing fragment merges back",
|
||||
in: "This sentence is long enough. End.",
|
||||
want: []string{"This sentence is long enough. End."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SplitSentences(tc.in)
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("SplitSentences(%q) = %#v, want %#v", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
# TTS (Text-to-Speech)
|
||||
|
||||
This package handles speech synthesis for PicoClaw.
|
||||
|
||||
If you are new to TTS setup, the simplest workflow is:
|
||||
|
||||
1. Add a TTS-capable entry to `model_list`.
|
||||
2. Point `voice.tts_model_name` at that entry.
|
||||
3. Put the API key in `.security.yml`.
|
||||
|
||||
## Quick Recommendation
|
||||
|
||||
For most users, these are the best starting points:
|
||||
|
||||
| Provider | Why start here |
|
||||
| --- | --- |
|
||||
| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | Best-supported path in PicoClaw today. The current TTS implementation is built around the OpenAI-compatible `/audio/speech` API shape, and OpenAI is the safest default. |
|
||||
| [Xiaomi MiMo](https://platform.xiaomimimo.com) | A good second option if you want an OpenAI-compatible provider endpoint and are already using MiMo models in the rest of your stack. |
|
||||
|
||||
## How TTS Configuration Works
|
||||
|
||||
PicoClaw does not keep TTS API keys inside `voice`.
|
||||
|
||||
Instead:
|
||||
|
||||
- `voice.tts_model_name` selects a named entry from `model_list`.
|
||||
- That `model_list` entry provides the provider, model ID, API base, and proxy settings.
|
||||
- `.security.yml` stores the API key for the same named model entry.
|
||||
|
||||
This is the recommended and supported configuration pattern.
|
||||
|
||||
## Recommended Setup
|
||||
|
||||
### Option A: OpenAI
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"tts_model_name": "openai-tts"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "openai-tts",
|
||||
"model": "openai/tts-1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
openai-tts:
|
||||
api_keys:
|
||||
- "sk-openai-your-key"
|
||||
```
|
||||
|
||||
### Option B: Xiaomi MiMo
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"tts_model_name": "mimo-tts"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "mimo-tts",
|
||||
"model": "mimo/mimo-v2-tts"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
mimo-tts:
|
||||
api_keys:
|
||||
- "your-mimo-key"
|
||||
```
|
||||
|
||||
If you use a custom MiMo endpoint, you can also set `api_base` explicitly. Otherwise PicoClaw will use the provider default.
|
||||
|
||||
## What PicoClaw Sends Today
|
||||
|
||||
The current TTS runtime uses an OpenAI-compatible speech request with these defaults:
|
||||
|
||||
- Endpoint: `/audio/speech`
|
||||
- Response format: `opus`
|
||||
- Voice: `alloy`
|
||||
- Model: taken from the selected `model_list` entry
|
||||
|
||||
That means:
|
||||
|
||||
- `openai/tts-1` works naturally.
|
||||
- Other OpenAI-compatible providers can work if they accept the same request format.
|
||||
- PicoClaw currently does not expose a user-facing config field for changing the TTS voice from `alloy`.
|
||||
|
||||
## How PicoClaw Chooses a TTS Provider
|
||||
|
||||
`DetectTTS` resolves TTS in this order:
|
||||
|
||||
1. **Preferred path**: resolve `voice.tts_model_name` against `model_list`.
|
||||
2. If a matching model entry exists and has an API key, PicoClaw creates an OpenAI-compatible TTS provider using that model's settings.
|
||||
3. **Fallback path**: if `voice.tts_model_name` is not set or cannot be resolved, PicoClaw scans `model_list` for the first entry whose model string contains `tts` and has an API key.
|
||||
|
||||
Fallback scanning exists for compatibility. New configs should set `voice.tts_model_name` explicitly.
|
||||
|
||||
## Notes About API Base Handling
|
||||
|
||||
PicoClaw normalizes the configured base URL for TTS:
|
||||
|
||||
- For OpenAI, a base like `https://api.openai.com` or `https://api.openai.com/v1` becomes `https://api.openai.com/v1/audio/speech`.
|
||||
- For other OpenAI-compatible providers, PicoClaw preserves the configured base path and ensures it ends with `/audio/speech`.
|
||||
- If `api_base` is omitted, PicoClaw uses the provider default base when the model prefix is known.
|
||||
|
||||
## Common Mistakes
|
||||
|
||||
- Setting `voice.tts_model_name` to a name that does not exist in `model_list`.
|
||||
- Adding a TTS model but forgetting to put its API key in `.security.yml`.
|
||||
- Assuming PicoClaw will automatically use provider-specific custom voices.
|
||||
- Using a provider endpoint that is not compatible with the OpenAI `/audio/speech` request format.
|
||||
|
||||
## Minimal Checklist
|
||||
|
||||
Before testing `send_tts`, make sure:
|
||||
|
||||
- `voice.tts_model_name` matches a `model_list[].model_name`.
|
||||
- The matching `.security.yml` entry contains a valid API key.
|
||||
- The chosen provider supports an OpenAI-compatible speech synthesis endpoint.
|
||||
- Your selected model is actually a TTS-capable model.
|
||||
@@ -0,0 +1,137 @@
|
||||
# TTS(文本转语音)
|
||||
|
||||
这个目录负责 PicoClaw 的语音合成能力。
|
||||
|
||||
如果你是第一次配置 TTS,可以参照下面这个流程:
|
||||
|
||||
1. 在 `model_list` 里添加一个支持 TTS 的模型。
|
||||
2. 用 `voice.tts_model_name` 指向这个模型。
|
||||
3. 在 `.security.yml` 里配置对应的 API Key。
|
||||
|
||||
## 快速推荐
|
||||
|
||||
对于大多数用户,建议优先从下面两种开始:
|
||||
|
||||
| 提供商 | 推荐理由 |
|
||||
| --- | --- |
|
||||
| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | 这是 PicoClaw 当前最稳定、最直接的 TTS 路径。当前实现就是围绕 OpenAI 兼容的 `/audio/speech` 接口格式构建的,所以 OpenAI 是最稳妥的默认选择。 |
|
||||
| [Xiaomi MiMo](https://platform.xiaomimimo.com) | 由于响应速度和语音音色对于中国用户更友好,MiMo 是一个不错的第二选择。 |
|
||||
|
||||
## TTS 配置是如何工作的
|
||||
|
||||
PicoClaw 不会把 TTS 的 API Key 放在 `voice` 配置里。
|
||||
|
||||
推荐方式是:
|
||||
|
||||
- `voice.tts_model_name` 用来选择 `model_list` 里的某个命名模型。
|
||||
- 对应的 `model_list` 条目提供真实的 provider、model ID、`api_base` 和代理配置。
|
||||
- `.security.yml` 负责保存该模型条目的 API Key。
|
||||
|
||||
这是当前推荐且受支持的配置方式。
|
||||
|
||||
## 推荐配置方式
|
||||
|
||||
### 方案 A:OpenAI
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"tts_model_name": "openai-tts"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "openai-tts",
|
||||
"model": "openai/tts-1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
openai-tts:
|
||||
api_keys:
|
||||
- "sk-openai-your-key"
|
||||
```
|
||||
|
||||
### 方案 B:Xiaomi MiMo
|
||||
|
||||
`config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"voice": {
|
||||
"tts_model_name": "mimo-tts"
|
||||
},
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "mimo-tts",
|
||||
"model": "mimo/mimo-v2-tts"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`.security.yml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
mimo-tts:
|
||||
api_keys:
|
||||
- "your-mimo-key"
|
||||
```
|
||||
|
||||
如果你使用自定义的 MiMo 接口地址,也可以显式设置 `api_base`。如果不设置,PicoClaw 会自动使用该 provider 的默认地址。
|
||||
|
||||
## PicoClaw 当前实际发送的 TTS 请求
|
||||
|
||||
当前 TTS 运行时使用的是 OpenAI 兼容的语音合成请求,并带有以下默认值:
|
||||
|
||||
- Endpoint:`/audio/speech`
|
||||
- 返回格式:`opus`
|
||||
- Voice:`alloy`
|
||||
- Model:来自你所选中的 `model_list` 条目
|
||||
|
||||
这意味着:
|
||||
|
||||
- `openai/tts-1` 可以自然工作。
|
||||
- 其他 OpenAI 兼容 provider 也可能可用,前提是它们接受相同的请求格式。
|
||||
- PicoClaw 目前还没有对用户暴露一个配置项来修改 TTS voice,当前固定为 `alloy`。
|
||||
|
||||
## PicoClaw 如何选择 TTS Provider
|
||||
|
||||
`DetectTTS` 会按下面顺序选择 TTS:
|
||||
|
||||
1. **首选路径**:根据 `voice.tts_model_name` 在 `model_list` 中找到对应模型。
|
||||
2. 如果找到了匹配条目,并且它有 API Key,PicoClaw 就会使用这个模型条目的配置创建一个 OpenAI 兼容的 TTS provider。
|
||||
3. **回退路径**:如果没有设置 `voice.tts_model_name`,或者该名字无法解析,PicoClaw 会扫描 `model_list`,选中第一个模型字符串里包含 `tts` 且带有 API Key 的条目。
|
||||
|
||||
回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.tts_model_name`。
|
||||
|
||||
## 关于 API Base 的处理方式
|
||||
|
||||
PicoClaw 会对 TTS 的 `api_base` 做规范化处理:
|
||||
|
||||
- 对 OpenAI 来说,像 `https://api.openai.com` 或 `https://api.openai.com/v1` 这样的地址,会自动变成 `https://api.openai.com/v1/audio/speech`。
|
||||
- 对其他 OpenAI 兼容 provider,PicoClaw 会尽量保留你提供的基础路径,只确保它最终以 `/audio/speech` 结尾。
|
||||
- 如果没有设置 `api_base`,并且模型前缀是已知 provider,PicoClaw 会自动使用该 provider 的默认地址。
|
||||
|
||||
## 常见错误
|
||||
|
||||
- `voice.tts_model_name` 指向了一个不存在的 `model_list` 名称。
|
||||
- 在 `model_list` 里定义了 TTS 模型,但忘了在 `.security.yml` 中配置对应 API Key。
|
||||
- 误以为 PicoClaw 会自动支持 provider 自定义 voice 参数。
|
||||
- 使用了不兼容 OpenAI `/audio/speech` 请求格式的接口地址。
|
||||
|
||||
## 最小检查清单
|
||||
|
||||
在测试 `send_tts` 之前,请确认:
|
||||
|
||||
- `voice.tts_model_name` 能正确匹配某个 `model_list[].model_name`。
|
||||
- `.security.yml` 中对应条目已经配置了有效 API Key。
|
||||
- 你所选的 provider 支持 OpenAI 兼容的语音合成接口。
|
||||
- 你选择的模型本身确实支持 TTS。
|
||||
@@ -0,0 +1,162 @@
|
||||
package tts
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type MimoTTSProvider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
voice string
|
||||
format string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewMimoTTSProvider(apiKey string, apiBase string, model string, proxyURL string) *MimoTTSProvider {
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.xiaomimimo.com/v1/chat/completions"
|
||||
} else {
|
||||
if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" {
|
||||
path := u.Path
|
||||
if u.Host == "api.xiaomimimo.com" {
|
||||
if path == "" || path == "/" || path == "/v1" || path == "/v1/" {
|
||||
path = "/v1/chat/completions"
|
||||
} else {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
path = "/v1" + strings.TrimSuffix(path, "/")
|
||||
}
|
||||
if !strings.HasSuffix(path, "/chat/completions") {
|
||||
path = strings.TrimSuffix(path, "/") + "/chat/completions"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(path, "/chat/completions") {
|
||||
path = strings.TrimSuffix(path, "/") + "/chat/completions"
|
||||
}
|
||||
}
|
||||
u.Path = path
|
||||
apiBase = u.String()
|
||||
} else {
|
||||
if apiBase == "https://api.xiaomimimo.com/v1" {
|
||||
apiBase = "https://api.xiaomimimo.com/v1/chat/completions"
|
||||
} else if !strings.HasSuffix(apiBase, "/chat/completions") {
|
||||
apiBase = strings.TrimSuffix(apiBase, "/") + "/chat/completions"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
model = "mimo-v2-tts"
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
if proxyURL != "" {
|
||||
if pURL, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{Proxy: http.ProxyURL(pURL)}
|
||||
} else {
|
||||
logger.WarnF(
|
||||
"NewMimoTTSProvider: invalid proxy URL; proceeding without proxy",
|
||||
map[string]any{"proxyURL": proxyURL, "error": err},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &MimoTTSProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
voice: "default_zh", // mimo_default now seems to be an alias for default_en, which is not working for Chinese TTS. default_zh seems to work fine with both English and Chinese, and is likely the intended default for TTS.
|
||||
format: "mp3",
|
||||
model: model,
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MimoTTSProvider) Name() string {
|
||||
return "mimo-tts"
|
||||
}
|
||||
|
||||
func (t *MimoTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
|
||||
logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text), "provider": t.Name()})
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": t.model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "assistant", "content": text},
|
||||
},
|
||||
"audio": map[string]string{
|
||||
"format": t.format,
|
||||
"voice": t.voice,
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Api-Key", t.apiKey)
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Audio struct {
|
||||
Data string `json:"data"`
|
||||
} `json:"audio"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(payload.Choices) == 0 || payload.Choices[0].Message.Audio.Data == "" {
|
||||
return nil, fmt.Errorf("invalid TTS response: missing audio data")
|
||||
}
|
||||
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(payload.Choices[0].Message.Audio.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode audio data: %w", err)
|
||||
}
|
||||
|
||||
return io.NopCloser(bytes.NewReader(audioBytes)), nil
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package tts
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
type OpenAITTSProvider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
voice string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAITTSProvider(apiKey string, apiBase string, proxyURL string, model string) *OpenAITTSProvider {
|
||||
// Normalize apiBase to avoid malformed endpoints like
|
||||
// "https://api.openai.com/audio/speech" when "/v1" is required.
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1/audio/speech"
|
||||
} else {
|
||||
if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" {
|
||||
path := u.Path
|
||||
if u.Host == "api.openai.com" {
|
||||
// For the official OpenAI host, ensure exactly one /v1 prefix and
|
||||
// that the path ends with /audio/speech.
|
||||
if path == "" || path == "/" || path == "/v1" {
|
||||
path = "/v1/audio/speech"
|
||||
} else {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
path = "/v1" + strings.TrimSuffix(path, "/")
|
||||
}
|
||||
if !strings.HasSuffix(path, "/audio/speech") {
|
||||
path = strings.TrimSuffix(path, "/") + "/audio/speech"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-OpenAI hosts (e.g., proxies), preserve the existing base
|
||||
// path and only ensure it ends with /audio/speech.
|
||||
if !strings.HasSuffix(path, "/audio/speech") {
|
||||
path = strings.TrimSuffix(path, "/") + "/audio/speech"
|
||||
}
|
||||
}
|
||||
u.Path = path
|
||||
apiBase = u.String()
|
||||
} else {
|
||||
// Fallback to the previous string-based behavior if parsing fails.
|
||||
if apiBase == "https://api.openai.com/v1" {
|
||||
apiBase = "https://api.openai.com/v1/audio/speech"
|
||||
} else if !strings.HasSuffix(apiBase, "/audio/speech") {
|
||||
// Just in case they provide openrouter base or standard base
|
||||
apiBase = strings.TrimSuffix(apiBase, "/") + "/audio/speech"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := common.NewHTTPClient(proxyURL)
|
||||
client.Timeout = 60 * time.Second
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
model = "tts-1"
|
||||
}
|
||||
|
||||
return &OpenAITTSProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
voice: "alloy",
|
||||
model: model,
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *OpenAITTSProvider) Name() string {
|
||||
return "openai-tts"
|
||||
}
|
||||
|
||||
func (t *OpenAITTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
|
||||
logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text)})
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": t.model,
|
||||
"input": text,
|
||||
"voice": t.voice,
|
||||
"response_format": "opus",
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+t.apiKey)
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package tts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type TTSProvider interface {
|
||||
Name() string
|
||||
Synthesize(ctx context.Context, text string) (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
func providerFromModelConfig(mc *config.ModelConfig) TTSProvider {
|
||||
if mc == nil || mc.APIKey() == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, modelID := providers.ExtractProtocol(mc.Model)
|
||||
if modelID == "" {
|
||||
modelID = strings.TrimSpace(mc.Model)
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
case "mimo":
|
||||
return NewMimoTTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), modelID, mc.Proxy)
|
||||
default:
|
||||
return NewOpenAITTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), mc.Proxy, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
func DetectTTS(cfg *config.Config) TTSProvider {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if modelName := strings.TrimSpace(cfg.Voice.TTSModelName); modelName != "" {
|
||||
if mc, err := cfg.GetModelConfig(modelName); err == nil {
|
||||
if provider := providerFromModelConfig(mc); provider != nil {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, mc := range cfg.ModelList {
|
||||
if strings.Contains(strings.ToLower(mc.Model), "tts") && mc.APIKey() != "" {
|
||||
if provider := providerFromModelConfig(mc); provider != nil {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SynthesizeAndStore synthesizes text to speech and registers it in the media store, returning the media reference.
|
||||
func SynthesizeAndStore(
|
||||
ctx context.Context,
|
||||
provider TTSProvider,
|
||||
store media.MediaStore,
|
||||
text string,
|
||||
filename string,
|
||||
channel string,
|
||||
chatID string,
|
||||
) (string, error) {
|
||||
if provider == nil {
|
||||
return "", fmt.Errorf("tts provider is not configured")
|
||||
}
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("media store not configured")
|
||||
}
|
||||
if channel == "" || chatID == "" {
|
||||
return "", fmt.Errorf("no target channel/chat available")
|
||||
}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "", fmt.Errorf("text is required")
|
||||
}
|
||||
|
||||
stream, err := provider.Synthesize(ctx, text)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("tts synthesize failed: %w", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
err = os.MkdirAll(media.TempDir(), 0o700)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create media temp dir: %w", err)
|
||||
}
|
||||
|
||||
fileExt := ".ogg"
|
||||
contentType := "audio/ogg"
|
||||
if provider.Name() == "mimo-tts" {
|
||||
fileExt = ".mp3"
|
||||
contentType = "audio/mpeg"
|
||||
}
|
||||
|
||||
file, err := os.CreateTemp(media.TempDir(), "tts-*"+fileExt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
|
||||
removeTemp := true
|
||||
defer func() {
|
||||
if removeTemp {
|
||||
_ = os.Remove(file.Name())
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(file, stream)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return "", fmt.Errorf("failed to write tts audio: %w", err)
|
||||
}
|
||||
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to close tts audio file: %w", err)
|
||||
}
|
||||
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" {
|
||||
filename = fmt.Sprintf("tts-%d%s", time.Now().Unix(), fileExt)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
if ext == "" {
|
||||
filename += fileExt
|
||||
} else if ext != fileExt {
|
||||
filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + fileExt
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf("tool:send_tts:%s:%s:%d", channel, chatID, time.Now().UnixNano())
|
||||
ref, err := store.Store(file.Name(), media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
Source: "tool:send_tts",
|
||||
}, scope)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to register audio: %w", err)
|
||||
}
|
||||
removeTemp = false
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package tts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestNewOpenAITTSProvider_APIBaseNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "empty base",
|
||||
input: "",
|
||||
expect: "https://api.openai.com/v1/audio/speech",
|
||||
},
|
||||
{
|
||||
name: "official host no path",
|
||||
input: "https://api.openai.com",
|
||||
expect: "https://api.openai.com/v1/audio/speech",
|
||||
},
|
||||
{
|
||||
name: "official host v1",
|
||||
input: "https://api.openai.com/v1",
|
||||
expect: "https://api.openai.com/v1/audio/speech",
|
||||
},
|
||||
{
|
||||
name: "official host v1 slash",
|
||||
input: "https://api.openai.com/v1/",
|
||||
expect: "https://api.openai.com/v1/audio/speech",
|
||||
},
|
||||
{
|
||||
name: "non-openai host preserves base path",
|
||||
input: "https://proxy.example.com/base",
|
||||
expect: "https://proxy.example.com/base/audio/speech",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
provider := NewOpenAITTSProvider("key", tc.input, "", "")
|
||||
if provider.apiBase != tc.expect {
|
||||
t.Fatalf("apiBase mismatch: got %q, want %q", provider.apiBase, tc.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITTSProvider_SynthesizeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotPath string
|
||||
var gotAuth string
|
||||
var gotContentType string
|
||||
var gotBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotContentType = r.Header.Get("Content-Type")
|
||||
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
_ = json.Unmarshal(bodyBytes, &gotBody)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("audio-bytes"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewOpenAITTSProvider("k123", server.URL, "", "")
|
||||
stream, err := provider.Synthesize(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Synthesize failed: %v", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
data, err := io.ReadAll(stream)
|
||||
if err != nil {
|
||||
t.Fatalf("read stream failed: %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/audio/speech" {
|
||||
t.Fatalf("request path mismatch: got %q", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer k123" {
|
||||
t.Fatalf("authorization mismatch: got %q", gotAuth)
|
||||
}
|
||||
if gotContentType != "application/json" {
|
||||
t.Fatalf("content-type mismatch: got %q", gotContentType)
|
||||
}
|
||||
if gotBody["model"] != "tts-1" || gotBody["voice"] != "alloy" || gotBody["response_format"] != "opus" ||
|
||||
gotBody["input"] != "hello" {
|
||||
bodyJSON, _ := json.Marshal(gotBody)
|
||||
t.Fatalf("request body mismatch: %s", string(bodyJSON))
|
||||
}
|
||||
if string(data) != "audio-bytes" {
|
||||
t.Fatalf("response body mismatch: got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITTSProvider_SynthesizeNon200(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("nope"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewOpenAITTSProvider("k123", server.URL, "", "")
|
||||
_, err := provider.Synthesize(context.Background(), "hello")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "API error (status 500): nope") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOpenAITTSProvider_UsesConfiguredModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewOpenAITTSProvider("key", "https://api.xiaomimimo.com/v1", "", "mimo-v2-tts")
|
||||
if provider.model != "mimo-v2-tts" {
|
||||
t.Fatalf("model mismatch: got %q, want %q", provider.model, "mimo-v2-tts")
|
||||
}
|
||||
if provider.apiBase != "https://api.xiaomimimo.com/v1/audio/speech" {
|
||||
t.Fatalf("apiBase mismatch: got %q", provider.apiBase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectTTS_UsesMimoProviderForMimoModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := DetectTTS(&config.Config{
|
||||
Voice: config.VoiceConfig{TTSModelName: "mimo-tts"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mimo-tts",
|
||||
Model: "mimo/mimo-v2-tts",
|
||||
APIKeys: config.SimpleSecureStrings("sk-mimo"),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ttsProvider, ok := provider.(*MimoTTSProvider)
|
||||
if !ok {
|
||||
t.Fatalf("DetectTTS() type = %T, want *MimoTTSProvider", provider)
|
||||
}
|
||||
if ttsProvider.model != "mimo-v2-tts" {
|
||||
t.Fatalf("model mismatch: got %q, want %q", ttsProvider.model, "mimo-v2-tts")
|
||||
}
|
||||
if ttsProvider.apiBase != "https://api.xiaomimimo.com/v1/chat/completions" {
|
||||
t.Fatalf("apiBase mismatch: got %q", ttsProvider.apiBase)
|
||||
}
|
||||
}
|
||||
|
||||
type stubTTSProvider struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (s stubTTSProvider) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
func (s stubTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
|
||||
return io.NopCloser(strings.NewReader("audio")), nil
|
||||
}
|
||||
|
||||
func TestSynthesizeAndStore_UsesOggMetadataByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ref, err := SynthesizeAndStore(
|
||||
context.Background(),
|
||||
stubTTSProvider{name: "openai-tts"},
|
||||
store,
|
||||
"hello",
|
||||
"",
|
||||
"discord",
|
||||
"chat123",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SynthesizeAndStore failed: %v", err)
|
||||
}
|
||||
|
||||
path, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveWithMeta failed: %v", err)
|
||||
}
|
||||
if meta.ContentType != "audio/ogg" {
|
||||
t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/ogg")
|
||||
}
|
||||
if filepath.Ext(path) != ".ogg" {
|
||||
t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".ogg")
|
||||
}
|
||||
if filepath.Ext(meta.Filename) != ".ogg" {
|
||||
t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".ogg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeAndStore_UsesMp3MetadataForMimo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ref, err := SynthesizeAndStore(
|
||||
context.Background(),
|
||||
stubTTSProvider{name: "mimo-tts"},
|
||||
store,
|
||||
"hello",
|
||||
"",
|
||||
"discord",
|
||||
"chat123",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SynthesizeAndStore failed: %v", err)
|
||||
}
|
||||
|
||||
path, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveWithMeta failed: %v", err)
|
||||
}
|
||||
if meta.ContentType != "audio/mpeg" {
|
||||
t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/mpeg")
|
||||
}
|
||||
if filepath.Ext(path) != ".mp3" {
|
||||
t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".mp3")
|
||||
}
|
||||
if filepath.Ext(meta.Filename) != ".mp3" {
|
||||
t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".mp3")
|
||||
}
|
||||
}
|
||||
@@ -34,6 +34,8 @@ type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
outboundMedia chan OutboundMediaMessage
|
||||
audioChunks chan AudioChunk
|
||||
voiceControls chan VoiceControl
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
@@ -47,6 +49,8 @@ func NewMessageBus() *MessageBus {
|
||||
inbound: make(chan InboundMessage, defaultBusBufferSize),
|
||||
outbound: make(chan OutboundMessage, defaultBusBufferSize),
|
||||
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
|
||||
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
|
||||
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
@@ -103,6 +107,22 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
return mb.outboundMedia
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error {
|
||||
return publish(ctx, mb, mb.audioChunks, chunk)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
|
||||
return mb.audioChunks
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error {
|
||||
return publish(ctx, mb, mb.voiceControls, ctrl)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl {
|
||||
return mb.voiceControls
|
||||
}
|
||||
|
||||
// SetStreamDelegate registers a StreamDelegate (typically the channel Manager).
|
||||
func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
|
||||
mb.streamDelegate.Store(d)
|
||||
@@ -132,6 +152,8 @@ func (mb *MessageBus) Close() {
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
close(mb.outboundMedia)
|
||||
close(mb.audioChunks)
|
||||
close(mb.voiceControls)
|
||||
|
||||
// clean up any remaining messages in channels
|
||||
drained := 0
|
||||
@@ -144,6 +166,12 @@ func (mb *MessageBus) Close() {
|
||||
for range mb.outboundMedia {
|
||||
drained++
|
||||
}
|
||||
for range mb.audioChunks {
|
||||
drained++
|
||||
}
|
||||
for range mb.voiceControls {
|
||||
drained++
|
||||
}
|
||||
|
||||
if drained > 0 {
|
||||
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
|
||||
|
||||
+27
-4
@@ -30,10 +30,11 @@ type InboundMessage struct {
|
||||
}
|
||||
|
||||
type OutboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MediaPart describes a single media attachment to send.
|
||||
@@ -51,3 +52,25 @@ type OutboundMediaMessage struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
Parts []MediaPart `json:"parts"`
|
||||
}
|
||||
|
||||
// AudioChunk represents a chunk of streaming voice data.
|
||||
type AudioChunk struct {
|
||||
SessionID string `json:"session_id"`
|
||||
SpeakerID string `json:"speaker_id"` // User ID or SSRC
|
||||
ChatID string `json:"chat_id"` // Where to respond
|
||||
Channel string `json:"channel"` // Source channel type (e.g. "discord")
|
||||
Sequence uint64 `json:"sequence"`
|
||||
Timestamp uint32 `json:"timestamp"`
|
||||
SampleRate int `json:"sample_rate"`
|
||||
Channels int `json:"channels"`
|
||||
Format string `json:"format"` // "opus", "pcm", etc
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
// VoiceControl represents state or commands for voice sessions.
|
||||
type VoiceControl struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Type string `json:"type"` // "state", "command"
|
||||
Action string `json:"action"` // "idle", "listening", "start", "stop", "leave"
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package discord
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -42,6 +45,15 @@ type DiscordChannel struct {
|
||||
typingMu sync.Mutex
|
||||
typingStop map[string]chan struct{} // chatID → stop signal
|
||||
botUserID string // stored for mention checking
|
||||
bus *bus.MessageBus
|
||||
tts tts.TTSProvider
|
||||
voiceMu sync.RWMutex
|
||||
voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
|
||||
|
||||
// TTS interruption: cancel active playback when user speaks
|
||||
ttsMu sync.Mutex
|
||||
cancelTTS context.CancelFunc
|
||||
ttsPlayID uint64
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
@@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
config: cfg,
|
||||
ctx: context.Background(),
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
bus: bus,
|
||||
voiceSSRC: make(map[string]map[uint32]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.session.AddHandler(c.handleMessage)
|
||||
|
||||
go c.listenVoiceControl(c.ctx)
|
||||
|
||||
if err := c.session.Open(); err != nil {
|
||||
return fmt.Errorf("failed to open discord session: %w", err)
|
||||
}
|
||||
@@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.tts != nil {
|
||||
if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
|
||||
if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
|
||||
// Cancel any previous TTS playback
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
}
|
||||
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
|
||||
c.ttsPlayID++
|
||||
playID := c.ttsPlayID
|
||||
c.cancelTTS = ttsCancel
|
||||
c.ttsMu.Unlock()
|
||||
|
||||
go c.playTTS(ttsCtx, vc, msg.Content, playID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
if c.handleVoiceCommand(s, m) {
|
||||
return
|
||||
}
|
||||
|
||||
content := m.Content
|
||||
|
||||
// In guild (group) channels, apply unified group trigger filtering
|
||||
@@ -630,3 +669,134 @@ func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) listenVoiceControl(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ctrl, ok := <-c.bus.VoiceControlsChan():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if ctrl.Type == "command" && ctrl.Action == "leave" {
|
||||
if strings.HasPrefix(ctrl.SessionID, "discord_vc_") {
|
||||
guildID := strings.TrimPrefix(ctrl.SessionID, "discord_vc_")
|
||||
vc, exists := c.session.VoiceConnections[guildID]
|
||||
if exists && vc != nil {
|
||||
vc.Disconnect(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) {
|
||||
// Capture the cancel func associated with this playback (if any).
|
||||
// Clear cancelTTS when playback finishes (normal or interrupted),
|
||||
// but only if it still refers to this playback's cancel func.
|
||||
defer func() {
|
||||
c.ttsMu.Lock()
|
||||
if c.ttsPlayID == playID {
|
||||
c.cancelTTS = nil
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
}()
|
||||
|
||||
sentences := audio.SplitSentences(text)
|
||||
if len(sentences) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)})
|
||||
|
||||
// Pipeline: prefetch next sentence's audio while playing current
|
||||
type ttResult struct {
|
||||
stream io.ReadCloser
|
||||
err error
|
||||
}
|
||||
|
||||
var prefetch chan ttResult
|
||||
|
||||
// Ensure any in-flight prefetch is drained on exit to prevent stream leaks,
|
||||
// but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends.
|
||||
defer func() {
|
||||
if prefetch != nil {
|
||||
select {
|
||||
case result := <-prefetch:
|
||||
if result.stream != nil {
|
||||
result.stream.Close()
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Timed out waiting for a prefetched result; avoid blocking on exit.
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for i, sentence := range sentences {
|
||||
// Check for cancellation (interruption)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i})
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Start prefetching the NEXT sentence while we process the current one
|
||||
var nextPrefetch chan ttResult
|
||||
if i+1 < len(sentences) {
|
||||
nextPrefetch = make(chan ttResult, 1)
|
||||
nextSentence := sentences[i+1]
|
||||
go func() {
|
||||
s, e := c.tts.Synthesize(ctx, nextSentence)
|
||||
nextPrefetch <- ttResult{s, e}
|
||||
}()
|
||||
}
|
||||
|
||||
// Get the current sentence's audio
|
||||
var stream io.ReadCloser
|
||||
var err error
|
||||
|
||||
if prefetch != nil {
|
||||
// Use prefetched result from previous iteration, but be responsive to cancellation.
|
||||
var result ttResult
|
||||
select {
|
||||
case result = <-prefetch:
|
||||
stream, err = result.stream, result.err
|
||||
case <-ctx.Done():
|
||||
// Context canceled while waiting for prefetched audio; abort playback.
|
||||
logger.InfoCF(
|
||||
"discord",
|
||||
"TTS interrupted while waiting for prefetched audio",
|
||||
map[string]any{"at_sentence": i},
|
||||
)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// First sentence: synthesize directly
|
||||
stream, err = c.tts.Synthesize(ctx, sentence)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if stream != nil {
|
||||
stream.Close()
|
||||
}
|
||||
logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i})
|
||||
prefetch = nextPrefetch
|
||||
continue
|
||||
}
|
||||
|
||||
if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil {
|
||||
logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i})
|
||||
}
|
||||
stream.Close()
|
||||
|
||||
prefetch = nextPrefetch
|
||||
}
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -8,6 +9,10 @@ import (
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewDiscordChannel(cfg.Channels.Discord, b)
|
||||
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
|
||||
if err == nil {
|
||||
ch.tts = tts.DetectTTS(cfg)
|
||||
}
|
||||
return ch, err
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) {
|
||||
if userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.voiceMu.Lock()
|
||||
defer c.voiceMu.Unlock()
|
||||
|
||||
ssrcMap, ok := c.voiceSSRC[guildID]
|
||||
if !ok {
|
||||
ssrcMap = make(map[uint32]string)
|
||||
c.voiceSSRC[guildID] = ssrcMap
|
||||
}
|
||||
ssrcMap[ssrc] = userID
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string {
|
||||
c.voiceMu.RLock()
|
||||
defer c.voiceMu.RUnlock()
|
||||
|
||||
ssrcMap, ok := c.voiceSSRC[guildID]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ssrcMap[ssrc]
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool {
|
||||
if m.Content == "!vc join" {
|
||||
vs, err := s.State.VoiceState(m.GuildID, m.Author.ID)
|
||||
if err != nil || vs == nil {
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
"You need to be in a voice channel first!",
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice channel requirement message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID})
|
||||
vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false)
|
||||
if err != nil {
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
fmt.Sprintf("Failed to join voice channel: %v", err),
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice join error message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
go c.receiveVoice(vc, m.GuildID, m.ChannelID)
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
"Joined Voice Channel! Listening for audio...",
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice join success message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
} else if m.Content == "!vc leave" {
|
||||
vc, exists := s.VoiceConnections[m.GuildID]
|
||||
if exists && vc != nil {
|
||||
if err := vc.Disconnect(c.ctx); err != nil {
|
||||
logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
|
||||
"guild": m.GuildID,
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Left Voice Channel."); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice leave success message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Not in a voice channel."); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice not-in-channel message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool {
|
||||
return vc != nil && vc.OpusRecv != nil
|
||||
}
|
||||
|
||||
func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) {
|
||||
// Recover from panic if vc.OpusSend is closed mid-send (e.g. on disconnect)
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
retErr = fmt.Errorf("voice connection closed during playback")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for the speaking transition to register
|
||||
vc.Speaking(true)
|
||||
defer vc.Speaking(false)
|
||||
|
||||
return audio.DecodeOggOpus(r, func(frame []byte) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case vc.OpusSend <- frame:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) {
|
||||
logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID})
|
||||
|
||||
vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) {
|
||||
if vs == nil {
|
||||
return
|
||||
}
|
||||
c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
c.voiceMu.Lock()
|
||||
delete(c.voiceSSRC, guildID)
|
||||
c.voiceMu.Unlock()
|
||||
}()
|
||||
|
||||
go func(ctx context.Context, vc *discordgo.VoiceConnection) {
|
||||
// Recover from potential panics if OpusSend is closed mid-send.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{
|
||||
"error": rec,
|
||||
"guild": guildID,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// If the voice connection or OpusSend are not available, nothing to do.
|
||||
if vc == nil || vc.OpusSend == nil {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle
|
||||
|
||||
// Abort if the context has already been canceled.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
vc.Speaking(true)
|
||||
defer vc.Speaking(false)
|
||||
|
||||
silenceFrame := []byte{0xF8, 0xFF, 0xFE}
|
||||
for i := 0; i < 5; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case vc.OpusSend <- silenceFrame:
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID})
|
||||
}(c.ctx, vc)
|
||||
sessionID := fmt.Sprintf("discord_vc_%s", guildID)
|
||||
|
||||
c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{
|
||||
SessionID: sessionID,
|
||||
Type: "state",
|
||||
Action: "listening",
|
||||
})
|
||||
|
||||
var sequence uint64 = 0
|
||||
var interruptCount int
|
||||
var lastInterruptAt time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case p, ok := <-vc.OpusRecv:
|
||||
if !ok {
|
||||
logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID})
|
||||
// Cancel any TTS that may still be playing
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
c.cancelTTS = nil
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
logger.DebugCF("discord", "Received nil Opus packet", nil)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(p.Opus) == 0 {
|
||||
logger.DebugCF("discord", "Received empty Opus packet", map[string]any{
|
||||
"seq": p.Sequence,
|
||||
"ssrc": p.SSRC,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Received Opus packet", map[string]any{
|
||||
"seq": p.Sequence,
|
||||
"len": len(p.Opus),
|
||||
"ssrc": p.SSRC,
|
||||
})
|
||||
// Interruption detection: if user sends voice while TTS is playing,
|
||||
// cancel TTS after a short debounce (3 packets in 200ms)
|
||||
now := time.Now()
|
||||
if now.Sub(lastInterruptAt) > 500*time.Millisecond {
|
||||
interruptCount = 0
|
||||
}
|
||||
interruptCount++
|
||||
lastInterruptAt = now
|
||||
|
||||
if interruptCount >= 3 {
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
c.cancelTTS = nil
|
||||
logger.InfoCF("discord", "TTS interrupted by user voice", nil)
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
interruptCount = 0
|
||||
}
|
||||
|
||||
userID := c.voiceUserID(guildID, p.SSRC)
|
||||
if userID == "" {
|
||||
logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{
|
||||
"ssrc": p.SSRC,
|
||||
"guild": guildID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "discord",
|
||||
PlatformID: userID,
|
||||
CanonicalID: identity.BuildCanonicalID("discord", userID),
|
||||
}
|
||||
if !c.IsAllowedSender(sender) {
|
||||
logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{
|
||||
"user_id": userID,
|
||||
"guild": guildID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
sequence++
|
||||
|
||||
chunk := bus.AudioChunk{
|
||||
SessionID: sessionID,
|
||||
SpeakerID: userID,
|
||||
ChatID: chatID,
|
||||
Channel: "discord",
|
||||
Sequence: sequence,
|
||||
Timestamp: p.Timestamp,
|
||||
SampleRate: 48000,
|
||||
Channels: 2,
|
||||
Format: "opus",
|
||||
Data: p.Opus,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond)
|
||||
err := c.bus.PublishAudioChunk(ctx, chunk)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{
|
||||
"guild": guildID,
|
||||
"sessionID": sessionID,
|
||||
"sequence": sequence,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
|
||||
@@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -684,3 +684,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.
|
||||
cleaned = strings.TrimLeft(cleaned, ",:; ")
|
||||
return strings.TrimSpace(cleaned)
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1104,3 +1104,8 @@ func truncate(s string, n int) string {
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1002,3 +1002,8 @@ func sanitizeURLs(text string) string {
|
||||
return scheme + domain + path
|
||||
})
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1133,3 +1133,8 @@ func cryptoRandInt() int {
|
||||
_, _ = rand.Read(b[:])
|
||||
return int(binary.BigEndian.Uint32(b[:])) | 1 // ensure non-zero
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package channels
|
||||
|
||||
// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech)
|
||||
// are available for a channel under the current configuration.
|
||||
type VoiceCapabilities struct {
|
||||
ASR bool
|
||||
TTS bool
|
||||
}
|
||||
|
||||
// VoiceCapabilityProvider is an optional interface for channels that want to
|
||||
// explicitly declare their ASR/TTS support.
|
||||
type VoiceCapabilityProvider interface {
|
||||
VoiceCapabilities() VoiceCapabilities
|
||||
}
|
||||
|
||||
// Deprecated: Channels should implement VoiceCapabilityProvider instead.
|
||||
// To be removed once all existing capable channels conform to the interface.
|
||||
var asrCapableChannels = map[string]bool{
|
||||
"discord": true,
|
||||
"telegram": true,
|
||||
"matrix": true,
|
||||
"qq": true,
|
||||
"weixin": true,
|
||||
"line": true,
|
||||
"feishu": true,
|
||||
"onebot": true,
|
||||
}
|
||||
|
||||
// DetectVoiceCapabilities returns ASR/TTS availability for a channel, gated by
|
||||
// whether providers are configured.
|
||||
func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities {
|
||||
if ch == nil {
|
||||
return VoiceCapabilities{}
|
||||
}
|
||||
|
||||
if vcp, ok := ch.(VoiceCapabilityProvider); ok {
|
||||
caps := vcp.VoiceCapabilities()
|
||||
if !asrAvailable {
|
||||
caps.ASR = false
|
||||
}
|
||||
if !ttsAvailable {
|
||||
caps.TTS = false
|
||||
}
|
||||
return caps
|
||||
}
|
||||
|
||||
caps := VoiceCapabilities{}
|
||||
if asrAvailable {
|
||||
caps.ASR = asrCapableChannels[channelName]
|
||||
}
|
||||
if ttsAvailable {
|
||||
if _, ok := ch.(MediaSender); ok {
|
||||
caps.TTS = true
|
||||
}
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
@@ -402,3 +402,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -558,9 +558,9 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type VoiceConfig struct {
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
|
||||
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
}
|
||||
|
||||
// ModelConfig represents a model-centric provider configuration.
|
||||
@@ -829,6 +829,7 @@ type ToolsConfig struct {
|
||||
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
|
||||
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SpawnStatus ToolConfig `json:"spawn_status" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
|
||||
SPI ToolConfig `json:"spi" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
@@ -1281,6 +1282,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.WebFetch.Enabled
|
||||
case "send_file":
|
||||
return t.SendFile.Enabled
|
||||
case "send_tts":
|
||||
return t.SendTTS.Enabled
|
||||
case "write_file":
|
||||
return t.WriteFile.Enabled
|
||||
case "mcp":
|
||||
|
||||
@@ -434,6 +434,9 @@ func DefaultConfig() *Config {
|
||||
SendFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SendTTS: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
|
||||
+51
-3
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
|
||||
@@ -41,7 +44,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,6 +63,7 @@ type services struct {
|
||||
ChannelManager *channels.Manager
|
||||
DeviceService *devices.Service
|
||||
HealthServer *health.Server
|
||||
VoiceAgentCancel context.CancelFunc
|
||||
manualReloadChan chan struct{}
|
||||
reloading atomic.Bool
|
||||
authToken string
|
||||
@@ -70,6 +73,27 @@ type startupBlockedProvider struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func logChannelVoiceCapabilities(cm *channels.Manager, asrAvailable bool, ttsAvailable bool) {
|
||||
if cm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
names := cm.GetEnabledChannels()
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
ch, ok := cm.GetChannel(name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
caps := channels.DetectVoiceCapabilities(name, ch, asrAvailable, ttsAvailable)
|
||||
logger.InfoCF("voice", "Channel voice capabilities", map[string]any{
|
||||
"channel": name,
|
||||
"asr": caps.ASR,
|
||||
"tts": caps.TTS,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
@@ -337,11 +361,14 @@ func setupAndStartServices(
|
||||
agentLoop.SetChannelManager(runningServices.ChannelManager)
|
||||
agentLoop.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
transcriber := asr.DetectTranscriber(cfg)
|
||||
if transcriber != nil {
|
||||
agentLoop.SetTranscriber(transcriber)
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
ttsAvailable := tts.DetectTTS(cfg) != nil
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
@@ -358,6 +385,16 @@ func setupAndStartServices(
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
|
||||
|
||||
if transcriber != nil {
|
||||
// Start Voice Agent Orchestrator after channels are ready.
|
||||
vaCtx, vaCancel := context.WithCancel(context.Background())
|
||||
runningServices.VoiceAgentCancel = vaCancel
|
||||
voiceAgent := asr.NewAgent(msgBus, transcriber)
|
||||
voiceAgent.Start(vaCtx)
|
||||
}
|
||||
|
||||
fmt.Printf(
|
||||
"✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n",
|
||||
cfg.Gateway.Host,
|
||||
@@ -387,6 +424,9 @@ func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Dura
|
||||
if !isReload && runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if runningServices.VoiceAgentCancel != nil {
|
||||
runningServices.VoiceAgentCancel()
|
||||
}
|
||||
if runningServices.DeviceService != nil {
|
||||
runningServices.DeviceService.Stop()
|
||||
}
|
||||
@@ -563,14 +603,22 @@ func restartServices(
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
transcriber := asr.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber)
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
|
||||
// Start Voice Agent Orchestrator on reload
|
||||
vaCtx, vaCancel := context.WithCancel(context.Background())
|
||||
runningServices.VoiceAgentCancel = vaCancel
|
||||
voiceAgent := asr.NewAgent(msgBus, transcriber)
|
||||
voiceAgent.Start(vaCtx)
|
||||
} else {
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
ttsAvailable := tts.DetectTTS(cfg) != nil
|
||||
logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
|
||||
// NOTE: PID file is written once at startup and not updated on reload.
|
||||
// Changing the gateway listen address requires a full restart.
|
||||
|
||||
|
||||
@@ -98,6 +98,19 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
return protocol, modelID
|
||||
}
|
||||
|
||||
// ResolveAPIBase returns the configured API base, or the protocol default when
|
||||
// the model uses an HTTP-based provider family with a known default endpoint.
|
||||
func ResolveAPIBase(cfg *config.ModelConfig) string {
|
||||
if cfg == nil {
|
||||
return ""
|
||||
}
|
||||
if apiBase := strings.TrimSpace(cfg.APIBase); apiBase != "" {
|
||||
return strings.TrimRight(apiBase, "/")
|
||||
}
|
||||
protocol, _ := ExtractProtocol(cfg.Model)
|
||||
return strings.TrimRight(getDefaultAPIBase(protocol), "/")
|
||||
}
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
type SendTTSTool struct {
|
||||
provider tts.TTSProvider
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
|
||||
return &SendTTSTool{
|
||||
provider: provider,
|
||||
mediaStore: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Name() string { return "send_tts" }
|
||||
|
||||
func (t *SendTTSTool) Description() string {
|
||||
return "Synthesize speech from text and send it as an audio file to the user."
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional filename for the audio file (e.g., response.ogg).",
|
||||
},
|
||||
},
|
||||
"required": []string{"text"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ErrorResult("text is required")
|
||||
}
|
||||
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
filename, _ := args["filename"].(string)
|
||||
|
||||
ref, err := tts.SynthesizeAndStore(
|
||||
ctx,
|
||||
t.provider,
|
||||
t.mediaStore,
|
||||
text,
|
||||
filename,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
|
||||
// Return with ForUser set to original text, Media containing the audio ref,
|
||||
// and mark as ResponseHandled so the audio is sent immediately without LLM intervention.
|
||||
return &ToolResult{
|
||||
ForLLM: "TTS audio sent",
|
||||
ForUser: text,
|
||||
Media: []string{ref},
|
||||
ResponseHandled: true,
|
||||
}
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type GroqTranscriber struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
|
||||
logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
|
||||
|
||||
apiBase := "https://api.groq.com/openai/v1"
|
||||
return &GroqTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
|
||||
|
||||
audioFile, err := os.Open(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
fileInfo, err := audioFile.Stat()
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Audio file details", map[string]any{
|
||||
"size_bytes": fileInfo.Size(),
|
||||
"file_name": filepath.Base(audioFilePath),
|
||||
})
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
copied, err := io.Copy(part, audioFile)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
|
||||
|
||||
if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("response_format", "json"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write response_format field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
url := t.apiBase + "/audio/transcriptions"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+t.apiKey)
|
||||
|
||||
logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
|
||||
"url": url,
|
||||
"request_size_bytes": requestBody.Len(),
|
||||
"file_size_bytes": fileInfo.Size(),
|
||||
})
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("voice", "API error", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response": string(body),
|
||||
})
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Received response from Groq API", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response_size_bytes": len(body),
|
||||
})
|
||||
|
||||
var result TranscriptionResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
|
||||
"text_length": len(result.Text),
|
||||
"language": result.Language,
|
||||
"duration_seconds": result.Duration,
|
||||
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||
})
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Name() string {
|
||||
return "groq"
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var _ Transcriber = (*GroqTranscriber)(nil)
|
||||
|
||||
func TestGroqTranscriberName(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
if got := tr.Name(); got != "groq" {
|
||||
t.Errorf("Name() = %q, want %q", got, "groq")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroqTranscribe(t *testing.T) {
|
||||
// Write a minimal fake audio file so the transcriber can open and send it.
|
||||
tmpDir := t.TempDir()
|
||||
audioPath := filepath.Join(tmpDir, "clip.ogg")
|
||||
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write fake audio file: %v", err)
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer sk-test" {
|
||||
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
|
||||
Text: "hello world",
|
||||
Language: "en",
|
||||
Duration: 1.5,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe() error: %v", err)
|
||||
}
|
||||
if resp.Text != "hello world" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
|
||||
}
|
||||
if resp.Language != "en" {
|
||||
t.Errorf("Language = %q, want %q", resp.Language, "en")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("api error", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-bad")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-200 response, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type Transcriber interface {
|
||||
Name() string
|
||||
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
|
||||
}
|
||||
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
func supportsAudioTranscription(model string) bool {
|
||||
protocol, _ := providers.ExtractProtocol(model)
|
||||
|
||||
switch protocol {
|
||||
case "openai", "azure", "azure-openai",
|
||||
"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// These protocols all go through the OpenAI-compatible or Azure provider path in
|
||||
// providers.CreateProviderFromConfig, so they are the only ones that can supply
|
||||
// the audio media payload shape expected by NewAudioModelTranscriber.
|
||||
|
||||
// TODO: Further restrict this by modelID, since not every model under these
|
||||
// protocols supports audio transcription.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
|
||||
// nil if no supported transcription provider is configured.
|
||||
func DetectTranscriber(cfg *config.Config) Transcriber {
|
||||
if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
|
||||
modelCfg, err := cfg.GetModelConfig(modelName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if supportsAudioTranscription(modelCfg.Model) {
|
||||
return NewAudioModelTranscriber(modelCfg)
|
||||
}
|
||||
}
|
||||
|
||||
// ElevenLabs voice config (supports Scribe STT).
|
||||
if key := strings.TrimSpace(cfg.Voice.ElevenLabsAPIKey); key != "" {
|
||||
return NewElevenLabsTranscriber(key)
|
||||
}
|
||||
// Fall back to any model-list entry that uses the groq/ protocol.
|
||||
for _, mc := range cfg.ModelList {
|
||||
if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey() != "" {
|
||||
return NewGroqTranscriber(mc.APIKey())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user