lookup user default drive for onedrive

instead of enumerating all drives during backup, if we're
processing a onedrive source, opt to look up only the default
drive instead of enumerating all drives.
This commit is contained in:
ryanfkeepers 2023-05-05 14:32:35 -06:00
parent ef2083bc20
commit aadfb1a8d5
8 changed files with 104 additions and 36 deletions

View File

@ -112,7 +112,7 @@ func runDisplayM365JSON(
creds account.M365Config, creds account.M365Config,
user, itemID string, user, itemID string,
) error { ) error {
drive, err := api.GetDriveByID(ctx, srv, user) drive, err := api.GetUsersDefaultDrive(ctx, srv, user)
if err != nil { if err != nil {
return err return err
} }

View File

@ -166,7 +166,7 @@ func purgeOneDriveFolders(
return nil, err return nil, err
} }
cfs, err := onedrive.GetAllFolders(ctx, gs, pager, prefix, fault.New(true)) cfs, err := onedrive.GetAllFolders(ctx, gs, "", onedrive.OneDriveSource, pager, prefix, fault.New(true))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -109,6 +109,13 @@ type userDrivePager struct {
options *users.ItemDrivesRequestBuilderGetRequestConfiguration options *users.ItemDrivesRequestBuilderGetRequestConfiguration
} }
// NewUserDrivePager produces a pager for getting all of a user's drives.
// Use with caution: users *can* have multiple drives, even though onedrive
// docs say that they cannot. This can be manufactured by force, but the
// more likely context is if microsoft is using the second drive behind the
// scenes as an eventually-consistent copy for preservation and recovery.
// Since corso generally only handles the user's default drive you probably
// want to call GetUsersDefaultDrive instead of paging over all of them.
func NewUserDrivePager( func NewUserDrivePager(
gs graph.Servicer, gs graph.Servicer,
userID string, userID string,
@ -211,7 +218,10 @@ type DrivePager interface {
ValuesIn(api.PageLinker) ([]models.Driveable, error) ValuesIn(api.PageLinker) ([]models.Driveable, error)
} }
// GetAllDrives fetches all drives for the given pager // GetAllDrives fetches all drives for the given pager.
// If you're using this to enumerate a User's dries, first check whether
// you need GetUsersDefaultDrive instead. In most cases, we don't want to
// track all of the drives in a user's onedrive; just the default one.
func GetAllDrives( func GetAllDrives(
ctx context.Context, ctx context.Context,
pager DrivePager, pager DrivePager,
@ -308,18 +318,18 @@ func GetItemPermission(
return perm, nil return perm, nil
} }
func GetDriveByID( func GetUsersDefaultDrive(
ctx context.Context, ctx context.Context,
srv graph.Servicer, srv graph.Servicer,
userID string, user string,
) (models.Driveable, error) { ) (models.Driveable, error) {
//revive:enable:context-as-argument //revive:enable:context-as-argument
d, err := srv.Client(). d, err := srv.Client().
UsersById(userID). UsersById(user).
Drive(). Drive().
Get(ctx, nil) Get(ctx, nil)
if err != nil { if err != nil {
return nil, graph.Wrap(ctx, err, "getting drive") return nil, graph.Wrap(ctx, err, "getting user's default drive")
} }
return d, nil return d, nil

View File

@ -69,6 +69,13 @@ type folderMatcher interface {
Matches(string) bool Matches(string) bool
} }
type drivePagerFunc func(
source driveSource,
servicer graph.Servicer,
resourceOwner string,
fields []string,
) (api.DrivePager, error)
// Collections is used to retrieve drive data for a // Collections is used to retrieve drive data for a
// resource owner, which can be either a user or a sharepoint site. // resource owner, which can be either a user or a sharepoint site.
type Collections struct { type Collections struct {
@ -91,7 +98,7 @@ type Collections struct {
// Not the most ideal, but allows us to change the pager function for testing // Not the most ideal, but allows us to change the pager function for testing
// as needed. This will allow us to mock out some scenarios during testing. // as needed. This will allow us to mock out some scenarios during testing.
drivePagerFunc func( dpf func(
source driveSource, source driveSource,
servicer graph.Servicer, servicer graph.Servicer,
resourceOwner string, resourceOwner string,
@ -119,17 +126,17 @@ func NewCollections(
ctrlOpts control.Options, ctrlOpts control.Options,
) *Collections { ) *Collections {
return &Collections{ return &Collections{
itemClient: itemClient, itemClient: itemClient,
tenant: tenant, tenant: tenant,
resourceOwner: resourceOwner, resourceOwner: resourceOwner,
source: source, source: source,
matcher: matcher, matcher: matcher,
CollectionMap: map[string]map[string]*Collection{}, CollectionMap: map[string]map[string]*Collection{},
drivePagerFunc: PagerForSource, dpf: PagerForSource,
itemPagerFunc: defaultItemPager, itemPagerFunc: defaultItemPager,
service: service, service: service,
statusUpdater: statusUpdater, statusUpdater: statusUpdater,
ctrl: ctrlOpts, ctrl: ctrlOpts,
} }
} }
@ -285,14 +292,19 @@ func (c *Collections) Get(
defer close(driveComplete) defer close(driveComplete)
// Enumerate drives for the specified resourceOwner // Enumerate drives for the specified resourceOwner
pager, err := c.drivePagerFunc(c.source, c.service, c.resourceOwner, nil) pager, err := c.dpf(c.source, c.service, c.resourceOwner, nil)
if err != nil { if err != nil {
return nil, graph.Stack(ctx, err) return nil, graph.Stack(ctx, err)
} }
drives, err := api.GetAllDrives(ctx, pager, true, maxDrivesRetries) drives, err := getDrivesBySource(
ctx,
c.service,
c.resourceOwner,
c.source,
pager)
if err != nil { if err != nil {
return nil, err return nil, clues.Wrap(err, "enumerating drives")
} }
var ( var (
@ -924,3 +936,32 @@ func updatePath(paths map[string]string, id, newPath string) {
paths[folderID] = strings.Replace(p, oldPath, newPath, 1) paths[folderID] = strings.Replace(p, oldPath, newPath, 1)
} }
} }
// gets either the user's default drive (if source is onedrive) or
// enumerates all drives for the provided pager.
func getDrivesBySource(
ctx context.Context,
gs graph.Servicer,
resourceOwner string,
source driveSource,
adp api.DrivePager,
) ([]models.Driveable, error) {
// onedrive users *can* have multiple drives, but we want to ignore all
// except the default drive.
switch source {
case OneDriveSource:
dd, err := api.GetUsersDefaultDrive(ctx, gs, resourceOwner)
if err != nil {
return nil, err
}
return []models.Driveable{dd}, nil
default:
drives, err := api.GetAllDrives(ctx, adp, true, maxDrivesRetries)
if err != nil {
return nil, err
}
return drives, nil
}
}

View File

@ -2242,7 +2242,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestGet() {
func(*support.ConnectorOperationStatus) {}, func(*support.ConnectorOperationStatus) {},
control.Options{ToggleFeatures: control.Toggles{}}, control.Options{ToggleFeatures: control.Toggles{}},
) )
c.drivePagerFunc = drivePagerFunc c.dpf = drivePagerFunc
c.itemPagerFunc = itemPagerFunc c.itemPagerFunc = itemPagerFunc
prevDelta := "prev-delta" prevDelta := "prev-delta"

View File

@ -234,19 +234,21 @@ func (op *Displayable) GetDisplayName() *string {
return op.GetName() return op.GetName()
} }
// GetAllFolders returns all folders in all drives for the given user. If a // GetAllFolders returns all folders in tracked drives for the given user. If a
// prefix is given, returns all folders with that prefix, regardless of if they // prefix is given, returns all folders with that prefix, regardless of if they
// are a subfolder or top-level folder in the hierarchy. // are a subfolder or top-level folder in the hierarchy.
func GetAllFolders( func GetAllFolders(
ctx context.Context, ctx context.Context,
gs graph.Servicer, gs graph.Servicer,
resourceOwner string,
source driveSource,
pager api.DrivePager, pager api.DrivePager,
prefix string, prefix string,
errs *fault.Bus, errs *fault.Bus,
) ([]*Displayable, error) { ) ([]*Displayable, error) {
drives, err := api.GetAllDrives(ctx, pager, true, maxDrivesRetries) drives, err := getDrivesBySource(ctx, gs, resourceOwner, source, pager)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "getting OneDrive folders") return nil, clues.Wrap(err, "getting folders across all drives")
} }
var ( var (

View File

@ -308,10 +308,7 @@ func (suite *OneDriveSuite) TestCreateGetDeleteFolder() {
gs = loadTestService(t) gs = loadTestService(t)
) )
pager, err := PagerForSource(OneDriveSource, gs, suite.userID, nil) drives, err := getDrivesBySource(ctx, gs, suite.userID, OneDriveSource, nil)
require.NoError(t, err, clues.ToCore(err))
drives, err := api.GetAllDrives(ctx, pager, true, maxDrivesRetries)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
require.NotEmpty(t, drives) require.NotEmpty(t, drives)
@ -367,10 +364,14 @@ func (suite *OneDriveSuite) TestCreateGetDeleteFolder() {
suite.Run(test.name, func() { suite.Run(test.name, func() {
t := suite.T() t := suite.T()
pager, err := PagerForSource(OneDriveSource, gs, suite.userID, nil) allFolders, err := GetAllFolders(
require.NoError(t, err, clues.ToCore(err)) ctx,
gs,
allFolders, err := GetAllFolders(ctx, gs, pager, test.prefix, fault.New(true)) suite.userID,
OneDriveSource,
nil,
test.prefix,
fault.New(true))
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
foundFolderIDs := []string{} foundFolderIDs := []string{}

View File

@ -8,15 +8,29 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/graph/mock"
"github.com/alcionai/corso/src/internal/connector/support" "github.com/alcionai/corso/src/internal/connector/support"
"github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/tester"
"github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/account"
"github.com/alcionai/corso/src/pkg/logger"
) )
type MockGraphService struct{} type MockGraphService struct {
useMockClient bool
creds account.M365Config // only required if useMockClient=true
}
func (ms *MockGraphService) Client() *msgraphsdk.GraphServiceClient { func (ms *MockGraphService) Client() *msgraphsdk.GraphServiceClient {
return nil if !ms.useMockClient {
return nil
}
s, err := mock.NewService(ms.creds)
if err != nil {
logger.Ctx(nil).Error("mocking client", err)
}
return s.Client()
} }
func (ms *MockGraphService) Adapter() *msgraphsdk.GraphRequestAdapter { func (ms *MockGraphService) Adapter() *msgraphsdk.GraphRequestAdapter {