WIP: fix COR-239

This commit is contained in:
HiteshRepo 2024-01-31 21:58:27 +05:30
parent 7e2b9dab62
commit 8762e62b2f
8 changed files with 88 additions and 8 deletions

View File

@ -1,6 +1,7 @@
package drive package drive
import ( import (
"bytes"
"context" "context"
"io" "io"
"net/http" "net/http"
@ -366,6 +367,14 @@ func downloadContent(
itemID := ptr.Val(item.GetId()) itemID := ptr.Val(item.GetId())
ctx = clues.Add(ctx, "item_id", itemID) ctx = clues.Add(ctx, "item_id", itemID)
// attempt to fetch item content directly via API
// before falling back to fetch content in chunks
contentBytes, err := iaag.GetItemContent(ctx, driveID, ptr.Val(item.GetId()))
if err == nil {
reader := bytes.NewReader(contentBytes)
return io.NopCloser(reader), nil
}
content, err := downloadItem(ctx, iaag, item) content, err := downloadItem(ctx, iaag, item)
if err == nil { if err == nil {
return content, nil return content, nil

View File

@ -733,7 +733,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
}, },
{ {
name: "expired url redownloads", name: "expired url redownloads",
mgi: getsItem{Item: itemWID, Err: nil}, mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError},
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, iorc}, respBody: []io.ReadCloser{nil, iorc},
getErr: []error{errUnauth, nil}, getErr: []error{errUnauth, nil},
@ -744,6 +744,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
{ {
name: "immediate error", name: "immediate error",
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
mgi: getsItem{ContentErr: assert.AnError},
getErr: []error{assert.AnError}, getErr: []error{assert.AnError},
expectErr: require.Error, expectErr: require.Error,
expect: require.Nil, expect: require.Nil,
@ -753,14 +754,14 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
name: "re-fetching the item fails", name: "re-fetching the item fails",
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
getErr: []error{errUnauth}, getErr: []error{errUnauth},
mgi: getsItem{Item: nil, Err: assert.AnError}, mgi: getsItem{Item: nil, Err: assert.AnError, ContentErr: assert.AnError},
expectErr: require.Error, expectErr: require.Error,
expect: require.Nil, expect: require.Nil,
muc: m, muc: m,
}, },
{ {
name: "expired url fails redownload", name: "expired url fails redownload",
mgi: getsItem{Item: itemWID, Err: nil}, mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError},
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, nil}, respBody: []io.ReadCloser{nil, nil},
getErr: []error{errUnauth, assert.AnError}, getErr: []error{errUnauth, assert.AnError},
@ -770,7 +771,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
}, },
{ {
name: "url refreshed from cache", name: "url refreshed from cache",
mgi: getsItem{Item: itemWID, Err: nil}, mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError},
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, iorc}, respBody: []io.ReadCloser{nil, iorc},
getErr: []error{errUnauth, nil}, getErr: []error{errUnauth, nil},
@ -788,7 +789,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
}, },
{ {
name: "url refreshed from cache but item deleted", name: "url refreshed from cache but item deleted",
mgi: getsItem{Item: itemWID, Err: core.ErrNotFound}, mgi: getsItem{Item: itemWID, Err: core.ErrNotFound, ContentErr: assert.AnError},
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, nil, nil}, respBody: []io.ReadCloser{nil, nil, nil},
getErr: []error{errUnauth, core.ErrNotFound, core.ErrNotFound}, getErr: []error{errUnauth, core.ErrNotFound, core.ErrNotFound},
@ -806,7 +807,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
}, },
{ {
name: "fallback to item fetch on any cache error", name: "fallback to item fetch on any cache error",
mgi: getsItem{Item: itemWID, Err: nil}, mgi: getsItem{Item: itemWID, Err: nil, ContentErr: assert.AnError},
itemInfo: details.ItemInfo{}, itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, iorc}, respBody: []io.ReadCloser{nil, iorc},
getErr: []error{errUnauth, nil}, getErr: []error{errUnauth, nil},
@ -818,6 +819,14 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
}, },
}, },
}, },
{
name: "fetches item content via direct API call",
mgi: getsItem{Item: itemWID, Err: nil, ContentErr: nil},
itemInfo: details.ItemInfo{},
respBody: []io.ReadCloser{nil, iorc},
expectErr: require.NoError,
expect: require.NotNil,
},
} }
for _, test := range table { for _, test := range table {
suite.Run(test.name, func() { suite.Run(test.name, func() {

View File

@ -83,6 +83,11 @@ type GetItemer interface {
ctx context.Context, ctx context.Context,
driveID, itemID string, driveID, itemID string,
) (models.DriveItemable, error) ) (models.DriveItemable, error)
GetItemContent(
ctx context.Context,
driveID, itemID string,
) ([]byte, error)
} }
type EnumerateDriveItemsDeltaer interface { type EnumerateDriveItemsDeltaer interface {

View File

@ -869,6 +869,10 @@ func (h mockBackupHandler[T]) GetItem(ctx context.Context, _, _ string) (models.
return h.GI.GetItem(ctx, "", "") return h.GI.GetItem(ctx, "", "")
} }
func (h mockBackupHandler[T]) GetItemContent(ctx context.Context, _, _ string) ([]byte, error) {
return h.GI.GetItemContent(ctx, "", "")
}
func (h mockBackupHandler[T]) GetItemPermission( func (h mockBackupHandler[T]) GetItemPermission(
ctx context.Context, ctx context.Context,
_, _ string, _, _ string,
@ -976,8 +980,9 @@ func (h mockBackupHandler[T]) GetRootFolder(context.Context, string) (models.Dri
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
type getsItem struct { type getsItem struct {
Item models.DriveItemable Item models.DriveItemable
Err error Err error
ContentErr error
} }
func (m getsItem) GetItem( func (m getsItem) GetItem(
@ -987,6 +992,17 @@ func (m getsItem) GetItem(
return m.Item, m.Err return m.Item, m.Err
} }
func (m getsItem) GetItemContent(
_ context.Context,
_, _ string,
) ([]byte, error) {
if m.ContentErr != nil {
return nil, m.ContentErr
}
return []byte("fnords"), nil
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Drive Item Enummerator // Drive Item Enummerator
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@ -167,6 +167,13 @@ func (h siteBackupHandler) GetItem(
return h.ac.GetItem(ctx, driveID, itemID) return h.ac.GetItem(ctx, driveID, itemID)
} }
func (h siteBackupHandler) GetItemContent(
ctx context.Context,
driveID, itemID string,
) ([]byte, error) {
return h.ac.GetItemContent(ctx, driveID, itemID)
}
func (h siteBackupHandler) IsAllPass() bool { func (h siteBackupHandler) IsAllPass() bool {
return h.scope.IsAny(selectors.SharePointLibraryFolder) return h.scope.IsAny(selectors.SharePointLibraryFolder)
} }

View File

@ -172,6 +172,13 @@ func (h userDriveBackupHandler) GetItem(
return h.ac.GetItem(ctx, driveID, itemID) return h.ac.GetItem(ctx, driveID, itemID)
} }
func (h userDriveBackupHandler) GetItemContent(
ctx context.Context,
driveID, itemID string,
) ([]byte, error) {
return h.ac.GetItemContent(ctx, driveID, itemID)
}
func (h userDriveBackupHandler) IsAllPass() bool { func (h userDriveBackupHandler) IsAllPass() bool {
return h.scope.IsAny(selectors.OneDriveFolder) return h.scope.IsAny(selectors.OneDriveFolder)
} }

View File

@ -225,6 +225,10 @@ func (h BackupHandler[T]) GetItem(ctx context.Context, _, _ string) (models.Driv
return h.GI.GetItem(ctx, "", "") return h.GI.GetItem(ctx, "", "")
} }
func (h BackupHandler[T]) GetItemContent(ctx context.Context, _, _ string) ([]byte, error) {
return h.GI.GetItemContent(ctx, "", "")
}
func (h BackupHandler[T]) GetItemPermission( func (h BackupHandler[T]) GetItemPermission(
ctx context.Context, ctx context.Context,
_, _ string, _, _ string,
@ -343,6 +347,13 @@ func (m GetsItem) GetItem(
return m.Item, m.Err return m.Item, m.Err
} }
func (m GetsItem) GetItemContent(
_ context.Context,
_, _ string,
) ([]byte, error) {
return nil, m.Err
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Drive Items Enumerator // Drive Items Enumerator
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@ -136,6 +136,22 @@ func (c Drives) GetItem(
return di, nil return di, nil
} }
func (c Drives) GetItemContent(
ctx context.Context,
driveID, itemID string,
) ([]byte, error) {
dic, err := c.Stable.
Client().
Drives().
ByDriveId(driveID).
Items().
ByDriveItemId(itemID).Content().Get(ctx, nil)
return dic, graph.Wrap(ctx, err, "getting item content").OrNil()
}
func (c Drives) NewItemContentUpload( func (c Drives) NewItemContentUpload(
ctx context.Context, ctx context.Context,
driveID, itemID string, driveID, itemID string,