diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index a45207f12..bec5dfdac 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "html" + "io" "mime" "net/url" "os" @@ -726,17 +727,23 @@ func (c *MatrixChannel) downloadMedia( reqCtx, cancel := context.WithTimeout(dlCtx, 20*time.Second) defer cancel() - data, err := c.client.DownloadBytes(reqCtx, parsed) + resp, err := c.client.Download(reqCtx, parsed) if err != nil { return "", err } + defer resp.Body.Close() + + reader := resp.Body + readerClose := func() error { return nil } // Encrypted attachments put URL in msgEvt.File and require client-side decryption. if msgEvt != nil && msgEvt.File != nil && msgEvt.URL == "" { - err = msgEvt.File.DecryptInPlace(data) - if err != nil { + if err = msgEvt.File.PrepareForDecryption(); err != nil { return "", fmt.Errorf("decrypt matrix media: %w", err) } + decryptReader := msgEvt.File.DecryptStream(resp.Body) + reader = decryptReader + readerClose = decryptReader.Close } label := matrixMediaLabel(msgEvt, mediaKind) @@ -749,14 +756,28 @@ func (c *MatrixChannel) downloadMedia( if err != nil { return "", err } - defer tmp.Close() + tmpPath := tmp.Name() + cleanup := true + defer func() { + _ = tmp.Close() + if cleanup { + _ = os.Remove(tmpPath) + } + }() - if _, err = tmp.Write(data); err != nil { - _ = os.Remove(tmp.Name()) + _, err = io.Copy(tmp, reader) + if err != nil { + return "", err + } + if err = readerClose(); err != nil { + return "", fmt.Errorf("decrypt matrix media: %w", err) + } + if err = tmp.Close(); err != nil { return "", err } - return tmp.Name(), nil + cleanup = false + return tmpPath, nil } func matrixContentType(msgEvt *event.MessageEventContent) string { diff --git a/pkg/channels/matrix/matrix_test.go b/pkg/channels/matrix/matrix_test.go index 806a98739..07a35c021 100644 --- a/pkg/channels/matrix/matrix_test.go +++ b/pkg/channels/matrix/matrix_test.go @@ -2,6 +2,8 @@ package matrix import ( "context" + "net/http" + "net/http/httptest" "os" "path/filepath" "strings" @@ -197,6 +199,50 @@ func TestMatrixMediaExt(t *testing.T) { } } +func TestDownloadMedia_WritesResponseToTempFile(t *testing.T) { + const wantBody = "matrix-media-payload" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/_matrix/client/v1/media/download/matrix.test/abc123") { + t.Fatalf("unexpected download path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "image/png") + _, _ = w.Write([]byte(wantBody)) + })) + defer server.Close() + + client, err := mautrix.NewClient(server.URL, id.UserID("@picoclaw:matrix.test"), "") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + ch := &MatrixChannel{client: client} + msg := &event.MessageEventContent{ + MsgType: event.MsgImage, + Body: "image.png", + URL: id.ContentURIString("mxc://matrix.test/abc123"), + Info: &event.FileInfo{MimeType: "image/png"}, + } + + path, err := ch.downloadMedia(context.Background(), msg, "image") + if err != nil { + t.Fatalf("downloadMedia: %v", err) + } + defer os.Remove(path) + + if ext := filepath.Ext(path); ext != ".png" { + t.Fatalf("temp file extension=%q want=.png", ext) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != wantBody { + t.Fatalf("file contents=%q want=%q", string(got), wantBody) + } +} + func TestExtractInboundContent_ImageNoURLFallback(t *testing.T) { ch := &MatrixChannel{} msg := &event.MessageEventContent{