diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index c62c868e3..3c2cb021d 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -33,6 +33,7 @@ import ( "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" @@ -123,14 +124,18 @@ func gatewayCmd() { return tools.SilentResult(response) }) - channelManager, err := channels.NewManager(cfg, msgBus) + // Create media store for file lifecycle management + mediaStore := media.NewFileMediaStore() + + channelManager, err := channels.NewManager(cfg, msgBus, mediaStore) if err != nil { fmt.Printf("Error creating channel manager: %v\n", err) os.Exit(1) } - // Inject channel manager into agent loop for command handling + // Inject channel manager and media store into agent loop agentLoop.SetChannelManager(channelManager) + agentLoop.SetMediaStore(mediaStore) var transcriber *voice.GroqTranscriber groqAPIKey := cfg.Providers.Groq.APIKey diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d8ea3b091..97569bef7 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -38,6 +39,7 @@ type AgentLoop struct { summarizing sync.Map fallback *providers.FallbackChain channelManager *channels.Manager + mediaStore media.MediaStore } // processOptions configures how a message is processed @@ -167,33 +169,47 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - response, err := al.processMessage(ctx, msg) - if err != nil { - response = 
fmt.Sprintf("Error processing message: %v", err) - } - - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + // Process message and ensure media is released afterward + func() { + defer func() { + if al.mediaStore != nil && msg.MediaScope != "" { + if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + logger.WarnCF("agent", "Failed to release media", map[string]any{ + "scope": msg.MediaScope, + "error": releaseErr.Error(), + }) } } + }() + + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) } - if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + } + + if !alreadySent { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } - } + }() } } @@ -216,6 +232,11 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetMediaStore injects a MediaStore for media lifecycle management. 
+func (al *AgentLoop) SetMediaStore(s media.MediaStore) { + al.mediaStore = s +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 081f13a0b..e49713eb8 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -12,8 +12,9 @@ type InboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` - Peer Peer `json:"peer"` // routing peer - MessageID string `json:"message_id,omitempty"` // platform message ID + Peer Peer `json:"peer"` // routing peer + MessageID string `json:"message_id,omitempty"` // platform message ID + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/pkg/channels/base.go b/pkg/channels/base.go index f70145981..d967d9e91 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -5,7 +5,10 @@ import ( "strings" "sync/atomic" + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/media" ) type Channel interface { @@ -41,6 +44,7 @@ type BaseChannel struct { name string allowList []string maxMessageLength int + mediaStore media.MediaStore } func NewBaseChannel( @@ -125,15 +129,18 @@ func (c *BaseChannel) HandleMessage( return } + scope := BuildMediaScope(c.name, chatID, messageID) + msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - Peer: peer, - MessageID: messageID, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Peer: peer, + MessageID: messageID, + MediaScope: scope, + Metadata: metadata, } c.bus.PublishInbound(msg) @@ -142,3 +149,18 @@ func (c 
*BaseChannel) HandleMessage( func (c *BaseChannel) SetRunning(running bool) { c.running.Store(running) } + +// SetMediaStore injects a MediaStore into the channel. +func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } + +// GetMediaStore returns the injected MediaStore (may be nil). +func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } + +// BuildMediaScope constructs a scope key for media lifecycle tracking. +func BuildMediaScope(channel, chatID, messageID string) string { + id := messageID + if id == "" { + id = uuid.New().String() + } + return channel + ":" + chatID + ":" + id +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 623bc9f48..7977d32e1 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,7 +3,6 @@ package discord import ( "context" "fmt" - "os" "strings" "sync" "time" @@ -14,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -202,19 +202,22 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content := m.Content content = c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) - localFiles := make([]string, 0, len(m.Attachments)) - // Ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, 
media.MediaMeta{ + Filename: filename, + Source: "discord", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } for _, attachment := range m.Attachments { isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) @@ -222,8 +225,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - localFiles = append(localFiles, localPath) - transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout) @@ -245,6 +246,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } + mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename)) content = appendContent(content, transcribedText) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 9744e1848..272a53c6e 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" "time" @@ -19,6 +18,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -308,18 +308,22 @@ func (c *LINEChannel) processEvent(event lineEvent) { var content string var mediaPaths []string - localFiles := []string{} - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + scope := channels.BuildMediaScope("line", chatID, msg.ID) + + // Helper to register a local file with the media 
store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "line", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } switch msg.Type { case "text": @@ -331,22 +335,19 @@ func (c *LINEChannel) processEvent(event lineEvent) { case "image": localPath := c.downloadContent(msg.ID, "image.jpg") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "image.jpg")) content = "[image]" } case "audio": localPath := c.downloadContent(msg.ID, "audio.m4a") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "audio.m4a")) content = "[audio]" } case "video": localPath := c.downloadContent(msg.ID, "video.mp4") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "video.mp4")) content = "[video]" } case "file": diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 081d616da..37af01796 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -31,6 +32,7 @@ type Manager struct { workers map[string]*channelWorker bus *bus.MessageBus config *config.Config + mediaStore media.MediaStore dispatchTask *asyncTask mu sync.RWMutex } @@ -39,12 +41,13 @@ type asyncTask struct { cancel context.CancelFunc } -func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) { +func NewManager(cfg *config.Config, 
messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ - channels: make(map[string]Channel), - workers: make(map[string]*channelWorker), - bus: messageBus, - config: cfg, + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: messageBus, + config: cfg, + mediaStore: store, } if err := m.initChannels(); err != nil { @@ -73,6 +76,12 @@ func (m *Manager) initChannel(name, displayName string) { "error": err.Error(), }) } else { + // Inject MediaStore if channel supports it + if m.mediaStore != nil { + if setter, ok := ch.(interface{ SetMediaStore(s media.MediaStore) }); ok { + setter.SetMediaStore(m.mediaStore) + } + } m.channels[name] = ch m.workers[name] = &channelWorker{ ch: ch, diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 4f35888ca..e2fe541f1 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "os" "strconv" "strings" "sync" @@ -17,6 +16,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -575,11 +575,15 @@ type parseMessageResult struct { Text string IsBotMentioned bool Media []string - LocalFiles []string ReplyTo string } -func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { +func (c *OneBotChannel) parseMessageSegments( + raw json.RawMessage, + selfID int64, + store media.MediaStore, + scope string, +) parseMessageResult { if len(raw) == 0 { return parseMessageResult{} } @@ -606,10 +610,23 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) var textParts []string mentioned := false selfIDStr := strconv.FormatInt(selfID, 10) - var media []string - var localFiles []string + var mediaRefs []string 
var replyTo string + // Helper to register a local file with the media store + storeFile := func(localPath, filename string) string { + if store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "onebot", + }, scope) + if err == nil { + return ref + } + } + return localPath // fallback + } + for _, seg := range segments { segType, _ := seg["type"].(string) data, _ := seg["data"].(map[string]any) @@ -645,8 +662,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - media = append(media, localPath) - localFiles = append(localFiles, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, filename)) textParts = append(textParts, fmt.Sprintf("[%s]", segType)) } } @@ -660,7 +676,6 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - localFiles = append(localFiles, localPath) if c.transcriber != nil && c.transcriber.IsAvailable() { tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) result, err := c.transcriber.Transcribe(tctx, localPath) @@ -670,13 +685,15 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) "error": err.Error(), }) textParts = append(textParts, "[voice (transcription failed)]") - media = append(media, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } else { textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + // Still store the file so it can be released later + storeFile(localPath, "voice.amr") } } else { textParts = append(textParts, "[voice]") - media = append(media, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } } } @@ -706,8 +723,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) return parseMessageResult{ Text: strings.TrimSpace(strings.Join(textParts, "")), 
IsBotMentioned: mentioned, - Media: media, - LocalFiles: localFiles, + Media: mediaRefs, ReplyTo: replyTo, } } @@ -799,7 +815,17 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { selfID = atomic.LoadInt64(&c.selfID) } - parsed := c.parseMessageSegments(raw.Message, selfID) + // Compute scope for media store before parsing (parsing may download files) + var chatIDForScope string + switch raw.MessageType { + case "group": + chatIDForScope = "group:" + strconv.FormatInt(groupID, 10) + default: + chatIDForScope = "private:" + strconv.FormatInt(userID, 10) + } + scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID) + + parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope) isBotMentioned := parsed.IsBotMentioned content := raw.RawMessage @@ -828,20 +854,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - // Clean up temp files when done - if len(parsed.LocalFiles) > 0 { - defer func() { - for _, f := range parsed.LocalFiles { - if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ - "path": f, - "error": err.Error(), - }) - } - } - }() - } - if c.isDuplicate(messageID) { logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ "message_id": messageID, diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index fc0bee505..53d7c0609 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -3,7 +3,6 @@ package slack import ( "context" "fmt" - "os" "strings" "sync" "time" @@ -16,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -233,19 +233,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content = c.stripBotMention(content) var mediaPaths []string - 
localFiles := []string{} // 跟踪需要清理的本地文件 - // 确保临时文件在函数返回时被清理 - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + scope := channels.BuildMediaScope("slack", chatID, messageTS) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "slack", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -253,8 +256,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name)) if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 578e3c51e..af7155799 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -20,6 +20,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -251,19 +252,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content := "" mediaPaths := []string{} - localFiles := []string{} // 跟踪需要清理的本地文件 - // 确保临时文件在函数返回时被清理 - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - 
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + chatIDStr := fmt.Sprintf("%d", chatID) + messageIDStr := fmt.Sprintf("%d", message.MessageID) + scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "telegram", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback: use raw path + } if message.Text != "" { content += message.Text @@ -280,8 +286,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { - localFiles = append(localFiles, photoPath) - mediaPaths = append(mediaPaths, photoPath) + mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg")) if content != "" { content += "\n" } @@ -292,8 +297,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { - localFiles = append(localFiles, voicePath) - mediaPaths = append(mediaPaths, voicePath) + mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg")) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { @@ -327,8 +331,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { - localFiles = append(localFiles, audioPath) - mediaPaths = append(mediaPaths, audioPath) + mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3")) if content != "" { content += "\n" } @@ -339,8 +342,7 @@ func (c 
*TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Document != nil { docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { - localFiles = append(localFiles, docPath) - mediaPaths = append(mediaPaths, docPath) + mediaPaths = append(mediaPaths, storeMedia(docPath, "document")) if content != "" { content += "\n" } @@ -367,7 +369,6 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes } // Stop any previous thinking animation - chatIDStr := fmt.Sprintf("%d", chatID) if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { cf.Cancel() diff --git a/pkg/media/store.go b/pkg/media/store.go new file mode 100644 index 000000000..8d03c03ef --- /dev/null +++ b/pkg/media/store.go @@ -0,0 +1,102 @@ +package media + +import ( + "fmt" + "os" + "sync" + + "github.com/google/uuid" +) + +// MediaMeta holds metadata about a stored media file. +type MediaMeta struct { + Filename string + ContentType string + Source string // "telegram", "discord", "tool:image-gen", etc. +} + +// MediaStore manages the lifecycle of media files associated with processing scopes. +type MediaStore interface { + // Store registers an existing local file under the given scope. + // Returns a ref identifier (e.g. "media://"). + // Store does not move or copy the file; it only records the mapping. + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + + // Resolve returns the local file path for a given ref. + Resolve(ref string) (localPath string, err error) + + // ReleaseAll deletes all files registered under the given scope + // and removes the mapping entries. File-not-exist errors are ignored. + ReleaseAll(scope string) error +} + +// FileMediaStore is a pure in-memory implementation of MediaStore. +// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/). 
+type FileMediaStore struct {
+	mu          sync.RWMutex                   // guards both maps below
+	refToPath   map[string]string              // ref -> local file path
+	scopeToRefs map[string]map[string]struct{} // scope -> set of refs registered under it
+}
+
+// NewFileMediaStore creates an empty FileMediaStore with both maps allocated.
+func NewFileMediaStore() *FileMediaStore {
+	return &FileMediaStore{
+		refToPath:   make(map[string]string),
+		scopeToRefs: make(map[string]map[string]struct{}),
+	}
+}
+
+// Store registers an existing local file under scope and returns a fresh
+// "media://<uuid>" ref. meta is accepted but not recorded by this store.
+func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) {
+	if _, err := os.Stat(localPath); err != nil {
+		// Wrap the real cause: Stat also fails for reasons other than a missing file.
+		return "", fmt.Errorf("media store: %w", err)
+	}
+	// Full UUID: a truncated one can collide and silently overwrite another ref's path.
+	ref := "media://" + uuid.New().String()
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.refToPath[ref] = localPath
+	if s.scopeToRefs[scope] == nil {
+		s.scopeToRefs[scope] = make(map[string]struct{})
+	}
+	s.scopeToRefs[scope][ref] = struct{}{}
+	return ref, nil
+}
+
+// Resolve returns the local path registered for ref, or an error when the
+// ref is not present in the store.
+func (s *FileMediaStore) Resolve(ref string) (string, error) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	path, ok := s.refToPath[ref]
+	if !ok {
+		return "", fmt.Errorf("media store: unknown ref: %s", ref)
+	}
+	return path, nil
+}
+
+// ReleaseAll removes all files under the given scope and cleans up mappings.
+func (s *FileMediaStore) ReleaseAll(scope string) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Best effort: every mapping is dropped and every file is attempted even
+	// after a failure; only the first error other than not-exist is returned.
+	var firstErr error
+	for ref := range s.scopeToRefs[scope] {
+		path, ok := s.refToPath[ref]
+		if !ok {
+			continue
+		}
+		delete(s.refToPath, ref)
+		if err := os.Remove(path); err != nil && !os.IsNotExist(err) && firstErr == nil {
+			firstErr = fmt.Errorf("media store: release scope %q: %w", scope, err)
+		}
+	}
+
+	delete(s.scopeToRefs, scope)
+	return firstErr
+}
diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go
new file mode 100644
index 000000000..361582307
--- /dev/null
+++ b/pkg/media/store_test.go
@@ -0,0 +1,205 @@
+package media
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+)
+
+func createTempFile(t *testing.T, dir, name string) string {
+	t.Helper()
+	path := filepath.Join(dir, name)
+	if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+	return path
+}
+
+func TestStoreAndResolve(t *testing.T) {
+	dir := t.TempDir()
+	store := NewFileMediaStore()
+
+	path := createTempFile(t, dir, "photo.jpg")
+
+	ref, err := store.Store(path, MediaMeta{Filename: "photo.jpg", Source: "telegram"}, "scope1")
+	if err != nil {
+		t.Fatalf("Store failed: %v", err)
+	}
+
+	if !strings.HasPrefix(ref, "media://") {
+		t.Errorf("ref should start with media://, got %q", ref)
+	}
+
+	resolved, err := store.Resolve(ref)
+	if err != nil {
+		t.Fatalf("Resolve failed: %v", err)
+	}
+	if resolved != path {
+		t.Errorf("Resolve returned %q, want %q", resolved, path)
+	}
+}
+
+func TestReleaseAll(t *testing.T) {
+	dir := t.TempDir()
+	store := NewFileMediaStore()
+
+	paths := make([]string, 3)
+	refs := make([]string, 3)
+	for i := 0; i < 3; i++ {
+		paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
+		var err error
+		refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
+		if err != nil {
+			t.Fatalf("Store failed: %v", err)
+		}
+	}
+
+	if err := store.ReleaseAll("scope1"); err != nil {
+		t.Fatalf("ReleaseAll failed: %v", err)
+	}
+
+	// Files should be deleted
+	for _, p := range paths {
+		if _, err := os.Stat(p); !os.IsNotExist(err) {
+			t.Errorf("file %q should have been deleted", p)
+		}
+	}
+
+	// Refs should be unresolvable
+	for _, ref := range refs {
+		if _, err := store.Resolve(ref); err == nil {
+			t.Errorf("Resolve(%q) should fail after ReleaseAll", ref)
+		}
+	}
+}
+
+func TestMultiScopeIsolation(t *testing.T) {
+	dir := t.TempDir()
+	store := NewFileMediaStore()
+
+	pathA := createTempFile(t, dir, "fileA.jpg")
+	pathB := createTempFile(t, dir, "fileB.jpg")
+
+	refA, _ := store.Store(pathA, MediaMeta{Source: "test"}, "scopeA")
+	refB, _ := store.Store(pathB, MediaMeta{Source: "test"}, "scopeB")
+
+	// Release only scopeA
+	if err := store.ReleaseAll("scopeA"); err != nil {
+		t.Fatalf("ReleaseAll(scopeA) failed: %v", err)
+	}
+
+	// scopeA file should be gone
+	if _, err := os.Stat(pathA); !os.IsNotExist(err) {
+		t.Error("file A should have been deleted")
+	}
+	if _, err := store.Resolve(refA); err == nil {
+		t.Error("refA should be unresolvable after release")
+	}
+
+	// scopeB file should still exist
+	if _, err := os.Stat(pathB); err != nil {
+		t.Error("file B should still exist")
+	}
+	resolved, err := store.Resolve(refB)
+	if err != nil {
+		t.Fatalf("refB should still resolve: %v", err)
+	}
+	if resolved != pathB {
+		t.Errorf("resolved %q, want %q", resolved, pathB)
+	}
+}
+
+func TestReleaseAllIdempotent(t *testing.T) {
+	store := NewFileMediaStore()
+
+	// ReleaseAll on non-existent scope should not error
+	if err := store.ReleaseAll("nonexistent"); err != nil {
+		t.Fatalf("ReleaseAll on empty scope should not error: %v", err)
+	}
+
+	// Create and release, then release again
+	dir := t.TempDir()
+	path := createTempFile(t, dir, "file.jpg")
+	_, _ = store.Store(path, MediaMeta{Source: "test"}, "scope1")
+
+	if err := store.ReleaseAll("scope1"); err != nil {
+		t.Fatalf("first ReleaseAll failed: %v", err)
+	}
+	if err := store.ReleaseAll("scope1"); err != nil {
+		t.Fatalf("second ReleaseAll should not error: %v", err)
+	}
+}
+
+func TestStoreNonexistentFile(t *testing.T) {
+	store := NewFileMediaStore()
+
+	_, err := store.Store("/nonexistent/path/file.jpg", MediaMeta{Source: "test"}, "scope1")
+	if err == nil {
+		t.Error("Store should fail for nonexistent file")
+	}
+}
+
+func TestConcurrentSafety(t *testing.T) {
+	dir := t.TempDir()
+	store := NewFileMediaStore()
+
+	const goroutines = 20
+	const filesPerGoroutine = 5
+
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+
+	for g := 0; g < goroutines; g++ {
+		go func(gIdx int) {
+			defer wg.Done()
+			scope := strings.Repeat("s", gIdx+1)
+
+			for i := 0; i < filesPerGoroutine; i++ {
+				// Written inline: createTempFile calls t.Fatalf, which is only
+				// valid on the goroutine running the test.
+				path := filepath.Join(dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
+				if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil {
+					t.Errorf("failed to create temp file: %v", err)
+					return
+				}
+				ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
+				if err != nil {
+					t.Errorf("Store failed: %v", err)
+					return
+				}
+
+				if _, err := store.Resolve(ref); err != nil {
+					t.Errorf("Resolve failed: %v", err)
+				}
+			}
+
+			if err := store.ReleaseAll(scope); err != nil {
+				t.Errorf("ReleaseAll failed: %v", err)
+			}
+		}(g)
+	}
+
+	wg.Wait()
+}
+
+func TestReleaseAllReportsRemoveError(t *testing.T) {
+	dir := t.TempDir()
+	store := NewFileMediaStore()
+
+	// A non-empty directory passes Store's existence check but cannot be
+	// deleted by os.Remove, which forces a failure other than "not exist".
+	createTempFile(t, dir, "child.jpg")
+	ref, err := store.Store(dir, MediaMeta{Source: "test"}, "scope1")
+	if err != nil {
+		t.Fatalf("Store failed: %v", err)
+	}
+
+	if err := store.ReleaseAll("scope1"); err == nil {
+		t.Error("ReleaseAll should report a removal failure")
+	}
+	if _, err := store.Resolve(ref); err == nil {
+		t.Error("ref should be unresolvable even when removal fails")
+	}
+}