diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 416a94710..faf1e1358 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -29,6 +29,7 @@ type DiscordChannel struct { config config.DiscordConfig transcriber *voice.GroqTranscriber ctx context.Context + cancel context.CancelFunc typingMu sync.Mutex typingStop map[string]chan struct{} // chatID → stop signal botUserID string // stored for mention checking @@ -56,17 +57,10 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } -func (c *DiscordChannel) getContext() context.Context { - if c.ctx == nil { - return context.Background() - } - return c.ctx -} - func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") - c.ctx = ctx + c.ctx, c.cancel = context.WithCancel(ctx) // Get bot user ID before opening session to avoid race condition botUser, err := c.session.User("@me") @@ -103,6 +97,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } c.typingMu.Unlock() + // Cancel our context so typing goroutines using c.ctx.Done() exit + if c.cancel != nil { + c.cancel() + } + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ -236,7 +235,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) + ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) cancel() // Release context resources immediately to avoid leaks in for loop diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index 280098dda..05213b095 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -17,6 +18,8 @@ type MaixCamChannel struct { *channels.BaseChannel config config.MaixCamConfig listener net.Listener + ctx context.Context + cancel context.CancelFunc clients map[net.Conn]bool clientsMux sync.RWMutex } @@ -41,9 +44,12 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC func (c *MaixCamChannel) Start(ctx context.Context) error { logger.InfoC("maixcam", "Starting MaixCam channel server") + c.ctx, c.cancel = context.WithCancel(ctx) + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) listener, err := net.Listen("tcp", addr) if err != nil { + c.cancel() return fmt.Errorf("failed to listen on %s: %w", addr, err) } @@ -55,17 +61,17 @@ func (c *MaixCamChannel) Start(ctx context.Context) error { "port": c.config.Port, }) - go c.acceptConnections(ctx) + go c.acceptConnections() return nil } -func (c *MaixCamChannel) acceptConnections(ctx context.Context) { +func (c *MaixCamChannel) acceptConnections() { logger.DebugC("maixcam", "Starting connection acceptor") for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): logger.InfoC("maixcam", "Stopping connection acceptor") return default: @@ -87,12 +93,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { c.clients[conn] = true c.clientsMux.Unlock() - go c.handleConnection(conn, ctx) + go c.handleConnection(conn) } } } -func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { +func (c *MaixCamChannel) handleConnection(conn net.Conn) { logger.DebugC("maixcam", "Handling MaixCam connection") defer func() { @@ -107,7 +113,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: var msg MaixCamMessage @@ -186,6 +192,11 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { logger.InfoC("maixcam", "Stopping MaixCam channel") c.SetRunning(false) + // Cancel context first to signal goroutines to exit + if c.cancel != nil { + c.cancel() + } + if c.listener != nil { c.listener.Close() } @@ -229,6 +240,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := conn.Write(data); err != nil { logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": conn.RemoteAddr().String(), @@ -236,6 +248,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro }) sendErr = err } + _ = conn.SetWriteDeadline(time.Time{}) } return sendErr diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 642eebd1d..4f35888ca 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -298,7 +298,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { @@ -354,8 +356,7 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { } c.pendingMu.Lock() - for echo, ch := range c.pending { - close(ch) + for echo := range c.pending { delete(c.pending, echo) } c.pendingMu.Unlock() @@ -402,7 +403,9 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 5703000b4..af825ddc9 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -27,10 +27,13 @@ import ( type TelegramChannel struct { *channels.BaseChannel bot *telego.Bot + bh *telegohandler.BotHandler commands TelegramCommander config *config.Config chatIDs map[string]int64 transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc placeholders sync.Map // chatID -> messageID stopThinking sync.Map // chatID -> thinkingCancel } @@ -94,17 +97,22 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { func (c *TelegramChannel) Start(ctx context.Context) error { logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + c.ctx, c.cancel = context.WithCancel(ctx) + + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) if err != nil { + c.cancel() return fmt.Errorf("failed to start long polling: %w", err) } bh, err := telegohandler.NewBotHandler(c.bot, updates) if err != nil { + c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } + c.bh = bh bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { c.commands.Help(ctx, message) @@ -133,17 +141,32 @@ func (c *TelegramChannel) Start(ctx context.Context) error { go bh.Start() - go func() { - <-ctx.Done() - bh.Stop() - }() - return nil } func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) + + // Clean up all thinking cancel functions to avoid context leaks + c.stopThinking.Range(func(key, value any) bool { + if cf, ok := value.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + c.stopThinking.Delete(key) + return true + }) + + // Stop the bot handler + if c.bh != nil { + c.bh.Stop() + } + + // Cancel our context (stops long polling) + if c.cancel != nil { + c.cancel() + } + return nil } diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 1a5401172..cbc82fd09 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "log" "sync" "time" @@ -13,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -21,6 +21,8 @@ type WhatsAppChannel struct { conn *websocket.Conn config config.WhatsAppConfig url string + ctx context.Context + cancel context.CancelFunc mu sync.Mutex connected bool } @@ -37,13 +39,18 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA } func (c *WhatsAppChannel) Start(ctx context.Context) error { - log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{ + "bridge_url": c.url, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 10 * time.Second conn, _, err := dialer.Dial(c.url, nil) if err != nil { + c.cancel() return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) } @@ -53,22 +60,29 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { c.mu.Unlock() c.SetRunning(true) - log.Println("WhatsApp channel connected") + logger.InfoC("whatsapp", "WhatsApp channel connected") - go c.listen(ctx) + go c.listen() return nil } func (c *WhatsAppChannel) Stop(ctx context.Context) error { - log.Println("Stopping WhatsApp channel...") + logger.InfoC("whatsapp", "Stopping WhatsApp channel...") + + // Cancel context first to signal listen goroutine to exit + if c.cancel != nil { + c.cancel() + } c.mu.Lock() defer c.mu.Unlock() if c.conn != nil { if err := c.conn.Close(); err != nil { - log.Printf("Error closing WhatsApp connection: %v", err) + logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{ + "error": err.Error(), + }) } c.conn = nil } @@ -98,17 +112,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("failed to marshal message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + _ = c.conn.SetWriteDeadline(time.Time{}) return fmt.Errorf("failed to send message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Time{}) return nil } -func (c *WhatsAppChannel) listen(ctx context.Context) { +func (c *WhatsAppChannel) listen() { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: c.mu.Lock() @@ -122,14 +139,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { _, message, err := conn.ReadMessage() if err != nil { - log.Printf("WhatsApp read error: %v", err) + logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{ + "error": err.Error(), + }) time.Sleep(2 * time.Second) continue } var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Failed to unmarshal WhatsApp message: %v", err) + logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{ + "error": err.Error(), + }) continue } @@ -187,7 +208,10 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { peer = bus.Peer{Kind: "group", ID: chatID} } - log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ + "sender": senderID, + "preview": utils.Truncate(content, 50), + }) c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata) }