mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(channels): unify Start/Stop lifecycle and fix goroutine/context leaks
- OneBot: remove close(ch) race in Stop() pending cleanup; add WriteDeadline to Send/sendAPIRequest - Telegram: add cancelCtx; Stop() now calls bh.Stop(), cancel(), and cleans up thinking CancelFuncs - Discord: add cancelCtx via WithCancel; Stop() calls cancel(); remove unused getContext() - WhatsApp: add cancelCtx; Send() adds WriteDeadline; replace stdlib log with project logger - MaixCam: add cancelCtx; Send() adds WriteDeadline; Stop() calls cancel() before closing
This commit is contained in:
@@ -29,6 +29,7 @@ type DiscordChannel struct {
|
||||
config config.DiscordConfig
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
typingMu sync.Mutex
|
||||
typingStop map[string]chan struct{} // chatID → stop signal
|
||||
botUserID string // stored for mention checking
|
||||
@@ -56,17 +57,10 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) getContext() context.Context {
|
||||
if c.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("discord", "Starting Discord bot")
|
||||
|
||||
c.ctx = ctx
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get bot user ID before opening session to avoid race condition
|
||||
botUser, err := c.session.User("@me")
|
||||
@@ -103,6 +97,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error {
|
||||
}
|
||||
c.typingMu.Unlock()
|
||||
|
||||
// Cancel our context so typing goroutines using c.ctx.Done() exit
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if err := c.session.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close discord session: %w", err)
|
||||
}
|
||||
@@ -236,7 +235,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout)
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
cancel() // Release context resources immediately to avoid leaks in for loop
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -17,6 +18,8 @@ type MaixCamChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.MaixCamConfig
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clients map[net.Conn]bool
|
||||
clientsMux sync.RWMutex
|
||||
}
|
||||
@@ -41,9 +44,12 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
|
||||
func (c *MaixCamChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("maixcam", "Starting MaixCam channel server")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
@@ -55,17 +61,17 @@ func (c *MaixCamChannel) Start(ctx context.Context) error {
|
||||
"port": c.config.Port,
|
||||
})
|
||||
|
||||
go c.acceptConnections(ctx)
|
||||
go c.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
|
||||
func (c *MaixCamChannel) acceptConnections() {
|
||||
logger.DebugC("maixcam", "Starting connection acceptor")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.ctx.Done():
|
||||
logger.InfoC("maixcam", "Stopping connection acceptor")
|
||||
return
|
||||
default:
|
||||
@@ -87,12 +93,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
|
||||
c.clients[conn] = true
|
||||
c.clientsMux.Unlock()
|
||||
|
||||
go c.handleConnection(conn, ctx)
|
||||
go c.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
|
||||
func (c *MaixCamChannel) handleConnection(conn net.Conn) {
|
||||
logger.DebugC("maixcam", "Handling MaixCam connection")
|
||||
|
||||
defer func() {
|
||||
@@ -107,7 +113,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var msg MaixCamMessage
|
||||
@@ -186,6 +192,11 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("maixcam", "Stopping MaixCam channel")
|
||||
c.SetRunning(false)
|
||||
|
||||
// Cancel context first to signal goroutines to exit
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.listener != nil {
|
||||
c.listener.Close()
|
||||
}
|
||||
@@ -229,6 +240,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
|
||||
var sendErr error
|
||||
for conn := range c.clients {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{
|
||||
"client": conn.RemoteAddr().String(),
|
||||
@@ -236,6 +248,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
})
|
||||
sendErr = err
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
return sendErr
|
||||
|
||||
@@ -298,7 +298,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -354,8 +356,7 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.pendingMu.Lock()
|
||||
for echo, ch := range c.pending {
|
||||
close(ch)
|
||||
for echo := range c.pending {
|
||||
delete(c.pending, echo)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
@@ -402,7 +403,9 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -27,10 +27,13 @@ import (
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *telegohandler.BotHandler
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
stopThinking sync.Map // chatID -> thinkingCancel
|
||||
}
|
||||
@@ -94,17 +97,22 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
c.bh = bh
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
@@ -133,17 +141,32 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
|
||||
go bh.Start()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
bh.Stop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.SetRunning(false)
|
||||
|
||||
// Clean up all thinking cancel functions to avoid context leaks
|
||||
c.stopThinking.Range(func(key, value any) bool {
|
||||
if cf, ok := value.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
c.stopThinking.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Stop the bot handler
|
||||
if c.bh != nil {
|
||||
c.bh.Stop()
|
||||
}
|
||||
|
||||
// Cancel our context (stops long polling)
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,8 @@ type WhatsAppChannel struct {
|
||||
conn *websocket.Conn
|
||||
config config.WhatsAppConfig
|
||||
url string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.Mutex
|
||||
connected bool
|
||||
}
|
||||
@@ -37,13 +39,18 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting WhatsApp channel connecting to %s...", c.url)
|
||||
logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{
|
||||
"bridge_url": c.url,
|
||||
})
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
conn, _, err := dialer.Dial(c.url, nil)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err)
|
||||
}
|
||||
|
||||
@@ -53,22 +60,29 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error {
|
||||
c.mu.Unlock()
|
||||
|
||||
c.SetRunning(true)
|
||||
log.Println("WhatsApp channel connected")
|
||||
logger.InfoC("whatsapp", "WhatsApp channel connected")
|
||||
|
||||
go c.listen(ctx)
|
||||
go c.listen()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping WhatsApp channel...")
|
||||
logger.InfoC("whatsapp", "Stopping WhatsApp channel...")
|
||||
|
||||
// Cancel context first to signal listen goroutine to exit
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Printf("Error closing WhatsApp connection: %v", err)
|
||||
logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
c.conn = nil
|
||||
}
|
||||
@@ -98,17 +112,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
_ = c.conn.SetWriteDeadline(time.Time{})
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
_ = c.conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) listen(ctx context.Context) {
|
||||
func (c *WhatsAppChannel) listen() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
@@ -122,14 +139,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) {
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("WhatsApp read error: %v", err)
|
||||
logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
log.Printf("Failed to unmarshal WhatsApp message: %v", err)
|
||||
logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -187,7 +208,10 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
}
|
||||
|
||||
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
|
||||
"sender": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user