diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go
index 7d361bcf8..d51eee8fb 100644
--- a/pkg/channels/matrix/matrix.go
+++ b/pkg/channels/matrix/matrix.go
@@ -26,9 +26,13 @@ import (
 )
 
 const (
-	typingRefreshInterval = 20 * time.Second
-	typingServerTTL       = 30 * time.Second
-	roomKindCacheTTL      = 5 * time.Minute
+	typingRefreshInterval      = 20 * time.Second
+	typingServerTTL            = 30 * time.Second
+	roomKindCacheTTL           = 5 * time.Minute
+	roomKindCacheCleanupPeriod = 1 * time.Minute
+	roomKindCacheMaxEntries    = 2048
+
+	matrixMediaTempDirName = "picoclaw_media"
 )
 
 var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)]+href=["']([^"']+)["']`)
@@ -36,6 +40,124 @@ var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)]+href=["']([^"']+)["
 type roomKindCacheEntry struct {
 	isGroup   bool
 	expiresAt time.Time
+	touchedAt time.Time
+}
+
+// roomKindCache is a size-bounded TTL cache recording, per room ID, whether
+// the room is a group room. mu guards entries.
+type roomKindCache struct {
+	mu         sync.Mutex
+	entries    map[string]roomKindCacheEntry
+	maxEntries int
+	ttl        time.Duration
+}
+
+// newRoomKindCache returns an empty cache. A non-positive maxEntries or ttl
+// falls back to roomKindCacheMaxEntries or roomKindCacheTTL respectively.
+func newRoomKindCache(maxEntries int, ttl time.Duration) *roomKindCache {
+	if maxEntries <= 0 {
+		maxEntries = roomKindCacheMaxEntries
+	}
+	if ttl <= 0 {
+		ttl = roomKindCacheTTL
+	}
+
+	return &roomKindCache{
+		entries:    make(map[string]roomKindCacheEntry),
+		maxEntries: maxEntries,
+		ttl:        ttl,
+	}
+}
+
+// get returns the cached kind for roomID and whether a live entry exists.
+// An entry whose expiresAt is not after now is deleted and reported as a miss.
+func (c *roomKindCache) get(roomID string, now time.Time) (bool, bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entry, ok := c.entries[roomID]
+	if !ok {
+		return false, false
+	}
+	if !entry.expiresAt.After(now) {
+		delete(c.entries, roomID)
+		return false, false
+	}
+
+	return entry.isGroup, true
+}
+
+// set stores isGroup for roomID with a TTL counted from now. Adding a new key
+// first drops expired entries, then evicts the entries with the oldest
+// touchedAt until there is room.
+func (c *roomKindCache) set(roomID string, isGroup bool, now time.Time) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if entry, ok := c.entries[roomID]; ok {
+		entry.isGroup = isGroup
+		entry.expiresAt = now.Add(c.ttl)
+		entry.touchedAt = now
+		c.entries[roomID] = entry
+		return
+	}
+
+	c.cleanupExpiredLocked(now)
+	for len(c.entries) >= c.maxEntries {
+		if !c.evictOldestLocked() {
+			break
+		}
+	}
+
+	c.entries[roomID] = roomKindCacheEntry{
+		isGroup:   isGroup,
+		expiresAt: now.Add(c.ttl),
+		touchedAt: now,
+	}
+}
+
+// cleanupExpired removes every expired entry and returns how many were removed.
+func (c *roomKindCache) cleanupExpired(now time.Time) int {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	return c.cleanupExpiredLocked(now)
+}
+
+// cleanupExpiredLocked is cleanupExpired without locking; the caller must hold
+// c.mu.
+func (c *roomKindCache) cleanupExpiredLocked(now time.Time) int {
+	removed := 0
+	for roomID, entry := range c.entries {
+		if !entry.expiresAt.After(now) {
+			delete(c.entries, roomID)
+			removed++
+		}
+	}
+	return removed
+}
+
+// evictOldestLocked deletes the entry with the earliest touchedAt and reports
+// whether an entry was removed. The caller must hold c.mu.
+func (c *roomKindCache) evictOldestLocked() bool {
+	var (
+		oldestRoomID string
+		oldestAt     time.Time
+		found        bool
+	)
+
+	// An explicit flag marks "nothing selected yet"; using an empty room ID as
+	// that sentinel would let a later entry displace an older one keyed by "".
+	for roomID, entry := range c.entries {
+		if !found || entry.touchedAt.Before(oldestAt) {
+			oldestRoomID, oldestAt, found = roomID, entry.touchedAt, true
+		}
+	}
+	if !found {
+		return false
+	}
+
+	delete(c.entries, oldestRoomID)
+	return true
 }
 
 type typingSession struct {
@@ -70,7 +192,8 @@ type MatrixChannel struct {
 	typingMu       sync.Mutex
 	typingSessions map[string]*typingSession // roomID -> session
 
-	roomKindCache sync.Map // roomID -> roomKindCacheEntry
+	roomKindCache     *roomKindCache
+	localpartMentionR *regexp.Regexp
 }
 
 func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*MatrixChannel, error) {
@@ -111,14 +234,15 @@ func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*Mat
 	)
 
 	return &MatrixChannel{
-		BaseChannel:    base,
-		client:         client,
-		config:         cfg,
-		syncer:         syncer,
-		typingSessions: make(map[string]*typingSession),
-		startTime:      time.Now(),
-		roomKindCache:  sync.Map{},
-		typingMu:       sync.Mutex{},
+		BaseChannel:       base,
+		client:            client,
+		config:            cfg,
+		syncer:            syncer,
+		typingSessions:    make(map[string]*typingSession),
+		startTime:         time.Now(),
+		roomKindCache:     newRoomKindCache(roomKindCacheMaxEntries, roomKindCacheTTL),
+		localpartMentionR: localpartMentionRegexp(matrixLocalpart(client.UserID)),
+		typingMu:          sync.Mutex{},
 	}, nil
 }
 
@@ -132,6 +256,7 @@ func (c *MatrixChannel) Start(ctx context.Context) error {
 	c.syncer.OnEventType(event.StateMember, c.handleMemberEvent)
 
 	c.SetRunning(true)
+	go c.runRoomKindCacheJanitor(c.ctx)
 
 	go func() {
 		if err := c.client.SyncWithContext(c.ctx); err != nil && c.ctx.Err() == nil {
@@ -469,7 +594,7 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
 	if isGroup {
 		isMentioned := c.isBotMentioned(msgEvt)
 		if isMentioned {
-			content = stripUserMention(content, c.client.UserID)
+			content = c.stripSelfMention(content)
 		}
 		respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
 		if !respond {
@@ -483,7 +608,7 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
 		}
 		content = cleaned
 	} else {
-		content = stripUserMention(content, c.client.UserID)
+		content = c.stripSelfMention(content)
 	}
 
 	content = strings.TrimSpace(content)
@@ -619,7 +744,11 @@ func (c *MatrixChannel) downloadMedia(
 
 	label := matrixMediaLabel(msgEvt, mediaKind)
 	ext := matrixMediaExt(label, matrixContentType(msgEvt), mediaKind)
-	tmp, err := os.CreateTemp("", "matrix-media-*"+ext)
+	mediaDir, err := matrixMediaTempDir()
+	if err != nil {
+		return "", fmt.Errorf("create matrix media directory: %w", err)
+	}
+	tmp, err := os.CreateTemp(mediaDir, "matrix-media-*"+ext)
 	if err != nil {
 		return "", err
 	}
@@ -777,11 +906,8 @@ func matrixMediaExt(filename, contentType, mediaKind string) string {
 
 func (c *MatrixChannel) isGroupRoom(ctx context.Context, roomID id.RoomID) bool {
 	now := time.Now()
-	if cached, ok := c.roomKindCache.Load(roomID.String()); ok {
-		entry := cached.(roomKindCacheEntry)
-		if now.Before(entry.expiresAt) {
-			return entry.isGroup
-		}
+	if isGroup, ok := c.roomKindCache.get(roomID.String(), now); ok {
+		return isGroup
 	}
 
 	qctx := c.baseContext()
@@ -801,10 +927,7 @@ func (c *MatrixChannel) isGroupRoom(ctx context.Context, roomID id.RoomID) bool
 	}
 
 	isGroup := len(resp.Joined) > 2
-	c.roomKindCache.Store(roomID.String(), roomKindCacheEntry{
-		isGroup:   isGroup,
-		expiresAt: now.Add(roomKindCacheTTL),
-	})
+	c.roomKindCache.set(roomID.String(), isGroup, now)
 	return isGroup
 }
 
@@ -825,13 +948,17 @@ func (c *MatrixChannel) isBotMentioned(msgEvt *event.MessageEventContent) bool {
 		return true
 	}
 
-	localpart := matrixLocalpart(c.client.UserID)
-	if localpart == "" {
+	mentionR := c.localpartMentionR
+	if mentionR == nil {
+		mentionR = localpartMentionRegexp(matrixLocalpart(c.client.UserID))
+	}
+	if mentionR == nil {
 		return false
 	}
 
-	re := localpartMentionRegexp(localpart)
-	return re.MatchString(msgEvt.Body) || re.MatchString(msgEvt.FormattedBody)
+	// Matrix users are addressed as MXID "@localpart:server", but many clients
+	// emit plain-text mentions as "@localpart". Both forms are handled here.
+	return mentionR.MatchString(msgEvt.Body) || mentionR.MatchString(msgEvt.FormattedBody)
 }
 
 func mentionsUserInFormattedBody(formattedBody string, userID id.UserID) bool {
@@ -941,6 +1068,43 @@ func (c *MatrixChannel) baseContext() context.Context {
 	return context.Background()
 }
 
+// runRoomKindCacheJanitor purges expired room-kind cache entries every
+// roomKindCacheCleanupPeriod until ctx is done.
+func (c *MatrixChannel) runRoomKindCacheJanitor(ctx context.Context) {
+	ticker := time.NewTicker(roomKindCacheCleanupPeriod)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case now := <-ticker.C:
+			c.roomKindCache.cleanupExpired(now)
+		}
+	}
+}
+
+// stripSelfMention removes mentions of the bot's own user ID from text.
+func (c *MatrixChannel) stripSelfMention(text string) string {
+	mentionR := c.localpartMentionR
+	if mentionR == nil {
+		// Mirror isBotMentioned: a channel built without NewMatrixChannel has
+		// no cached pattern, so compile one on demand.
+		mentionR = localpartMentionRegexp(matrixLocalpart(c.client.UserID))
+	}
+	return stripUserMentionWithRegexp(text, c.client.UserID, mentionR)
+}
+
+// matrixMediaTempDir returns the picoclaw media directory under os.TempDir(),
+// creating it with mode 0o700 when it does not exist yet.
+func matrixMediaTempDir() (string, error) {
+	mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
+	if err := os.MkdirAll(mediaDir, 0o700); err != nil {
+		return "", err
+	}
+	return mediaDir, nil
+}
+
 func matrixLocalpart(userID id.UserID) string {
 	s := strings.TrimPrefix(userID.String(), "@")
 	localpart, _, _ := strings.Cut(s, ":")
@@ -948,17 +1112,31 @@ func matrixLocalpart(userID id.UserID) string {
 }
 
+// localpartMentionRegexp compiles the pattern that matches "@localpart" and
+// "@localpart:server" mentions. It returns nil when localpart is blank.
 func localpartMentionRegexp(localpart string) *regexp.Regexp {
+	localpart = strings.TrimSpace(localpart)
+	if localpart == "" {
+		return nil
+	}
+
+	// Match Matrix mentions in plain text while avoiding false positives:
+	// "@picoclaw" and "@picoclaw:matrix.org" should match,
+	// "test@example.com" and "hellopicoclawworld" should not.
 	pattern := `(?i)(^|[^[:alnum:]_])@` + regexp.QuoteMeta(localpart) + `(?::[A-Za-z0-9._:-]+)?([^[:alnum:]_]|$)`
 	return regexp.MustCompile(pattern)
 }
 
 func stripUserMention(text string, userID id.UserID) string {
+	return stripUserMentionWithRegexp(text, userID, localpartMentionRegexp(matrixLocalpart(userID)))
+}
+
+// stripUserMentionWithRegexp is stripUserMention with a precompiled localpart
+// pattern; a nil mentionR skips the localpart pass.
+func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.Regexp) string {
 	cleaned := strings.ReplaceAll(text, userID.String(), "")
 
-	localpart := matrixLocalpart(userID)
-	if localpart != "" {
-		re := localpartMentionRegexp(localpart)
-		cleaned = re.ReplaceAllString(cleaned, "$1$2")
+	if mentionR != nil {
+		cleaned = mentionR.ReplaceAllString(cleaned, "$1$2")
 	}
 
 	cleaned = strings.TrimSpace(cleaned)
diff --git a/pkg/channels/matrix/matrix_test.go b/pkg/channels/matrix/matrix_test.go
index 6a0ad03b8..4eb5ac083 100644
--- a/pkg/channels/matrix/matrix_test.go
+++ b/pkg/channels/matrix/matrix_test.go
@@ -2,7 +2,10 @@ package matrix
 
 import (
 	"context"
+	"os"
+	"path/filepath"
 	"testing"
+	"time"
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/event"
@@ -116,6 +119,72 @@ func TestIsBotMentioned(t *testing.T) {
 	}
 }
 
+func TestRoomKindCache_ExpiresEntries(t *testing.T) {
+	cache := newRoomKindCache(4, 5*time.Second)
+	now := time.Unix(100, 0)
+	cache.set("!room:matrix.org", true, now)
+
+	if got, ok := cache.get("!room:matrix.org", now.Add(2*time.Second)); !ok || !got {
+		t.Fatalf("expected cached group room before ttl, got ok=%v group=%v", ok, got)
+	}
+
+	if _, ok := cache.get("!room:matrix.org", now.Add(6*time.Second)); ok {
+		t.Fatal("expected cache miss after ttl expiry")
+	}
+}
+
+func TestRoomKindCache_EvictsOldestWhenFull(t *testing.T) {
+	cache := newRoomKindCache(2, time.Minute)
+	now := time.Unix(200, 0)
+
+	cache.set("!room1:matrix.org", false, now)
+	cache.set("!room2:matrix.org", false, now.Add(1*time.Second))
+	cache.set("!room3:matrix.org", true, now.Add(2*time.Second))
+
+	if _, ok := cache.get("!room1:matrix.org", now.Add(2*time.Second)); ok {
+		t.Fatal("expected oldest cache entry to be evicted")
+	}
+	if got, ok := cache.get("!room2:matrix.org", now.Add(2*time.Second)); !ok || got {
+		t.Fatalf("expected room2 to remain and be direct, got ok=%v group=%v", ok, got)
+	}
+	if got, ok := cache.get("!room3:matrix.org", now.Add(2*time.Second)); !ok || !got {
+		t.Fatalf("expected room3 to remain and be group, got ok=%v group=%v", ok, got)
+	}
+}
+
+func TestRoomKindCache_EvictsOldestWithEmptyRoomID(t *testing.T) {
+	cache := newRoomKindCache(2, time.Minute)
+	now := time.Unix(300, 0)
+
+	cache.set("", false, now)
+	cache.set("!room1:matrix.org", false, now.Add(1*time.Second))
+	cache.set("!room2:matrix.org", true, now.Add(2*time.Second))
+
+	if _, ok := cache.get("", now.Add(2*time.Second)); ok {
+		t.Fatal("expected oldest entry with empty room ID to be evicted")
+	}
+	if _, ok := cache.get("!room1:matrix.org", now.Add(2*time.Second)); !ok {
+		t.Fatal("expected room1 to remain cached")
+	}
+}
+
+func TestMatrixMediaTempDir(t *testing.T) {
+	dir, err := matrixMediaTempDir()
+	if err != nil {
+		t.Fatalf("matrixMediaTempDir failed: %v", err)
+	}
+	if filepath.Base(dir) != matrixMediaTempDirName {
+		t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
+	}
+
+	info, err := os.Stat(dir)
+	if err != nil {
+		t.Fatalf("media dir not created: %v", err)
+	}
+	if !info.IsDir() {
+		t.Fatalf("expected directory, got mode=%v", info.Mode())
+	}
+}
+
 func TestMatrixMediaExt(t *testing.T) {
 	if got := matrixMediaExt("photo.png", "", "image"); got != ".png" {
 		t.Fatalf("filename extension mismatch: got=%q", got)