fix: address PR review feedback across channel system

- MediaStore: use full UUID to prevent ref collisions, preserve and
  expose metadata via ResolveWithMeta, include underlying OS errors
- Agent loop: populate MediaPart Type/Filename/ContentType from
  MediaStore metadata so channels can dispatch media correctly
- SplitMessage: fix byte-vs-rune index mixup in code block header
  parsing, remove dead candidateStr variable
- Pico auth: restrict query-param token behind AllowTokenQuery config
  flag (default false) to prevent token leakage via logs/referer
- HandleMessage: replace context.TODO with caller-propagated ctx,
  log PublishInbound failures instead of silently discarding
- Gateway shutdown: use fresh 15s timeout context for StopAll so
  graceful shutdown is not short-circuited by the cancelled parent ctx
This commit is contained in:
Hoshina
2026-02-23 06:03:23 +08:00
parent f4b0f080e2
commit db3c1e011f
20 changed files with 187 additions and 47 deletions
+7 -1
View File
@@ -190,7 +190,13 @@ func gatewayCmd() {
fmt.Println("\nShutting down...")
cancel()
msgBus.Close()
channelManager.StopAll(ctx)
// Use a fresh context with timeout for graceful shutdown,
// since the original ctx is already cancelled.
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
defer shutdownCancel()
channelManager.StopAll(shutdownCtx)
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
+41 -1
View File
@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"path/filepath"
"strings"
"sync"
"sync/atomic"
@@ -237,6 +238,36 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
}
// inferMediaType determines the media type ("image", "audio", "video", "file")
// from a filename and MIME content type.
func inferMediaType(filename, contentType string) string {
ct := strings.ToLower(contentType)
fn := strings.ToLower(filename)
if strings.HasPrefix(ct, "image/") {
return "image"
}
if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" {
return "audio"
}
if strings.HasPrefix(ct, "video/") {
return "video"
}
// Fallback: infer from extension
ext := filepath.Ext(fn)
switch ext {
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
return "image"
case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus":
return "audio"
case ".mp4", ".avi", ".mov", ".webm", ".mkv":
return "video"
}
return "file"
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
@@ -731,7 +762,16 @@ func (al *AgentLoop) runLLMIteration(
if len(toolResult.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
parts = append(parts, bus.MediaPart{Ref: ref})
part := bus.MediaPart{Ref: ref}
// Populate metadata from MediaStore when available
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
Channel: opts.Channel,
+9 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
@@ -168,6 +169,7 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
}
func (c *BaseChannel) HandleMessage(
ctx context.Context,
peer bus.Peer,
messageID, senderID, chatID, content string,
media []string,
@@ -191,7 +193,13 @@ func (c *BaseChannel) HandleMessage(
Metadata: metadata,
}
c.bus.PublishInbound(context.TODO(), msg)
if err := c.bus.PublishInbound(ctx, msg); err != nil {
logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
"channel": c.name,
"chat_id": chatID,
"error": err.Error(),
})
}
}
func (c *BaseChannel) SetRunning(running bool) {
+1 -1
View File
@@ -183,7 +183,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
})
// Handle the message through the base channel
c.HandleMessage(peer, "", senderID, chatID, content, nil, metadata)
c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata)
// Return nil to indicate we've handled the message asynchronously
// The response will be sent through the message bus
+1 -1
View File
@@ -381,7 +381,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
}
c.HandleMessage(peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata)
c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata)
}
// startTyping starts a continuous typing indicator loop for the given chatID.
+2 -2
View File
@@ -131,7 +131,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return nil
}
func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error {
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
if event == nil || event.Event == nil || event.Event.Message == nil {
return nil
}
@@ -189,7 +189,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
"preview": utils.Truncate(content, 80),
})
c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata)
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata)
return nil
}
+1 -1
View File
@@ -370,7 +370,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
// Show typing/loading indicator (requires user ID, not group ID)
c.sendLoading(senderID)
c.HandleMessage(peer, msg.ID, senderID, chatID, content, mediaPaths, metadata)
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata)
}
// isBotMentioned checks if the bot is mentioned in the message.
+10 -1
View File
@@ -179,7 +179,16 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
"h": fmt.Sprintf("%.0f", h),
}
c.HandleMessage(bus.Peer{Kind: "channel", ID: "default"}, "", senderID, chatID, content, []string{}, metadata)
c.HandleMessage(
c.ctx,
bus.Peer{Kind: "channel", ID: "default"},
"",
senderID,
chatID,
content,
[]string{},
metadata,
)
}
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
+1 -1
View File
@@ -1040,7 +1040,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
}
}
c.HandleMessage(peer, messageID, senderID, chatID, content, parsed.Media, metadata)
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
+8 -5
View File
@@ -255,7 +255,8 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
go c.readLoop(pc)
}
// authenticate checks the Bearer token from header or query parameter.
// authenticate checks the Bearer token from the Authorization header.
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
func (c *PicoChannel) authenticate(r *http.Request) bool {
token := c.config.Token
if token == "" {
@@ -270,9 +271,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
}
}
// Check query parameter
if r.URL.Query().Get("token") == token {
return true
// Check query parameter only when explicitly allowed
if c.config.AllowTokenQuery {
if r.URL.Query().Get("token") == token {
return true
}
}
return false
@@ -417,7 +420,7 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
}
}
c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata)
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata)
}
// truncate truncates a string to maxLen runes.
+2 -2
View File
@@ -168,7 +168,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
// 转发到消息总线
metadata := map[string]string{}
c.HandleMessage(
c.HandleMessage(c.ctx,
bus.Peer{Kind: "direct", ID: senderID},
data.ID,
senderID,
@@ -224,7 +224,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
"group_id": data.GroupID,
}
c.HandleMessage(
c.HandleMessage(c.ctx,
bus.Peer{Kind: "group", ID: data.GroupID},
data.ID,
senderID,
+3 -3
View File
@@ -360,7 +360,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
"has_thread": threadTS != "",
})
c.HandleMessage(peer, messageTS, senderID, chatID, content, mediaPaths, metadata)
c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
@@ -433,7 +433,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
"team_id": c.teamID,
}
c.HandleMessage(mentionPeer, messageTS, senderID, chatID, content, nil, metadata)
c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
@@ -476,7 +476,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"text": utils.Truncate(content, 50),
})
c.HandleMessage(bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata)
c.HandleMessage(c.ctx, bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
+13 -5
View File
@@ -66,9 +66,8 @@ func SplitMessage(content string, maxLen int) []string {
} else {
// Code block is too long to fit in one chunk or missing closing fence.
// Try to split inside by injecting closing and reopening fences.
candidateStr := string(candidate)
unclosedStr := string(runes[unclosedIdx:])
headerEnd := strings.Index(unclosedStr, "\n")
fenceRunes := runes[unclosedIdx:]
headerEnd := findNewlineInRunes(fenceRunes)
var header string
if headerEnd == -1 {
header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3]))
@@ -80,8 +79,6 @@ func SplitMessage(content string, maxLen int) []string {
headerEndIdx = unclosedIdx + headerEnd
}
_ = candidateStr // used above for context
// If we have a reasonable amount of content after the header, split inside
if msgEnd > headerEndIdx+20 {
// Find a better split point closer to maxLen
@@ -170,6 +167,17 @@ func findNextClosingCodeBlockRunes(runes []rune, startIdx int) int {
return -1
}
// findNewlineInRunes finds the first newline character in a rune slice.
// Returns the rune index of the newline or -1 if not found.
func findNewlineInRunes(runes []rune) int {
for i, r := range runes {
if r == '\n' {
return i
}
}
return -1
}
// findLastNewlineRunes finds the last newline character within the last N runes
// Returns the rune position of the newline or -1 if not found
func findLastNewlineRunes(runes []rune, searchWindow int) int {
+1 -1
View File
@@ -448,7 +448,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
c.HandleMessage(
c.HandleMessage(c.ctx,
peer,
messageID,
fmt.Sprintf("%d", user.ID),
+1 -1
View File
@@ -630,7 +630,7 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
})
// Handle the message through the base channel
c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata)
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata)
}
// tokenRefreshLoop periodically refreshes the access token
+1 -1
View File
@@ -399,7 +399,7 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
})
// Handle the message through the base channel
c.HandleMessage(peer, msg.MsgID, senderID, chatID, content, nil, metadata)
c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata)
}
// sendWebhookReply sends a reply using the webhook URL
+1 -1
View File
@@ -224,5 +224,5 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
"preview": utils.Truncate(content, 50),
})
c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata)
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata)
}
+9 -8
View File
@@ -335,14 +335,15 @@ type WeComAppConfig struct {
}
type PicoConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
AllowOrigins []string `json:"allow_origins,omitempty"`
PingInterval int `json:"ping_interval,omitempty"` // seconds, default 30
ReadTimeout int `json:"read_timeout,omitempty"` // seconds, default 60
WriteTimeout int `json:"write_timeout,omitempty"` // seconds, default 10
MaxConnections int `json:"max_connections,omitempty"` // default 100
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
AllowTokenQuery bool `json:"allow_token_query,omitempty"`
AllowOrigins []string `json:"allow_origins,omitempty"`
PingInterval int `json:"ping_interval,omitempty"`
ReadTimeout int `json:"read_timeout,omitempty"`
WriteTimeout int `json:"write_timeout,omitempty"`
MaxConnections int `json:"max_connections,omitempty"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"`
}
type HeartbeatConfig struct {
+31 -10
View File
@@ -25,23 +25,32 @@ type MediaStore interface {
// Resolve returns the local file path for a given ref.
Resolve(ref string) (localPath string, err error)
// ResolveWithMeta returns the local file path and metadata for a given ref.
ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error)
// ReleaseAll deletes all files registered under the given scope
// and removes the mapping entries. File-not-exist errors are ignored.
ReleaseAll(scope string) error
}
// mediaEntry holds the path and metadata for a stored media file.
type mediaEntry struct {
path string
meta MediaMeta
}
// FileMediaStore is a pure in-memory implementation of MediaStore.
// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/).
type FileMediaStore struct {
mu sync.RWMutex
refToPath map[string]string
refs map[string]mediaEntry
scopeToRefs map[string]map[string]struct{}
}
// NewFileMediaStore creates a new FileMediaStore.
func NewFileMediaStore() *FileMediaStore {
return &FileMediaStore{
refToPath: make(map[string]string),
refs: make(map[string]mediaEntry),
scopeToRefs: make(map[string]map[string]struct{}),
}
}
@@ -49,15 +58,15 @@ func NewFileMediaStore() *FileMediaStore {
// Store registers a local file under the given scope. The file must exist.
func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) {
if _, err := os.Stat(localPath); err != nil {
return "", fmt.Errorf("media store: file does not exist: %s", localPath)
return "", fmt.Errorf("media store: %s: %w", localPath, err)
}
ref := "media://" + uuid.New().String()[:8]
ref := "media://" + uuid.New().String()
s.mu.Lock()
defer s.mu.Unlock()
s.refToPath[ref] = localPath
s.refs[ref] = mediaEntry{path: localPath, meta: meta}
if s.scopeToRefs[scope] == nil {
s.scopeToRefs[scope] = make(map[string]struct{})
}
@@ -71,11 +80,23 @@ func (s *FileMediaStore) Resolve(ref string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
path, ok := s.refToPath[ref]
entry, ok := s.refs[ref]
if !ok {
return "", fmt.Errorf("media store: unknown ref: %s", ref)
}
return path, nil
return entry.path, nil
}
// ResolveWithMeta returns the local path and metadata for the given ref.
func (s *FileMediaStore) ResolveWithMeta(ref string) (string, MediaMeta, error) {
s.mu.RLock()
defer s.mu.RUnlock()
entry, ok := s.refs[ref]
if !ok {
return "", MediaMeta{}, fmt.Errorf("media store: unknown ref: %s", ref)
}
return entry.path, entry.meta, nil
}
// ReleaseAll removes all files under the given scope and cleans up mappings.
@@ -89,11 +110,11 @@ func (s *FileMediaStore) ReleaseAll(scope string) error {
}
for ref := range refs {
if path, exists := s.refToPath[ref]; exists {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
if entry, exists := s.refs[ref]; exists {
if err := os.Remove(entry.path); err != nil && !os.IsNotExist(err) {
// Log but continue — best effort cleanup
}
delete(s.refToPath, ref)
delete(s.refs, ref)
}
}
+44
View File
@@ -139,6 +139,50 @@ func TestStoreNonexistentFile(t *testing.T) {
if err == nil {
t.Error("Store should fail for nonexistent file")
}
// Error message should include the underlying os error, not just "file does not exist"
if !strings.Contains(err.Error(), "no such file or directory") {
t.Errorf("Error should contain OS error detail, got: %v", err)
}
}
func TestResolveWithMeta(t *testing.T) {
dir := t.TempDir()
store := NewFileMediaStore()
path := createTempFile(t, dir, "image.png")
meta := MediaMeta{
Filename: "image.png",
ContentType: "image/png",
Source: "telegram",
}
ref, err := store.Store(path, meta, "scope1")
if err != nil {
t.Fatalf("Store failed: %v", err)
}
resolvedPath, resolvedMeta, err := store.ResolveWithMeta(ref)
if err != nil {
t.Fatalf("ResolveWithMeta failed: %v", err)
}
if resolvedPath != path {
t.Errorf("ResolveWithMeta path = %q, want %q", resolvedPath, path)
}
if resolvedMeta.Filename != meta.Filename {
t.Errorf("ResolveWithMeta Filename = %q, want %q", resolvedMeta.Filename, meta.Filename)
}
if resolvedMeta.ContentType != meta.ContentType {
t.Errorf("ResolveWithMeta ContentType = %q, want %q", resolvedMeta.ContentType, meta.ContentType)
}
if resolvedMeta.Source != meta.Source {
t.Errorf("ResolveWithMeta Source = %q, want %q", resolvedMeta.Source, meta.Source)
}
// Unknown ref should fail
_, _, err = store.ResolveWithMeta("media://nonexistent")
if err == nil {
t.Error("ResolveWithMeta should fail for unknown ref")
}
}
func TestConcurrentSafety(t *testing.T) {