refactor(channels): unify Start/Stop lifecycle and fix goroutine/context leaks

- OneBot: remove close(ch) race in Stop() pending cleanup; add WriteDeadline to Send/sendAPIRequest
- Telegram: add cancelCtx; Stop() now calls bh.Stop(), cancel(), and cleans up thinking CancelFuncs
- Discord: add cancelCtx via WithCancel; Stop() calls cancel(); remove unused getContext()
- WhatsApp: add cancelCtx; Send() adds WriteDeadline; replace stdlib log with project logger
- MaixCam: add cancelCtx; Send() adds WriteDeadline; Stop() calls cancel() before closing
This commit is contained in:
Hoshina
2026-02-22 22:25:07 +08:00
parent 153198e0f3
commit b6161aec3f
5 changed files with 96 additions and 34 deletions
+8 -9
View File
@@ -29,6 +29,7 @@ type DiscordChannel struct {
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
botUserID string // stored for mention checking
@@ -56,17 +57,10 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.ctx, c.cancel = context.WithCancel(ctx)
// Get bot user ID before opening session to avoid race condition
botUser, err := c.session.User("@me")
@@ -103,6 +97,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error {
}
c.typingMu.Unlock()
// Cancel our context so typing goroutines using c.ctx.Done() exit
if c.cancel != nil {
c.cancel()
}
if err := c.session.Close(); err != nil {
return fmt.Errorf("failed to close discord session: %w", err)
}
@@ -236,7 +235,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // Release context resources immediately to avoid leaks in for loop
+19 -6
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -17,6 +18,8 @@ type MaixCamChannel struct {
*channels.BaseChannel
config config.MaixCamConfig
listener net.Listener
ctx context.Context
cancel context.CancelFunc
clients map[net.Conn]bool
clientsMux sync.RWMutex
}
@@ -41,9 +44,12 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
func (c *MaixCamChannel) Start(ctx context.Context) error {
logger.InfoC("maixcam", "Starting MaixCam channel server")
c.ctx, c.cancel = context.WithCancel(ctx)
addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
listener, err := net.Listen("tcp", addr)
if err != nil {
c.cancel()
return fmt.Errorf("failed to listen on %s: %w", addr, err)
}
@@ -55,17 +61,17 @@ func (c *MaixCamChannel) Start(ctx context.Context) error {
"port": c.config.Port,
})
go c.acceptConnections(ctx)
go c.acceptConnections()
return nil
}
func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
func (c *MaixCamChannel) acceptConnections() {
logger.DebugC("maixcam", "Starting connection acceptor")
for {
select {
case <-ctx.Done():
case <-c.ctx.Done():
logger.InfoC("maixcam", "Stopping connection acceptor")
return
default:
@@ -87,12 +93,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
c.clients[conn] = true
c.clientsMux.Unlock()
go c.handleConnection(conn, ctx)
go c.handleConnection(conn)
}
}
}
func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
func (c *MaixCamChannel) handleConnection(conn net.Conn) {
logger.DebugC("maixcam", "Handling MaixCam connection")
defer func() {
@@ -107,7 +113,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
for {
select {
case <-ctx.Done():
case <-c.ctx.Done():
return
default:
var msg MaixCamMessage
@@ -186,6 +192,11 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error {
logger.InfoC("maixcam", "Stopping MaixCam channel")
c.SetRunning(false)
// Cancel context first to signal goroutines to exit
if c.cancel != nil {
c.cancel()
}
if c.listener != nil {
c.listener.Close()
}
@@ -229,6 +240,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
var sendErr error
for conn := range c.clients {
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := conn.Write(data); err != nil {
logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{
"client": conn.RemoteAddr().String(),
@@ -236,6 +248,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
})
sendErr = err
}
_ = conn.SetWriteDeadline(time.Time{})
}
return sendErr
+5 -2
View File
@@ -298,7 +298,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
}
c.writeMu.Lock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
_ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
@@ -354,8 +356,7 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
}
c.pendingMu.Lock()
for echo, ch := range c.pending {
close(ch)
for echo := range c.pending {
delete(c.pending, echo)
}
c.pendingMu.Unlock()
@@ -402,7 +403,9 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
}
c.writeMu.Lock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
_ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
+29 -6
View File
@@ -27,10 +27,13 @@ import (
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *telegohandler.BotHandler
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> thinkingCancel
}
@@ -94,17 +97,22 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
func (c *TelegramChannel) Start(ctx context.Context) error {
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
c.ctx, c.cancel = context.WithCancel(ctx)
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
if err != nil {
c.cancel()
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := telegohandler.NewBotHandler(c.bot, updates)
if err != nil {
c.cancel()
return fmt.Errorf("failed to create bot handler: %w", err)
}
c.bh = bh
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
c.commands.Help(ctx, message)
@@ -133,17 +141,32 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
go bh.Start()
go func() {
<-ctx.Done()
bh.Stop()
}()
return nil
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.SetRunning(false)
// Clean up all thinking cancel functions to avoid context leaks
c.stopThinking.Range(func(key, value any) bool {
if cf, ok := value.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(key)
return true
})
// Stop the bot handler
if c.bh != nil {
c.bh.Stop()
}
// Cancel our context (stops long polling)
if c.cancel != nil {
c.cancel()
}
return nil
}
+35 -11
View File
@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
@@ -13,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -21,6 +21,8 @@ type WhatsAppChannel struct {
conn *websocket.Conn
config config.WhatsAppConfig
url string
ctx context.Context
cancel context.CancelFunc
mu sync.Mutex
connected bool
}
@@ -37,13 +39,18 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA
}
func (c *WhatsAppChannel) Start(ctx context.Context) error {
log.Printf("Starting WhatsApp channel connecting to %s...", c.url)
logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{
"bridge_url": c.url,
})
c.ctx, c.cancel = context.WithCancel(ctx)
dialer := websocket.DefaultDialer
dialer.HandshakeTimeout = 10 * time.Second
conn, _, err := dialer.Dial(c.url, nil)
if err != nil {
c.cancel()
return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err)
}
@@ -53,22 +60,29 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error {
c.mu.Unlock()
c.SetRunning(true)
log.Println("WhatsApp channel connected")
logger.InfoC("whatsapp", "WhatsApp channel connected")
go c.listen(ctx)
go c.listen()
return nil
}
func (c *WhatsAppChannel) Stop(ctx context.Context) error {
log.Println("Stopping WhatsApp channel...")
logger.InfoC("whatsapp", "Stopping WhatsApp channel...")
// Cancel context first to signal listen goroutine to exit
if c.cancel != nil {
c.cancel()
}
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Printf("Error closing WhatsApp connection: %v", err)
logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{
"error": err.Error(),
})
}
c.conn = nil
}
@@ -98,17 +112,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("failed to marshal message: %w", err)
}
_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
_ = c.conn.SetWriteDeadline(time.Time{})
return fmt.Errorf("failed to send message: %w", err)
}
_ = c.conn.SetWriteDeadline(time.Time{})
return nil
}
func (c *WhatsAppChannel) listen(ctx context.Context) {
func (c *WhatsAppChannel) listen() {
for {
select {
case <-ctx.Done():
case <-c.ctx.Done():
return
default:
c.mu.Lock()
@@ -122,14 +139,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) {
_, message, err := conn.ReadMessage()
if err != nil {
log.Printf("WhatsApp read error: %v", err)
logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{
"error": err.Error(),
})
time.Sleep(2 * time.Second)
continue
}
var msg map[string]any
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("Failed to unmarshal WhatsApp message: %v", err)
logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{
"error": err.Error(),
})
continue
}
@@ -187,7 +208,10 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
peer = bus.Peer{Kind: "group", ID: chatID}
}
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
"sender": senderID,
"preview": utils.Truncate(content, 50),
})
c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata)
}