From 0f395ce11057d941344c43f32ef95295fe8c178d Mon Sep 17 00:00:00 2001 From: Hua Audio Date: Wed, 1 Apr 2026 06:21:21 +0200 Subject: [PATCH] Refactor/asr tts (#1939) * refactor: update ASR and TTS implementations * fix lint * Integrating asr/tts models w/ new security config * update documents * add arbitrary whisper transcriptor support * update documents * fix lint * add mimo tts --- .golangci.yaml | 3 + config/config.example.json | 3 + go.mod | 6 + go.sum | 13 +- pkg/agent/loop.go | 70 +++- pkg/audio/asr/README.md | 166 ++++++++++ pkg/audio/asr/README_zh.md | 166 ++++++++++ pkg/audio/asr/agent.go | 252 ++++++++++++++ pkg/audio/asr/agent_test.go | 196 +++++++++++ pkg/audio/asr/asr.go | 131 ++++++++ .../asr/asr_test.go} | 72 +++- .../asr}/audio_model_transcriber.go | 2 +- .../asr}/audio_model_transcriber_test.go | 2 +- .../asr}/elevenlabs_transcriber.go | 10 +- .../asr}/elevenlabs_transcriber_test.go | 10 +- pkg/audio/asr/whisper_transcriber.go | 245 ++++++++++++++ pkg/audio/asr/whisper_transcriber_test.go | 102 ++++++ pkg/audio/ogg.go | 57 ++++ pkg/audio/ogg_test.go | 146 ++++++++ pkg/audio/sentence.go | 96 ++++++ pkg/audio/sentence_test.go | 69 ++++ pkg/audio/tts/README.md | 137 ++++++++ pkg/audio/tts/README_zh.md | 137 ++++++++ pkg/audio/tts/mimo_tts.go | 162 +++++++++ pkg/audio/tts/openai_tts.go | 126 +++++++ pkg/audio/tts/tts.go | 151 +++++++++ pkg/audio/tts/tts_test.go | 247 ++++++++++++++ pkg/bus/bus.go | 28 ++ pkg/bus/types.go | 31 +- pkg/channels/discord/discord.go | 170 ++++++++++ pkg/channels/discord/init.go | 7 +- pkg/channels/discord/voice.go | 313 ++++++++++++++++++ pkg/channels/feishu/common.go | 7 + pkg/channels/line/line.go | 5 + pkg/channels/matrix/matrix.go | 5 + pkg/channels/onebot/onebot.go | 5 + pkg/channels/qq/qq.go | 5 + pkg/channels/telegram/telegram.go | 5 + pkg/channels/voice_capabilities.go | 58 ++++ pkg/channels/weixin/weixin.go | 5 + pkg/config/config.go | 9 +- pkg/config/defaults.go | 3 + pkg/gateway/gateway.go | 54 ++- pkg/providers/factory_provider.go | 13 + pkg/tools/tts_send.go | 82 +++++ pkg/voice/groq_transcriber.go | 151 --------- pkg/voice/groq_transcriber_test.go | 84 ----- pkg/voice/transcriber.go | 68 ---- 48 files changed, 3527 insertions(+), 358 deletions(-) create mode 100644 pkg/audio/asr/README.md create mode 100644 pkg/audio/asr/README_zh.md create mode 100644 pkg/audio/asr/agent.go create mode 100644 pkg/audio/asr/agent_test.go create mode 100644 pkg/audio/asr/asr.go rename pkg/{voice/transcriber_test.go => audio/asr/asr_test.go} (67%) rename pkg/{voice => audio/asr}/audio_model_transcriber.go (99%) rename pkg/{voice => audio/asr}/audio_model_transcriber_test.go (99%) rename pkg/{voice => audio/asr}/elevenlabs_transcriber.go (96%) rename pkg/{voice => audio/asr}/elevenlabs_transcriber_test.go (91%) create mode 100644 pkg/audio/asr/whisper_transcriber.go create mode 100644 pkg/audio/asr/whisper_transcriber_test.go create mode 100644 pkg/audio/ogg.go create mode 100644 pkg/audio/ogg_test.go create mode 100644 pkg/audio/sentence.go create mode 100644 pkg/audio/sentence_test.go create mode 100644 pkg/audio/tts/README.md create mode 100644 pkg/audio/tts/README_zh.md create mode 100644 pkg/audio/tts/mimo_tts.go create mode 100644 pkg/audio/tts/openai_tts.go create mode 100644 pkg/audio/tts/tts.go create mode 100644 pkg/audio/tts/tts_test.go create mode 100644 pkg/channels/discord/voice.go create mode 100644 pkg/channels/voice_capabilities.go create mode 100644 pkg/tools/tts_send.go delete mode 100644 pkg/voice/groq_transcriber.go delete mode 100644 pkg/voice/groq_transcriber_test.go delete mode 100644 pkg/voice/transcriber.go diff --git a/.golangci.yaml b/.golangci.yaml index ea3107ec8..b2b772406 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -61,6 +61,9 @@ linters: - usestdlibvars - usetesting settings: + gomoddirectives: + replace-allow-list: + - github.com/bwmarrin/discordgo errcheck: check-type-assertions: true check-blank: true diff --git a/config/config.example.json b/config/config.example.json index 814c82503..95fe24e0b 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -418,6 +418,9 @@ "read_file": { "enabled": true }, + "send_tts": { + "enabled": false + }, "spawn": { "enabled": true }, diff --git a/go.mod b/go.mod index 7d242d498..5f311306e 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,8 @@ require ( github.com/mymmrac/telego v1.7.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 + github.com/pion/rtp v1.8.7 + github.com/pion/webrtc/v3 v3.3.6 github.com/rivo/tview v0.42.0 github.com/rs/zerolog v1.34.0 github.com/slack-go/slack v0.17.3 @@ -61,6 +63,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect github.com/aws/smithy-go v1.24.2 // indirect github.com/beeper/argo-go v1.1.2 // indirect + github.com/cloudflare/circl v1.6.3 // indirect github.com/coder/websocket v1.8.14 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -76,6 +79,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.34 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 // indirect + github.com/pion/randutil v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect @@ -123,3 +127,5 @@ require ( golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 ) + +replace github.com/bwmarrin/discordgo => github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532 diff --git a/go.sum b/go.sum index 76d1b46c7..ca5dd0423 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,6 @@ github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs= github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4= -github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= -github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= @@ -65,6 +63,8 @@ github.com/caarlos0/env/v11 v11.4.0 h1:Kcb6t5kIIr4XkoQC9AF2j+8E1Jsrl3Wz/hhm1LtoG github.com/caarlos0/env/v11 v11.4.0/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= +github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= @@ -204,6 +204,12 @@ github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixi github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 h1:rh2lKw/P/EqHa724vYH2+VVQ1YnW4u6EOXl0PMAovZE= github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtp v1.8.7 h1:qslKkG8qxvQ7hqaxkmL7Pl0XcUm+/Er7nMnu6Vq+ZxM= +github.com/pion/rtp v1.8.7/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE= +github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -273,6 +279,8 @@ github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTd github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532 h1:gxFHYeUDGziRb0zXYEqBFohC+NJbIW9L0tddaXMWr2o= +github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532/go.mod h1:A0FcMFJKJ9fRjgSuZ2o+pIQ6mPS81SVuiLN2vYTa7Ao= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -300,7 +308,6 @@ golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d7461e76f..b376ed0af 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -18,6 +18,8 @@ import ( "sync/atomic" "time" + "github.com/sipeed/picoclaw/pkg/audio/asr" + "github.com/sipeed/picoclaw/pkg/audio/tts" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/commands" @@ -31,7 +33,6 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type AgentLoop struct { @@ -51,7 +52,7 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore - transcriber voice.Transcriber + transcriber asr.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime hookRuntime hookRuntime @@ -159,6 +160,13 @@ func registerSharedTools( provider providers.LLMProvider, ) { allowReadPaths := buildAllowReadPatterns(cfg) + var ttsProvider tts.TTSProvider + if cfg.Tools.IsToolEnabled("send_tts") { + ttsProvider = tts.DetectTTS(cfg) + if ttsProvider == nil { + logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil) + } + } for _, agentID := range registry.ListAgentIDs() { agent, ok := registry.GetAgent(agentID) @@ -269,6 +277,10 @@ func registerSharedTools( agent.Tools.Register(sendFileTool) } + if ttsProvider != nil { + agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil)) + } + // Skill discovery and installation tools skills_enabled := cfg.Tools.IsToolEnabled("skills") find_skills_enable := cfg.Tools.IsToolEnabled("find_skills") @@ -1059,10 +1071,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) { agent.Tools.SetMediaStore(s) } } + registry.ForEachTool("send_tts", func(t tools.Tool) { + if st, ok := t.(*tools.SendTTSTool); ok { + st.SetMediaStore(s) + } + }) } // SetTranscriber injects a voice transcriber for agent-level audio transcription. -func (al *AgentLoop) SetTranscriber(t voice.Transcriber) { +func (al *AgentLoop) SetTranscriber(t asr.Transcriber) { al.transcriber = t } @@ -1083,19 +1100,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou // Transcribe each audio media ref in order. var transcriptions []string + var keptMedia []string for _, ref := range msg.Media { path, meta, err := al.mediaStore.ResolveWithMeta(ref) if err != nil { logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err}) + keptMedia = append(keptMedia, ref) continue } if !utils.IsAudioFile(meta.Filename, meta.ContentType) { + keptMedia = append(keptMedia, ref) continue } result, err := al.transcriber.Transcribe(ctx, path) if err != nil { logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err}) transcriptions = append(transcriptions, "") + keptMedia = append(keptMedia, ref) continue } transcriptions = append(transcriptions, result.Text) @@ -1115,15 +1136,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou } text := transcriptions[idx] idx++ + if text == "" { + return match + } return "[voice: " + text + "]" }) // Append any remaining transcriptions not matched by an annotation. for ; idx < len(transcriptions); idx++ { - newContent += "\n[voice: " + transcriptions[idx] + "]" + if transcriptions[idx] != "" { + newContent += "\n[voice: " + transcriptions[idx] + "]" + } } msg.Content = newContent + msg.Media = keptMedia return msg, true } @@ -2464,6 +2491,28 @@ turnLoop: if toolResult == nil { toolResult = tools.ErrorResult("hook returned nil tool result") } + + // Send ForUser if not silent and has content. + // For ResponseHandled tools, send regardless of SendResponse setting, + // since they've already handled the response (e.g., send_tts, send_file). + shouldSendForUser := !toolResult.Silent && toolResult.ForUser != "" && + (ts.opts.SendResponse || toolResult.ResponseHandled) + if shouldSendForUser { + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: ts.channel, + ChatID: ts.chatID, + Content: toolResult.ForUser, + Metadata: map[string]string{ + "is_tool_call": "true", + }, + }) + logger.DebugCF("agent", "Sent tool result to user", + map[string]any{ + "tool": toolName, + "content_len": len(toolResult.ForUser), + }) + } + if len(toolResult.Media) > 0 && toolResult.ResponseHandled { parts := make([]bus.MediaPart, 0, len(toolResult.Media)) for _, ref := range toolResult.Media { @@ -2509,19 +2558,6 @@ turnLoop: allResponsesHandled = false } - if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: toolResult.ForUser, - }) - logger.DebugCF("agent", "Sent tool result to user", - map[string]any{ - "tool": toolName, - "content_len": len(toolResult.ForUser), - }) - } - contentForLLM := toolResult.ContentForLLM() // Filter sensitive data (API keys, tokens, secrets) before sending to LLM diff --git a/pkg/audio/asr/README.md b/pkg/audio/asr/README.md new file mode 100644 index 000000000..0477276dd --- /dev/null +++ b/pkg/audio/asr/README.md @@ -0,0 +1,166 @@ +# ASR (Automatic Speech Recognition) + +This package handles speech-to-text for PicoClaw voice input. + +If you are new to ASR setup, the simplest mental model is: + +1. Add one or more ASR-capable entries to `model_list`. +2. Point `voice.model_name` at the one you want to use. +3. Put the API key in `.security.yml`. + +## Quick Recommendation + +For most new users, start with one of these: + +| Provider | Example model | Why start here | +| --- | --- | --- | +| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Fast Whisper-style transcription and a straightforward OpenAI-compatible API. Groq currently advertises a free tier plan for 2000 reqs/day. | +| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | Easy setup and strong speech-to-text quality. ElevenLabs currently advertises a free plan that includes speech-to-text usage. | + +Pricing and free-plan limits can change, so check the linked pricing pages before depending on them in production. + +## How ASR Configuration Works + +PicoClaw does not keep ASR API keys inside the `voice` section. + +Instead: + +- `voice.model_name` chooses a named entry from `model_list`. +- The matching `model_list` entry describes the actual provider and model. +- `.security.yml` stores the API key for that named model entry. + +This is the recommended pattern because it is explicit, reusable, and consistent with the rest of PicoClaw's model configuration. + +## Recommended Setup + +### Option A: Groq Whisper + +`config.json` + +```json +{ + "voice": { + "model_name": "groq-asr", + "echo_transcription": true + }, + "model_list": [ + { + "model_name": "groq-asr", + "model": "groq/whisper-large-v3-turbo" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + groq-asr: + api_keys: + - "gsk_your_groq_key" +``` + +Notes: + +- You can omit `api_base` and PicoClaw will use Groq's default API base automatically. +- If you set `api_base` manually for Groq Whisper, both of these forms work: + - `https://api.groq.com/openai/v1` + - `https://api.groq.com/openai/v1/audio/transcriptions` +- Any OpenAI-compatible Whisper model name containing `whisper` can use the Whisper transcription path, not only `whisper-large-v3-turbo`. + +### Option B: ElevenLabs + +`config.json` + +```json +{ + "voice": { + "model_name": "elevenlabs-asr", + "echo_transcription": true + }, + "model_list": [ + { + "model_name": "elevenlabs-asr", + "model": "elevenlabs/scribe_v1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + elevenlabs-asr: + api_keys: + - "sk-elevenlabs-your-key" +``` + +### Option C: OpenAI Whisper + +`config.json` + +```json +{ + "voice": { + "model_name": "openai-asr" + }, + "model_list": [ + { + "model_name": "openai-asr", + "model": "openai/whisper-1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + openai-asr: + api_keys: + - "sk-openai-your-key" +``` + +## Other ASR-Capable Model Types + +PicoClaw currently supports three main ASR routes: + +| Route | Example models | Behavior | +| --- | --- | --- | +| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. | +| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. | +| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. | + +If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first. + +## How PicoClaw Chooses a Transcriber + +`DetectTranscriber` resolves ASR in this order: + +1. **Preferred path**: resolve `voice.model_name` against `model_list`. +2. If that resolved model is: + - `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber. + - an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber. + - an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`. +3. **Fallback path**: if `voice.model_name` is not set, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries. + +Fallback scanning exists for backward compatibility. New configurations should set `voice.model_name` explicitly. + +## Common Mistakes + +- Defining an ASR model in `model_list` but forgetting to set `voice.model_name`. +- Putting the API key in `voice` instead of `.security.yml`. +- Using a non-ASR model and expecting Whisper-style transcription behavior. +- Setting a custom `api_base` that points to the wrong provider endpoint. + +## Minimal Checklist + +Before testing voice input, make sure: + +- `voice.model_name` matches a `model_list[].model_name`. +- The matching `.security.yml` entry contains a valid API key. +- The selected model is actually ASR-capable. +- Voice input is enabled for the channel you are using. diff --git a/pkg/audio/asr/README_zh.md b/pkg/audio/asr/README_zh.md new file mode 100644 index 000000000..104116080 --- /dev/null +++ b/pkg/audio/asr/README_zh.md @@ -0,0 +1,166 @@ +# ASR(自动语音识别) + +这个目录负责 PicoClaw 的语音转文字能力。 + +如果你是第一次配置 ASR,可以参考如下步骤: + +1. 在 `model_list` 里添加一个或多个支持 ASR 的模型条目。 +2. 用 `voice.model_name` 指向你想使用的那个条目。 +3. 在 `.security.yml` 里配置对应的 API Key。 + +## 快速推荐 + +对于大多数新用户,建议先从下面两种开始: + +| 提供商 | 示例模型 | 推荐理由 | +| --- | --- | --- | +| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Whisper 风格转录速度快,并且提供 OpenAI 兼容接口,配置比较直接。Groq 目前官方提供2000请求每日的免费套餐。 | +| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | 上手简单,语音转文字质量也不错。ElevenLabs 目前官方免费套餐包含 STT 用量。 | + +价格和免费额度可能会变化,正式使用前请以官网定价页为准。 + +## ASR 配置是如何工作的 + +PicoClaw 不会把 ASR 的 API Key 放在 `voice` 配置里。 + +推荐的方式是: + +- `voice.model_name` 用来选择 `model_list` 里的某个命名模型。 +- `model_list` 条目描述真实的提供商和模型。 +- `.security.yml` 负责保存该模型条目的 API Key。 + +这种方式更明确、更安全,也和 PicoClaw 其他模型配置方式保持一致。 + +## 推荐配置方式 + +### 方案 A:Groq Whisper + +`config.json` + +```json +{ + "voice": { + "model_name": "groq-asr", + "echo_transcription": true + }, + "model_list": [ + { + "model_name": "groq-asr", + "model": "groq/whisper-large-v3-turbo" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + groq-asr: + api_keys: + - "gsk_your_groq_key" +``` + +说明: + +- 你可以不写 `api_base`,PicoClaw 会自动使用 Groq 默认接口地址。 +- 如果你手动设置 Groq Whisper 的 `api_base`,下面两种写法都可以: + - `https://api.groq.com/openai/v1` + - `https://api.groq.com/openai/v1/audio/transcriptions` +- 只要是 OpenAI 兼容、并且模型名里包含 `whisper` 的模型,都可以走 Whisper 转录路径,不仅限于 `whisper-large-v3-turbo`。 + +### 方案 B:ElevenLabs + +`config.json` + +```json +{ + "voice": { + "model_name": "elevenlabs-asr", + "echo_transcription": true + }, + "model_list": [ + { + "model_name": "elevenlabs-asr", + "model": "elevenlabs/scribe_v1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + elevenlabs-asr: + api_keys: + - "sk-elevenlabs-your-key" +``` + +### 方案 C:OpenAI Whisper + +`config.json` + +```json +{ + "voice": { + "model_name": "openai-asr" + }, + "model_list": [ + { + "model_name": "openai-asr", + "model": "openai/whisper-1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + openai-asr: + api_keys: + - "sk-openai-your-key" +``` + +## 其他支持 ASR 的模型类型 + +PicoClaw 目前主要支持三种 ASR 路径: + +| 路径 | 示例模型 | 行为说明 | +| --- | --- | --- | +| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 | +| Whisper 接口模型 | `openai/whisper-1`、`groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 | +| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview`、`gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 | + +如果你不确定该选哪种,建议优先使用 Groq Whisper 或 ElevenLabs。 + +## PicoClaw 如何选择转录器 + +`DetectTranscriber` 会按下面顺序选择 ASR: + +1. **首选路径**:根据 `voice.model_name` 在 `model_list` 中找到对应模型。 +2. 如果找到的模型属于以下类型: + - `elevenlabs/...`,则使用 ElevenLabs transcriber。 + - OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。 + - 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`。 +3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。 + +回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.model_name`。 + +## 常见错误 + +- 在 `model_list` 里定义了 ASR 模型,但忘了设置 `voice.model_name`。 +- 把 API Key 写进了 `voice`,而不是 `.security.yml`。 +- 选择了不支持 ASR 的模型,却期望得到 Whisper 风格的转录结果。 +- 自定义了错误的 `api_base`,导致请求打到错误的接口地址。 + +## 最小检查清单 + +在测试语音输入前,请确认: + +- `voice.model_name` 能正确匹配某个 `model_list[].model_name`。 +- `.security.yml` 中对应条目已经配置了有效 API Key。 +- 你选择的模型确实支持 ASR。 +- 你当前使用的频道已经启用了语音输入能力。 diff --git a/pkg/audio/asr/agent.go b/pkg/audio/asr/agent.go new file mode 100644 index 000000000..32ce0c92a --- /dev/null +++ b/pkg/audio/asr/agent.go @@ -0,0 +1,252 @@ +package asr + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v3/pkg/media/oggwriter" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/logger" +) + +type speechAccumulator struct { + writer *oggwriter.OggWriter + file string + lastAudioAt time.Time + mu sync.Mutex + closed bool + chatID string + speakerID string + sessionID string + channel string +} + +func (a *speechAccumulator) Push(chunk bus.AudioChunk) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.closed { + return + } + + a.lastAudioAt = time.Now() + + pkt := &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: uint16(chunk.Sequence), + Timestamp: chunk.Timestamp, + SSRC: 1, // Stable arbitrary dummy + }, + Payload: chunk.Data, + } + + if err := a.writer.WriteRTP(pkt); err != nil { + logger.ErrorCF("voice-agent", "Failed to write RTP", map[string]any{"error": err}) + } +} + +func (a *speechAccumulator) Close() { + a.mu.Lock() + defer a.mu.Unlock() + if !a.closed { + a.writer.Close() + a.closed = true + } +} + +type Agent struct { + bus *bus.MessageBus + transcriber Transcriber + + mu sync.Mutex + sessions map[string]*speechAccumulator // keyed by sessionID_speakerID +} + +func NewAgent(mb *bus.MessageBus, t Transcriber) *Agent { + return &Agent{ + bus: mb, + transcriber: t, + sessions: make(map[string]*speechAccumulator), + } +} + +func (a *Agent) Start(ctx context.Context) { + logger.InfoCF("voice-agent", "Started Voice Agent orchestrator", nil) + go a.listenChunks(ctx) + go a.vadTick(ctx) + + // Cleanup sessions on shutdown + go func() { + <-ctx.Done() + a.mu.Lock() + for key, acc := range a.sessions { + acc.Close() + os.Remove(acc.file) + delete(a.sessions, key) + } + a.mu.Unlock() + logger.InfoCF("voice-agent", "Cleaned up voice sessions on shutdown", nil) + }() +} + +func (a *Agent) listenChunks(ctx context.Context) { + chunks := a.bus.AudioChunksChan() + for { + select { + case <-ctx.Done(): + return + case chunk, ok := <-chunks: + if !ok { + return + } + a.handleChunk(chunk) + } + } +} + +func (a *Agent) handleChunk(chunk bus.AudioChunk) { + // Only accept Opus-encoded audio + if chunk.Format != "opus" { + logger.DebugCF("voice-agent", "Ignoring unsupported audio format", map[string]any{"format": chunk.Format}) + return + } + + key := fmt.Sprintf("%s_%s", chunk.SessionID, chunk.SpeakerID) + + a.mu.Lock() + acc, exists := a.sessions[key] + if !exists { + filename := filepath.Join(os.TempDir(), fmt.Sprintf("voice_%s_%d.ogg", key, time.Now().UnixNano())) + writer, err := oggwriter.New(filename, uint32(chunk.SampleRate), uint16(chunk.Channels)) + if err != nil { + a.mu.Unlock() + logger.ErrorCF("voice-agent", "Failed to create OggWriter", map[string]any{"error": err}) + return + } + + acc = &speechAccumulator{ + writer: writer, + file: filename, + lastAudioAt: time.Now(), + chatID: chunk.ChatID, + speakerID: chunk.SpeakerID, + sessionID: chunk.SessionID, + channel: chunk.Channel, + } + a.sessions[key] = acc + logger.DebugCF("voice-agent", "Started accumulating voice", map[string]any{"key": key, "file": filename}) + } + a.mu.Unlock() + + acc.Push(chunk) +} + +func (a *Agent) vadTick(ctx context.Context) { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + a.checkSilence(ctx) + } + } +} + +func (a *Agent) checkSilence(ctx context.Context) { + a.mu.Lock() + now := time.Now() + var finished []*speechAccumulator + + for key, acc := range a.sessions { + acc.mu.Lock() + last := acc.lastAudioAt + acc.mu.Unlock() + + if now.Sub(last) > 1500*time.Millisecond { + acc.Close() + delete(a.sessions, key) + finished = append(finished, acc) + } + } + a.mu.Unlock() + + for _, acc := range finished { + go a.processUtterance(ctx, acc) + } +} + +func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) { + defer os.Remove(acc.file) + + logger.InfoCF("voice-agent", "User finished speaking, transcribing...", map[string]any{"file": acc.file}) + + if a.transcriber == nil { + logger.ErrorCF("voice-agent", "No STT configured!", nil) + return + } + + res, err := a.transcriber.Transcribe(ctx, acc.file) + if err != nil { + logger.ErrorCF("voice-agent", "Transcription failed", map[string]any{"error": err}) + return + } + + if res.Text == "" { + logger.DebugCF("voice-agent", "Ignored empty transcription", map[string]any{"file": acc.file}) + return + } + + logger.InfoCF("voice-agent", "Transcription result", map[string]any{"text": res.Text, "duration": res.Duration}) + + channelType := acc.channel + if channelType == "" { + channelType = "discord" // fallback for legacy chunks + } + + text := strings.ToLower(strings.TrimSpace(res.Text)) + if strings.Contains(text, "leave the voice channel") || strings.Contains(text, "leave voice") || + strings.Contains(text, "disconnect voice") || strings.Contains(text, "leave the channel") || + strings.Contains(text, "leave channel") { + logger.InfoCF("voice-agent", "Voice command triggered: leave", nil) + if err := a.bus.PublishVoiceControl(ctx, bus.VoiceControl{ + SessionID: acc.sessionID, + Type: "command", + Action: "leave", + }); err != nil { + logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err}) + } + if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channelType, + ChatID: acc.chatID, + Content: "Goodbye! Leaving the voice channel.", + }); err != nil { + logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err}) + } + return + } + + oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally." + + if err := a.bus.PublishInbound(ctx, bus.InboundMessage{ + Channel: channelType, + SenderID: acc.speakerID, + ChatID: acc.chatID, + Content: res.Text + oralPrompt, + Peer: bus.Peer{Kind: "channel", ID: acc.chatID}, + Metadata: map[string]string{ + "is_voice": "true", + }, + }); err != nil { + logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err}) + } +} diff --git a/pkg/audio/asr/agent_test.go b/pkg/audio/asr/agent_test.go new file mode 100644 index 000000000..cc1b008a4 --- /dev/null +++ b/pkg/audio/asr/agent_test.go @@ -0,0 +1,196 @@ +package asr + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/pion/webrtc/v3/pkg/media/oggwriter" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +type fakeTranscriber struct { + text string + err error + lastPath string +} + +func (f *fakeTranscriber) Name() string { return "fake" } + +func (f *fakeTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { + f.lastPath = audioFilePath + if f.err != nil { + return nil, f.err + } + return &TranscriptionResponse{Text: f.text}, nil +} + +func waitForFileRemoval(t *testing.T, path string, timeout time.Duration) { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return + } + time.Sleep(10 * time.Millisecond) + } + if _, err := os.Stat(path); err == nil { + t.Fatalf("expected file to be removed: %s", path) + } +} + +func TestAgentHandleChunkCreatesSession(t *testing.T) { + t.Parallel() + + mb := bus.NewMessageBus() + defer mb.Close() + + agent := NewAgent(mb, &fakeTranscriber{}) + + chunk := bus.AudioChunk{ + SessionID: "sess", + SpeakerID: "speaker", + ChatID: "chat", + Channel: "discord", + Sequence: 1, + Timestamp: 1, + SampleRate: 48000, + Channels: 2, + Format: "opus", + Data: []byte{0xF8, 0xFF, 0xFE}, + } + + agent.handleChunk(chunk) + + key := "sess_speaker" + agent.mu.Lock() + acc, ok := agent.sessions[key] + agent.mu.Unlock() + if !ok { + t.Fatal("expected session to be created") + } + + acc.Close() + _ = os.Remove(acc.file) +} + +func TestAgentHandleChunkIgnoresUnsupportedFormat(t *testing.T) { + t.Parallel() + + mb := bus.NewMessageBus() + defer mb.Close() + + agent := NewAgent(mb, &fakeTranscriber{}) + + chunk := bus.AudioChunk{Format: "pcm"} + agent.handleChunk(chunk) + + agent.mu.Lock() + count := len(agent.sessions) + agent.mu.Unlock() + if count != 0 { + t.Fatalf("expected no sessions, got %d", count) + } +} + +func TestAgentProcessUtteranceLeaveCommand(t *testing.T) { + t.Parallel() + + mb := bus.NewMessageBus() + defer mb.Close() + + tr := &fakeTranscriber{text: "please leave the voice channel now"} + agent := NewAgent(mb, tr) + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "voice.ogg") + if err := os.WriteFile(filePath, []byte("data"), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + + acc := &speechAccumulator{ + file: filePath, + chatID: "chat", + speakerID: "speaker", + sessionID: "sess", + channel: "discord", + } + + agent.processUtterance(context.Background(), acc) + + select { + case ctrl := <-mb.VoiceControlsChan(): + if ctrl.Action != "leave" || ctrl.Type != "command" || ctrl.SessionID != "sess" { + t.Fatalf("unexpected voice control: %#v", ctrl) + } + case <-time.After(250 * time.Millisecond): + t.Fatal("expected voice control publish") + } + + select { + case out := <-mb.OutboundChan(): + if !strings.Contains(out.Content, "Leaving the voice channel") { + t.Fatalf("unexpected outbound content: %q", out.Content) + } + case <-time.After(250 * time.Millisecond): + t.Fatal("expected outbound publish") + } + + if _, err := os.Stat(filePath); !os.IsNotExist(err) { + t.Fatalf("expected temp file to be removed") + } +} + +func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) { + t.Parallel() + + mb := bus.NewMessageBus() + defer mb.Close() + + tr := &fakeTranscriber{text: "hello there"} + agent := NewAgent(mb, tr) + + filePath := filepath.Join(t.TempDir(), "voice.ogg") + writer, err := oggwriter.New(filePath, 48000, 2) + if err != nil { + t.Fatalf("create ogg writer: %v", err) + } + + acc := &speechAccumulator{ + writer: writer, + file: filePath, + lastAudioAt: time.Now().Add(-2 * time.Second), + chatID: "chat", + speakerID: "speaker", + sessionID: "sess", + channel: "slack", + } + + agent.mu.Lock() + agent.sessions["sess_speaker"] = acc + agent.mu.Unlock() + + agent.checkSilence(context.Background()) + + select { + case msg := <-mb.InboundChan(): + if msg.Channel != "slack" { + t.Fatalf("unexpected inbound channel: %q", msg.Channel) + } + if !strings.Contains(msg.Content, "hello there") { + t.Fatalf("unexpected inbound content: %q", msg.Content) + } + if msg.Metadata["is_voice"] != "true" { + t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("expected inbound publish") + } + + waitForFileRemoval(t, filePath, 500*time.Millisecond) +} diff --git a/pkg/audio/asr/asr.go b/pkg/audio/asr/asr.go new file mode 100644 index 000000000..d15dc3f09 --- /dev/null +++ b/pkg/audio/asr/asr.go @@ -0,0 +1,131 @@ +package asr + +import ( + "context" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type Transcriber interface { + Name() string + Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) +} + +type TranscriptionResponse struct { + Text string `json:"text"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` +} + +func supportsAudioTranscription(model string) bool { + protocol, _ := providers.ExtractProtocol(model) + + switch protocol { + case "openai", "azure", "azure-openai", + "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", + "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", + "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", + "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita", + "coding-plan", "alibaba-coding", "qwen-coding": + // These protocols all go through the OpenAI-compatible or Azure provider path in + // providers.CreateProviderFromConfig, so they are the only ones that can supply + // the audio media payload shape expected by NewAudioModelTranscriber. + + // TODO: Further restrict this by modelID, since not every model under these + // protocols supports audio transcription. + return true + default: + return false + } +} + +func supportsWhisperTranscription(model string) bool { + protocol, _ := providers.ExtractProtocol(model) + + switch protocol { + case "openai", "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", + "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", + "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", + "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita", + "coding-plan", "alibaba-coding", "qwen-coding", "mimo": + return true + default: + return false + } +} + +func whisperModelID(modelCfg *config.ModelConfig) string { + if modelCfg == nil || modelCfg.APIKey() == "" { + return "" + } + + if !supportsWhisperTranscription(modelCfg.Model) { + return "" + } + + _, modelID := providers.ExtractProtocol(strings.TrimSpace(modelCfg.Model)) + if strings.Contains(strings.ToLower(modelID), "whisper") { + return modelID + } + return "" +} + +func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber { + if modelCfg == nil { + return nil + } + + protocol, _ := providers.ExtractProtocol(modelCfg.Model) + if protocol == "elevenlabs" && modelCfg.APIKey() != "" { + return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase) + } + if modelID := whisperModelID(modelCfg); modelID != "" { + return NewWhisperTranscriber(modelCfg) + } + if supportsAudioTranscription(modelCfg.Model) { + return NewAudioModelTranscriber(modelCfg) + } + return nil +} + +func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber { + if modelCfg == nil { + return nil + } + + protocol, _ := providers.ExtractProtocol(modelCfg.Model) + if protocol == "elevenlabs" && modelCfg.APIKey() != "" { + return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase) + } + if modelID := whisperModelID(modelCfg); modelID != "" { + return NewWhisperTranscriber(modelCfg) + } + return nil +} + +// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or +// nil if no supported transcription provider is configured. +func DetectTranscriber(cfg *config.Config) Transcriber { + if cfg == nil { + return nil + } + + if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" { + modelCfg, err := cfg.GetModelConfig(modelName) + if err == nil { + if tr := transcriberFromModelConfig(modelCfg); tr != nil { + return tr + } + } + } + + // Fall back to compatibility scanning for legacy auto-detected ASR providers. + for _, mc := range cfg.ModelList { + if tr := fallbackTranscriberFromModelConfig(mc); tr != nil { + return tr + } + } + return nil +} diff --git a/pkg/voice/transcriber_test.go b/pkg/audio/asr/asr_test.go similarity index 67% rename from pkg/voice/transcriber_test.go rename to pkg/audio/asr/asr_test.go index 3e71ff13a..0970d69f4 100644 --- a/pkg/voice/transcriber_test.go +++ b/pkg/audio/asr/asr_test.go @@ -1,4 +1,4 @@ -package voice +package asr import ( "testing" @@ -33,26 +33,68 @@ func TestDetectTranscriber(t *testing.T) { wantName: "audio-model", }, { - name: "groq via model list", + name: "voice model name alias selects elevenlabs transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "my-asr-model"}, + ModelList: []*config.ModelConfig{ + { + ModelName: "my-asr-model", + Model: "elevenlabs/scribe_v1", + APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"), + }, + }, + }, + wantName: "elevenlabs", + }, + { + name: "voice model name alias selects whisper transcriber for groq", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "my-asr-model"}, + ModelList: []*config.ModelConfig{ + { + ModelName: "my-asr-model", + Model: "groq/whisper-large-v3", + APIKeys: config.SimpleSecureStrings("sk-groq-model"), + }, + }, + }, + wantName: "whisper", + }, + { + name: "openai whisper alias selects whisper transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "my-asr-model"}, + ModelList: []*config.ModelConfig{ + { + ModelName: "my-asr-model", + Model: "openai/whisper-1", + APIKeys: config.SimpleSecureStrings("sk-openai-model"), + }, + }, + }, + wantName: "whisper", + }, + { + name: "whisper via model list fallback", cfg: &config.Config{ ModelList: []*config.ModelConfig{ {ModelName: "openai", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("sk-openai")}, { ModelName: "groq", - Model: "groq/llama-3.3-70b", + Model: "groq/whisper-large-v3-turbo", APIKeys: config.SimpleSecureStrings("sk-groq-model"), }, }, }, - wantName: "groq", + wantName: "whisper", }, { - name: "voice model name selects non-gemini audio model transcriber", + name: "voice model name alias selects non-gemini audio model transcriber", cfg: &config.Config{ - Voice: config.VoiceConfig{ModelName: "voice-openai-audio"}, + Voice: config.VoiceConfig{ModelName: "my-asr-model"}, ModelList: []*config.ModelConfig{ { - ModelName: "voice-openai-audio", + ModelName: "my-asr-model", Model: "openai/gpt-4o-audio-preview", APIKeys: config.SimpleSecureStrings("sk-openai"), }, @@ -92,7 +134,7 @@ func TestDetectTranscriber(t *testing.T) { name: "groq model list entry without key is skipped", cfg: &config.Config{ ModelList: []*config.ModelConfig{ - {Model: "groq/llama-3.3-70b"}, + {Model: "groq/whisper-large-v3"}, }, }, wantNil: true, @@ -103,12 +145,12 @@ func TestDetectTranscriber(t *testing.T) { ModelList: []*config.ModelConfig{ { ModelName: "groq", - Model: "groq/llama-3.3-70b", + Model: "groq/whisper-large-v3", APIKeys: config.SimpleSecureStrings("sk-groq-model"), }, }, }, - wantName: "groq", + wantName: "whisper", }, { name: "missing voice model name config returns nil", @@ -127,15 +169,17 @@ func TestDetectTranscriber(t *testing.T) { { name: "elevenlabs voice config key", cfg: &config.Config{ - Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"}, + ModelList: []*config.ModelConfig{ + {Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")}, + }, }, wantName: "elevenlabs", }, { name: "elevenlabs takes priority over groq model list", cfg: &config.Config{ - Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"}, ModelList: []*config.ModelConfig{ + {Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")}, { ModelName: "groq", Model: "groq/llama-3.3-70b", @@ -149,10 +193,10 @@ func TestDetectTranscriber(t *testing.T) { name: "voice model name takes priority over elevenlabs", cfg: &config.Config{ Voice: config.VoiceConfig{ - ModelName: "voice-gemini", - ElevenLabsAPIKey: "sk_elevenlabs_test", + ModelName: "voice-gemini", }, ModelList: []*config.ModelConfig{ + {Model: "elevenlabs", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")}, { ModelName: "voice-gemini", Model: "gemini/gemini-2.5-flash", diff --git a/pkg/voice/audio_model_transcriber.go b/pkg/audio/asr/audio_model_transcriber.go similarity index 99% rename from pkg/voice/audio_model_transcriber.go rename to pkg/audio/asr/audio_model_transcriber.go index f3ca81961..e8ded15dd 100644 --- a/pkg/voice/audio_model_transcriber.go +++ b/pkg/audio/asr/audio_model_transcriber.go @@ -1,4 +1,4 @@ -package voice +package asr import ( "context" diff --git a/pkg/voice/audio_model_transcriber_test.go b/pkg/audio/asr/audio_model_transcriber_test.go similarity index 99% rename from pkg/voice/audio_model_transcriber_test.go rename to pkg/audio/asr/audio_model_transcriber_test.go index c33e3bf97..5aaa82061 100644 --- a/pkg/voice/audio_model_transcriber_test.go +++ b/pkg/audio/asr/audio_model_transcriber_test.go @@ -1,4 +1,4 @@ -package voice +package asr import ( "context" diff --git a/pkg/voice/elevenlabs_transcriber.go b/pkg/audio/asr/elevenlabs_transcriber.go similarity index 96% rename from pkg/voice/elevenlabs_transcriber.go rename to pkg/audio/asr/elevenlabs_transcriber.go index 93db10f8d..452b9512d 100644 --- a/pkg/voice/elevenlabs_transcriber.go +++ b/pkg/audio/asr/elevenlabs_transcriber.go @@ -1,4 +1,4 @@ -package voice +package asr import ( "bytes" @@ -23,12 +23,16 @@ type ElevenLabsTranscriber struct { httpClient *http.Client } -func NewElevenLabsTranscriber(apiKey string) *ElevenLabsTranscriber { +func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber { logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""}) + if apiBase == "" { + apiBase = "https://api.elevenlabs.io" + } + return &ElevenLabsTranscriber{ apiKey: apiKey, - apiBase: "https://api.elevenlabs.io", + apiBase: apiBase, httpClient: &http.Client{ Timeout: 120 * time.Second, }, diff --git a/pkg/voice/elevenlabs_transcriber_test.go b/pkg/audio/asr/elevenlabs_transcriber_test.go similarity index 91% rename from pkg/voice/elevenlabs_transcriber_test.go rename to pkg/audio/asr/elevenlabs_transcriber_test.go index 78be8958a..fa80110be 100644 --- a/pkg/voice/elevenlabs_transcriber_test.go +++ b/pkg/audio/asr/elevenlabs_transcriber_test.go @@ -1,4 +1,4 @@ -package voice +package asr import ( "context" @@ -14,7 +14,7 @@ import ( var _ Transcriber = (*ElevenLabsTranscriber)(nil) func TestElevenLabsTranscriberName(t *testing.T) { - tr := NewElevenLabsTranscriber("sk_test") + tr := NewElevenLabsTranscriber("sk_test", "") if got := tr.Name(); got != "elevenlabs" { t.Errorf("Name() = %q, want %q", got, "elevenlabs") } @@ -43,7 +43,7 @@ func TestElevenLabsTranscribe(t *testing.T) { })) defer srv.Close() - tr := NewElevenLabsTranscriber("sk_test") + tr := NewElevenLabsTranscriber("sk_test", "") tr.apiBase = srv.URL resp, err := tr.Transcribe(context.Background(), audioPath) @@ -64,7 +64,7 @@ func TestElevenLabsTranscribe(t *testing.T) { })) defer srv.Close() - tr := NewElevenLabsTranscriber("sk_bad") + tr := NewElevenLabsTranscriber("sk_bad", "") tr.apiBase = srv.URL _, err := tr.Transcribe(context.Background(), audioPath) @@ -74,7 +74,7 @@ func TestElevenLabsTranscribe(t *testing.T) { }) t.Run("missing file", func(t *testing.T) { - tr := NewElevenLabsTranscriber("sk_test") + tr := NewElevenLabsTranscriber("sk_test", "") _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) if err == nil { t.Fatal("expected error for missing file, got nil") diff --git a/pkg/audio/asr/whisper_transcriber.go b/pkg/audio/asr/whisper_transcriber.go new file mode 100644 index 000000000..406710a8a --- /dev/null +++ b/pkg/audio/asr/whisper_transcriber.go @@ -0,0 +1,245 @@ +package asr + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type WhisperTranscriber struct { + apiKey string + apiBase string + modelID string + providerName string + httpClient *http.Client +} + +func NewWhisperTranscriber(modelCfg *config.ModelConfig) *WhisperTranscriber { + if modelCfg == nil { + return nil + } + + protocol, modelID := providers.ExtractProtocol(modelCfg.Model) + if modelID == "" { + modelID = strings.TrimSpace(modelCfg.Model) + } + + tr := newWhisperTranscriber( + modelCfg.APIKey(), + providers.ResolveAPIBase(modelCfg), + modelID, + protocol, + ) + if tr == nil { + return nil + } + + logger.DebugCF("voice", "Creating whisper transcriber", map[string]any{ + "api_base": tr.apiBase, + "has_key": tr.apiKey != "", + "model": tr.modelID, + "provider": tr.providerName, + }) + return tr +} + +func NewGroqTranscriber(apiKey, modelID string) *WhisperTranscriber { + return newWhisperTranscriber(apiKey, "https://api.groq.com/openai/v1", modelID, "groq") +} + +func newWhisperTranscriber(apiKey, apiBase, modelID, providerName string) *WhisperTranscriber { + if modelID == "" { + return nil + } + if providerName == "" { + providerName = "whisper" + } + return &WhisperTranscriber{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + modelID: modelID, + providerName: providerName, + httpClient: &http.Client{ + Timeout: 60 * time.Second, + }, + } +} + +func (t *WhisperTranscriber) transcriptionURL() string { + base := strings.TrimRight(t.apiBase, "/") + if strings.HasSuffix(base, "/audio/transcriptions") { + return base + } + return base + "/audio/transcriptions" +} + +func (t *WhisperTranscriber) TranscribeData( + ctx context.Context, + data []byte, + filename string, +) (*TranscriptionResponse, error) { + logger.InfoCF("voice", "Starting whisper transcription from memory", map[string]any{ + "bytes": len(data), + "filename": filename, + "model": t.modelID, + "provider": t.providerName, + }) + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + part, err := writer.CreateFormFile("file", filename) + if err != nil { + logger.ErrorCF("voice", "Failed to create whisper form file", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to create form file: %w", err) + } + + if _, copyErr := io.Copy(part, bytes.NewReader(data)); copyErr != nil { + logger.ErrorCF("voice", "Failed to copy whisper file content", map[string]any{"error": copyErr}) + return nil, fmt.Errorf("failed to copy file content: %w", copyErr) + } + + if err = writer.WriteField("model", t.modelID); err != nil { + logger.ErrorCF("voice", "Failed to write whisper model field", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + if err = writer.WriteField("response_format", "json"); err != nil { + logger.ErrorCF("voice", "Failed to write whisper response_format field", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to write response_format field: %w", err) + } + + if err = writer.Close(); err != nil { + logger.ErrorCF("voice", "Failed to close whisper multipart writer", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), int64(len(data))) +} + +func (t *WhisperTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { + logger.InfoCF("voice", "Starting whisper transcription", map[string]any{ + "audio_file": audioFilePath, + "model": t.modelID, + "provider": t.providerName, + }) + + audioFile, err := os.Open(audioFilePath) + if err != nil { + return nil, fmt.Errorf("failed to open audio file %s: %w", audioFilePath, err) + } + defer audioFile.Close() + + fileInfo, err := audioFile.Stat() + if err != nil { + return nil, fmt.Errorf("failed to stat audio file %s: %w", audioFilePath, err) + } + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath)) + if err != nil { + return nil, fmt.Errorf("failed to create form file: %w", err) + } + + if _, copyErr := io.Copy(part, audioFile); copyErr != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", copyErr) + } + + if err = writer.WriteField("model", t.modelID); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + if err = writer.WriteField("response_format", "json"); err != nil { + return nil, fmt.Errorf("failed to write response_format field: %w", err) + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), fileInfo.Size()) +} + +func (t *WhisperTranscriber) doRequest( + ctx context.Context, + requestBody *bytes.Buffer, + contentType string, + fileSize int64, +) (*TranscriptionResponse, error) { + url := t.transcriptionURL() + req, err := http.NewRequestWithContext(ctx, "POST", url, requestBody) + if err != nil { + logger.ErrorCF("voice", "Failed to create whisper request", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", contentType) + if t.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+t.apiKey) + } + + logger.DebugCF("voice", "Sending whisper transcription request", map[string]any{ + "file_size_bytes": fileSize, + "model": t.modelID, + "provider": t.providerName, + "request_size_bytes": requestBody.Len(), + "url": url, + }) + + resp, err := t.httpClient.Do(req) + if err != nil { + logger.ErrorCF("voice", "Failed to send whisper request", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.ErrorCF("voice", "Failed to read whisper response", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + logger.ErrorCF("voice", "Whisper API error", map[string]any{ + "provider": t.providerName, + "response": string(body), + "status_code": resp.StatusCode, + }) + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result TranscriptionResponse + if err := json.Unmarshal(body, &result); err != nil { + logger.ErrorCF("voice", "Failed to unmarshal whisper response", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + logger.InfoCF("voice", "Whisper transcription completed successfully", map[string]any{ + "duration_seconds": result.Duration, + "language": result.Language, + "provider": t.providerName, + "text_length": len(result.Text), + "transcription_preview": utils.Truncate(result.Text, 50), + }) + + return &result, nil +} + +func (t *WhisperTranscriber) Name() string { + return "whisper" +} diff --git a/pkg/audio/asr/whisper_transcriber_test.go b/pkg/audio/asr/whisper_transcriber_test.go new file mode 100644 index 000000000..a2a5178d1 --- /dev/null +++ b/pkg/audio/asr/whisper_transcriber_test.go @@ -0,0 +1,102 @@ +package asr + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestWhisperTranscriberTranscribeDataUsesConfiguredModel(t *testing.T) { + var gotModel string + var gotPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + if got := r.Header.Get("Authorization"); got != "Bearer sk-openai-test" { + t.Errorf("Authorization = %q, want %q", got, "Bearer sk-openai-test") + } + + reader, err := r.MultipartReader() + if err != nil { + t.Fatalf("MultipartReader() error: %v", err) + } + + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("NextPart() error: %v", err) + } + + data, err := io.ReadAll(part) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + + if part.FormName() == "model" { + gotModel = string(data) + } + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "hello from whisper"}); err != nil { + t.Fatalf("Encode() error: %v", err) + } + })) + defer server.Close() + + tr := NewWhisperTranscriber(&config.ModelConfig{ + Model: "openai/whisper-1", + APIBase: server.URL, + APIKeys: config.SimpleSecureStrings("sk-openai-test"), + }) + tr.httpClient = server.Client() + + resp, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg") + if err != nil { + t.Fatalf("TranscribeData() error: %v", err) + } + if resp.Text != "hello from whisper" { + t.Errorf("Text = %q, want %q", resp.Text, "hello from whisper") + } + if gotModel != "whisper-1" { + t.Errorf("model field = %q, want %q", gotModel, "whisper-1") + } + if gotPath != "/audio/transcriptions" { + t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions") + } +} + +func TestWhisperTranscriberUsesEndpointAPIBaseWithoutDoubleAppend(t *testing.T) { + var gotPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"}); err != nil { + t.Fatalf("Encode() error: %v", err) + } + })) + defer server.Close() + + tr := NewWhisperTranscriber(&config.ModelConfig{ + Model: "groq/whisper-large-v3", + APIBase: server.URL + "/audio/transcriptions", + APIKeys: config.SimpleSecureStrings("sk-groq-test"), + }) + tr.httpClient = server.Client() + + if _, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg"); err != nil { + t.Fatalf("TranscribeData() error: %v", err) + } + if gotPath != "/audio/transcriptions" { + t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions") + } +} diff --git a/pkg/audio/ogg.go b/pkg/audio/ogg.go new file mode 100644 index 000000000..f0055a574 --- /dev/null +++ b/pkg/audio/ogg.go @@ -0,0 +1,57 @@ +package audio + +import ( + "bytes" + "fmt" + "io" +) + +// DecodeOggOpus reads an Ogg format stream and extracts individual Opus payloads. +// It calls onFrame for every complete Opus frame found in the stream. +func DecodeOggOpus(r io.Reader, onFrame func([]byte) error) error { + var packet bytes.Buffer + header := make([]byte, 27) + segment := make([]byte, 255) + + for { + if _, err := io.ReadFull(r, header); err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return nil + } + return fmt.Errorf("failed to read ogg header: %w", err) + } + if string(header[:4]) != "OggS" { + return fmt.Errorf("invalid ogg magic string") + } + + pageSegments := int(header[26]) + segmentTable := make([]byte, pageSegments) + if _, err := io.ReadFull(r, segmentTable); err != nil { + return fmt.Errorf("failed to read segment table: %w", err) + } + + for _, lacing := range segmentTable { + if _, err := io.ReadFull(r, segment[:lacing]); err != nil { + return fmt.Errorf("failed to read segment data: %w", err) + } + + packet.Write(segment[:lacing]) + + // If lacing is less than 255, the packet is complete + if lacing < 255 { + if packet.Len() > 0 { + packetBytes := packet.Bytes() + // Ignore Ogg Opus headers + if !bytes.HasPrefix(packetBytes, []byte("OpusHead")) && + !bytes.HasPrefix(packetBytes, []byte("OpusTags")) { + if err := onFrame(packetBytes); err != nil { + return err + } + } + // Start new packet + packet.Reset() + } + } + } + } +} diff --git a/pkg/audio/ogg_test.go b/pkg/audio/ogg_test.go new file mode 100644 index 000000000..8d5e5ac2a --- /dev/null +++ b/pkg/audio/ogg_test.go @@ -0,0 +1,146 @@ +package audio + +import ( + "bytes" + "reflect" + "strings" + "testing" +) + +// buildOggPage helper creates an Ogg page for testing. +// lacingVals specifies the segment table, and data is the payload. +func buildOggPage(lacingVals []byte, data []byte) []byte { + var buf bytes.Buffer + // 27-byte Ogg header + header := make([]byte, 27) + copy(header[:4], "OggS") + header[5] = 0 // type flag + // For testing, we only care about OggS magic and page_segments (byte 26) + header[26] = byte(len(lacingVals)) + buf.Write(header) + buf.Write(lacingVals) + buf.Write(data) + return buf.Bytes() +} + +func TestDecodeOggOpus_ValidParsing(t *testing.T) { + var b bytes.Buffer + + // Packet 1: Single segment, length 50 + pkt1 := bytes.Repeat([]byte{1}, 50) + // Packet 2: Multi-segment (255 + 10 = 265 bytes) + pkt2Part1 := bytes.Repeat([]byte{2}, 255) + pkt2Part2 := bytes.Repeat([]byte{2}, 10) + // Packet 3: Continued across pages. Page 1 gets 255, Page 2 gets 20. Total 275 bytes. + pkt3Part1 := bytes.Repeat([]byte{3}, 255) + pkt3Part2 := bytes.Repeat([]byte{3}, 20) + + // Page 1: OpusHead (skip), OpusTags (skip), pkt1, pkt2, pkt3Part1 + page1Lacing := []byte{8, 8, 50, 255, 10, 255} + page1Data := bytes.Join([][]byte{ + []byte("OpusHead"), + []byte("OpusTags"), + pkt1, + pkt2Part1, pkt2Part2, + pkt3Part1, + }, nil) + + // Page 2: pkt3Part2, pkt4 (length 10) + pkt4 := bytes.Repeat([]byte{4}, 10) + page2Lacing := []byte{20, 10} + page2Data := bytes.Join([][]byte{ + pkt3Part2, + pkt4, + }, nil) + + b.Write(buildOggPage(page1Lacing, page1Data)) + b.Write(buildOggPage(page2Lacing, page2Data)) + + var frames [][]byte + err := DecodeOggOpus(&b, func(frame []byte) error { + // making a copy to store as DecodeOggOpus might reuse backing array + cpy := make([]byte, len(frame)) + copy(cpy, frame) + frames = append(frames, cpy) + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedFrames := [][]byte{ + pkt1, + append(pkt2Part1, pkt2Part2...), + append(pkt3Part1, pkt3Part2...), + pkt4, + } + + if len(frames) != len(expectedFrames) { + t.Fatalf("expected %d frames, got %d", len(expectedFrames), len(frames)) + } + + for i, expected := range expectedFrames { + if !reflect.DeepEqual(frames[i], expected) { + t.Errorf("frame %d mismatch:\nexp: %v\ngot: %v", i, expected, frames[i]) + } + } +} + +func TestDecodeOggOpus_Errors(t *testing.T) { + tests := []struct { + name string + data []byte + errContains string + }{ + { + name: "invalid magic string", + data: []byte( + "OggX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ), + errContains: "invalid ogg magic string", + }, + { + name: "short header", + data: []byte("Ogg"), + errContains: "failed to read ogg header", + }, + { + name: "eof in segment table", + data: func() []byte { + h := make([]byte, 27) + copy(h, "OggS") + h[26] = 5 // expects 5 bytes of segment table, but none provided + return h + }(), + errContains: "failed to read segment table", + }, + { + name: "eof in segment data", + data: func() []byte { + h := make([]byte, 27, 28) + copy(h, "OggS") + h[26] = 1 + return append(h, 100) // expects 100 bytes of data, but none provided + }(), + errContains: "failed to read segment data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := DecodeOggOpus(bytes.NewReader(tt.data), func(b []byte) error { return nil }) + if tt.name == "short header" { + if err != nil { + t.Errorf("expected no error (io.EOF/ErrUnexpectedEOF swallowed), got %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.errContains) + } + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("expected error to contain %q, got: %q", tt.errContains, err.Error()) + } + }) + } +} diff --git a/pkg/audio/sentence.go b/pkg/audio/sentence.go new file mode 100644 index 000000000..89b9ac03e --- /dev/null +++ b/pkg/audio/sentence.go @@ -0,0 +1,96 @@ +package audio + +import ( + "strings" + "unicode" +) + +// SplitSentences splits text into sentence-sized chunks suitable for TTS synthesis. +// It splits on sentence-ending punctuation (.!?\n, as well as CJK 。, !, ?) while avoiding false splits +// on decimal numbers. Very short fragments are merged with +// the next sentence to prevent choppy playback. +func SplitSentences(text string) []string { + if text == "" { + return nil + } + + var sentences []string + var current strings.Builder + runes := []rune(text) + + for i := 0; i < len(runes); i++ { + r := runes[i] + if r == '\n' { + s := strings.TrimSpace(current.String()) + if s != "" { + sentences = append(sentences, s) + } + current.Reset() + continue + } + + current.WriteRune(r) + + if r == '.' || r == '!' || r == '?' || r == '。' || r == '!' || r == '?' { + // Avoid splitting on decimal numbers like "3.14" + if r == '.' && i > 0 && unicode.IsDigit(runes[i-1]) && + i+1 < len(runes) && unicode.IsDigit(runes[i+1]) { + continue + } + + // Consume contiguous punctuation clusters (e.g., "..." or "?!"). + for i+1 < len(runes) && (runes[i+1] == '.' || runes[i+1] == '!' || runes[i+1] == '?' || runes[i+1] == '。' || runes[i+1] == '!' || runes[i+1] == '?') { + i++ + current.WriteRune(runes[i]) + } + + s := strings.TrimSpace(current.String()) + if s != "" { + sentences = append(sentences, s) + } + current.Reset() + } + } + + // Flush remaining text + if s := strings.TrimSpace(current.String()); s != "" { + sentences = append(sentences, s) + } + + // Merge very short fragments with the next sentence + return mergeShorties(sentences, 15) +} + +// mergeShorties merges sentences shorter than minLen characters with the following sentence. +func mergeShorties(sentences []string, minLen int) []string { + if len(sentences) <= 1 { + return sentences + } + + var merged []string + var buf string + + for _, s := range sentences { + if buf != "" { + buf += " " + s + if len([]rune(buf)) >= minLen { + merged = append(merged, buf) + buf = "" + } + } else if len([]rune(s)) < minLen { + buf = s + } else { + merged = append(merged, s) + } + } + + if buf != "" { + if len(merged) > 0 { + merged[len(merged)-1] += " " + buf + } else { + merged = append(merged, buf) + } + } + + return merged +} diff --git a/pkg/audio/sentence_test.go b/pkg/audio/sentence_test.go new file mode 100644 index 000000000..54d69e4a6 --- /dev/null +++ b/pkg/audio/sentence_test.go @@ -0,0 +1,69 @@ +package audio + +import ( + "reflect" + "testing" +) + +func TestSplitSentences(t *testing.T) { + tests := []struct { + name string + in string + want []string + }{ + { + name: "empty input", + in: "", + want: nil, + }, + { + name: "single sentence", + in: "Hello world.", + want: []string{"Hello world."}, + }, + { + name: "decimal numbers do not split", + in: "The value is 3.14 today. Keep watching closely.", + want: []string{"The value is 3.14 today.", "Keep watching closely."}, + }, + { + name: "newline boundary", + in: "This is line number one\nThis is line number two", + want: []string{"This is line number one", "This is line number two"}, + }, + { + name: "newline with surrounding spaces", + in: " This is the first line \n This is the second line ", + want: []string{"This is the first line", "This is the second line"}, + }, + { + name: "trailing punctuation consumed", + in: "Please wait a moment... What on earth?! That is perfectly fine.", + want: []string{"Please wait a moment...", "What on earth?!", "That is perfectly fine."}, + }, + { + name: "short leading fragment merges with next", + in: "Hi. This is a longer sentence.", + want: []string{"Hi. This is a longer sentence."}, + }, + { + name: "consecutive short fragments keep merging", + in: "A. B. C. This is the real sentence.", + want: []string{"A. B. C. This is the real sentence."}, + }, + { + name: "short trailing fragment merges back", + in: "This sentence is long enough. End.", + want: []string{"This sentence is long enough. End."}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SplitSentences(tc.in) + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("SplitSentences(%q) = %#v, want %#v", tc.in, got, tc.want) + } + }) + } +} diff --git a/pkg/audio/tts/README.md b/pkg/audio/tts/README.md new file mode 100644 index 000000000..ab8491da6 --- /dev/null +++ b/pkg/audio/tts/README.md @@ -0,0 +1,137 @@ +# TTS (Text-to-Speech) + +This package handles speech synthesis for PicoClaw. + +If you are new to TTS setup, the simplest workflow is: + +1. Add a TTS-capable entry to `model_list`. +2. Point `voice.tts_model_name` at that entry. +3. Put the API key in `.security.yml`. + +## Quick Recommendation + +For most users, these are the best starting points: + +| Provider | Why start here | +| --- | --- | +| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | Best-supported path in PicoClaw today. The current TTS implementation is built around the OpenAI-compatible `/audio/speech` API shape, and OpenAI is the safest default. | +| [Xiaomi MiMo](https://platform.xiaomimimo.com) | A good second option if you want an OpenAI-compatible provider endpoint and are already using MiMo models in the rest of your stack. | + +## How TTS Configuration Works + +PicoClaw does not keep TTS API keys inside `voice`. + +Instead: + +- `voice.tts_model_name` selects a named entry from `model_list`. +- That `model_list` entry provides the provider, model ID, API base, and proxy settings. +- `.security.yml` stores the API key for the same named model entry. + +This is the recommended and supported configuration pattern. + +## Recommended Setup + +### Option A: OpenAI + +`config.json` + +```json +{ + "voice": { + "tts_model_name": "openai-tts" + }, + "model_list": [ + { + "model_name": "openai-tts", + "model": "openai/tts-1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + openai-tts: + api_keys: + - "sk-openai-your-key" +``` + +### Option B: Xiaomi MiMo + +`config.json` + +```json +{ + "voice": { + "tts_model_name": "mimo-tts" + }, + "model_list": [ + { + "model_name": "mimo-tts", + "model": "mimo/mimo-v2-tts" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + mimo-tts: + api_keys: + - "your-mimo-key" +``` + +If you use a custom MiMo endpoint, you can also set `api_base` explicitly. Otherwise PicoClaw will use the provider default. + +## What PicoClaw Sends Today + +The current TTS runtime uses an OpenAI-compatible speech request with these defaults: + +- Endpoint: `/audio/speech` +- Response format: `opus` +- Voice: `alloy` +- Model: taken from the selected `model_list` entry + +That means: + +- `openai/tts-1` works naturally. +- Other OpenAI-compatible providers can work if they accept the same request format. +- PicoClaw currently does not expose a user-facing config field for changing the TTS voice from `alloy`. + +## How PicoClaw Chooses a TTS Provider + +`DetectTTS` resolves TTS in this order: + +1. **Preferred path**: resolve `voice.tts_model_name` against `model_list`. +2. If a matching model entry exists and has an API key, PicoClaw creates an OpenAI-compatible TTS provider using that model's settings. +3. **Fallback path**: if `voice.tts_model_name` is not set or cannot be resolved, PicoClaw scans `model_list` for the first entry whose model string contains `tts` and has an API key. + +Fallback scanning exists for compatibility. New configs should set `voice.tts_model_name` explicitly. + +## Notes About API Base Handling + +PicoClaw normalizes the configured base URL for TTS: + +- For OpenAI, a base like `https://api.openai.com` or `https://api.openai.com/v1` becomes `https://api.openai.com/v1/audio/speech`. +- For other OpenAI-compatible providers, PicoClaw preserves the configured base path and ensures it ends with `/audio/speech`. +- If `api_base` is omitted, PicoClaw uses the provider default base when the model prefix is known. + +## Common Mistakes + +- Setting `voice.tts_model_name` to a name that does not exist in `model_list`. +- Adding a TTS model but forgetting to put its API key in `.security.yml`. +- Assuming PicoClaw will automatically use provider-specific custom voices. +- Using a provider endpoint that is not compatible with the OpenAI `/audio/speech` request format. + +## Minimal Checklist + +Before testing `send_tts`, make sure: + +- `voice.tts_model_name` matches a `model_list[].model_name`. +- The matching `.security.yml` entry contains a valid API key. +- The chosen provider supports an OpenAI-compatible speech synthesis endpoint. +- Your selected model is actually a TTS-capable model. diff --git a/pkg/audio/tts/README_zh.md b/pkg/audio/tts/README_zh.md new file mode 100644 index 000000000..a48b612a9 --- /dev/null +++ b/pkg/audio/tts/README_zh.md @@ -0,0 +1,137 @@ +# TTS(文本转语音) + +这个目录负责 PicoClaw 的语音合成能力。 + +如果你是第一次配置 TTS,可以参照下面这个流程: + +1. 在 `model_list` 里添加一个支持 TTS 的模型。 +2. 用 `voice.tts_model_name` 指向这个模型。 +3. 在 `.security.yml` 里配置对应的 API Key。 + +## 快速推荐 + +对于大多数用户,建议优先从下面两种开始: + +| 提供商 | 推荐理由 | +| --- | --- | +| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | 这是 PicoClaw 当前最稳定、最直接的 TTS 路径。当前实现就是围绕 OpenAI 兼容的 `/audio/speech` 接口格式构建的,所以 OpenAI 是最稳妥的默认选择。 | +| [Xiaomi MiMo](https://platform.xiaomimimo.com) | 由于响应速度和语音音色对于中国用户更友好,MiMo 是一个不错的第二选择。 | + +## TTS 配置是如何工作的 + +PicoClaw 不会把 TTS 的 API Key 放在 `voice` 配置里。 + +推荐方式是: + +- `voice.tts_model_name` 用来选择 `model_list` 里的某个命名模型。 +- 对应的 `model_list` 条目提供真实的 provider、model ID、`api_base` 和代理配置。 +- `.security.yml` 负责保存该模型条目的 API Key。 + +这是当前推荐且受支持的配置方式。 + +## 推荐配置方式 + +### 方案 A:OpenAI + +`config.json` + +```json +{ + "voice": { + "tts_model_name": "openai-tts" + }, + "model_list": [ + { + "model_name": "openai-tts", + "model": "openai/tts-1" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + openai-tts: + api_keys: + - "sk-openai-your-key" +``` + +### 方案 B:Xiaomi MiMo + +`config.json` + +```json +{ + "voice": { + "tts_model_name": "mimo-tts" + }, + "model_list": [ + { + "model_name": "mimo-tts", + "model": "mimo/mimo-v2-tts" + } + ] +} +``` + +`.security.yml` + +```yaml +model_list: + mimo-tts: + api_keys: + - "your-mimo-key" +``` + +如果你使用自定义的 MiMo 接口地址,也可以显式设置 `api_base`。如果不设置,PicoClaw 会自动使用该 provider 的默认地址。 + +## PicoClaw 当前实际发送的 TTS 请求 + +当前 TTS 运行时使用的是 OpenAI 兼容的语音合成请求,并带有以下默认值: + +- Endpoint:`/audio/speech` +- 返回格式:`opus` +- Voice:`alloy` +- Model:来自你所选中的 `model_list` 条目 + +这意味着: + +- `openai/tts-1` 可以自然工作。 +- 其他 OpenAI 兼容 provider 也可能可用,前提是它们接受相同的请求格式。 +- PicoClaw 目前还没有对用户暴露一个配置项来修改 TTS voice,当前固定为 `alloy`。 + +## PicoClaw 如何选择 TTS Provider + +`DetectTTS` 会按下面顺序选择 TTS: + +1. **首选路径**:根据 `voice.tts_model_name` 在 `model_list` 中找到对应模型。 +2. 如果找到了匹配条目,并且它有 API Key,PicoClaw 就会使用这个模型条目的配置创建一个 OpenAI 兼容的 TTS provider。 +3. **回退路径**:如果没有设置 `voice.tts_model_name`,或者该名字无法解析,PicoClaw 会扫描 `model_list`,选中第一个模型字符串里包含 `tts` 且带有 API Key 的条目。 + +回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.tts_model_name`。 + +## 关于 API Base 的处理方式 + +PicoClaw 会对 TTS 的 `api_base` 做规范化处理: + +- 对 OpenAI 来说,像 `https://api.openai.com` 或 `https://api.openai.com/v1` 这样的地址,会自动变成 `https://api.openai.com/v1/audio/speech`。 +- 对其他 OpenAI 兼容 provider,PicoClaw 会尽量保留你提供的基础路径,只确保它最终以 `/audio/speech` 结尾。 +- 如果没有设置 `api_base`,并且模型前缀是已知 provider,PicoClaw 会自动使用该 provider 的默认地址。 + +## 常见错误 + +- `voice.tts_model_name` 指向了一个不存在的 `model_list` 名称。 +- 在 `model_list` 里定义了 TTS 模型,但忘了在 `.security.yml` 中配置对应 API Key。 +- 误以为 PicoClaw 会自动支持 provider 自定义 voice 参数。 +- 使用了不兼容 OpenAI `/audio/speech` 请求格式的接口地址。 + +## 最小检查清单 + +在测试 `send_tts` 之前,请确认: + +- `voice.tts_model_name` 能正确匹配某个 `model_list[].model_name`。 +- `.security.yml` 中对应条目已经配置了有效 API Key。 +- 你所选的 provider 支持 OpenAI 兼容的语音合成接口。 +- 你选择的模型本身确实支持 TTS。 diff --git a/pkg/audio/tts/mimo_tts.go b/pkg/audio/tts/mimo_tts.go new file mode 100644 index 000000000..a8aee6b8c --- /dev/null +++ b/pkg/audio/tts/mimo_tts.go @@ -0,0 +1,162 @@ +package tts + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +type MimoTTSProvider struct { + apiKey string + apiBase string + voice string + format string + model string + httpClient *http.Client +} + +func NewMimoTTSProvider(apiKey string, apiBase string, model string, proxyURL string) *MimoTTSProvider { + if apiBase == "" { + apiBase = "https://api.xiaomimimo.com/v1/chat/completions" + } else { + if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" { + path := u.Path + if u.Host == "api.xiaomimimo.com" { + if path == "" || path == "/" || path == "/v1" || path == "/v1/" { + path = "/v1/chat/completions" + } else { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + if !strings.HasPrefix(path, "/v1/") { + path = "/v1" + strings.TrimSuffix(path, "/") + } + if !strings.HasSuffix(path, "/chat/completions") { + path = strings.TrimSuffix(path, "/") + "/chat/completions" + } + } + } else { + if !strings.HasSuffix(path, "/chat/completions") { + path = strings.TrimSuffix(path, "/") + "/chat/completions" + } + } + u.Path = path + apiBase = u.String() + } else { + if apiBase == "https://api.xiaomimimo.com/v1" { + apiBase = "https://api.xiaomimimo.com/v1/chat/completions" + } else if !strings.HasSuffix(apiBase, "/chat/completions") { + apiBase = strings.TrimSuffix(apiBase, "/") + "/chat/completions" + } + } + } + + model = strings.TrimSpace(model) + if model == "" { + model = "mimo-v2-tts" + } + + client := &http.Client{Timeout: 60 * time.Second} + if proxyURL != "" { + if pURL, err := url.Parse(proxyURL); err == nil { + client.Transport = &http.Transport{Proxy: http.ProxyURL(pURL)} + } else { + logger.WarnF( + "NewMimoTTSProvider: invalid proxy URL; proceeding without proxy", + map[string]any{"proxyURL": proxyURL, "error": err}, + ) + } + } + + return &MimoTTSProvider{ + apiKey: apiKey, + apiBase: apiBase, + voice: "default_zh", // mimo_default now seems to be an alias for default_en, which is not working for Chinese TTS. default_zh seems to work fine with both English and Chinese, and is likely the intended default for TTS. + format: "mp3", + model: model, + httpClient: client, + } +} + +func (t *MimoTTSProvider) Name() string { + return "mimo-tts" +} + +func (t *MimoTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) { + logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text), "provider": t.Name()}) + + reqBody := map[string]any{ + "model": t.model, + "messages": []map[string]string{ + {"role": "assistant", "content": text}, + }, + "audio": map[string]string{ + "format": t.format, + "voice": t.voice, + }, + "stream": false, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Api-Key", t.apiKey) + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var payload struct { + Choices []struct { + Message struct { + Audio struct { + Data string `json:"data"` + } `json:"audio"` + } `json:"message"` + } `json:"choices"` + } + + err = json.Unmarshal(body, &payload) + if err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(payload.Choices) == 0 || payload.Choices[0].Message.Audio.Data == "" { + return nil, fmt.Errorf("invalid TTS response: missing audio data") + } + + audioBytes, err := base64.StdEncoding.DecodeString(payload.Choices[0].Message.Audio.Data) + if err != nil { + return nil, fmt.Errorf("failed to decode audio data: %w", err) + } + + return io.NopCloser(bytes.NewReader(audioBytes)), nil +} diff --git a/pkg/audio/tts/openai_tts.go b/pkg/audio/tts/openai_tts.go new file mode 100644 index 000000000..786414873 --- /dev/null +++ b/pkg/audio/tts/openai_tts.go @@ -0,0 +1,126 @@ +package tts + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers/common" +) + +type OpenAITTSProvider struct { + apiKey string + apiBase string + voice string + model string + httpClient *http.Client +} + +func NewOpenAITTSProvider(apiKey string, apiBase string, proxyURL string, model string) *OpenAITTSProvider { + // Normalize apiBase to avoid malformed endpoints like + // "https://api.openai.com/audio/speech" when "/v1" is required. + if apiBase == "" { + apiBase = "https://api.openai.com/v1/audio/speech" + } else { + if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" { + path := u.Path + if u.Host == "api.openai.com" { + // For the official OpenAI host, ensure exactly one /v1 prefix and + // that the path ends with /audio/speech. + if path == "" || path == "/" || path == "/v1" { + path = "/v1/audio/speech" + } else { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + if !strings.HasPrefix(path, "/v1/") { + path = "/v1" + strings.TrimSuffix(path, "/") + } + if !strings.HasSuffix(path, "/audio/speech") { + path = strings.TrimSuffix(path, "/") + "/audio/speech" + } + } + } else { + // For non-OpenAI hosts (e.g., proxies), preserve the existing base + // path and only ensure it ends with /audio/speech. + if !strings.HasSuffix(path, "/audio/speech") { + path = strings.TrimSuffix(path, "/") + "/audio/speech" + } + } + u.Path = path + apiBase = u.String() + } else { + // Fallback to the previous string-based behavior if parsing fails. + if apiBase == "https://api.openai.com/v1" { + apiBase = "https://api.openai.com/v1/audio/speech" + } else if !strings.HasSuffix(apiBase, "/audio/speech") { + // Just in case they provide openrouter base or standard base + apiBase = strings.TrimSuffix(apiBase, "/") + "/audio/speech" + } + } + } + + client := common.NewHTTPClient(proxyURL) + client.Timeout = 60 * time.Second + + model = strings.TrimSpace(model) + if model == "" { + model = "tts-1" + } + + return &OpenAITTSProvider{ + apiKey: apiKey, + apiBase: apiBase, + voice: "alloy", + model: model, + httpClient: client, + } +} + +func (t *OpenAITTSProvider) Name() string { + return "openai-tts" +} + +func (t *OpenAITTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) { + logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text)}) + + reqBody := map[string]any{ + "model": t.model, + "input": text, + "voice": t.voice, + "response_format": "opus", + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+t.apiKey) + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + return resp.Body, nil +} diff --git a/pkg/audio/tts/tts.go b/pkg/audio/tts/tts.go new file mode 100644 index 000000000..99a9ef203 --- /dev/null +++ b/pkg/audio/tts/tts.go @@ -0,0 +1,151 @@ +package tts + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type TTSProvider interface { + Name() string + Synthesize(ctx context.Context, text string) (io.ReadCloser, error) +} + +func providerFromModelConfig(mc *config.ModelConfig) TTSProvider { + if mc == nil || mc.APIKey() == "" { + return nil + } + + protocol, modelID := providers.ExtractProtocol(mc.Model) + if modelID == "" { + modelID = strings.TrimSpace(mc.Model) + } + + switch protocol { + case "mimo": + return NewMimoTTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), modelID, mc.Proxy) + default: + return NewOpenAITTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), mc.Proxy, modelID) + } +} + +func DetectTTS(cfg *config.Config) TTSProvider { + if cfg == nil { + return nil + } + + if modelName := strings.TrimSpace(cfg.Voice.TTSModelName); modelName != "" { + if mc, err := cfg.GetModelConfig(modelName); err == nil { + if provider := providerFromModelConfig(mc); provider != nil { + return provider + } + } + } + + for _, mc := range cfg.ModelList { + if strings.Contains(strings.ToLower(mc.Model), "tts") && mc.APIKey() != "" { + if provider := providerFromModelConfig(mc); provider != nil { + return provider + } + } + } + return nil +} + +// SynthesizeAndStore synthesizes text to speech and registers it in the media store, returning the media reference. +func SynthesizeAndStore( + ctx context.Context, + provider TTSProvider, + store media.MediaStore, + text string, + filename string, + channel string, + chatID string, +) (string, error) { + if provider == nil { + return "", fmt.Errorf("tts provider is not configured") + } + if store == nil { + return "", fmt.Errorf("media store not configured") + } + if channel == "" || chatID == "" { + return "", fmt.Errorf("no target channel/chat available") + } + if strings.TrimSpace(text) == "" { + return "", fmt.Errorf("text is required") + } + + stream, err := provider.Synthesize(ctx, text) + if err != nil { + return "", fmt.Errorf("tts synthesize failed: %w", err) + } + defer stream.Close() + + err = os.MkdirAll(media.TempDir(), 0o700) + if err != nil { + return "", fmt.Errorf("failed to create media temp dir: %w", err) + } + + fileExt := ".ogg" + contentType := "audio/ogg" + if provider.Name() == "mimo-tts" { + fileExt = ".mp3" + contentType = "audio/mpeg" + } + + file, err := os.CreateTemp(media.TempDir(), "tts-*"+fileExt) + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + + removeTemp := true + defer func() { + if removeTemp { + _ = os.Remove(file.Name()) + } + }() + + _, err = io.Copy(file, stream) + if err != nil { + file.Close() + return "", fmt.Errorf("failed to write tts audio: %w", err) + } + + err = file.Close() + if err != nil { + return "", fmt.Errorf("failed to close tts audio file: %w", err) + } + + filename = strings.TrimSpace(filename) + if filename == "" { + filename = fmt.Sprintf("tts-%d%s", time.Now().Unix(), fileExt) + } + + ext := strings.ToLower(filepath.Ext(filename)) + if ext == "" { + filename += fileExt + } else if ext != fileExt { + filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + fileExt + } + + scope := fmt.Sprintf("tool:send_tts:%s:%s:%d", channel, chatID, time.Now().UnixNano()) + ref, err := store.Store(file.Name(), media.MediaMeta{ + Filename: filename, + ContentType: contentType, + Source: "tool:send_tts", + }, scope) + if err != nil { + return "", fmt.Errorf("failed to register audio: %w", err) + } + removeTemp = false + + return ref, nil +} diff --git a/pkg/audio/tts/tts_test.go b/pkg/audio/tts/tts_test.go new file mode 100644 index 000000000..053aa7220 --- /dev/null +++ b/pkg/audio/tts/tts_test.go @@ -0,0 +1,247 @@ +package tts + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestNewOpenAITTSProvider_APIBaseNormalization(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input string + expect string + }{ + { + name: "empty base", + input: "", + expect: "https://api.openai.com/v1/audio/speech", + }, + { + name: "official host no path", + input: "https://api.openai.com", + expect: "https://api.openai.com/v1/audio/speech", + }, + { + name: "official host v1", + input: "https://api.openai.com/v1", + expect: "https://api.openai.com/v1/audio/speech", + }, + { + name: "official host v1 slash", + input: "https://api.openai.com/v1/", + expect: "https://api.openai.com/v1/audio/speech", + }, + { + name: "non-openai host preserves base path", + input: "https://proxy.example.com/base", + expect: "https://proxy.example.com/base/audio/speech", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + provider := NewOpenAITTSProvider("key", tc.input, "", "") + if provider.apiBase != tc.expect { + t.Fatalf("apiBase mismatch: got %q, want %q", provider.apiBase, tc.expect) + } + }) + } +} + +func TestOpenAITTSProvider_SynthesizeSuccess(t *testing.T) { + t.Parallel() + + var gotPath string + var gotAuth string + var gotContentType string + var gotBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotContentType = r.Header.Get("Content-Type") + + bodyBytes, _ := io.ReadAll(r.Body) + _ = r.Body.Close() + _ = json.Unmarshal(bodyBytes, &gotBody) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("audio-bytes")) + })) + defer server.Close() + + provider := NewOpenAITTSProvider("k123", server.URL, "", "") + stream, err := provider.Synthesize(context.Background(), "hello") + if err != nil { + t.Fatalf("Synthesize failed: %v", err) + } + defer stream.Close() + + data, err := io.ReadAll(stream) + if err != nil { + t.Fatalf("read stream failed: %v", err) + } + + if gotPath != "/audio/speech" { + t.Fatalf("request path mismatch: got %q", gotPath) + } + if gotAuth != "Bearer k123" { + t.Fatalf("authorization mismatch: got %q", gotAuth) + } + if gotContentType != "application/json" { + t.Fatalf("content-type mismatch: got %q", gotContentType) + } + if gotBody["model"] != "tts-1" || gotBody["voice"] != "alloy" || gotBody["response_format"] != "opus" || + gotBody["input"] != "hello" { + bodyJSON, _ := json.Marshal(gotBody) + t.Fatalf("request body mismatch: %s", string(bodyJSON)) + } + if string(data) != "audio-bytes" { + t.Fatalf("response body mismatch: got %q", string(data)) + } +} + +func TestOpenAITTSProvider_SynthesizeNon200(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("nope")) + })) + defer server.Close() + + provider := NewOpenAITTSProvider("k123", server.URL, "", "") + _, err := provider.Synthesize(context.Background(), "hello") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "API error (status 500): nope") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNewOpenAITTSProvider_UsesConfiguredModel(t *testing.T) { + t.Parallel() + + provider := NewOpenAITTSProvider("key", "https://api.xiaomimimo.com/v1", "", "mimo-v2-tts") + if provider.model != "mimo-v2-tts" { + t.Fatalf("model mismatch: got %q, want %q", provider.model, "mimo-v2-tts") + } + if provider.apiBase != "https://api.xiaomimimo.com/v1/audio/speech" { + t.Fatalf("apiBase mismatch: got %q", provider.apiBase) + } +} + +func TestDetectTTS_UsesMimoProviderForMimoModels(t *testing.T) { + t.Parallel() + + provider := DetectTTS(&config.Config{ + Voice: config.VoiceConfig{TTSModelName: "mimo-tts"}, + ModelList: []*config.ModelConfig{ + { + ModelName: "mimo-tts", + Model: "mimo/mimo-v2-tts", + APIKeys: config.SimpleSecureStrings("sk-mimo"), + }, + }, + }) + + ttsProvider, ok := provider.(*MimoTTSProvider) + if !ok { + t.Fatalf("DetectTTS() type = %T, want *MimoTTSProvider", provider) + } + if ttsProvider.model != "mimo-v2-tts" { + t.Fatalf("model mismatch: got %q, want %q", ttsProvider.model, "mimo-v2-tts") + } + if ttsProvider.apiBase != "https://api.xiaomimimo.com/v1/chat/completions" { + t.Fatalf("apiBase mismatch: got %q", ttsProvider.apiBase) + } +} + +type stubTTSProvider struct { + name string +} + +func (s stubTTSProvider) Name() string { + return s.name +} + +func (s stubTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("audio")), nil +} + +func TestSynthesizeAndStore_UsesOggMetadataByDefault(t *testing.T) { + t.Parallel() + + store := media.NewFileMediaStore() + ref, err := SynthesizeAndStore( + context.Background(), + stubTTSProvider{name: "openai-tts"}, + store, + "hello", + "", + "discord", + "chat123", + ) + if err != nil { + t.Fatalf("SynthesizeAndStore failed: %v", err) + } + + path, meta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("ResolveWithMeta failed: %v", err) + } + if meta.ContentType != "audio/ogg" { + t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/ogg") + } + if filepath.Ext(path) != ".ogg" { + t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".ogg") + } + if filepath.Ext(meta.Filename) != ".ogg" { + t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".ogg") + } +} + +func TestSynthesizeAndStore_UsesMp3MetadataForMimo(t *testing.T) { + t.Parallel() + + store := media.NewFileMediaStore() + ref, err := SynthesizeAndStore( + context.Background(), + stubTTSProvider{name: "mimo-tts"}, + store, + "hello", + "", + "discord", + "chat123", + ) + if err != nil { + t.Fatalf("SynthesizeAndStore failed: %v", err) + } + + path, meta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("ResolveWithMeta failed: %v", err) + } + if meta.ContentType != "audio/mpeg" { + t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/mpeg") + } + if filepath.Ext(path) != ".mp3" { + t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".mp3") + } + if filepath.Ext(meta.Filename) != ".mp3" { + t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".mp3") + } +} diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 37fcb74c5..a9c74ef90 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -34,6 +34,8 @@ type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage outboundMedia chan OutboundMediaMessage + audioChunks chan AudioChunk + voiceControls chan VoiceControl closeOnce sync.Once done chan struct{} @@ -47,6 +49,8 @@ func NewMessageBus() *MessageBus { inbound: make(chan InboundMessage, defaultBusBufferSize), outbound: make(chan OutboundMessage, defaultBusBufferSize), outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize), + audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer + voiceControls: make(chan VoiceControl, defaultBusBufferSize), done: make(chan struct{}), } } @@ -103,6 +107,22 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage { return mb.outboundMedia } +func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error { + return publish(ctx, mb, mb.audioChunks, chunk) +} + +func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk { + return mb.audioChunks +} + +func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error { + return publish(ctx, mb, mb.voiceControls, ctrl) +} + +func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl { + return mb.voiceControls +} + // SetStreamDelegate registers a StreamDelegate (typically the channel Manager). func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) { mb.streamDelegate.Store(d) @@ -132,6 +152,8 @@ func (mb *MessageBus) Close() { close(mb.inbound) close(mb.outbound) close(mb.outboundMedia) + close(mb.audioChunks) + close(mb.voiceControls) // clean up any remaining messages in channels drained := 0 @@ -144,6 +166,12 @@ func (mb *MessageBus) Close() { for range mb.outboundMedia { drained++ } + for range mb.audioChunks { + drained++ + } + for range mb.voiceControls { + drained++ + } if drained > 0 { logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 12da3f1dd..27cf61b5f 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -30,10 +30,11 @@ type InboundMessage struct { } type OutboundMessage struct { - Channel string `json:"channel"` - ChatID string `json:"chat_id"` - Content string `json:"content"` - ReplyToMessageID string `json:"reply_to_message_id,omitempty"` + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Content string `json:"content"` + ReplyToMessageID string `json:"reply_to_message_id,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } // MediaPart describes a single media attachment to send. @@ -51,3 +52,25 @@ type OutboundMediaMessage struct { ChatID string `json:"chat_id"` Parts []MediaPart `json:"parts"` } + +// AudioChunk represents a chunk of streaming voice data. +type AudioChunk struct { + SessionID string `json:"session_id"` + SpeakerID string `json:"speaker_id"` // User ID or SSRC + ChatID string `json:"chat_id"` // Where to respond + Channel string `json:"channel"` // Source channel type (e.g. "discord") + Sequence uint64 `json:"sequence"` + Timestamp uint32 `json:"timestamp"` + SampleRate int `json:"sample_rate"` + Channels int `json:"channels"` + Format string `json:"format"` // "opus", "pcm", etc + Data []byte `json:"data"` +} + +// VoiceControl represents state or commands for voice sessions. +type VoiceControl struct { + SessionID string `json:"session_id"` + ChatID string `json:"chat_id"` + Type string `json:"type"` // "state", "command" + Action string `json:"action"` // "idle", "listening", "start", "stop", "leave" +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index b3070a822..01b1b4053 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,6 +3,7 @@ package discord import ( "context" "fmt" + "io" "net/http" "net/url" "os" @@ -14,6 +15,8 @@ import ( "github.com/bwmarrin/discordgo" "github.com/gorilla/websocket" + "github.com/sipeed/picoclaw/pkg/audio" + "github.com/sipeed/picoclaw/pkg/audio/tts" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -42,6 +45,15 @@ type DiscordChannel struct { typingMu sync.Mutex typingStop map[string]chan struct{} // chatID → stop signal botUserID string // stored for mention checking + bus *bus.MessageBus + tts tts.TTSProvider + voiceMu sync.RWMutex + voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID + + // TTS interruption: cancel active playback when user speaks + ttsMu sync.Mutex + cancelTTS context.CancelFunc + ttsPlayID uint64 } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC config: cfg, ctx: context.Background(), typingStop: make(map[string]chan struct{}), + bus: bus, + voiceSSRC: make(map[string]map[uint32]string), }, nil } @@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error { c.session.AddHandler(c.handleMessage) + go c.listenVoiceControl(c.ctx) + if err := c.session.Open(); err != nil { return fmt.Errorf("failed to open discord session: %w", err) } @@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s return nil, nil } + if c.tts != nil { + if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" { + if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil { + // Cancel any previous TTS playback + c.ttsMu.Lock() + if c.cancelTTS != nil { + c.cancelTTS() + } + ttsCtx, ttsCancel := context.WithCancel(c.ctx) + c.ttsPlayID++ + playID := c.ttsPlayID + c.cancelTTS = ttsCancel + c.ttsMu.Unlock() + + go c.playTTS(ttsCtx, vc, msg.Content, playID) + } + } + } + msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID) if err != nil { return nil, err @@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + if c.handleVoiceCommand(s, m) { + return + } + content := m.Content // In guild (group) channels, apply unified group trigger filtering @@ -630,3 +669,134 @@ func (c *DiscordChannel) stripBotMention(text string) string { text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "") return strings.TrimSpace(text) } + +func (c *DiscordChannel) listenVoiceControl(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case ctrl, ok := <-c.bus.VoiceControlsChan(): + if !ok { + return + } + if ctrl.Type == "command" && ctrl.Action == "leave" { + if strings.HasPrefix(ctrl.SessionID, "discord_vc_") { + guildID := strings.TrimPrefix(ctrl.SessionID, "discord_vc_") + vc, exists := c.session.VoiceConnections[guildID] + if exists && vc != nil { + vc.Disconnect(ctx) + } + } + } + } + } +} + +func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) { + // Capture the cancel func associated with this playback (if any). + // Clear cancelTTS when playback finishes (normal or interrupted), + // but only if it still refers to this playback's cancel func. + defer func() { + c.ttsMu.Lock() + if c.ttsPlayID == playID { + c.cancelTTS = nil + } + c.ttsMu.Unlock() + }() + + sentences := audio.SplitSentences(text) + if len(sentences) == 0 { + return + } + + logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)}) + + // Pipeline: prefetch next sentence's audio while playing current + type ttResult struct { + stream io.ReadCloser + err error + } + + var prefetch chan ttResult + + // Ensure any in-flight prefetch is drained on exit to prevent stream leaks, + // but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends. + defer func() { + if prefetch != nil { + select { + case result := <-prefetch: + if result.stream != nil { + result.stream.Close() + } + case <-time.After(100 * time.Millisecond): + // Timed out waiting for a prefetched result; avoid blocking on exit. + } + } + }() + + for i, sentence := range sentences { + // Check for cancellation (interruption) + select { + case <-ctx.Done(): + logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i}) + return + default: + } + + // Start prefetching the NEXT sentence while we process the current one + var nextPrefetch chan ttResult + if i+1 < len(sentences) { + nextPrefetch = make(chan ttResult, 1) + nextSentence := sentences[i+1] + go func() { + s, e := c.tts.Synthesize(ctx, nextSentence) + nextPrefetch <- ttResult{s, e} + }() + } + + // Get the current sentence's audio + var stream io.ReadCloser + var err error + + if prefetch != nil { + // Use prefetched result from previous iteration, but be responsive to cancellation. + var result ttResult + select { + case result = <-prefetch: + stream, err = result.stream, result.err + case <-ctx.Done(): + // Context canceled while waiting for prefetched audio; abort playback. + logger.InfoCF( + "discord", + "TTS interrupted while waiting for prefetched audio", + map[string]any{"at_sentence": i}, + ) + return + } + } else { + // First sentence: synthesize directly + stream, err = c.tts.Synthesize(ctx, sentence) + } + + if err != nil { + if stream != nil { + stream.Close() + } + logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i}) + prefetch = nextPrefetch + continue + } + + if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil { + logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i}) + } + stream.Close() + + prefetch = nextPrefetch + } +} + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go index 15a539804..8381dc9e9 100644 --- a/pkg/channels/discord/init.go +++ b/pkg/channels/discord/init.go @@ -1,6 +1,7 @@ package discord import ( + "github.com/sipeed/picoclaw/pkg/audio/tts" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -8,6 +9,10 @@ import ( func init() { channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewDiscordChannel(cfg.Channels.Discord, b) + ch, err := NewDiscordChannel(cfg.Channels.Discord, b) + if err == nil { + ch.tts = tts.DetectTTS(cfg) + } + return ch, err }) } diff --git a/pkg/channels/discord/voice.go b/pkg/channels/discord/voice.go new file mode 100644 index 000000000..5b686b141 --- /dev/null +++ b/pkg/channels/discord/voice.go @@ -0,0 +1,313 @@ +package discord + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/bwmarrin/discordgo" + + "github.com/sipeed/picoclaw/pkg/audio" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) { + if userID == "" { + return + } + + c.voiceMu.Lock() + defer c.voiceMu.Unlock() + + ssrcMap, ok := c.voiceSSRC[guildID] + if !ok { + ssrcMap = make(map[uint32]string) + c.voiceSSRC[guildID] = ssrcMap + } + ssrcMap[ssrc] = userID +} + +func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string { + c.voiceMu.RLock() + defer c.voiceMu.RUnlock() + + ssrcMap, ok := c.voiceSSRC[guildID] + if !ok { + return "" + } + return ssrcMap[ssrc] +} + +func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool { + if m.Content == "!vc join" { + vs, err := s.State.VoiceState(m.GuildID, m.Author.ID) + if err != nil || vs == nil { + if _, sendErr := s.ChannelMessageSend( + m.ChannelID, + "You need to be in a voice channel first!", + ); sendErr != nil { + logger.InfoCF("discord", "Failed to send voice channel requirement message", map[string]any{ + "channel": m.ChannelID, + "error": sendErr, + }) + } + return true + } + + logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID}) + vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false) + if err != nil { + if _, sendErr := s.ChannelMessageSend( + m.ChannelID, + fmt.Sprintf("Failed to join voice channel: %v", err), + ); sendErr != nil { + logger.InfoCF("discord", "Failed to send voice join error message", map[string]any{ + "channel": m.ChannelID, + "error": sendErr, + }) + } + return true + } + + go c.receiveVoice(vc, m.GuildID, m.ChannelID) + if _, sendErr := s.ChannelMessageSend( + m.ChannelID, + "Joined Voice Channel! Listening for audio...", + ); sendErr != nil { + logger.InfoCF("discord", "Failed to send voice join success message", map[string]any{ + "channel": m.ChannelID, + "error": sendErr, + }) + } + return true + } else if m.Content == "!vc leave" { + vc, exists := s.VoiceConnections[m.GuildID] + if exists && vc != nil { + if err := vc.Disconnect(c.ctx); err != nil { + logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{ + "guild": m.GuildID, + "error": err, + }) + } + if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Left Voice Channel."); sendErr != nil { + logger.InfoCF("discord", "Failed to send voice leave success message", map[string]any{ + "channel": m.ChannelID, + "error": sendErr, + }) + } + } else { + if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Not in a voice channel."); sendErr != nil { + logger.InfoCF("discord", "Failed to send voice not-in-channel message", map[string]any{ + "channel": m.ChannelID, + "error": sendErr, + }) + } + } + return true + } + return false +} + +func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool { + return vc != nil && vc.OpusRecv != nil +} + +func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) { + // Recover from panic if vc.OpusSend is closed mid-send (e.g. on disconnect) + defer func() { + if rec := recover(); rec != nil { + retErr = fmt.Errorf("voice connection closed during playback") + } + }() + + // Wait for the speaking transition to register + vc.Speaking(true) + defer vc.Speaking(false) + + return audio.DecodeOggOpus(r, func(frame []byte) error { + select { + case <-ctx.Done(): + return ctx.Err() + case vc.OpusSend <- frame: + return nil + } + }) +} + +func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) { + logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID}) + + vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) { + if vs == nil { + return + } + c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID) + }) + + defer func() { + c.voiceMu.Lock() + delete(c.voiceSSRC, guildID) + c.voiceMu.Unlock() + }() + + go func(ctx context.Context, vc *discordgo.VoiceConnection) { + // Recover from potential panics if OpusSend is closed mid-send. + defer func() { + if rec := recover(); rec != nil { + logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{ + "error": rec, + "guild": guildID, + }) + } + }() + + // If the voice connection or OpusSend are not available, nothing to do. + if vc == nil || vc.OpusSend == nil { + return + } + + time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle + + // Abort if the context has already been canceled. + select { + case <-ctx.Done(): + return + default: + } + + vc.Speaking(true) + defer vc.Speaking(false) + + silenceFrame := []byte{0xF8, 0xFF, 0xFE} + for i := 0; i < 5; i++ { + select { + case <-ctx.Done(): + return + case vc.OpusSend <- silenceFrame: + } + time.Sleep(20 * time.Millisecond) + } + + logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID}) + }(c.ctx, vc) + sessionID := fmt.Sprintf("discord_vc_%s", guildID) + + c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{ + SessionID: sessionID, + Type: "state", + Action: "listening", + }) + + var sequence uint64 = 0 + var interruptCount int + var lastInterruptAt time.Time + + for { + select { + case <-c.ctx.Done(): + return + case p, ok := <-vc.OpusRecv: + if !ok { + logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID}) + // Cancel any TTS that may still be playing + c.ttsMu.Lock() + if c.cancelTTS != nil { + c.cancelTTS() + c.cancelTTS = nil + } + c.ttsMu.Unlock() + return + } + + if p == nil { + logger.DebugCF("discord", "Received nil Opus packet", nil) + continue + } + + if len(p.Opus) == 0 { + logger.DebugCF("discord", "Received empty Opus packet", map[string]any{ + "seq": p.Sequence, + "ssrc": p.SSRC, + }) + continue + } + + logger.DebugCF("discord", "Received Opus packet", map[string]any{ + "seq": p.Sequence, + "len": len(p.Opus), + "ssrc": p.SSRC, + }) + // Interruption detection: if user sends voice while TTS is playing, + // cancel TTS after a short debounce (3 packets in 200ms) + now := time.Now() + if now.Sub(lastInterruptAt) > 500*time.Millisecond { + interruptCount = 0 + } + interruptCount++ + lastInterruptAt = now + + if interruptCount >= 3 { + c.ttsMu.Lock() + if c.cancelTTS != nil { + c.cancelTTS() + c.cancelTTS = nil + logger.InfoCF("discord", "TTS interrupted by user voice", nil) + } + c.ttsMu.Unlock() + interruptCount = 0 + } + + userID := c.voiceUserID(guildID, p.SSRC) + if userID == "" { + logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{ + "ssrc": p.SSRC, + "guild": guildID, + }) + continue + } + + sender := bus.SenderInfo{ + Platform: "discord", + PlatformID: userID, + CanonicalID: identity.BuildCanonicalID("discord", userID), + } + if !c.IsAllowedSender(sender) { + logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{ + "user_id": userID, + "guild": guildID, + }) + continue + } + + sequence++ + + chunk := bus.AudioChunk{ + SessionID: sessionID, + SpeakerID: userID, + ChatID: chatID, + Channel: "discord", + Sequence: sequence, + Timestamp: p.Timestamp, + SampleRate: 48000, + Channels: 2, + Format: "opus", + Data: p.Opus, + } + + ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond) + err := c.bus.PublishAudioChunk(ctx, chunk) + cancel() + if err != nil { + logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{ + "guild": guildID, + "sessionID": sessionID, + "sequence": sequence, + "error": err.Error(), + }) + } + } + } +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go index 4952394b7..81238460a 100644 --- a/pkg/channels/feishu/common.go +++ b/pkg/channels/feishu/common.go @@ -6,6 +6,8 @@ import ( "strings" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/sipeed/picoclaw/pkg/channels" ) // mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions. @@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) { } } } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index e29896389..230983935 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -684,3 +684,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string { }, }) } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index 96db964cf..5e975b4f0 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp. cleaned = strings.TrimLeft(cleaned, ",:; ") return strings.TrimSpace(cleaned) } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index a9b95c20f..0c59965c1 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -1104,3 +1104,8 @@ func truncate(s string, n int) string { } return string(runes[:n]) + "..." } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 3a8cf9652..f2b70aec9 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -1002,3 +1002,8 @@ func sanitizeURLs(text string) string { return scheme + domain + path }) } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 831eb43cc..ccb394a57 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -1133,3 +1133,8 @@ func cryptoRandInt() int { _, _ = rand.Read(b[:]) return int(binary.BigEndian.Uint32(b[:])) | 1 // ensure non-zero } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/channels/voice_capabilities.go b/pkg/channels/voice_capabilities.go new file mode 100644 index 000000000..34fd24269 --- /dev/null +++ b/pkg/channels/voice_capabilities.go @@ -0,0 +1,58 @@ +package channels + +// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech) +// are available for a channel under the current configuration. +type VoiceCapabilities struct { + ASR bool + TTS bool +} + +// VoiceCapabilityProvider is an optional interface for channels that want to +// explicitly declare their ASR/TTS support. +type VoiceCapabilityProvider interface { + VoiceCapabilities() VoiceCapabilities +} + +// Deprecated: Channels should implement VoiceCapabilityProvider instead. +// To be removed once all existing capable channels conform to the interface. +var asrCapableChannels = map[string]bool{ + "discord": true, + "telegram": true, + "matrix": true, + "qq": true, + "weixin": true, + "line": true, + "feishu": true, + "onebot": true, +} + +// DetectVoiceCapabilities returns ASR/TTS availability for a channel, gated by +// whether providers are configured. +func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities { + if ch == nil { + return VoiceCapabilities{} + } + + if vcp, ok := ch.(VoiceCapabilityProvider); ok { + caps := vcp.VoiceCapabilities() + if !asrAvailable { + caps.ASR = false + } + if !ttsAvailable { + caps.TTS = false + } + return caps + } + + caps := VoiceCapabilities{} + if asrAvailable { + caps.ASR = asrCapableChannels[channelName] + } + if ttsAvailable { + if _, ok := ch.(MediaSender); ok { + caps.TTS = true + } + } + + return caps +} diff --git a/pkg/channels/weixin/weixin.go b/pkg/channels/weixin/weixin.go index 0e9010131..a0d0c96b5 100644 --- a/pkg/channels/weixin/weixin.go +++ b/pkg/channels/weixin/weixin.go @@ -402,3 +402,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st return nil, nil } + +// VoiceCapabilities returns the voice capabilities of the channel. +func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities { + return channels.VoiceCapabilities{ASR: true, TTS: true} +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 397cd4ab8..7a11d1ab7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -558,9 +558,9 @@ type DevicesConfig struct { } type VoiceConfig struct { - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"` - EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` - ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"` + TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"` + EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` } // ModelConfig represents a model-centric provider configuration. @@ -829,6 +829,7 @@ type ToolsConfig struct { Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"` + SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"` Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` SpawnStatus ToolConfig `json:"spawn_status" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"` SPI ToolConfig `json:"spi" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPI_"` @@ -1281,6 +1282,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool { return t.WebFetch.Enabled case "send_file": return t.SendFile.Enabled + case "send_tts": + return t.SendTTS.Enabled case "write_file": return t.WriteFile.Enabled case "mcp": diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index c3845e3e2..6eac5d8b9 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -434,6 +434,9 @@ func DefaultConfig() *Config { SendFile: ToolConfig{ Enabled: true, }, + SendTTS: ToolConfig{ + Enabled: false, + }, MCP: MCPConfig{ ToolConfig: ToolConfig{ Enabled: false, diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 64aed5e8c..b5e8c1f36 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "path/filepath" + "sort" "strings" "sync" "sync/atomic" @@ -13,6 +14,8 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/audio/asr" + "github.com/sipeed/picoclaw/pkg/audio/tts" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" @@ -41,7 +44,6 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) const ( @@ -61,6 +63,7 @@ type services struct { ChannelManager *channels.Manager DeviceService *devices.Service HealthServer *health.Server + VoiceAgentCancel context.CancelFunc manualReloadChan chan struct{} reloading atomic.Bool authToken string @@ -70,6 +73,27 @@ type startupBlockedProvider struct { reason string } +func logChannelVoiceCapabilities(cm *channels.Manager, asrAvailable bool, ttsAvailable bool) { + if cm == nil { + return + } + + names := cm.GetEnabledChannels() + sort.Strings(names) + for _, name := range names { + ch, ok := cm.GetChannel(name) + if !ok { + continue + } + caps := channels.DetectVoiceCapabilities(name, ch, asrAvailable, ttsAvailable) + logger.InfoCF("voice", "Channel voice capabilities", map[string]any{ + "channel": name, + "asr": caps.ASR, + "tts": caps.TTS, + }) + } +} + func (p *startupBlockedProvider) Chat( _ context.Context, _ []providers.Message, @@ -337,11 +361,14 @@ func setupAndStartServices( agentLoop.SetChannelManager(runningServices.ChannelManager) agentLoop.SetMediaStore(runningServices.MediaStore) - if transcriber := voice.DetectTranscriber(cfg); transcriber != nil { + transcriber := asr.DetectTranscriber(cfg) + if transcriber != nil { agentLoop.SetTranscriber(transcriber) logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) } + ttsAvailable := tts.DetectTTS(cfg) != nil + enabledChannels := runningServices.ChannelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) @@ -358,6 +385,16 @@ func setupAndStartServices( return nil, fmt.Errorf("error starting channels: %w", err) } + logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable) + + if transcriber != nil { + // Start Voice Agent Orchestrator after channels are ready. + vaCtx, vaCancel := context.WithCancel(context.Background()) + runningServices.VoiceAgentCancel = vaCancel + voiceAgent := asr.NewAgent(msgBus, transcriber) + voiceAgent.Start(vaCtx) + } + fmt.Printf( "✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n", cfg.Gateway.Host, @@ -387,6 +424,9 @@ func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Dura if !isReload && runningServices.ChannelManager != nil { runningServices.ChannelManager.StopAll(shutdownCtx) } + if runningServices.VoiceAgentCancel != nil { + runningServices.VoiceAgentCancel() + } if runningServices.DeviceService != nil { runningServices.DeviceService.Stop() } @@ -563,14 +603,22 @@ func restartServices( fmt.Println(" ✓ Device event service restarted") } - transcriber := voice.DetectTranscriber(cfg) + transcriber := asr.DetectTranscriber(cfg) al.SetTranscriber(transcriber) if transcriber != nil { logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) + + // Start Voice Agent Orchestrator on reload + vaCtx, vaCancel := context.WithCancel(context.Background()) + runningServices.VoiceAgentCancel = vaCancel + voiceAgent := asr.NewAgent(msgBus, transcriber) + voiceAgent.Start(vaCtx) } else { logger.InfoCF("voice", "Transcription disabled", nil) } + ttsAvailable := tts.DetectTTS(cfg) != nil + logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable) // NOTE: PID file is written once at startup and not updated on reload. // Changing the gateway listen address requires a full restart. diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index e956db209..16b2ead10 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -98,6 +98,19 @@ func ExtractProtocol(model string) (protocol, modelID string) { return protocol, modelID } +// ResolveAPIBase returns the configured API base, or the protocol default when +// the model uses an HTTP-based provider family with a known default endpoint. +func ResolveAPIBase(cfg *config.ModelConfig) string { + if cfg == nil { + return "" + } + if apiBase := strings.TrimSpace(cfg.APIBase); apiBase != "" { + return strings.TrimRight(apiBase, "/") + } + protocol, _ := ExtractProtocol(cfg.Model) + return strings.TrimRight(getDefaultAPIBase(protocol), "/") +} + // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. // Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini), diff --git a/pkg/tools/tts_send.go b/pkg/tools/tts_send.go new file mode 100644 index 000000000..3d569e3f7 --- /dev/null +++ b/pkg/tools/tts_send.go @@ -0,0 +1,82 @@ +package tools + +import ( + "context" + "strings" + + "github.com/sipeed/picoclaw/pkg/audio/tts" + "github.com/sipeed/picoclaw/pkg/media" +) + +type SendTTSTool struct { + provider tts.TTSProvider + mediaStore media.MediaStore +} + +func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool { + return &SendTTSTool{ + provider: provider, + mediaStore: store, + } +} + +func (t *SendTTSTool) Name() string { return "send_tts" } + +func (t *SendTTSTool) Description() string { + return "Synthesize speech from text and send it as an audio file to the user." +} + +func (t *SendTTSTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "text": map[string]any{ + "type": "string", + "description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.", + }, + "filename": map[string]any{ + "type": "string", + "description": "Optional filename for the audio file (e.g., response.ogg).", + }, + }, + "required": []string{"text"}, + } +} + +func (t *SendTTSTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + +func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + text, _ := args["text"].(string) + text = strings.TrimSpace(text) + if text == "" { + return ErrorResult("text is required") + } + + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) + filename, _ := args["filename"].(string) + + ref, err := tts.SynthesizeAndStore( + ctx, + t.provider, + t.mediaStore, + text, + filename, + channel, + chatID, + ) + if err != nil { + return ErrorResult(err.Error()).WithError(err) + } + + // Return with ForUser set to original text, Media containing the audio ref, + // and mark as ResponseHandled so the audio is sent immediately without LLM intervention. + return &ToolResult{ + ForLLM: "TTS audio sent", + ForUser: text, + Media: []string{ref}, + ResponseHandled: true, + } +} diff --git a/pkg/voice/groq_transcriber.go b/pkg/voice/groq_transcriber.go deleted file mode 100644 index b42e598f7..000000000 --- a/pkg/voice/groq_transcriber.go +++ /dev/null @@ -1,151 +0,0 @@ -package voice - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "mime/multipart" - "net/http" - "os" - "path/filepath" - "time" - - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -type GroqTranscriber struct { - apiKey string - apiBase string - httpClient *http.Client -} - -func NewGroqTranscriber(apiKey string) *GroqTranscriber { - logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""}) - - apiBase := "https://api.groq.com/openai/v1" - return &GroqTranscriber{ - apiKey: apiKey, - apiBase: apiBase, - httpClient: &http.Client{ - Timeout: 60 * time.Second, - }, - } -} - -func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { - logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath}) - - audioFile, err := os.Open(audioFilePath) - if err != nil { - logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err}) - return nil, fmt.Errorf("failed to open audio file: %w", err) - } - defer audioFile.Close() - - fileInfo, err := audioFile.Stat() - if err != nil { - logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err}) - return nil, fmt.Errorf("failed to get file info: %w", err) - } - - logger.DebugCF("voice", "Audio file details", map[string]any{ - "size_bytes": fileInfo.Size(), - "file_name": filepath.Base(audioFilePath), - }) - - var requestBody bytes.Buffer - writer := multipart.NewWriter(&requestBody) - - part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath)) - if err != nil { - logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to create form file: %w", err) - } - - copied, err := io.Copy(part, audioFile) - if err != nil { - logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to copy file content: %w", err) - } - - logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied}) - - if err = writer.WriteField("model", "whisper-large-v3"); err != nil { - logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to write model field: %w", err) - } - - if err = writer.WriteField("response_format", "json"); err != nil { - logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to write response_format field: %w", err) - } - - if err = writer.Close(); err != nil { - logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - url := t.apiBase + "/audio/transcriptions" - req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody) - if err != nil { - logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Authorization", "Bearer "+t.apiKey) - - logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{ - "url": url, - "request_size_bytes": requestBody.Len(), - "file_size_bytes": fileInfo.Size(), - }) - - resp, err := t.httpClient.Do(req) - if err != nil { - logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - logger.ErrorCF("voice", "API error", map[string]any{ - "status_code": resp.StatusCode, - "response": string(body), - }) - return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) - } - - logger.DebugCF("voice", "Received response from Groq API", map[string]any{ - "status_code": resp.StatusCode, - "response_size_bytes": len(body), - }) - - var result TranscriptionResponse - if err := json.Unmarshal(body, &result); err != nil { - logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - logger.InfoCF("voice", "Transcription completed successfully", map[string]any{ - "text_length": len(result.Text), - "language": result.Language, - "duration_seconds": result.Duration, - "transcription_preview": utils.Truncate(result.Text, 50), - }) - - return &result, nil -} - -func (t *GroqTranscriber) Name() string { - return "groq" -} diff --git a/pkg/voice/groq_transcriber_test.go b/pkg/voice/groq_transcriber_test.go deleted file mode 100644 index fdcaa7580..000000000 --- a/pkg/voice/groq_transcriber_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package voice - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" -) - -var _ Transcriber = (*GroqTranscriber)(nil) - -func TestGroqTranscriberName(t *testing.T) { - tr := NewGroqTranscriber("sk-test") - if got := tr.Name(); got != "groq" { - t.Errorf("Name() = %q, want %q", got, "groq") - } -} - -func TestGroqTranscribe(t *testing.T) { - // Write a minimal fake audio file so the transcriber can open and send it. - tmpDir := t.TempDir() - audioPath := filepath.Join(tmpDir, "clip.ogg") - if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil { - t.Fatalf("failed to write fake audio file: %v", err) - } - - t.Run("success", func(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/audio/transcriptions" { - t.Errorf("unexpected path: %s", r.URL.Path) - } - if r.Header.Get("Authorization") != "Bearer sk-test" { - t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization")) - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(TranscriptionResponse{ - Text: "hello world", - Language: "en", - Duration: 1.5, - }) - })) - defer srv.Close() - - tr := NewGroqTranscriber("sk-test") - tr.apiBase = srv.URL - - resp, err := tr.Transcribe(context.Background(), audioPath) - if err != nil { - t.Fatalf("Transcribe() error: %v", err) - } - if resp.Text != "hello world" { - t.Errorf("Text = %q, want %q", resp.Text, "hello world") - } - if resp.Language != "en" { - t.Errorf("Language = %q, want %q", resp.Language, "en") - } - }) - - t.Run("api error", func(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized) - })) - defer srv.Close() - - tr := NewGroqTranscriber("sk-bad") - tr.apiBase = srv.URL - - _, err := tr.Transcribe(context.Background(), audioPath) - if err == nil { - t.Fatal("expected error for non-200 response, got nil") - } - }) - - t.Run("missing file", func(t *testing.T) { - tr := NewGroqTranscriber("sk-test") - _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) - if err == nil { - t.Fatal("expected error for missing file, got nil") - } - }) -} diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go deleted file mode 100644 index f56fdeedd..000000000 --- a/pkg/voice/transcriber.go +++ /dev/null @@ -1,68 +0,0 @@ -package voice - -import ( - "context" - "strings" - - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/providers" -) - -type Transcriber interface { - Name() string - Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) -} - -type TranscriptionResponse struct { - Text string `json:"text"` - Language string `json:"language,omitempty"` - Duration float64 `json:"duration,omitempty"` -} - -func supportsAudioTranscription(model string) bool { - protocol, _ := providers.ExtractProtocol(model) - - switch protocol { - case "openai", "azure", "azure-openai", - "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", - "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", - "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", - "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita", - "coding-plan", "alibaba-coding", "qwen-coding": - // These protocols all go through the OpenAI-compatible or Azure provider path in - // providers.CreateProviderFromConfig, so they are the only ones that can supply - // the audio media payload shape expected by NewAudioModelTranscriber. - - // TODO: Further restrict this by modelID, since not every model under these - // protocols supports audio transcription. - return true - default: - return false - } -} - -// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or -// nil if no supported transcription provider is configured. -func DetectTranscriber(cfg *config.Config) Transcriber { - if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" { - modelCfg, err := cfg.GetModelConfig(modelName) - if err != nil { - return nil - } - if supportsAudioTranscription(modelCfg.Model) { - return NewAudioModelTranscriber(modelCfg) - } - } - - // ElevenLabs voice config (supports Scribe STT). - if key := strings.TrimSpace(cfg.Voice.ElevenLabsAPIKey); key != "" { - return NewElevenLabsTranscriber(key) - } - // Fall back to any model-list entry that uses the groq/ protocol. - for _, mc := range cfg.ModelList { - if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey() != "" { - return NewGroqTranscriber(mc.APIKey()) - } - } - return nil -}