feat(wecom-aibot): add context management for stream tasks to improve agent cancellation

This commit is contained in:
Zhang Rui
2026-02-28 15:38:49 +08:00
parent 0b6d913dfc
commit 4e09c91dda
+23 -6
View File
@@ -48,7 +48,9 @@ type streamTask struct {
StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url
Finished bool // fully done
mu sync.Mutex
answerCh chan string // receives agent reply from Send()
answerCh chan string // receives agent reply from Send()
ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine
cancel context.CancelFunc // call on task removal to cancel ctx
}
// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot
@@ -109,7 +111,7 @@ type WeComAIBotStreamInfo struct {
// WeComAIBotStreamResponse represents the streaming response format.
// Its struct tags map the fields to the JSON keys "msgtype" and "stream".
type WeComAIBotStreamResponse struct {
// NOTE(review): MsgType is listed twice below. This text looks like a diff hunk
// whose +/- markers were stripped, so the two lines are presumably the removed
// and added sides of a whitespace-only change; the real file can declare the
// field only once — confirm against the repository.
MsgType string `json:"msgtype"`
MsgType string `json:"msgtype"`
// Stream carries the nested WeComAIBotStreamInfo value under the "stream" key.
Stream WeComAIBotStreamInfo `json:"stream"`
}
@@ -237,6 +239,9 @@ func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) e
// Stream still open: deliver via answerCh for the next poll response.
select {
case task.answerCh <- msg.Content:
case <-task.ctx.Done():
// Task was canceled (cleanup removed it); silently drop the reply.
return nil
case <-ctx.Done():
return ctx.Err()
}
@@ -490,6 +495,10 @@ func (c *WeComAIBotChannel) handleTextMessage(
// Set a slightly shorter deadline so we can send a timeout notice before it gives up.
deadline := time.Now().Add(30 * time.Second)
// Each task gets its own context derived from the channel lifetime context.
// Canceling taskCancel interrupts the agent goroutine when the task is removed.
taskCtx, taskCancel := context.WithCancel(c.ctx)
task := &streamTask{
StreamID: streamID,
ChatID: chatID,
@@ -499,6 +508,8 @@ func (c *WeComAIBotChannel) handleTextMessage(
Deadline: deadline,
Finished: false,
answerCh: make(chan string, 1),
ctx: taskCtx,
cancel: taskCancel,
}
c.taskMu.Lock()
@@ -506,8 +517,8 @@ func (c *WeComAIBotChannel) handleTextMessage(
c.chatTasks[chatID] = append(c.chatTasks[chatID], task)
c.taskMu.Unlock()
// Publish to agent asynchronously; agent will call Send() with reply
// Use c.ctx (channel lifetime) instead of r.Context() which is canceled when the HTTP handler returns.
// Publish to agent asynchronously; agent will call Send() with reply.
// Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed.
go func() {
sender := bus.SenderInfo{
Platform: "wecom_aibot",
@@ -529,7 +540,7 @@ func (c *WeComAIBotChannel) handleTextMessage(
"stream_id": streamID,
"response_url": msg.ResponseURL,
}
c.HandleMessage(c.ctx, peer, msg.MsgID, userID, chatID,
c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID,
content, nil, metadata, sender)
}()
@@ -800,11 +811,13 @@ func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce
return c.encryptResponse(task.StreamID, timestamp, nonce, response)
}
// removeTask removes a task from both streamTasks and chatTasks and marks it finished.
// removeTask removes a task from both streamTasks and chatTasks, marks it finished,
// and cancels its context to interrupt the associated agent goroutine.
func (c *WeComAIBotChannel) removeTask(task *streamTask) {
task.mu.Lock()
task.Finished = true
task.mu.Unlock()
task.cancel() // interrupt agent goroutine bound to this task
c.taskMu.Lock()
delete(c.streamTasks, task.StreamID)
@@ -1114,6 +1127,7 @@ func (c *WeComAIBotChannel) cleanupOldTasks() {
for id, task := range c.streamTasks {
if task.CreatedTime.Before(cutoff) {
delete(c.streamTasks, id)
task.cancel() // interrupt agent goroutine still waiting for LLM
queue := c.chatTasks[task.ChatID]
for i, t := range queue {
if t == task {
@@ -1130,11 +1144,14 @@ func (c *WeComAIBotChannel) cleanupOldTasks() {
}
}
// Also clean up StreamClosed tasks from chatTasks that are older than 1 hour.
// These were removed from streamTasks earlier but kept alive for response_url delivery.
for chatID, queue := range c.chatTasks {
filtered := queue[:0]
for _, t := range queue {
if !t.Finished && t.CreatedTime.After(cutoff) {
filtered = append(filtered, t)
} else if !t.Finished {
t.cancel() // cancel any lingering agent goroutine
}
}
if len(filtered) == 0 {