From 6ebdc306edd9b1ee0d853bdad63c0fb418382eb7 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Wed, 21 Dec 2022 11:17:43 +0100 Subject: [PATCH] [bugfix] Close reader gracefully when streaming recache of remote media to fileserver api caller (#1281) * close pipereader on failed data function * gently slurp the bytes * readability updates * go fmt * tidy up file server tests + add more cases * start moving io wrappers to separate iotools package. Remove use of buffering while piping recache stream Signed-off-by: kim * add license text Signed-off-by: kim Co-authored-by: kim --- .../api/client/fileserver/fileserver_test.go | 109 +++++ internal/api/client/fileserver/servefile.go | 13 +- .../api/client/fileserver/servefile_test.go | 381 ++++++++++-------- internal/iotools/io.go | 121 ++++++ internal/media/manager_test.go | 2 +- internal/processing/media/getfile.go | 66 +-- internal/processing/media/getfile_test.go | 11 +- internal/processing/media/util.go | 14 - 8 files changed, 503 insertions(+), 214 deletions(-) create mode 100644 internal/api/client/fileserver/fileserver_test.go create mode 100644 internal/iotools/io.go diff --git a/internal/api/client/fileserver/fileserver_test.go b/internal/api/client/fileserver/fileserver_test.go new file mode 100644 index 00000000..f1fab567 --- /dev/null +++ b/internal/api/client/fileserver/fileserver_test.go @@ -0,0 +1,109 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . 
+*/ + +package fileserver_test + +import ( + "context" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type FileserverTestSuite struct { + // standard suite interfaces + suite.Suite + db db.DB + storage *storage.Driver + federator federation.Federator + tc typeutils.TypeConverter + processor processing.Processor + mediaManager media.Manager + oauthServer oauth.Server + emailSender email.Sender + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + + // item being tested + fileServer *fileserver.FileServer +} + +/* + TEST INFRASTRUCTURE +*/ + +func (suite *FileserverTestSuite) SetupSuite() { + testrig.InitTestConfig() + testrig.InitTestLog() + + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + + suite.db = testrig.NewTestDB() + suite.storage = testrig.NewInMemoryStorage() + suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) + + suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker) + suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.oauthServer = testrig.NewTestOauthServer(suite.db) + + suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer) +} + +func (suite *FileserverTestSuite) SetupTest() { + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() +} + +func (suite *FileserverTestSuite) TearDownSuite() { + if err := suite.db.Stop(context.Background()); err != nil { + log.Panicf("error closing db connection: %s", err) + } +} + +func (suite *FileserverTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + 
testrig.StandardStorageTeardown(suite.storage) +} diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go index e4eca770..d2328a5f 100644 --- a/internal/api/client/fileserver/servefile.go +++ b/internal/api/client/fileserver/servefile.go @@ -19,7 +19,9 @@ package fileserver import ( + "bytes" "fmt" + "io" "net/http" "strconv" @@ -120,5 +122,14 @@ func (m *FileServer) ServeFile(c *gin.Context) { return } - c.DataFromReader(http.StatusOK, content.ContentLength, format, content.Content, nil) + // try to slurp the first few bytes to make sure we have something + b := bytes.NewBuffer(make([]byte, 0, 64)) + if _, err := io.CopyN(b, content.Content, 64); err != nil { + err = fmt.Errorf("ServeFile: error reading from content: %w", err) + api.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet) + return + } + + // we're good, return the slurped bytes + the rest of the content + c.DataFromReader(http.StatusOK, content.ContentLength, format, io.MultiReader(b, content.Content), nil) } diff --git a/internal/api/client/fileserver/servefile_test.go b/internal/api/client/fileserver/servefile_test.go index a6c46e23..1ca0c60d 100644 --- a/internal/api/client/fileserver/servefile_test.go +++ b/internal/api/client/fileserver/servefile_test.go @@ -20,196 +20,251 @@ package fileserver_test import ( "context" - "fmt" "io/ioutil" "net/http" "net/http/httptest" "testing" - "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/email" - "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/oauth" - "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/storage" - "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" ) type ServeFileTestSuite struct { - // standard suite interfaces - suite.Suite - db db.DB - storage *storage.Driver - federator federation.Federator - tc typeutils.TypeConverter - processor processing.Processor - mediaManager media.Manager - oauthServer oauth.Server - emailSender email.Sender - - // standard suite models - testTokens map[string]*gtsmodel.Token - testClients map[string]*gtsmodel.Client - testApplications map[string]*gtsmodel.Application - testUsers map[string]*gtsmodel.User - testAccounts map[string]*gtsmodel.Account - testAttachments map[string]*gtsmodel.MediaAttachment - - // item being tested - fileServer *fileserver.FileServer + FileserverTestSuite } -/* - TEST INFRASTRUCTURE -*/ - -func (suite *ServeFileTestSuite) SetupSuite() { - // setup standard items - testrig.InitTestConfig() - testrig.InitTestLog() - - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() - suite.storage = testrig.NewInMemoryStorage() - suite.federator = testrig.NewTestFederator(suite.db, 
testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) - suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker) - suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.oauthServer = testrig.NewTestOauthServer(suite.db) - - // setup module being tested - suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer) -} - -func (suite *ServeFileTestSuite) TearDownSuite() { - if err := suite.db.Stop(context.Background()); err != nil { - log.Panicf("error closing db connection: %s", err) - } -} - -func (suite *ServeFileTestSuite) SetupTest() { - testrig.StandardDBSetup(suite.db, nil) - testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - suite.testTokens = testrig.NewTestTokens() - suite.testClients = testrig.NewTestClients() - suite.testApplications = testrig.NewTestApplications() - suite.testUsers = testrig.NewTestUsers() - suite.testAccounts = testrig.NewTestAccounts() - suite.testAttachments = testrig.NewTestAttachments() -} - -func (suite *ServeFileTestSuite) TearDownTest() { - testrig.StandardDBTeardown(suite.db) - testrig.StandardStorageTeardown(suite.storage) -} - -/* - ACTUAL TESTS -*/ - -func (suite *ServeFileTestSuite) TestServeOriginalFileSuccessful() { - targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"] - suite.True(ok) - suite.NotNil(targetAttachment) - +// GetFile is just a convenience function to save repetition in this test suite. +// It takes the required params to serve a file, calls the handler, and returns +// the http status code, the response headers, and the parsed body bytes. +func (suite *ServeFileTestSuite) GetFile( + accountID string, + mediaType media.Type, + mediaSize media.Size, + filename string, +) (code int, headers http.Header, body []byte) { recorder := httptest.NewRecorder() - ctx, _ := testrig.CreateGinTestContext(recorder, nil) - ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.URL, nil) - ctx.Request.Header.Set("accept", "*/*") - // normally the router would populate these params from the path values, - // but because we're calling the ServeFile function directly, we need to set them manually. 
- ctx.Params = gin.Params{ - gin.Param{ - Key: fileserver.AccountIDKey, - Value: targetAttachment.AccountID, - }, - gin.Param{ - Key: fileserver.MediaTypeKey, - Value: string(media.TypeAttachment), - }, - gin.Param{ - Key: fileserver.MediaSizeKey, - Value: string(media.SizeOriginal), - }, - gin.Param{ - Key: fileserver.FileNameKey, - Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID), - }, + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/whatever", nil) + ctx.Request.Header.Set("accept", "*/*") + ctx.AddParam(fileserver.AccountIDKey, accountID) + ctx.AddParam(fileserver.MediaTypeKey, string(mediaType)) + ctx.AddParam(fileserver.MediaSizeKey, string(mediaSize)) + ctx.AddParam(fileserver.FileNameKey, filename) + + suite.fileServer.ServeFile(ctx) + code = recorder.Code + headers = recorder.Result().Header + + var err error + body, err = ioutil.ReadAll(recorder.Body) + if err != nil { + suite.FailNow(err.Error()) } - // call the function we're testing and check status code - suite.fileServer.ServeFile(ctx) - suite.EqualValues(http.StatusOK, recorder.Code) - suite.EqualValues("image/jpeg", recorder.Header().Get("content-type")) - - b, err := ioutil.ReadAll(recorder.Body) - suite.NoError(err) - suite.NotNil(b) - - fileInStorage, err := suite.storage.Get(ctx, targetAttachment.File.Path) - suite.NoError(err) - suite.NotNil(fileInStorage) - suite.Equal(b, fileInStorage) + return } -func (suite *ServeFileTestSuite) TestServeSmallFileSuccessful() { - targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"] - suite.True(ok) - suite.NotNil(targetAttachment) +// UncacheAttachment is a convenience function that uncaches the targetAttachment by +// removing its associated files from storage, and updating the database. +func (suite *ServeFileTestSuite) UncacheAttachment(targetAttachment *gtsmodel.MediaAttachment) { + ctx := context.Background() - recorder := httptest.NewRecorder() - ctx, _ := testrig.CreateGinTestContext(recorder, nil) - ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.Thumbnail.URL, nil) - ctx.Request.Header.Set("accept", "*/*") + cached := false + targetAttachment.Cached = &cached - // normally the router would populate these params from the path values, - // but because we're calling the ServeFile function directly, we need to set them manually. 
- ctx.Params = gin.Params{ - gin.Param{ - Key: fileserver.AccountIDKey, - Value: targetAttachment.AccountID, - }, - gin.Param{ - Key: fileserver.MediaTypeKey, - Value: string(media.TypeAttachment), - }, - gin.Param{ - Key: fileserver.MediaSizeKey, - Value: string(media.SizeSmall), - }, - gin.Param{ - Key: fileserver.FileNameKey, - Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID), - }, + if err := suite.db.UpdateByID(ctx, targetAttachment, targetAttachment.ID, "cached"); err != nil { + suite.FailNow(err.Error()) + } + if err := suite.storage.Delete(ctx, targetAttachment.File.Path); err != nil { + suite.FailNow(err.Error()) + } + if err := suite.storage.Delete(ctx, targetAttachment.Thumbnail.Path); err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *ServeFileTestSuite) TestServeOriginalLocalFileOK() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) + if err != nil { + suite.FailNow(err.Error()) } - // call the function we're testing and check status code - suite.fileServer.ServeFile(ctx) - suite.EqualValues(http.StatusOK, recorder.Code) - suite.EqualValues("image/jpeg", recorder.Header().Get("content-type")) + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeOriginal, + targetAttachment.ID+".jpeg", + ) - b, err := ioutil.ReadAll(recorder.Body) - suite.NoError(err) - suite.NotNil(b) + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} - fileInStorage, err := suite.storage.Get(ctx, targetAttachment.Thumbnail.Path) - suite.NoError(err) - suite.NotNil(fileInStorage) - suite.Equal(b, fileInStorage) +func (suite *ServeFileTestSuite) TestServeSmallLocalFileOK() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) + if err != nil { + suite.FailNow(err.Error()) + } + + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeSmall, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} + +func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileOK() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) + if err != nil { + suite.FailNow(err.Error()) + } + + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeOriginal, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} + +func (suite *ServeFileTestSuite) TestServeSmallRemoteFileOK() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) + if err != nil { + suite.FailNow(err.Error()) + } + + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + 
media.SizeSmall, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} + +func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecache() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path) + if err != nil { + suite.FailNow(err.Error()) + } + + // uncache the attachment so we'll have to refetch it from the 'remote' instance + suite.UncacheAttachment(targetAttachment) + + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeOriginal, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} + +func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecache() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path) + if err != nil { + suite.FailNow(err.Error()) + } + + // uncache the attachment so we'll have to refetch it from the 'remote' instance + suite.UncacheAttachment(targetAttachment) + + code, headers, body := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeSmall, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusOK, code) + suite.Equal("image/jpeg", headers.Get("content-type")) + suite.Equal(fileInStorage, body) +} + +func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecacheNotFound() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + + // uncache the attachment *and* set the remote URL to something that will return a 404 + suite.UncacheAttachment(targetAttachment) + targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee" + if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil { + suite.FailNow(err.Error()) + } + + code, _, _ := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeOriginal, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusNotFound, code) +} + +func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecacheNotFound() { + targetAttachment := >smodel.MediaAttachment{} + *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"] + + // uncache the attachment *and* set the remote URL to something that will return a 404 + suite.UncacheAttachment(targetAttachment) + targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee" + if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil { + suite.FailNow(err.Error()) + } + + code, _, _ := suite.GetFile( + targetAttachment.AccountID, + media.TypeAttachment, + media.SizeSmall, + targetAttachment.ID+".jpeg", + ) + + suite.Equal(http.StatusNotFound, code) +} + +// Callers trying to get some random-ass file that doesn't exist should just get a 404 +func (suite *ServeFileTestSuite) TestServeFileNotFound() { + code, _, _ := suite.GetFile( + "01GMMY4G9B0QEG0PQK5Q5JGJWZ", + media.TypeAttachment, + media.SizeOriginal, + "01GMMY68Y7E5DJ3CA3Y9SS8524.jpeg", + ) + + 
suite.Equal(http.StatusNotFound, code) } func TestServeFileTestSuite(t *testing.T) { diff --git a/internal/iotools/io.go b/internal/iotools/io.go new file mode 100644 index 00000000..d16a4ce9 --- /dev/null +++ b/internal/iotools/io.go @@ -0,0 +1,121 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package iotools + +import ( + "io" +) + +// ReadFnCloser takes an io.Reader and wraps it to use the provided function to implement io.Closer. +func ReadFnCloser(r io.Reader, close func() error) io.ReadCloser { + return &readFnCloser{ + Reader: r, + close: close, + } +} + +type readFnCloser struct { + io.Reader + close func() error +} + +func (r *readFnCloser) Close() error { + return r.close() +} + +// WriteFnCloser takes an io.Writer and wraps it to use the provided function to implement io.Closer. +func WriteFnCloser(w io.Writer, close func() error) io.WriteCloser { + return &writeFnCloser{ + Writer: w, + close: close, + } +} + +type writeFnCloser struct { + io.Writer + close func() error +} + +func (r *writeFnCloser) Close() error { + return r.close() +} + +// SilentReader wraps an io.Reader to silence any +// error output during reads. Instead they are stored +// and accessible (not concurrency safe!) via .Error(). +type SilentReader struct { + io.Reader + err error +} + +// SilenceReader wraps an io.Reader within SilentReader{}. +func SilenceReader(r io.Reader) *SilentReader { + return &SilentReader{Reader: r} +} + +func (r *SilentReader) Read(b []byte) (int, error) { + n, err := r.Reader.Read(b) + if err != nil { + // Store error for now + if r.err == nil { + r.err = err + } + + // Pretend we're happy + // to continue reading. + n = len(b) + } + return n, nil +} + +func (r *SilentReader) Error() error { + return r.err +} + +// SilentWriter wraps an io.Writer to silence any +// error output during writes. Instead they are stored +// and accessible (not concurrency safe!) via .Error(). +type SilentWriter struct { + io.Writer + err error +} + +// SilenceWriter wraps an io.Writer within SilentWriter{}. +func SilenceWriter(w io.Writer) *SilentWriter { + return &SilentWriter{Writer: w} +} + +func (w *SilentWriter) Write(b []byte) (int, error) { + n, err := w.Writer.Write(b) + if err != nil { + // Store error for now + if w.err == nil { + w.err = err + } + + // Pretend we're happy + // to continue writing. 
+ n = len(b) + } + return n, nil +} + +func (w *SilentWriter) Error() error { + return w.err +} diff --git a/internal/media/manager_test.go b/internal/media/manager_test.go index a8912bde..f9361a83 100644 --- a/internal/media/manager_test.go +++ b/internal/media/manager_test.go @@ -440,7 +440,7 @@ func (suite *ManagerTestSuite) TestSlothVineProcessBlocking() { processedThumbnailBytes, err := suite.storage.Get(ctx, attachment.Thumbnail.Path) suite.NoError(err) suite.NotEmpty(processedThumbnailBytes) - + processedThumbnailBytesExpected, err := os.ReadFile("./test/test-mp4-thumbnail.jpg") suite.NoError(err) suite.NotEmpty(processedThumbnailBytesExpected) diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index ddc14479..eba3fdb7 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -19,7 +19,6 @@ package media import ( - "bufio" "context" "fmt" "io" @@ -29,7 +28,7 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/iotools" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/uris" @@ -135,7 +134,6 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount } var data media.DataFunc - var postDataCallback media.PostDataCallbackFunc if mediaSize == media.SizeSmall { // if it's the thumbnail that's requested then the user will have to wait a bit while we process the @@ -155,7 +153,7 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount // // this looks a bit like this: // - // http fetch buffered pipe + // http fetch pipe // remote server ------------> data function ----------------> api caller // | // | tee @@ -163,54 +161,58 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount // ▼ // instance storage - // Buffer each end of the pipe, so that if the caller drops the connection during the flow, the tee - // reader can continue without having to worry about tee-ing into a closed or blocked pipe. + // This pipe will connect the caller to the in-process media retrieval... pipeReader, pipeWriter := io.Pipe() - bufferedWriter := bufio.NewWriterSize(pipeWriter, int(attachmentContent.ContentLength)) - bufferedReader := bufio.NewReaderSize(pipeReader, int(attachmentContent.ContentLength)) - // the caller will read from the buffered reader, so it doesn't matter if they drop out without reading everything - attachmentContent.Content = io.NopCloser(bufferedReader) + // Wrap the output pipe to silence any errors during the actual media + // streaming process. We catch the error later but they must be silenced + // during stream to prevent interruptions to storage of the actual media. + silencedWriter := iotools.SilenceWriter(pipeWriter) + // Pass the reader side of the pipe to the caller to slurp from. + attachmentContent.Content = pipeReader + + // Create a data function which injects the writer end of the pipe + // into the data retrieval process. If something goes wrong while + // doing the data retrieval, we hang up the underlying pipeReader + // to indicate to the caller that no data is available. 
It's up to + // the caller of this processor function to handle that gracefully. data = func(innerCtx context.Context) (io.ReadCloser, int64, error) { t, err := p.transportController.NewTransportForUsername(innerCtx, requestingUsername) if err != nil { + // propagate the transport error to read end of pipe. + _ = pipeWriter.CloseWithError(fmt.Errorf("error getting transport for user: %w", err)) return nil, 0, err } readCloser, fileSize, err := t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) if err != nil { + // propagate the dereference error to read end of pipe. + _ = pipeWriter.CloseWithError(fmt.Errorf("error dereferencing media: %w", err)) return nil, 0, err } - // Make a TeeReader so that everything read from the readCloser by the media manager will be written into the bufferedWriter. - // We wrap this in a teeReadCloser which implements io.ReadCloser, so that whoever uses the teeReader can close the readCloser - // when they're done with it. - trc := teeReadCloser{ - teeReader: io.TeeReader(readCloser, bufferedWriter), - close: readCloser.Close, - } + // Make a TeeReader so that everything read from the readCloser, + // aka the remote instance, will also be written into the pipe. + teeReader := io.TeeReader(readCloser, silencedWriter) - return trc, fileSize, nil - } + // Wrap teereader to implement original readcloser's close, + // and also ensuring that we close the pipe from write end. + return iotools.ReadFnCloser(teeReader, func() error { + defer func() { + // We use the error (if any) encountered by the + // silenced writer to close connection to make sure it + // gets propagated to the attachment.Content reader. + _ = pipeWriter.CloseWithError(silencedWriter.Error()) + }() - // close the pipewriter after data has been piped into it, so the reader on the other side doesn't block; - // we don't need to close the reader here because that's the caller's responsibility - postDataCallback = func(innerCtx context.Context) error { - // close the underlying pipe writer when we're done with it - defer func() { - if err := pipeWriter.Close(); err != nil { - log.Errorf("getAttachmentContent: error closing pipeWriter: %s", err) - } - }() - - // and flush the buffered writer into the buffer of the reader - return bufferedWriter.Flush() + return readCloser.Close() + }), fileSize, nil } } // put the media recached in the queue - processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, postDataCallback, wantedMediaID) + processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, nil, wantedMediaID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error recaching media: %s", err)) } diff --git a/internal/processing/media/getfile_test.go b/internal/processing/media/getfile_test.go index ba726953..7b978691 100644 --- a/internal/processing/media/getfile_test.go +++ b/internal/processing/media/getfile_test.go @@ -19,6 +19,7 @@ package media_test import ( + "bytes" "context" "io" "path" @@ -143,9 +144,13 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() { suite.NotNil(content) // only read the first kilobyte and then stop - b := make([]byte, 1024) - _, err = content.Content.Read(b) - suite.NoError(err) + b := make([]byte, 0, 1024) + if !testrig.WaitFor(func() bool { + read, err := io.CopyN(bytes.NewBuffer(b), content.Content, 1024) + return err == nil && read == 1024 + }) { + suite.FailNow("timed out trying to read first 1024 bytes") + } // close the reader suite.NoError(content.Content.Close()) diff --git 
a/internal/processing/media/util.go b/internal/processing/media/util.go index 9739e70b..37dc8797 100644 --- a/internal/processing/media/util.go +++ b/internal/processing/media/util.go @@ -20,7 +20,6 @@ package media import ( "fmt" - "io" "strconv" "strings" ) @@ -62,16 +61,3 @@ func parseFocus(focus string) (focusx, focusy float32, err error) { focusy = float32(fy) return } - -type teeReadCloser struct { - teeReader io.Reader - close func() error -} - -func (t teeReadCloser) Read(p []byte) (n int, err error) { - return t.teeReader.Read(p) -} - -func (t teeReadCloser) Close() error { - return t.close() -}
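
For illustration, here is a minimal, self-contained Go sketch (not part of the patch) of the streaming pattern getfile.go now uses: the media manager drains a TeeReader to fill storage, every byte it reads is tee'd into an io.Pipe whose read end is handed to the API caller, pipe write errors are silenced so a caller hanging up cannot interrupt storage, and the swallowed error is propagated on close via CloseWithError. The silentWriter type below is a stand-in mirroring iotools.SilentWriter, and the toy "remote" reader and "storage" buffer are assumptions made purely to keep the example runnable.

// Sketch only: wiring analogous to the recache stream in getfile.go,
// using a toy remote source and an in-memory stand-in for storage.
package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
)

// silentWriter records the first write error and then pretends writes
// keep succeeding, so the tee into the caller's pipe can't abort storage.
type silentWriter struct {
    w   io.Writer
    err error
}

func (s *silentWriter) Write(b []byte) (int, error) {
    n, err := s.w.Write(b)
    if err != nil {
        if s.err == nil {
            s.err = err
        }
        n = len(b) // pretend we're happy to continue
    }
    return n, nil
}

func main() {
    remote := strings.NewReader("pretend this is a remote media file")

    pipeReader, pipeWriter := io.Pipe()
    silenced := &silentWriter{w: pipeWriter}

    // Everything the "media manager" reads from remote is tee'd to the caller.
    tee := io.TeeReader(remote, silenced)

    storage := &bytes.Buffer{}
    done := make(chan struct{})

    // "Media manager": drain the tee into storage, then close the write end
    // of the pipe, propagating any error the silenced writer swallowed.
    go func() {
        defer close(done)
        if _, err := io.Copy(storage, tee); err != nil {
            _ = pipeWriter.CloseWithError(err)
            return
        }
        _ = pipeWriter.CloseWithError(silenced.err)
    }()

    // "API caller": read a little, then hang up early.
    buf := make([]byte, 7)
    if _, err := io.ReadFull(pipeReader, buf); err != nil {
        panic(err)
    }
    _ = pipeReader.Close() // the blocked tee write now fails, but is silenced

    <-done
    fmt.Printf("caller saw: %q\n", buf)
    fmt.Printf("storage holds %d bytes despite the early hang-up\n", storage.Len())
}

Running this prints that the caller saw only the first few bytes while storage still holds the full payload, which is the property the patch is after: a client dropping the connection mid-stream no longer interrupts the recache, and any error the caller does need to know about reaches it through the pipe reader's close error rather than through a buffered copy.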