From 8762e62b2f7470c6e0efb710c334e521960071d7 Mon Sep 17 00:00:00 2001 From: HiteshRepo Date: Wed, 31 Jan 2024 21:58:27 +0530 Subject: [PATCH] WIP: fix COR-239 --- .../m365/collection/drive/collection.go | 9 ++++++++ .../m365/collection/drive/collection_test.go | 21 +++++++++++++------ .../m365/collection/drive/handlers.go | 5 +++++ .../m365/collection/drive/helper_test.go | 20 ++++++++++++++++-- .../m365/collection/drive/site_handler.go | 7 +++++++ .../collection/drive/user_drive_handler.go | 7 +++++++ .../m365/service/onedrive/mock/handlers.go | 11 ++++++++++ src/pkg/services/m365/api/drive.go | 16 ++++++++++++++ 8 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/internal/m365/collection/drive/collection.go b/src/internal/m365/collection/drive/collection.go index 7ef897fed..c74ede4a7 100644 --- a/src/internal/m365/collection/drive/collection.go +++ b/src/internal/m365/collection/drive/collection.go @@ -1,6 +1,7 @@ package drive import ( + "bytes" "context" "io" "net/http" @@ -366,6 +367,14 @@ func downloadContent( itemID := ptr.Val(item.GetId()) ctx = clues.Add(ctx, "item_id", itemID) + // attempt to fetch item content directly via API + // before falling back to fetch content in chunks + contentBytes, err := iaag.GetItemContent(ctx, driveID, ptr.Val(item.GetId())) + if err == nil { + reader := bytes.NewReader(contentBytes) + return io.NopCloser(reader), nil + } + content, err := downloadItem(ctx, iaag, item) if err == nil { return content, nil diff --git a/src/internal/m365/collection/drive/collection_test.go b/src/internal/m365/collection/drive/collection_test.go index ed4d72151..86513f942 100644 --- a/src/internal/m365/collection/drive/collection_test.go +++ b/src/internal/m365/collection/drive/collection_test.go @@ -733,7 +733,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { }, { name: "expired url redownloads", - mgi: getsItem{Item: itemWID, Err: nil}, + mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError}, itemInfo: details.ItemInfo{}, respBody: []io.ReadCloser{nil, iorc}, getErr: []error{errUnauth, nil}, @@ -744,6 +744,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { { name: "immediate error", itemInfo: details.ItemInfo{}, + mgi: getsItem{ContentErr: assert.AnError}, getErr: []error{assert.AnError}, expectErr: require.Error, expect: require.Nil, @@ -753,14 +754,14 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { name: "re-fetching the item fails", itemInfo: details.ItemInfo{}, getErr: []error{errUnauth}, - mgi: getsItem{Item: nil, Err: assert.AnError}, + mgi: getsItem{Item: nil, Err: assert.AnError, ContentErr: assert.AnError}, expectErr: require.Error, expect: require.Nil, muc: m, }, { name: "expired url fails redownload", - mgi: getsItem{Item: itemWID, Err: nil}, + mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError}, itemInfo: details.ItemInfo{}, respBody: []io.ReadCloser{nil, nil}, getErr: []error{errUnauth, assert.AnError}, @@ -770,7 +771,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { }, { name: "url refreshed from cache", - mgi: getsItem{Item: itemWID, Err: nil}, + mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError}, itemInfo: details.ItemInfo{}, respBody: []io.ReadCloser{nil, iorc}, getErr: []error{errUnauth, nil}, @@ -788,7 +789,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { }, { name: "url refreshed from cache but item deleted", - mgi: getsItem{Item: itemWID, Err: core.ErrNotFound}, + mgi: getsItem{Item: itemWID, Err: core.ErrNotFound, ContentErr: assert.AnError}, itemInfo: details.ItemInfo{}, respBody: []io.ReadCloser{nil, nil, nil}, getErr: []error{errUnauth, core.ErrNotFound, core.ErrNotFound}, @@ -806,7 +807,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { }, { name: "fallback to item fetch on any cache error", - mgi: getsItem{Item: itemWID, Err: nil}, + mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError}, itemInfo: details.ItemInfo{}, respBody: []io.ReadCloser{nil, iorc}, getErr: []error{errUnauth, nil}, @@ -818,6 +819,14 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { }, }, }, + { + name: "fetches item content via direct API call", + mgi: getsItem{Item: itemWID, Err: nil, ContentErr: nil}, + itemInfo: details.ItemInfo{}, + respBody: []io.ReadCloser{nil, iorc}, + expectErr: require.NoError, + expect: require.NotNil, + }, } for _, test := range table { suite.Run(test.name, func() { diff --git a/src/internal/m365/collection/drive/handlers.go b/src/internal/m365/collection/drive/handlers.go index 1a2b98479..b2d4fc991 100644 --- a/src/internal/m365/collection/drive/handlers.go +++ b/src/internal/m365/collection/drive/handlers.go @@ -83,6 +83,11 @@ type GetItemer interface { ctx context.Context, driveID, itemID string, ) (models.DriveItemable, error) + + GetItemContent( + ctx context.Context, + driveID, itemID string, + ) ([]byte, error) } type EnumerateDriveItemsDeltaer interface { diff --git a/src/internal/m365/collection/drive/helper_test.go b/src/internal/m365/collection/drive/helper_test.go index d00a5edf4..b3809f48e 100644 --- a/src/internal/m365/collection/drive/helper_test.go +++ b/src/internal/m365/collection/drive/helper_test.go @@ -869,6 +869,10 @@ func (h mockBackupHandler[T]) GetItem(ctx context.Context, _, _ string) (models. return h.GI.GetItem(ctx, "", "") } +func (h mockBackupHandler[T]) GetItemContent(ctx context.Context, _, _ string) ([]byte, error) { + return h.GI.GetItemContent(ctx, "", "") +} + func (h mockBackupHandler[T]) GetItemPermission( ctx context.Context, _, _ string, @@ -976,8 +980,9 @@ func (h mockBackupHandler[T]) GetRootFolder(context.Context, string) (models.Dri // --------------------------------------------------------------------------- type getsItem struct { - Item models.DriveItemable - Err error + Item models.DriveItemable + Err error + ContentErr error } func (m getsItem) GetItem( @@ -987,6 +992,17 @@ func (m getsItem) GetItem( return m.Item, m.Err } +func (m getsItem) GetItemContent( + _ context.Context, + _, _ string, +) ([]byte, error) { + if m.ContentErr != nil { + return nil, m.ContentErr + } + + return []byte("fnords"), nil +} + // --------------------------------------------------------------------------- // Drive Item Enummerator // --------------------------------------------------------------------------- diff --git a/src/internal/m365/collection/drive/site_handler.go b/src/internal/m365/collection/drive/site_handler.go index 189131e04..d3277f7ac 100644 --- a/src/internal/m365/collection/drive/site_handler.go +++ b/src/internal/m365/collection/drive/site_handler.go @@ -167,6 +167,13 @@ func (h siteBackupHandler) GetItem( return h.ac.GetItem(ctx, driveID, itemID) } +func (h siteBackupHandler) GetItemContent( + ctx context.Context, + driveID, itemID string, +) ([]byte, error) { + return h.ac.GetItemContent(ctx, driveID, itemID) +} + func (h siteBackupHandler) IsAllPass() bool { return h.scope.IsAny(selectors.SharePointLibraryFolder) } diff --git a/src/internal/m365/collection/drive/user_drive_handler.go b/src/internal/m365/collection/drive/user_drive_handler.go index fcf3943d0..fc11704fe 100644 --- a/src/internal/m365/collection/drive/user_drive_handler.go +++ b/src/internal/m365/collection/drive/user_drive_handler.go @@ -172,6 +172,13 @@ func (h userDriveBackupHandler) GetItem( return h.ac.GetItem(ctx, driveID, itemID) } +func (h userDriveBackupHandler) GetItemContent( + ctx context.Context, + driveID, itemID string, +) ([]byte, error) { + return h.ac.GetItemContent(ctx, driveID, itemID) +} + func (h userDriveBackupHandler) IsAllPass() bool { return h.scope.IsAny(selectors.OneDriveFolder) } diff --git a/src/internal/m365/service/onedrive/mock/handlers.go b/src/internal/m365/service/onedrive/mock/handlers.go index e56c7bab5..b28e0bac1 100644 --- a/src/internal/m365/service/onedrive/mock/handlers.go +++ b/src/internal/m365/service/onedrive/mock/handlers.go @@ -225,6 +225,10 @@ func (h BackupHandler[T]) GetItem(ctx context.Context, _, _ string) (models.Driv return h.GI.GetItem(ctx, "", "") } +func (h BackupHandler[T]) GetItemContent(ctx context.Context, _, _ string) ([]byte, error) { + return h.GI.GetItemContent(ctx, "", "") +} + func (h BackupHandler[T]) GetItemPermission( ctx context.Context, _, _ string, @@ -343,6 +347,13 @@ func (m GetsItem) GetItem( return m.Item, m.Err } +func (m GetsItem) GetItemContent( + _ context.Context, + _, _ string, +) ([]byte, error) { + return nil, m.Err +} + // --------------------------------------------------------------------------- // Drive Items Enumerator // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/api/drive.go b/src/pkg/services/m365/api/drive.go index 5394ad309..225b3c61e 100644 --- a/src/pkg/services/m365/api/drive.go +++ b/src/pkg/services/m365/api/drive.go @@ -136,6 +136,22 @@ func (c Drives) GetItem( return di, nil } +func (c Drives) GetItemContent( + ctx context.Context, + driveID, itemID string, +) ([]byte, error) { + + dic, err := c.Stable. + Client(). + Drives(). + ByDriveId(driveID). + Items(). + ByDriveItemId(itemID).Content().Get(ctx, nil) + + return dic, graph.Wrap(ctx, err, "getting item content").OrNil() + +} + func (c Drives) NewItemContentUpload( ctx context.Context, driveID, itemID string,