From b63f45acb41d769e967c8a16708850fae8de1a97 Mon Sep 17 00:00:00 2001
From: ryanfkeepers
Date: Fri, 27 Jan 2023 16:23:57 -0700
Subject: [PATCH] separate, test streamItems

In OneDrive, break the code for streaming each item out into its own
function. Add a unit test that uses mocks to exercise the new
streaming func.
---
 src/internal/connector/onedrive/collection.go | 509 ++++++++++--------
 .../connector/onedrive/collection_test.go     | 215 +++++++-
 2 files changed, 500 insertions(+), 224 deletions(-)

diff --git a/src/internal/connector/onedrive/collection.go b/src/internal/connector/onedrive/collection.go
index 343a8911e..78849e0be 100644
--- a/src/internal/connector/onedrive/collection.go
+++ b/src/internal/connector/onedrive/collection.go
@@ -61,13 +61,16 @@ type Collection struct {
 	// M365 IDs of file items within this collection
 	driveItems map[string]models.DriveItemable
 	// M365 ID of the drive this collection was created from
-	driveID       string
-	source        driveSource
-	service       graph.Servicer
-	statusUpdater support.StatusUpdater
+	driveID       string
+	source        driveSource
+	service       graph.Servicer
+	statusUpdater support.StatusUpdater
+	ctrl          control.Options
+
+	// TODO: these should be interfaces, not funcs
 	itemReader     itemReaderFunc
 	itemMetaReader itemMetaReaderFunc
-	ctrl           control.Options
+	itemGetter     itemGetterFunc

 	// should only be true if the old delta token expired
 	doNotMergeItems bool
@@ -88,6 +91,12 @@ type itemMetaReaderFunc func(
 	item models.DriveItemable,
 ) (io.ReadCloser, int, error)

+type itemGetterFunc func(
+	ctx context.Context,
+	srv graph.Servicer,
+	driveID, itemID string,
+) (models.DriveItemable, error)
+
 // NewCollection creates a Collection
 func NewCollection(
 	itemClient *http.Client,
@@ -99,15 +108,16 @@ func NewCollection(
 	ctrlOpts control.Options,
 ) *Collection {
 	c := &Collection{
-		itemClient:    itemClient,
-		folderPath:    folderPath,
-		driveItems:    map[string]models.DriveItemable{},
-		driveID:       driveID,
-		source:        source,
-		service:       service,
-		data:          make(chan data.Stream, collectionChannelBufferSize),
-		statusUpdater: statusUpdater,
 		ctrl:          ctrlOpts,
+		data:          make(chan data.Stream, collectionChannelBufferSize),
+		driveID:       driveID,
+		driveItems:    map[string]models.DriveItemable{},
+		folderPath:    folderPath,
+		itemClient:    itemClient,
+		itemGetter:    getDriveItem,
+		service:       service,
+		source:        source,
+		statusUpdater: statusUpdater,
 	}

 	// Allows tests to set a mock populator
@@ -203,6 +213,13 @@ func (od *Item) ModTime() time.Time {
 // populateItems iterates through items added to the collection
 // and uses the collection `itemReader` to read the item
 func (oc *Collection) populateItems(ctx context.Context) {
+	// Retrieve the OneDrive folder path to set later in `details.OneDriveInfo`
+	parentPathString, err := path.GetDriveFolderPath(oc.folderPath)
+	if err != nil {
+		oc.reportAsCompleted(ctx, 0, 0, 0, err)
+		return
+	}
+
 	var (
 		errs       error
 		byteCount  int64
@@ -210,17 +227,23 @@ func (oc *Collection) populateItems(ctx context.Context) {
 		dirsRead   int64
 		itemsFound int64
 		dirsFound  int64
-		wg         sync.WaitGroup
-		m          sync.Mutex
-	)

-	// Retrieve the OneDrive folder path to set later in
-	// `details.OneDriveInfo`
-	parentPathString, err := path.GetDriveFolderPath(oc.folderPath)
-	if err != nil {
-		oc.reportAsCompleted(ctx, 0, 0, 0, err)
-		return
-	}
+		wg sync.WaitGroup
+		m  sync.Mutex
+
+		errUpdater = func(id string, err error) {
+			m.Lock()
+			defer m.Unlock()
+			errs = support.WrapAndAppend(id, err, errs)
+		}
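+		// countUpdater folds each item's tallies into the collection
+		// totals; atomic adds keep it safe across the concurrent item
+		// goroutines spawned below.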
+		countUpdater = func(size, dirs, items, dReads, iReads int64) {
+			atomic.AddInt64(&dirsRead, dReads)
+			atomic.AddInt64(&itemsRead, iReads)
+			atomic.AddInt64(&byteCount, size)
+			atomic.AddInt64(&dirsFound, dirs)
+			atomic.AddInt64(&itemsFound, items)
+		}
+	)

 	folderProgress, colCloser := observe.ProgressWithCount(
 		ctx,
@@ -233,12 +256,6 @@ func (oc *Collection) populateItems(ctx context.Context) {
 	semaphoreCh := make(chan struct{}, urlPrefetchChannelBufferSize)
 	defer close(semaphoreCh)

-	errUpdater := func(id string, err error) {
-		m.Lock()
-		errs = support.WrapAndAppend(id, err, errs)
-		m.Unlock()
-	}
-
 	for _, item := range oc.driveItems {
 		if oc.ctrl.FailFast && errs != nil {
 			break
 		}
@@ -248,198 +265,270 @@ func (oc *Collection) populateItems(ctx context.Context) {

 		semaphoreCh <- struct{}{}

 		wg.Add(1)

-		go func(item models.DriveItemable) {
-			defer wg.Done()
-			defer func() { <-semaphoreCh }()
-
-			// Read the item
-			var (
-				itemID       = *item.GetId()
-				itemName     = *item.GetName()
-				itemSize     = *item.GetSize()
-				itemInfo     details.ItemInfo
-				itemMeta     io.ReadCloser
-				itemMetaSize int
-				metaSuffix   string
-				err          error
-			)
-
-			isFile := item.GetFile() != nil
-
-			if isFile {
-				atomic.AddInt64(&itemsFound, 1)
-
-				metaSuffix = MetaFileSuffix
-			} else {
-				atomic.AddInt64(&dirsFound, 1)
-
-				metaSuffix = DirMetaFileSuffix
-			}
-
-			if oc.source == OneDriveSource {
-				// Fetch metadata for the file
-				for i := 1; i <= maxRetries; i++ {
-					if !oc.ctrl.ToggleFeatures.EnablePermissionsBackup {
-						// We are still writing the metadata file but with
-						// empty permissions as we don't have a way to
-						// signify that the permissions was explicitly
-						// not added.
-						itemMeta = io.NopCloser(strings.NewReader("{}"))
-						itemMetaSize = 2
-
-						break
-					}
-
-					itemMeta, itemMetaSize, err = oc.itemMetaReader(ctx, oc.service, oc.driveID, item)
-
-					// retry on Timeout type errors, break otherwise.
-					if err == nil ||
-						!graph.IsErrTimeout(err) ||
-						!graph.IsInternalServerError(err) {
-						break
-					}
-
-					if i < maxRetries {
-						time.Sleep(1 * time.Second)
-					}
-				}
-
-				if err != nil {
-					errUpdater(*item.GetId(), errors.Wrap(err, "failed to get item permissions"))
-					return
-				}
-			}
-
-			switch oc.source {
-			case SharePointSource:
-				itemInfo.SharePoint = sharePointItemInfo(item, itemSize)
-				itemInfo.SharePoint.ParentPath = parentPathString
-			default:
-				itemInfo.OneDrive = oneDriveItemInfo(item, itemSize)
-				itemInfo.OneDrive.ParentPath = parentPathString
-			}
-
-			if isFile {
-				dataSuffix := ""
-				if oc.source == OneDriveSource {
-					dataSuffix = DataFileSuffix
-				}
-
-				// Construct a new lazy readCloser to feed to the collection consumer.
-				// This ensures that downloads won't be attempted unless that consumer
-				// attempts to read bytes. Assumption is that kopia will check things
-				// like file modtimes before attempting to read.
-				itemReader := lazy.NewLazyReadCloser(func() (io.ReadCloser, error) {
-					// Read the item
-					var (
-						itemData io.ReadCloser
-						err      error
-					)
-
-					for i := 1; i <= maxRetries; i++ {
-						_, itemData, err = oc.itemReader(oc.itemClient, item)
-						if err == nil {
-							break
-						}
-
-						if graph.IsErrUnauthorized(err) {
-							// assume unauthorized requests are a sign of an expired
-							// jwt token, and that we've overrun the available window
-							// to download the actual file. Re-downloading the item
-							// will refresh that download url.
-							di, diErr := getDriveItem(ctx, oc.service, oc.driveID, itemID)
-							if diErr != nil {
-								err = errors.Wrap(diErr, "retrieving expired item")
-								break
-							}
-
-							item = di
-
-							continue
-
-						} else if !graph.IsErrTimeout(err) &&
-							!graph.IsInternalServerError(err) {
-							// Don't retry for non-timeout, non-unauth, as
-							// we are already retrying it in the default
-							// retry middleware
-							break
-						}
-
-						if i < maxRetries {
-							time.Sleep(1 * time.Second)
-						}
-					}
-
-					// check for errors following retries
-					if err != nil {
-						errUpdater(itemID, err)
-						return nil, err
-					}
-
-					// display/log the item download
-					progReader, closer := observe.ItemProgress(
-						ctx,
-						itemData,
-						observe.ItemBackupMsg,
-						observe.PII(itemName+dataSuffix),
-						itemSize,
-					)
-					go closer()
-
-					return progReader, nil
-				})
-
-				oc.data <- &Item{
-					id:   itemName + dataSuffix,
-					data: itemReader,
-					info: itemInfo,
-				}
-			}
-
-			if oc.source == OneDriveSource {
-				metaReader := lazy.NewLazyReadCloser(func() (io.ReadCloser, error) {
-					progReader, closer := observe.ItemProgress(
-						ctx, itemMeta, observe.ItemBackupMsg,
-						observe.PII(itemName+metaSuffix), int64(itemMetaSize))
-					go closer()
-					return progReader, nil
-				})
-
-				oc.data <- &Item{
-					id:   itemName + metaSuffix,
-					data: metaReader,
-					info: itemInfo,
-				}
-			}
-
-			// Item read successfully, add to collection
-			if isFile {
-				atomic.AddInt64(&itemsRead, 1)
-			} else {
-				atomic.AddInt64(&dirsRead, 1)
-			}
-
-			// byteCount iteration
-			atomic.AddInt64(&byteCount, itemSize)
-
-			folderProgress <- struct{}{}
-		}(item)
+		// fetch the item, and stream it into the collection's data channel
+		go oc.streamItem(
+			ctx,
+			&wg,
+			semaphoreCh,
+			folderProgress,
+			errUpdater,
+			countUpdater,
+			item,
+			parentPathString)
 	}

 	wg.Wait()

-	oc.reportAsCompleted(ctx, int(itemsFound), int(itemsRead), byteCount, errs)
+	oc.reportAsCompleted(ctx, itemsFound, itemsRead, byteCount, errs)
 }

-func (oc *Collection) reportAsCompleted(ctx context.Context, itemsFound, itemsRead int, byteCount int64, errs error) {
+func (oc *Collection) streamItem(
+	ctx context.Context,
+	wg *sync.WaitGroup,
+	semaphore <-chan struct{},
+	progress chan<- struct{},
+	errUpdater func(string, error),
+	countUpdater func(int64, int64, int64, int64, int64),
+	item models.DriveItemable,
+	parentPath string,
+) {
+	defer wg.Done()
+	defer func() { <-semaphore }()
+
+	var (
+		itemID   = *item.GetId()
+		itemName = *item.GetName()
+		itemSize = *item.GetSize()
+		isFile   = item.GetFile() != nil
+
+		itemInfo details.ItemInfo
+		itemMeta io.ReadCloser
+
+		// declared here so that both readers remain in scope for the
+		// channel sends at the end of the func
+		metaReader, lazyData io.ReadCloser
+
+		dataSuffix string
+		metaSuffix string
+
+		dirsFound    int64
+		itemsFound   int64
+		dirsRead     int64
+		itemsRead    int64
+		itemMetaSize int
+
+		err error
+	)
+
+	if isFile {
+		itemsFound++
+		itemsRead++
+		metaSuffix = MetaFileSuffix
+	} else {
+		dirsFound++
+		dirsRead++
+		metaSuffix = DirMetaFileSuffix
+	}
+
+	switch oc.source {
+	case SharePointSource:
+		itemInfo.SharePoint = sharePointItemInfo(item, itemSize)
+		itemInfo.SharePoint.ParentPath = parentPath
+	default:
+		itemInfo.OneDrive = oneDriveItemInfo(item, itemSize)
+		itemInfo.OneDrive.ParentPath = parentPath
+		dataSuffix = DataFileSuffix
+	}
+
+	// metadata files are only produced for the OneDrive source
+	hasMeta := oc.source == OneDriveSource
+
+	if hasMeta {
+		// Construct a new lazy readCloser to feed to the collection consumer.
+		// This ensures that downloads won't be attempted unless that consumer
+		// attempts to read bytes. Assumption is that kopia will check things
+		// like file modtimes before attempting to read.
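+		//
+		// Note: because the read is deferred, failures inside this closure
+		// surface when the consumer reads the stream, not during item
+		// enumeration; errUpdater records them at that later point.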
+		metaReader = lazy.NewLazyReadCloser(func() (io.ReadCloser, error) {
+			itemMeta, itemMetaSize, err = getItemMeta(
+				ctx,
+				oc.service,
+				oc.driveID,
+				item,
+				maxRetries,
+				oc.ctrl.ToggleFeatures.EnablePermissionsBackup,
+				oc.itemMetaReader)
+			if err != nil {
+				errUpdater(itemID, err)
+				return nil, err
+			}
+
+			progReader, closer := observe.ItemProgress(
+				ctx, itemMeta, observe.ItemBackupMsg,
+				observe.PII(itemName+metaSuffix), int64(itemMetaSize))
+			go closer()
+			return progReader, nil
+		})
+	}
+
+	if isFile {
+		// Construct a new lazy readCloser to feed to the collection consumer.
+		// This ensures that downloads won't be attempted unless that consumer
+		// attempts to read bytes. Assumption is that kopia will check things
+		// like file modtimes before attempting to read.
+		lazyData = lazy.NewLazyReadCloser(func() (io.ReadCloser, error) {
+			// Read the item
+			var (
+				itemData io.ReadCloser
+				err      error
+			)
+
+			itemData, item, err = readDriveItem(
+				ctx,
+				oc.service,
+				oc.itemClient,
+				oc.driveID, itemID,
+				item,
+				oc.itemReader,
+				oc.itemGetter)
+			if err != nil {
+				errUpdater(itemID, err)
+				return nil, err
+			}
+
+			// display/log the item download
+			progReader, closer := observe.ItemProgress(
+				ctx,
+				itemData,
+				observe.ItemBackupMsg,
+				observe.PII(itemName+dataSuffix),
+				itemSize)
+			go closer()
+
+			return progReader, nil
+		})
+	}
+
+	// Item read successfully, record its addition.
+	//
+	// Note: this can cause inaccurate counts. Right now it counts all
+	// the items we intend to read. Errors within the lazy readCloser
+	// will create a conflict: an item is both successful and erroneous.
+	// But the async control to fix that is more error-prone than helpful.
+	//
+	// TODO: transform this into a stats bus so that async control of stats
+	// aggregation is handled at the backup level, not at the item iteration
+	// level.
+	countUpdater(itemSize, dirsFound, itemsFound, dirsRead, itemsRead)
+
+	if hasMeta {
+		oc.data <- &Item{
+			id:   itemName + metaSuffix,
+			data: metaReader,
+			info: itemInfo,
+		}
+	}
+
+	if isFile {
+		// stream the item to the data consumer.
+		oc.data <- &Item{
+			id:   itemName + dataSuffix,
+			data: lazyData,
+			info: itemInfo,
+		}
+	}
+
+	progress <- struct{}{}
+}
+
+func getItemMeta(
+	ctx context.Context,
+	service graph.Servicer,
+	driveID string,
+	item models.DriveItemable,
+	maxRetries int,
+	enablePermissionsBackup bool,
+	read itemMetaReaderFunc,
+) (io.ReadCloser, int, error) {
+	if !enablePermissionsBackup {
+		// We are still writing the metadata file, but with empty
+		// permissions, as we don't have a way to signify that the
+		// permissions were explicitly not added.
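+		//
+		// "{}" is an empty JSON permissions object; 2 is its length in bytes.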
+		return io.NopCloser(strings.NewReader("{}")), 2, nil
+	}
+
+	var (
+		rc   io.ReadCloser
+		size int
+		err  error
+	)
+
+	for i := 1; i <= maxRetries; i++ {
+		rc, size, err = read(ctx, service, driveID, item)
+		if err == nil {
+			return rc, size, nil
+		}
+
+		// retry only on timeout and internal-server errors
+		if !graph.IsErrTimeout(err) && !graph.IsInternalServerError(err) {
+			break
+		}
+
+		if i < maxRetries {
+			time.Sleep(1 * time.Second)
+		}
+	}
+
+	return nil, 0, errors.Wrap(err, "getting item metadata")
+}
+
+func readDriveItem(
+	ctx context.Context,
+	service graph.Servicer,
+	itemClient *http.Client,
+	driveID, itemID string,
+	original models.DriveItemable,
+	read itemReaderFunc,
+	get itemGetterFunc,
+) (io.ReadCloser, models.DriveItemable, error) {
+	var (
+		err  error
+		rc   io.ReadCloser
+		item = original
+	)
+
+	for i := 1; i <= maxRetries; i++ {
+		_, rc, err = read(itemClient, item)
+		if err == nil {
+			return rc, item, nil
+		}
+
+		if graph.IsErrUnauthorized(err) {
+			// assume unauthorized requests are a sign of an expired
+			// jwt token, and that we've overrun the available window
+			// to download the actual file. Re-downloading the item
+			// will refresh that download url.
+			di, diErr := get(ctx, service, driveID, itemID)
+			if diErr != nil {
+				return nil, nil, errors.Wrap(diErr, "retrieving item to refresh download url")
+			}
+
+			item = di
+
+			continue
+		} else if !graph.IsErrTimeout(err) &&
+			!graph.IsInternalServerError(err) {
+			// don't retry non-timeout, non-internal errors; the default
+			// retry middleware already handles those cases
+			break
+		}
+
+		if i < maxRetries {
+			time.Sleep(1 * time.Second)
+		}
+	}
+
+	return nil, item, errors.Wrap(err, "reading drive item")
+}
+
+func (oc *Collection) reportAsCompleted(ctx context.Context, itemsFound, itemsRead, byteCount int64, errs error) {
 	close(oc.data)

 	status := support.CreateStatus(ctx, support.Backup,
 		1, // num folders (always 1)
 		support.CollectionMetrics{
-			Objects:    itemsFound, // items to read,
-			Successes:  itemsRead,  // items read successfully,
-			TotalBytes: byteCount,  // Number of bytes read in the operation,
+			Objects:    int(itemsFound), // items to read,
+			Successes:  int(itemsRead),  // items read successfully,
+			TotalBytes: byteCount,       // Number of bytes read in the operation,
 		},
 		errs,
 		oc.folderPath.Folder(), // Additional details
diff --git a/src/internal/connector/onedrive/collection_test.go b/src/internal/connector/onedrive/collection_test.go
index 734009d72..c80f2ca4e 100644
--- a/src/internal/connector/onedrive/collection_test.go
+++ b/src/internal/connector/onedrive/collection_test.go
@@ -11,7 +11,7 @@ import (
 	"testing"
 	"time"

-	msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go"
+	"github.com/hashicorp/go-multierror"
 	"github.com/microsoftgraph/msgraph-sdk-go/models"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -20,6 +20,7 @@ import (
 	"github.com/alcionai/corso/src/internal/connector/graph"
 	"github.com/alcionai/corso/src/internal/connector/support"
 	"github.com/alcionai/corso/src/internal/data"
+	"github.com/alcionai/corso/src/internal/tester"
 	"github.com/alcionai/corso/src/pkg/backup/details"
 	"github.com/alcionai/corso/src/pkg/control"
 	"github.com/alcionai/corso/src/pkg/path"
@@ -29,17 +30,6 @@ type CollectionUnitTestSuite struct {
 	suite.Suite
 }

-// Allows `*CollectionUnitTestSuite` to be used as a graph.Servicer
-// TODO: Implement these methods
-
-func (suite *CollectionUnitTestSuite) Client() *msgraphsdk.GraphServiceClient {
-	return nil
-}
-
-func (suite *CollectionUnitTestSuite) Adapter() *msgraphsdk.GraphRequestAdapter {
-	return nil
-}
-
 func TestCollectionUnitTestSuite(t *testing.T) {
 	suite.Run(t, new(CollectionUnitTestSuite))
 }
@@ -165,7 +155,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
 		graph.HTTPClient(graph.NoTimeout()),
 		folderPath,
 		"drive-id",
-		suite,
+		&MockGraphService{},
 		suite.testStatusUpdater(&wg, &collStatus),
 		test.source,
 		control.Options{ToggleFeatures: control.Toggles{EnablePermissionsBackup: true}})
@@ -298,7 +288,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
 		graph.HTTPClient(graph.NoTimeout()),
 		folderPath,
 		"fakeDriveID",
-		suite,
+		&MockGraphService{},
 		suite.testStatusUpdater(&wg, &collStatus),
 		test.source,
 		control.Options{ToggleFeatures: control.Toggles{EnablePermissionsBackup: true}})
@@ -422,3 +412,200 @@ func (suite *CollectionUnitTestSuite) TestCollectionDisablePermissionsBackup() {
 		})
 	}
 }
+
+func (suite *CollectionUnitTestSuite) TestStreamItem() {
+	var (
+		id         = "id"
+		name       = "name"
+		size int64 = 42
+		now        = time.Now()
+	)
+
+	mockItem := models.NewDriveItem()
+	mockItem.SetId(&id)
+	mockItem.SetName(&name)
+	mockItem.SetSize(&size)
+	mockItem.SetCreatedDateTime(&now)
+	mockItem.SetLastModifiedDateTime(&now)
+	// mark the item as a file so that streamItem exercises the
+	// itemReader path
+	mockItem.SetFile(models.NewFile())
+
+	mockReader := func(v string, e error) itemReaderFunc {
+		return func(*http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
+			return details.ItemInfo{}, io.NopCloser(strings.NewReader(v)), e
+		}
+	}
+
+	mockGetter := func(e error) itemGetterFunc {
+		return func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
+			return mockItem, e
+		}
+	}
+
+	mockDataChan := func() chan data.Stream {
+		return make(chan data.Stream, 1)
+	}
+
+	table := []struct {
+		name       string
+		coll       *Collection
+		expectData string
+		errsIs     func(*testing.T, error, int)
+		readErrIs  func(*testing.T, error)
+	}{
+		{
+			name:       "happy",
+			expectData: "happy",
+			coll: &Collection{
+				data:       mockDataChan(),
+				itemReader: mockReader("happy", nil),
+				itemGetter: mockGetter(nil),
+			},
+			errsIs: func(t *testing.T, e error, count int) {
+				assert.NoError(t, e, "no errors")
+				assert.Zero(t, count, "zero errors")
+			},
+			readErrIs: func(t *testing.T, e error) {
+				assert.NoError(t, e, "no reader error")
+			},
+		},
+		{
+			name:       "reader err",
+			expectData: "",
+			coll: &Collection{
+				data:       mockDataChan(),
+				itemReader: mockReader("foo", assert.AnError),
+				itemGetter: mockGetter(nil),
+			},
+			errsIs: func(t *testing.T, e error, count int) {
+				assert.ErrorIs(t, e, assert.AnError)
+				assert.Equal(t, 1, count, "one error")
+			},
+			readErrIs: func(t *testing.T, e error) {
+				assert.Error(t, e, "basic error")
+			},
+		},
+		{
+			name:       "iteration err",
+			expectData: "",
+			coll: &Collection{
+				data:       mockDataChan(),
+				itemReader: mockReader("foo", graph.Err401Unauthorized),
+				itemGetter: mockGetter(assert.AnError),
+			},
+			errsIs: func(t *testing.T, e error, count int) {
+				assert.True(t, graph.IsErrUnauthorized(e), "is unauthorized error")
+				assert.ErrorIs(t, e, graph.Err401Unauthorized)
+				assert.Equal(t, 2, count, "count of errors aggregated")
+			},
+			readErrIs: func(t *testing.T, e error) {
+				assert.True(t, graph.IsErrUnauthorized(e), "is unauthorized error")
+				assert.ErrorIs(t, e, graph.Err401Unauthorized)
+			},
+		},
+		{
+			name:       "timeout errors",
+			expectData: "",
+			coll: &Collection{
+				data:       mockDataChan(),
+				itemReader: mockReader("foo", context.DeadlineExceeded),
+				itemGetter: mockGetter(nil),
+			},
+			errsIs: func(t *testing.T, e error, count int) {
+				assert.True(t, graph.IsErrTimeout(e), "is timeout error")
+				assert.ErrorIs(t, e, context.DeadlineExceeded)
+				assert.Equal(t, 1, count, "one error")
+			},
+			readErrIs: func(t *testing.T, e error) {
+				assert.True(t, graph.IsErrTimeout(e), "is timeout error")
+				assert.ErrorIs(t, e, context.DeadlineExceeded)
+			},
+		},
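+		// the remaining cases exercise retryable server-side failures,
+		// which should surface to both the reader and the error aggregator.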
error") + assert.ErrorIs(t, e, context.DeadlineExceeded) + }, + }, + { + name: "throttled errors", + expectData: "", + coll: &Collection{ + data: mockDataChan(), + itemReader: mockReader("foo", graph.Err429TooManyRequests), + itemGetter: mockGetter(nil), + }, + errsIs: func(t *testing.T, e error, count int) { + assert.True(t, graph.IsErrThrottled(e), "is throttled error") + assert.ErrorIs(t, e, graph.Err429TooManyRequests) + assert.Equal(t, 1, count, "one errors") + }, + readErrIs: func(t *testing.T, e error) { + assert.True(t, graph.IsErrThrottled(e), "is throttled error") + assert.ErrorIs(t, e, graph.Err429TooManyRequests) + }, + }, + { + name: "service unavailable errors", + expectData: "", + coll: &Collection{ + data: mockDataChan(), + itemReader: mockReader("foo", graph.Err503ServiceUnavailable), + itemGetter: mockGetter(nil), + }, + errsIs: func(t *testing.T, e error, count int) { + assert.True(t, graph.IsSericeUnavailable(e), "is unavailable error") + assert.ErrorIs(t, e, graph.Err503ServiceUnavailable) + assert.Equal(t, 1, count, "one errors") + }, + readErrIs: func(t *testing.T, e error) { + assert.True(t, graph.IsSericeUnavailable(e), "is unavailable error") + assert.ErrorIs(t, e, graph.Err503ServiceUnavailable) + }, + }, + } + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + ctx, flush := tester.NewContext() + defer flush() + + var ( + wg sync.WaitGroup + errs error + errCount int + size int64 + + countUpdater = func(sz int64) { size = sz } + errUpdater = func(s string, e error) { + errs = multierror.Append(errs, e) + errCount++ + } + + semaphore = make(chan struct{}, 1) + progress = make(chan struct{}, 1) + ) + + wg.Add(1) + semaphore <- struct{}{} + + go test.coll.streamItem( + ctx, + &wg, + semaphore, + progress, + errUpdater, + countUpdater, + mockItem, + "parentPath", + ) + + // wait for the func to run + wg.Wait() + + assert.Zero(t, len(semaphore), "semaphore was released") + assert.NotNil(t, <-progress, "progress was communicated") + assert.NotZero(t, size, "countUpdater was called") + + data, ok := <-test.coll.data + assert.True(t, ok, "data channel survived") + + bs, err := io.ReadAll(data.ToReader()) + + test.readErrIs(t, err) + test.errsIs(t, errs, errCount) + assert.Equal(t, test.expectData, string(bs), "streamed item bytes") + }) + } +}