fix(matrix): bound room cache and align temp media dir

This commit is contained in:
horsley
2026-03-08 09:23:02 +00:00
parent cd955d730b
commit 6e16ac7f68
2 changed files with 234 additions and 32 deletions
+180 -32
View File
@@ -26,9 +26,13 @@ import (
)
const (
typingRefreshInterval = 20 * time.Second
typingServerTTL = 30 * time.Second
roomKindCacheTTL = 5 * time.Minute
typingRefreshInterval = 20 * time.Second
typingServerTTL = 30 * time.Second
roomKindCacheTTL = 5 * time.Minute
roomKindCacheCleanupPeriod = 1 * time.Minute
roomKindCacheMaxEntries = 2048
matrixMediaTempDirName = "picoclaw_media"
)
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
@@ -36,6 +40,109 @@ var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["
type roomKindCacheEntry struct {
isGroup bool
expiresAt time.Time
touchedAt time.Time
}
type roomKindCache struct {
mu sync.Mutex
entries map[string]roomKindCacheEntry
maxEntries int
ttl time.Duration
}
func newRoomKindCache(maxEntries int, ttl time.Duration) *roomKindCache {
if maxEntries <= 0 {
maxEntries = roomKindCacheMaxEntries
}
if ttl <= 0 {
ttl = roomKindCacheTTL
}
return &roomKindCache{
entries: make(map[string]roomKindCacheEntry),
maxEntries: maxEntries,
ttl: ttl,
}
}
func (c *roomKindCache) get(roomID string, now time.Time) (bool, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.entries[roomID]
if !ok {
return false, false
}
if !entry.expiresAt.After(now) {
delete(c.entries, roomID)
return false, false
}
return entry.isGroup, true
}
func (c *roomKindCache) set(roomID string, isGroup bool, now time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
if entry, ok := c.entries[roomID]; ok {
entry.isGroup = isGroup
entry.expiresAt = now.Add(c.ttl)
entry.touchedAt = now
c.entries[roomID] = entry
return
}
c.cleanupExpiredLocked(now)
for len(c.entries) >= c.maxEntries {
if !c.evictOldestLocked() {
break
}
}
c.entries[roomID] = roomKindCacheEntry{
isGroup: isGroup,
expiresAt: now.Add(c.ttl),
touchedAt: now,
}
}
func (c *roomKindCache) cleanupExpired(now time.Time) int {
c.mu.Lock()
defer c.mu.Unlock()
return c.cleanupExpiredLocked(now)
}
func (c *roomKindCache) cleanupExpiredLocked(now time.Time) int {
removed := 0
for roomID, entry := range c.entries {
if !entry.expiresAt.After(now) {
delete(c.entries, roomID)
removed++
}
}
return removed
}
func (c *roomKindCache) evictOldestLocked() bool {
if len(c.entries) == 0 {
return false
}
var (
oldestRoomID string
oldestAt time.Time
)
for roomID, entry := range c.entries {
if oldestRoomID == "" || entry.touchedAt.Before(oldestAt) {
oldestRoomID = roomID
oldestAt = entry.touchedAt
}
}
delete(c.entries, oldestRoomID)
return true
}
type typingSession struct {
@@ -70,7 +177,8 @@ type MatrixChannel struct {
typingMu sync.Mutex
typingSessions map[string]*typingSession // roomID -> session
roomKindCache sync.Map // roomID -> roomKindCacheEntry
roomKindCache *roomKindCache
localpartMentionR *regexp.Regexp
}
func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*MatrixChannel, error) {
@@ -111,14 +219,15 @@ func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*Mat
)
return &MatrixChannel{
BaseChannel: base,
client: client,
config: cfg,
syncer: syncer,
typingSessions: make(map[string]*typingSession),
startTime: time.Now(),
roomKindCache: sync.Map{},
typingMu: sync.Mutex{},
BaseChannel: base,
client: client,
config: cfg,
syncer: syncer,
typingSessions: make(map[string]*typingSession),
startTime: time.Now(),
roomKindCache: newRoomKindCache(roomKindCacheMaxEntries, roomKindCacheTTL),
localpartMentionR: localpartMentionRegexp(matrixLocalpart(client.UserID)),
typingMu: sync.Mutex{},
}, nil
}
@@ -132,6 +241,7 @@ func (c *MatrixChannel) Start(ctx context.Context) error {
c.syncer.OnEventType(event.StateMember, c.handleMemberEvent)
c.SetRunning(true)
go c.runRoomKindCacheJanitor(c.ctx)
go func() {
if err := c.client.SyncWithContext(c.ctx); err != nil && c.ctx.Err() == nil {
@@ -469,7 +579,7 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
if isGroup {
isMentioned := c.isBotMentioned(msgEvt)
if isMentioned {
content = stripUserMention(content, c.client.UserID)
content = c.stripSelfMention(content)
}
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
@@ -483,7 +593,7 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
}
content = cleaned
} else {
content = stripUserMention(content, c.client.UserID)
content = c.stripSelfMention(content)
}
content = strings.TrimSpace(content)
@@ -619,7 +729,11 @@ func (c *MatrixChannel) downloadMedia(
label := matrixMediaLabel(msgEvt, mediaKind)
ext := matrixMediaExt(label, matrixContentType(msgEvt), mediaKind)
tmp, err := os.CreateTemp("", "matrix-media-*"+ext)
mediaDir, err := matrixMediaTempDir()
if err != nil {
return "", fmt.Errorf("create matrix media directory: %w", err)
}
tmp, err := os.CreateTemp(mediaDir, "matrix-media-*"+ext)
if err != nil {
return "", err
}
@@ -777,11 +891,8 @@ func matrixMediaExt(filename, contentType, mediaKind string) string {
func (c *MatrixChannel) isGroupRoom(ctx context.Context, roomID id.RoomID) bool {
now := time.Now()
if cached, ok := c.roomKindCache.Load(roomID.String()); ok {
entry := cached.(roomKindCacheEntry)
if now.Before(entry.expiresAt) {
return entry.isGroup
}
if isGroup, ok := c.roomKindCache.get(roomID.String(), now); ok {
return isGroup
}
qctx := c.baseContext()
@@ -801,10 +912,7 @@ func (c *MatrixChannel) isGroupRoom(ctx context.Context, roomID id.RoomID) bool
}
isGroup := len(resp.Joined) > 2
c.roomKindCache.Store(roomID.String(), roomKindCacheEntry{
isGroup: isGroup,
expiresAt: now.Add(roomKindCacheTTL),
})
c.roomKindCache.set(roomID.String(), isGroup, now)
return isGroup
}
@@ -825,13 +933,17 @@ func (c *MatrixChannel) isBotMentioned(msgEvt *event.MessageEventContent) bool {
return true
}
localpart := matrixLocalpart(c.client.UserID)
if localpart == "" {
mentionR := c.localpartMentionR
if mentionR == nil {
mentionR = localpartMentionRegexp(matrixLocalpart(c.client.UserID))
}
if mentionR == nil {
return false
}
re := localpartMentionRegexp(localpart)
return re.MatchString(msgEvt.Body) || re.MatchString(msgEvt.FormattedBody)
// Matrix users are addressed as MXID "@localpart:server", but many clients
// emit plain-text mentions as "@localpart". Both forms are handled here.
return mentionR.MatchString(msgEvt.Body) || mentionR.MatchString(msgEvt.FormattedBody)
}
func mentionsUserInFormattedBody(formattedBody string, userID id.UserID) bool {
@@ -941,6 +1053,32 @@ func (c *MatrixChannel) baseContext() context.Context {
return context.Background()
}
func (c *MatrixChannel) runRoomKindCacheJanitor(ctx context.Context) {
ticker := time.NewTicker(roomKindCacheCleanupPeriod)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case now := <-ticker.C:
c.roomKindCache.cleanupExpired(now)
}
}
}
func (c *MatrixChannel) stripSelfMention(text string) string {
return stripUserMentionWithRegexp(text, c.client.UserID, c.localpartMentionR)
}
func matrixMediaTempDir() (string, error) {
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
return "", err
}
return mediaDir, nil
}
func matrixLocalpart(userID id.UserID) string {
s := strings.TrimPrefix(userID.String(), "@")
localpart, _, _ := strings.Cut(s, ":")
@@ -948,17 +1086,27 @@ func matrixLocalpart(userID id.UserID) string {
}
func localpartMentionRegexp(localpart string) *regexp.Regexp {
localpart = strings.TrimSpace(localpart)
if localpart == "" {
return nil
}
// Match Matrix mentions in plain text while avoiding false positives:
// "@picoclaw" and "@picoclaw:matrix.org" should match,
// "test@example.com" and "hellopicoclawworld" should not.
pattern := `(?i)(^|[^[:alnum:]_])@` + regexp.QuoteMeta(localpart) + `(?::[A-Za-z0-9._:-]+)?([^[:alnum:]_]|$)`
return regexp.MustCompile(pattern)
}
func stripUserMention(text string, userID id.UserID) string {
return stripUserMentionWithRegexp(text, userID, localpartMentionRegexp(matrixLocalpart(userID)))
}
func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.Regexp) string {
cleaned := strings.ReplaceAll(text, userID.String(), "")
localpart := matrixLocalpart(userID)
if localpart != "" {
re := localpartMentionRegexp(localpart)
cleaned = re.ReplaceAllString(cleaned, "$1$2")
if mentionR != nil {
cleaned = mentionR.ReplaceAllString(cleaned, "$1$2")
}
cleaned = strings.TrimSpace(cleaned)
+54
View File
@@ -2,7 +2,10 @@ package matrix
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
@@ -116,6 +119,57 @@ func TestIsBotMentioned(t *testing.T) {
}
}
func TestRoomKindCache_ExpiresEntries(t *testing.T) {
cache := newRoomKindCache(4, 5*time.Second)
now := time.Unix(100, 0)
cache.set("!room:matrix.org", true, now)
if got, ok := cache.get("!room:matrix.org", now.Add(2*time.Second)); !ok || !got {
t.Fatalf("expected cached group room before ttl, got ok=%v group=%v", ok, got)
}
if _, ok := cache.get("!room:matrix.org", now.Add(6*time.Second)); ok {
t.Fatal("expected cache miss after ttl expiry")
}
}
func TestRoomKindCache_EvictsOldestWhenFull(t *testing.T) {
cache := newRoomKindCache(2, time.Minute)
now := time.Unix(200, 0)
cache.set("!room1:matrix.org", false, now)
cache.set("!room2:matrix.org", false, now.Add(1*time.Second))
cache.set("!room3:matrix.org", true, now.Add(2*time.Second))
if _, ok := cache.get("!room1:matrix.org", now.Add(2*time.Second)); ok {
t.Fatal("expected oldest cache entry to be evicted")
}
if got, ok := cache.get("!room2:matrix.org", now.Add(2*time.Second)); !ok || got {
t.Fatalf("expected room2 to remain and be direct, got ok=%v group=%v", ok, got)
}
if got, ok := cache.get("!room3:matrix.org", now.Add(2*time.Second)); !ok || !got {
t.Fatalf("expected room3 to remain and be group, got ok=%v group=%v", ok, got)
}
}
func TestMatrixMediaTempDir(t *testing.T) {
dir, err := matrixMediaTempDir()
if err != nil {
t.Fatalf("matrixMediaTempDir failed: %v", err)
}
if filepath.Base(dir) != matrixMediaTempDirName {
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
}
info, err := os.Stat(dir)
if err != nil {
t.Fatalf("media dir not created: %v", err)
}
if !info.IsDir() {
t.Fatalf("expected directory, got mode=%v", info.Mode())
}
}
func TestMatrixMediaExt(t *testing.T) {
if got := matrixMediaExt("photo.png", "", "image"); got != ".png" {
t.Fatalf("filename extension mismatch: got=%q", got)