diff --git a/src/internal/connector/onedrive/collections.go b/src/internal/connector/onedrive/collections.go
index cc27f4fb2..aabee1bcf 100644
--- a/src/internal/connector/onedrive/collections.go
+++ b/src/internal/connector/onedrive/collections.go
@@ -96,10 +96,7 @@ type Collections struct {
 		resourceOwner string,
 		fields []string,
 	) (api.DrivePager, error)
-	itemPagerFunc func(
-		servicer graph.Servicer,
-		driveID, link string,
-	) itemPager
+	itemPagerFunc driveItemPagerFunc
 	servicePathPfxFunc pathPrefixerFunc
 
 	// Track stats from drive enumeration. Represents the items backed up.
diff --git a/src/internal/connector/onedrive/drive.go b/src/internal/connector/onedrive/drive.go
index e75b38cd9..2e0017ace 100644
--- a/src/internal/connector/onedrive/drive.go
+++ b/src/internal/connector/onedrive/drive.go
@@ -88,6 +88,11 @@ type itemCollector func(
 	errs *fault.Bus,
 ) error
 
+type driveItemPagerFunc func(
+	servicer graph.Servicer,
+	driveID, link string,
+) itemPager
+
 type itemPager interface {
 	GetPage(context.Context) (api.DeltaPageLinker, error)
 	SetNext(nextLink string)
diff --git a/src/internal/connector/onedrive/url_cache.go b/src/internal/connector/onedrive/url_cache.go
new file mode 100644
index 000000000..23c8f84f1
--- /dev/null
+++ b/src/internal/connector/onedrive/url_cache.go
@@ -0,0 +1,263 @@
+package onedrive
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/alcionai/clues"
+	"github.com/microsoftgraph/msgraph-sdk-go/models"
+
+	"github.com/alcionai/corso/src/internal/common/ptr"
+	"github.com/alcionai/corso/src/internal/connector/graph"
+	"github.com/alcionai/corso/src/pkg/fault"
+	"github.com/alcionai/corso/src/pkg/logger"
+)
+
+type itemProps struct {
+	downloadURL string
+	isDeleted   bool
+}
+
+// urlCache caches download URLs for drive items
+type urlCache struct {
+	driveID         string
+	idToProps       map[string]itemProps
+	lastRefreshTime time.Time
+	refreshInterval time.Duration
+	// cacheMu protects idToProps and lastRefreshTime
+	cacheMu sync.RWMutex
+	// refreshMu serializes cache refresh attempts by potential writers
+	refreshMu       sync.Mutex
+	deltaQueryCount int
+
+	svc           graph.Servicer
+	itemPagerFunc driveItemPagerFunc
+
+	errors *fault.Bus
+}
+
+// newURLCache creates a new URL cache for the specified drive ID
+func newURLCache(
+	driveID string,
+	refreshInterval time.Duration,
+	svc graph.Servicer,
+	errors *fault.Bus,
+	itemPagerFunc driveItemPagerFunc,
+) (*urlCache, error) {
+	err := validateCacheParams(
+		driveID,
+		refreshInterval,
+		svc,
+		itemPagerFunc)
+	if err != nil {
+		return nil, clues.Wrap(err, "cache params")
+	}
+
+	return &urlCache{
+		idToProps:       make(map[string]itemProps),
+		lastRefreshTime: time.Time{},
+		driveID:         driveID,
+		refreshInterval: refreshInterval,
+		svc:             svc,
+		itemPagerFunc:   itemPagerFunc,
+		errors:          errors,
+	}, nil
+}
+
+// validateCacheParams validates input params
+func validateCacheParams(
+	driveID string,
+	refreshInterval time.Duration,
+	svc graph.Servicer,
+	itemPagerFunc driveItemPagerFunc,
+) error {
+	if len(driveID) == 0 {
+		return clues.New("drive id is empty")
+	}
+
+	if refreshInterval <= 1*time.Second {
+		return clues.New("invalid refresh interval")
+	}
+
+	if svc == nil {
+		return clues.New("nil graph servicer")
+	}
+
+	if itemPagerFunc == nil {
+		return clues.New("nil item pager")
+	}
+
+	return nil
+}
+
+// getItemProperties returns the item properties for the specified drive
+// item ID
+func (uc *urlCache) getItemProperties(
+	ctx context.Context,
+	itemID string,
+) (itemProps, error) {
+	if len(itemID) == 0 {
+		return itemProps{}, clues.New("item id is empty")
+	}
+
+	ctx = clues.Add(ctx, "drive_id", uc.driveID)
+
+	// Lazy refresh
+	if uc.needsRefresh() {
+		err := uc.refreshCache(ctx)
+		if err != nil {
+			return itemProps{}, err
+		}
+	}
+
+	props, err := uc.readCache(ctx, itemID)
+	if err != nil {
+		return itemProps{}, err
+	}
+
+	return props, nil
+}
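+
+// The read path above is a double-checked lock: needsRefresh is first
+// consulted without refreshMu held, and refreshCache re-checks it after
+// acquiring the mutex so that goroutines queued behind an in-flight
+// refresh don't repeat the work. Roughly (illustrative sketch, mirroring
+// the code below):
+//
+//	if uc.needsRefresh() {         // cheap check, read lock only
+//		uc.refreshMu.Lock()        // serialize would-be refreshers
+//		if uc.needsRefresh() {     // re-check under the lock
+//			_ = uc.deltaQuery(ctx) // a single delta pass repopulates the map
+//		}
+//		uc.refreshMu.Unlock()
+//	}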
+
+// needsRefresh returns true if the cache is empty or if the refresh
+// interval has elapsed
+func (uc *urlCache) needsRefresh() bool {
+	uc.cacheMu.RLock()
+	defer uc.cacheMu.RUnlock()
+
+	return len(uc.idToProps) == 0 ||
+		time.Since(uc.lastRefreshTime) > uc.refreshInterval
+}
+
+// refreshCache refreshes the URL cache by performing a delta query.
+func (uc *urlCache) refreshCache(
+	ctx context.Context,
+) error {
+	// Acquire refreshMu to prevent multiple threads from refreshing the
+	// cache at the same time
+	uc.refreshMu.Lock()
+	defer uc.refreshMu.Unlock()
+
+	// If the cache was refreshed by another thread while we were waiting
+	// to acquire the mutex, return early
+	if !uc.needsRefresh() {
+		return nil
+	}
+
+	// Hold the cache lock in write mode for the entire duration of the
+	// refresh. This prevents other threads from reading the cache while
+	// it is being updated page by page
+	uc.cacheMu.Lock()
+	defer uc.cacheMu.Unlock()
+
+	// Issue a delta query to graph
+	logger.Ctx(ctx).Info("refreshing url cache")
+
+	err := uc.deltaQuery(ctx)
+	if err != nil {
+		return err
+	}
+
+	logger.Ctx(ctx).Info("url cache refreshed")
+
+	// Update the last refresh time
+	uc.lastRefreshTime = time.Now()
+
+	return nil
+}
+
+// deltaQuery performs a delta query on the drive and updates the cache
+func (uc *urlCache) deltaQuery(
+	ctx context.Context,
+) error {
+	logger.Ctx(ctx).Debug("starting delta query")
+
+	_, _, _, err := collectItems(
+		ctx,
+		uc.itemPagerFunc(uc.svc, uc.driveID, ""),
+		uc.driveID,
+		"",
+		uc.updateCache,
+		map[string]string{},
+		"",
+		uc.errors)
+	if err != nil {
+		return clues.Wrap(err, "delta query")
+	}
+
+	uc.deltaQueryCount++
+
+	return nil
+}
+
+// readCache returns the item properties for the specified item
+func (uc *urlCache) readCache(
+	ctx context.Context,
+	itemID string,
+) (itemProps, error) {
+	uc.cacheMu.RLock()
+	defer uc.cacheMu.RUnlock()
+
+	ctx = clues.Add(ctx, "item_id", itemID)
+
+	props, ok := uc.idToProps[itemID]
+	if !ok {
+		return itemProps{}, clues.New("item not found in cache").WithClues(ctx)
+	}
+
+	return props, nil
+}
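+
+// Graph returns a drive item's pre-authenticated download URL inside its
+// AdditionalData map rather than through a typed getter; downloadURLKeys
+// (defined elsewhere in this package) lists the candidate keys. The raw
+// payload consumed by updateCache below looks roughly like this (shape
+// assumed from the Graph API docs):
+//
+//	"@microsoft.graph.downloadUrl": "https://contoso.sharepoint.com/...",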
+
+// updateCache consumes a slice of drive items and updates the url cache.
+// It assumes that cacheMu is held by the caller in write mode
+func (uc *urlCache) updateCache(
+	ctx context.Context,
+	_, _ string,
+	items []models.DriveItemable,
+	_ map[string]string,
+	_ map[string]string,
+	_ map[string]struct{},
+	_ map[string]map[string]string,
+	_ bool,
+	errs *fault.Bus,
+) error {
+	el := errs.Local()
+
+	for _, item := range items {
+		if el.Failure() != nil {
+			break
+		}
+
+		// Skip if not a file
+		if item.GetFile() == nil {
+			continue
+		}
+
+		var url string
+
+		for _, key := range downloadURLKeys {
+			tmp, ok := item.GetAdditionalData()[key].(*string)
+			if ok {
+				url = ptr.Val(tmp)
+				break
+			}
+		}
+
+		itemID := ptr.Val(item.GetId())
+
+		uc.idToProps[itemID] = itemProps{
+			downloadURL: url,
+			isDeleted:   false,
+		}
+
+		// Mark deleted items in the cache
+		if item.GetDeleted() != nil {
+			uc.idToProps[itemID] = itemProps{
+				downloadURL: "",
+				isDeleted:   true,
+			}
+		}
+	}
+
+	return el.Failure()
+}
diff --git a/src/internal/connector/onedrive/url_cache_test.go b/src/internal/connector/onedrive/url_cache_test.go
new file mode 100644
index 000000000..987ed77d9
--- /dev/null
+++ b/src/internal/connector/onedrive/url_cache_test.go
@@ -0,0 +1,162 @@
+package onedrive
+
+import (
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/alcionai/clues"
+	"github.com/microsoftgraph/msgraph-sdk-go/models"
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+
+	"github.com/alcionai/corso/src/internal/common/dttm"
+	"github.com/alcionai/corso/src/internal/common/ptr"
+	"github.com/alcionai/corso/src/internal/connector/graph"
+	"github.com/alcionai/corso/src/internal/tester"
+	"github.com/alcionai/corso/src/pkg/fault"
+	"github.com/alcionai/corso/src/pkg/logger"
+	"github.com/alcionai/corso/src/pkg/services/m365/api"
+)
+
+type URLCacheIntegrationSuite struct {
+	tester.Suite
+	service graph.Servicer
+	user    string
+	driveID string
+}
+
+func TestURLCacheIntegrationSuite(t *testing.T) {
+	suite.Run(t, &URLCacheIntegrationSuite{
+		Suite: tester.NewIntegrationSuite(
+			t,
+			[][]string{tester.M365AcctCredEnvs}),
+	})
+}
+
+func (suite *URLCacheIntegrationSuite) SetupSuite() {
+	t := suite.T()
+
+	ctx, flush := tester.NewContext(t)
+	defer flush()
+
+	suite.service = loadTestService(t)
+	suite.user = tester.SecondaryM365UserID(t)
+
+	pager, err := PagerForSource(OneDriveSource, suite.service, suite.user, nil)
+	require.NoError(t, err, clues.ToCore(err))
+
+	odDrives, err := api.GetAllDrives(ctx, pager, true, maxDrivesRetries)
+	require.NoError(t, err, clues.ToCore(err))
+	require.Greaterf(t, len(odDrives), 0, "user %s does not have a drive", suite.user)
+	suite.driveID = ptr.Val(odDrives[0].GetId())
+}
+
+// Basic test for urlCache. Creates some files in OneDrive, then accesses
+// them via the URL cache
+func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() {
+	t := suite.T()
+
+	ctx, flush := tester.NewContext(t)
+	defer flush()
+
+	svc := suite.service
+	driveID := suite.driveID
+
+	// Create a new test folder
+	root, err := svc.Client().Drives().ByDriveId(driveID).Root().Get(ctx, nil)
+	require.NoError(t, err, clues.ToCore(err))
+
+	newFolderName := tester.DefaultTestRestoreDestination("folder").ContainerName
+
+	newFolder, err := CreateItem(
+		ctx,
+		svc,
+		driveID,
+		ptr.Val(root.GetId()),
+		newItem(newFolderName, true))
+	require.NoError(t, err, clues.ToCore(err))
+	require.NotNil(t, newFolder.GetId())
+
+	// Delete the folder on exit
+	defer func() {
+		ictx := clues.Add(ctx, "folder_id", ptr.Val(newFolder.GetId()))
+
+		err := api.DeleteDriveItem(
+			ictx,
+			loadTestService(t),
+			driveID,
+			ptr.Val(newFolder.GetId()))
+		if err != nil {
+			logger.CtxErr(ictx, err).Errorw("deleting folder")
+		}
+	}()
+
+	// Create a batch of files in the new folder
+	var items []models.DriveItemable
+
+	for i := 0; i < 10; i++ {
+		newItemName := "testItem_" + dttm.FormatNow(dttm.SafeForTesting)
+
+		item, err := CreateItem(
+			ctx,
+			svc,
+			driveID,
+			ptr.Val(newFolder.GetId()),
+			newItem(newItemName, false))
+		if err != nil {
+			// Item creation failed; skip it
+			continue
+		}
+
+		items = append(items, item)
+	}
+
+	// Create a new URL cache with a long TTL
+	cache, err := newURLCache(
+		suite.driveID,
+		1*time.Hour,
+		svc,
+		fault.New(true),
+		defaultItemPager)
+	require.NoError(t, err, clues.ToCore(err))
+
+	// Launch parallel requests to the cache, one per item
+	var wg sync.WaitGroup
+
+	for i := 0; i < len(items); i++ {
+		wg.Add(1)
+
+		go func(i int) {
+			defer wg.Done()
+
+			// Read the item from the URL cache
+			props, err := cache.getItemProperties(
+				ctx,
+				ptr.Val(items[i].GetId()))
+			require.NoError(t, err, clues.ToCore(err))
+			require.NotNil(t, props)
+			require.NotEmpty(t, props.downloadURL)
+			require.False(t, props.isDeleted)
+
+			// Validate the download URL
+			c := graph.NewNoTimeoutHTTPWrapper()
+
+			resp, err := c.Request(
+				ctx,
+				http.MethodGet,
+				props.downloadURL,
+				nil,
+				nil)
+			require.NoError(t, err, clues.ToCore(err))
+			require.Equal(t, http.StatusOK, resp.StatusCode)
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Validate that at most one delta query was made
+	require.LessOrEqual(t, cache.deltaQueryCount, 1)
+}
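
A minimal sketch of how a consumer could wire the cache into a download path (illustrative only: downloadWithCache is a hypothetical helper, and the retry/fallback policy around it is an assumption, not part of this change). Expired URLs surface as non-2xx responses, which the caller can handle by re-fetching the item directly:

	package onedrive

	import (
		"context"
		"net/http"

		"github.com/alcionai/clues"
	)

	// downloadWithCache resolves an item's download URL via the cache and
	// issues a plain GET against it. The URL is pre-authenticated, so no
	// auth headers are required.
	func downloadWithCache(
		ctx context.Context,
		uc *urlCache,
		client *http.Client,
		itemID string,
	) (*http.Response, error) {
		props, err := uc.getItemProperties(ctx, itemID)
		if err != nil {
			return nil, clues.Wrap(err, "reading url cache")
		}

		if props.isDeleted {
			return nil, clues.New("item is deleted")
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, props.downloadURL, nil)
		if err != nil {
			return nil, clues.Wrap(err, "building download request")
		}

		return client.Do(req)
	}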