diff --git a/src/cmd/sanity_test/restore/exchange.go b/src/cmd/sanity_test/restore/exchange.go new file mode 100644 index 000000000..e484550de --- /dev/null +++ b/src/cmd/sanity_test/restore/exchange.go @@ -0,0 +1,219 @@ +package restore + +import ( + "context" + "fmt" + stdpath "path" + "strings" + "time" + + "github.com/alcionai/clues" + msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go" + "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/microsoftgraph/msgraph-sdk-go/users" + + "github.com/alcionai/corso/src/cmd/sanity_test/utils" + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/pkg/filters" +) + +// CheckEmailRestoration verifies that the emails count in restored folder is equivalent to +// emails in actual m365 account +func CheckEmailRestoration( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + testUser, folderName, dataFolder, baseBackupFolder string, + startTime time.Time, +) { + var ( + restoreFolder models.MailFolderable + itemCount = make(map[string]int32) + restoreItemCount = make(map[string]int32) + builder = client.Users().ByUserId(testUser).MailFolders() + ) + + for { + result, err := builder.Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting mail folders", err) + } + + values := result.GetValue() + + for _, v := range values { + itemName := ptr.Val(v.GetDisplayName()) + + if itemName == folderName { + restoreFolder = v + continue + } + + if itemName == dataFolder || itemName == baseBackupFolder { + // otherwise, recursively aggregate all child folders. + getAllMailSubFolders(ctx, client, testUser, v, itemName, dataFolder, itemCount) + + itemCount[itemName] = ptr.Val(v.GetTotalItemCount()) + } + } + + link, ok := ptr.ValOK(result.GetOdataNextLink()) + if !ok { + break + } + + builder = users.NewItemMailFoldersRequestBuilder(link, client.GetAdapter()) + } + + folderID := ptr.Val(restoreFolder.GetId()) + folderName = ptr.Val(restoreFolder.GetDisplayName()) + ctx = clues.Add( + ctx, + "restore_folder_id", folderID, + "restore_folder_name", folderName) + + childFolder, err := client. + Users(). + ByUserId(testUser). + MailFolders(). + ByMailFolderId(folderID). + ChildFolders(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting restore folder child folders", err) + } + + for _, fld := range childFolder.GetValue() { + restoreDisplayName := ptr.Val(fld.GetDisplayName()) + + // check if folder is the data folder we loaded or the base backup to verify + // the incremental backup worked fine + if strings.EqualFold(restoreDisplayName, dataFolder) || strings.EqualFold(restoreDisplayName, baseBackupFolder) { + count, _ := ptr.ValOK(fld.GetTotalItemCount()) + + restoreItemCount[restoreDisplayName] = count + checkAllSubFolder(ctx, client, fld, testUser, restoreDisplayName, dataFolder, restoreItemCount) + } + } + + verifyEmailData(ctx, restoreItemCount, itemCount) +} + +func verifyEmailData(ctx context.Context, restoreMessageCount, messageCount map[string]int32) { + for fldName, expected := range messageCount { + got := restoreMessageCount[fldName] + + utils.Assert( + ctx, + func() bool { return expected == got }, + fmt.Sprintf("Restore item counts do not match: %s", fldName), + expected, + got) + } +} + +// getAllSubFolder will recursively check for all subfolders and get the corresponding +// email count. +func getAllMailSubFolders( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + testUser string, + r models.MailFolderable, + parentFolder, + dataFolder string, + messageCount map[string]int32, +) { + var ( + folderID = ptr.Val(r.GetId()) + count int32 = 99 + options = &users.ItemMailFoldersItemChildFoldersRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemMailFoldersItemChildFoldersRequestBuilderGetQueryParameters{ + Top: &count, + }, + } + ) + + ctx = clues.Add(ctx, "parent_folder_id", folderID) + + childFolder, err := client. + Users(). + ByUserId(testUser). + MailFolders(). + ByMailFolderId(folderID). + ChildFolders(). + Get(ctx, options) + if err != nil { + utils.Fatal(ctx, "getting mail subfolders", err) + } + + for _, child := range childFolder.GetValue() { + var ( + childDisplayName = ptr.Val(child.GetDisplayName()) + childFolderCount = ptr.Val(child.GetChildFolderCount()) + //nolint:forbidigo + fullFolderName = stdpath.Join(parentFolder, childDisplayName) + ) + + if filters.PathContains([]string{dataFolder}).Compare(fullFolderName) { + messageCount[fullFolderName] = ptr.Val(child.GetTotalItemCount()) + // recursively check for subfolders + if childFolderCount > 0 { + parentFolder := fullFolderName + + getAllMailSubFolders(ctx, client, testUser, child, parentFolder, dataFolder, messageCount) + } + } + } +} + +// checkAllSubFolder will recursively traverse inside the restore folder and +// verify that data matched in all subfolders +func checkAllSubFolder( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + r models.MailFolderable, + testUser, + parentFolder, + dataFolder string, + restoreMessageCount map[string]int32, +) { + var ( + folderID = ptr.Val(r.GetId()) + count int32 = 99 + options = &users.ItemMailFoldersItemChildFoldersRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemMailFoldersItemChildFoldersRequestBuilderGetQueryParameters{ + Top: &count, + }, + } + ) + + childFolder, err := client. + Users(). + ByUserId(testUser). + MailFolders(). + ByMailFolderId(folderID). + ChildFolders(). + Get(ctx, options) + if err != nil { + utils.Fatal(ctx, "getting mail subfolders", err) + } + + for _, child := range childFolder.GetValue() { + var ( + childDisplayName = ptr.Val(child.GetDisplayName()) + //nolint:forbidigo + fullFolderName = stdpath.Join(parentFolder, childDisplayName) + ) + + if filters.PathContains([]string{dataFolder}).Compare(fullFolderName) { + childTotalCount, _ := ptr.ValOK(child.GetTotalItemCount()) + restoreMessageCount[fullFolderName] = childTotalCount + } + + childFolderCount := ptr.Val(child.GetChildFolderCount()) + + if childFolderCount > 0 { + parentFolder := fullFolderName + checkAllSubFolder(ctx, client, child, testUser, parentFolder, dataFolder, restoreMessageCount) + } + } +} diff --git a/src/cmd/sanity_test/restore/onedrive.go b/src/cmd/sanity_test/restore/onedrive.go new file mode 100644 index 000000000..ca332d608 --- /dev/null +++ b/src/cmd/sanity_test/restore/onedrive.go @@ -0,0 +1,353 @@ +package restore + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/alcionai/clues" + msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go" + "golang.org/x/exp/slices" + + "github.com/alcionai/corso/src/cmd/sanity_test/utils" + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/pkg/path" +) + +const ( + owner = "owner" +) + +type permissionInfo struct { + entityID string + roles []string +} + +func CheckOneDriveRestoration( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + userID, folderName, dataFolder string, + startTime time.Time, +) { + drive, err := client. + Users(). + ByUserId(userID). + Drive(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting the drive:", err) + } + + checkDriveRestoration( + ctx, + client, + path.OneDriveService, + folderName, + ptr.Val(drive.GetId()), + ptr.Val(drive.GetName()), + dataFolder, + startTime, + false) +} + +func checkDriveRestoration( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + service path.ServiceType, + folderName, + driveID, + driveName, + dataFolder string, + startTime time.Time, + skipPermissionTest bool, +) { + var ( + // map itemID -> item size + fileSizes = make(map[string]int64) + // map itemID -> permission id -> []permission roles + folderPermissions = make(map[string][]permissionInfo) + restoreFile = make(map[string]int64) + restoredFolderPermissions = make(map[string][]permissionInfo) + ) + + var restoreFolderID string + + ctx = clues.Add(ctx, "drive_id", driveID, "drive_name", driveName) + + response, err := client. + Drives(). + ByDriveId(driveID). + Items(). + ByDriveItemId("root"). + Children(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting drive by id", err) + } + + for _, driveItem := range response.GetValue() { + var ( + itemID = ptr.Val(driveItem.GetId()) + itemName = ptr.Val(driveItem.GetName()) + ) + + if itemName == folderName { + restoreFolderID = itemID + continue + } + + if itemName != dataFolder { + utils.LogAndPrint(ctx, "test data for folder: %s", dataFolder) + continue + } + + // if it's a file check the size + if driveItem.GetFile() != nil { + fileSizes[itemName] = ptr.Val(driveItem.GetSize()) + } + + if driveItem.GetFolder() == nil && driveItem.GetPackage() == nil { + continue + } + + // currently we don't restore blank folders. + // skip permission check for empty folders + if ptr.Val(driveItem.GetFolder().GetChildCount()) == 0 { + utils.LogAndPrint(ctx, "skipped empty folder: %s", itemName) + continue + } + + folderPermissions[itemName] = permissionIn(ctx, client, driveID, itemID) + getOneDriveChildFolder(ctx, client, driveID, itemID, itemName, fileSizes, folderPermissions, startTime) + } + + getRestoredDrive(ctx, client, driveID, restoreFolderID, restoreFile, restoredFolderPermissions, startTime) + + checkRestoredDriveItemPermissions( + ctx, + service, + skipPermissionTest, + folderPermissions, + restoredFolderPermissions) + + for fileName, expected := range fileSizes { + utils.LogAndPrint(ctx, "checking for file: %s", fileName) + + got := restoreFile[fileName] + + utils.Assert( + ctx, + func() bool { return expected == got }, + fmt.Sprintf("different file size: %s", fileName), + expected, + got) + } + + fmt.Println("Success") +} + +func checkRestoredDriveItemPermissions( + ctx context.Context, + service path.ServiceType, + skip bool, + folderPermissions map[string][]permissionInfo, + restoredFolderPermissions map[string][]permissionInfo, +) { + if skip { + return + } + + /** + TODO: replace this check with testElementsMatch + from internal/connecter/graph_connector_helper_test.go + **/ + + for folderName, permissions := range folderPermissions { + utils.LogAndPrint(ctx, "checking for folder: %s", folderName) + + restoreFolderPerm := restoredFolderPermissions[folderName] + + if len(permissions) < 1 { + utils.LogAndPrint(ctx, "no permissions found in: %s", folderName) + continue + } + + permCheck := func() bool { return len(permissions) == len(restoreFolderPerm) } + + if service == path.SharePointService { + permCheck = func() bool { return len(permissions) <= len(restoreFolderPerm) } + } + + utils.Assert( + ctx, + permCheck, + fmt.Sprintf("wrong number of restored permissions: %s", folderName), + permissions, + restoreFolderPerm) + + for _, perm := range permissions { + eqID := func(pi permissionInfo) bool { return strings.EqualFold(pi.entityID, perm.entityID) } + i := slices.IndexFunc(restoreFolderPerm, eqID) + + utils.Assert( + ctx, + func() bool { return i >= 0 }, + fmt.Sprintf("permission was restored in: %s", folderName), + perm.entityID, + restoreFolderPerm) + + // permissions should be sorted, so a by-index comparison works + restored := restoreFolderPerm[i] + + utils.Assert( + ctx, + func() bool { return slices.Equal(perm.roles, restored.roles) }, + fmt.Sprintf("different roles restored: %s", folderName), + perm.roles, + restored.roles) + } + } +} + +func getOneDriveChildFolder( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + driveID, itemID, parentName string, + fileSizes map[string]int64, + folderPermission map[string][]permissionInfo, + startTime time.Time, +) { + response, err := client.Drives().ByDriveId(driveID).Items().ByDriveItemId(itemID).Children().Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting child folder", err) + } + + for _, driveItem := range response.GetValue() { + var ( + itemID = ptr.Val(driveItem.GetId()) + itemName = ptr.Val(driveItem.GetName()) + fullName = parentName + "/" + itemName + ) + + folderTime, hasTime := utils.MustGetTimeFromName(ctx, itemName) + if !utils.IsWithinTimeBound(ctx, startTime, folderTime, hasTime) { + continue + } + + // if it's a file check the size + if driveItem.GetFile() != nil { + fileSizes[fullName] = ptr.Val(driveItem.GetSize()) + } + + if driveItem.GetFolder() == nil && driveItem.GetPackage() == nil { + continue + } + + // currently we don't restore blank folders. + // skip permission check for empty folders + if ptr.Val(driveItem.GetFolder().GetChildCount()) == 0 { + utils.LogAndPrint(ctx, "skipped empty folder: %s", fullName) + + continue + } + + folderPermission[fullName] = permissionIn(ctx, client, driveID, itemID) + getOneDriveChildFolder(ctx, client, driveID, itemID, fullName, fileSizes, folderPermission, startTime) + } +} + +func getRestoredDrive( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + driveID, restoreFolderID string, + restoreFile map[string]int64, + restoreFolder map[string][]permissionInfo, + startTime time.Time, +) { + restored, err := client. + Drives(). + ByDriveId(driveID). + Items(). + ByDriveItemId(restoreFolderID). + Children(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting child folder", err) + } + + for _, item := range restored.GetValue() { + var ( + itemID = ptr.Val(item.GetId()) + itemName = ptr.Val(item.GetName()) + itemSize = ptr.Val(item.GetSize()) + ) + + if item.GetFile() != nil { + restoreFile[itemName] = itemSize + continue + } + + if item.GetFolder() == nil && item.GetPackage() == nil { + continue + } + + restoreFolder[itemName] = permissionIn(ctx, client, driveID, itemID) + getOneDriveChildFolder(ctx, client, driveID, itemID, itemName, restoreFile, restoreFolder, startTime) + } +} + +// --------------------------------------------------------------------------- +// permission helpers +// --------------------------------------------------------------------------- + +func permissionIn( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + driveID, itemID string, +) []permissionInfo { + pi := []permissionInfo{} + + pcr, err := client. + Drives(). + ByDriveId(driveID). + Items(). + ByDriveItemId(itemID). + Permissions(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting permission", err) + } + + for _, perm := range pcr.GetValue() { + if perm.GetGrantedToV2() == nil { + continue + } + + var ( + gv2 = perm.GetGrantedToV2() + permInfo = permissionInfo{} + entityID string + ) + + // TODO: replace with filterUserPermissions in onedrive item.go + if gv2.GetUser() != nil { + entityID = ptr.Val(gv2.GetUser().GetId()) + } else if gv2.GetGroup() != nil { + entityID = ptr.Val(gv2.GetGroup().GetId()) + } + + roles := utils.FilterSlice(perm.GetRoles(), owner) + for _, role := range roles { + permInfo.entityID = entityID + permInfo.roles = append(permInfo.roles, role) + } + + if len(roles) > 0 { + slices.Sort(permInfo.roles) + pi = append(pi, permInfo) + } + } + + return pi +} diff --git a/src/cmd/sanity_test/restore/sharepoint.go b/src/cmd/sanity_test/restore/sharepoint.go new file mode 100644 index 000000000..d84e28f6c --- /dev/null +++ b/src/cmd/sanity_test/restore/sharepoint.go @@ -0,0 +1,39 @@ +package restore + +import ( + "context" + "time" + + msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go" + + "github.com/alcionai/corso/src/cmd/sanity_test/utils" + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/pkg/path" +) + +func CheckSharePointRestoration( + ctx context.Context, + client *msgraphsdk.GraphServiceClient, + siteID, userID, folderName, dataFolder string, + startTime time.Time, +) { + drive, err := client. + Sites(). + BySiteId(siteID). + Drive(). + Get(ctx, nil) + if err != nil { + utils.Fatal(ctx, "getting the drive:", err) + } + + checkDriveRestoration( + ctx, + client, + path.SharePointService, + folderName, + ptr.Val(drive.GetId()), + ptr.Val(drive.GetName()), + dataFolder, + startTime, + true) +} diff --git a/src/cmd/sanity_test/sanity_tests.go b/src/cmd/sanity_test/sanity_tests.go index ab89586b4..a4d455edd 100644 --- a/src/cmd/sanity_test/sanity_tests.go +++ b/src/cmd/sanity_test/sanity_tests.go @@ -2,45 +2,20 @@ package main import ( "context" - "errors" - "fmt" "os" - stdpath "path" "strings" "time" "github.com/alcionai/clues" msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go" - "github.com/microsoftgraph/msgraph-sdk-go/models" - "github.com/microsoftgraph/msgraph-sdk-go/users" - "golang.org/x/exp/slices" - "github.com/alcionai/corso/src/internal/common/dttm" - "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/cmd/sanity_test/restore" + "github.com/alcionai/corso/src/cmd/sanity_test/utils" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/tester/tconfig" - "github.com/alcionai/corso/src/pkg/filters" "github.com/alcionai/corso/src/pkg/logger" - "github.com/alcionai/corso/src/pkg/path" ) -// --------------------------------------------------------------------------- -// types, consts, etc -// --------------------------------------------------------------------------- - -type permissionInfo struct { - entityID string - roles []string -} - -const ( - owner = "owner" -) - -// --------------------------------------------------------------------------- -// main -// --------------------------------------------------------------------------- - func main() { ls := logger.Settings{ File: logger.GetLogFile(""), @@ -60,7 +35,7 @@ func main() { os.Getenv("AZURE_CLIENT_ID"), os.Getenv("AZURE_CLIENT_SECRET")) if err != nil { - fatal(ctx, "creating adapter", err) + utils.Fatal(ctx, "creating adapter", err) } var ( @@ -69,7 +44,7 @@ func main() { testSite = tconfig.GetM365SiteID(ctx) testService = os.Getenv("SANITY_RESTORE_SERVICE") folder = strings.TrimSpace(os.Getenv("SANITY_RESTORE_FOLDER")) - startTime, _ = mustGetTimeFromName(ctx, folder) + startTime, _ = utils.MustGetTimeFromName(ctx, folder) dataFolder = os.Getenv("TEST_DATA") baseBackupFolder = os.Getenv("BASE_BACKUP") ) @@ -85,657 +60,12 @@ func main() { switch testService { case "exchange": - checkEmailRestoration(ctx, client, testUser, folder, dataFolder, baseBackupFolder, startTime) + restore.CheckEmailRestoration(ctx, client, testUser, folder, dataFolder, baseBackupFolder, startTime) case "onedrive": - checkOneDriveRestoration(ctx, client, testUser, folder, dataFolder, startTime) + restore.CheckOneDriveRestoration(ctx, client, testUser, folder, dataFolder, startTime) case "sharepoint": - checkSharePointRestoration(ctx, client, testSite, testUser, folder, dataFolder, startTime) + restore.CheckSharePointRestoration(ctx, client, testSite, testUser, folder, dataFolder, startTime) default: - fatal(ctx, "no service specified", nil) + utils.Fatal(ctx, "unknown service for restore sanity tests", nil) } } - -// --------------------------------------------------------------------------- -// exchange -// --------------------------------------------------------------------------- - -// checkEmailRestoration verifies that the emails count in restored folder is equivalent to -// emails in actual m365 account -func checkEmailRestoration( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - testUser, folderName, dataFolder, baseBackupFolder string, - startTime time.Time, -) { - var ( - restoreFolder models.MailFolderable - itemCount = make(map[string]int32) - restoreItemCount = make(map[string]int32) - builder = client.Users().ByUserId(testUser).MailFolders() - ) - - for { - result, err := builder.Get(ctx, nil) - if err != nil { - fatal(ctx, "getting mail folders", err) - } - - values := result.GetValue() - - for _, v := range values { - itemName := ptr.Val(v.GetDisplayName()) - - if itemName == folderName { - restoreFolder = v - continue - } - - if itemName == dataFolder || itemName == baseBackupFolder { - // otherwise, recursively aggregate all child folders. - getAllMailSubFolders(ctx, client, testUser, v, itemName, dataFolder, itemCount) - - itemCount[itemName] = ptr.Val(v.GetTotalItemCount()) - } - } - - link, ok := ptr.ValOK(result.GetOdataNextLink()) - if !ok { - break - } - - builder = users.NewItemMailFoldersRequestBuilder(link, client.GetAdapter()) - } - - folderID := ptr.Val(restoreFolder.GetId()) - folderName = ptr.Val(restoreFolder.GetDisplayName()) - ctx = clues.Add( - ctx, - "restore_folder_id", folderID, - "restore_folder_name", folderName) - - childFolder, err := client. - Users(). - ByUserId(testUser). - MailFolders(). - ByMailFolderId(folderID). - ChildFolders(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting restore folder child folders", err) - } - - for _, fld := range childFolder.GetValue() { - restoreDisplayName := ptr.Val(fld.GetDisplayName()) - - // check if folder is the data folder we loaded or the base backup to verify - // the incremental backup worked fine - if strings.EqualFold(restoreDisplayName, dataFolder) || strings.EqualFold(restoreDisplayName, baseBackupFolder) { - count, _ := ptr.ValOK(fld.GetTotalItemCount()) - - restoreItemCount[restoreDisplayName] = count - checkAllSubFolder(ctx, client, fld, testUser, restoreDisplayName, dataFolder, restoreItemCount) - } - } - - verifyEmailData(ctx, restoreItemCount, itemCount) -} - -func verifyEmailData(ctx context.Context, restoreMessageCount, messageCount map[string]int32) { - for fldName, expected := range messageCount { - got := restoreMessageCount[fldName] - - assert( - ctx, - func() bool { return expected == got }, - fmt.Sprintf("Restore item counts do not match: %s", fldName), - expected, - got) - } -} - -// getAllSubFolder will recursively check for all subfolders and get the corresponding -// email count. -func getAllMailSubFolders( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - testUser string, - r models.MailFolderable, - parentFolder, - dataFolder string, - messageCount map[string]int32, -) { - var ( - folderID = ptr.Val(r.GetId()) - count int32 = 99 - options = &users.ItemMailFoldersItemChildFoldersRequestBuilderGetRequestConfiguration{ - QueryParameters: &users.ItemMailFoldersItemChildFoldersRequestBuilderGetQueryParameters{ - Top: &count, - }, - } - ) - - ctx = clues.Add(ctx, "parent_folder_id", folderID) - - childFolder, err := client. - Users(). - ByUserId(testUser). - MailFolders(). - ByMailFolderId(folderID). - ChildFolders(). - Get(ctx, options) - if err != nil { - fatal(ctx, "getting mail subfolders", err) - } - - for _, child := range childFolder.GetValue() { - var ( - childDisplayName = ptr.Val(child.GetDisplayName()) - childFolderCount = ptr.Val(child.GetChildFolderCount()) - //nolint:forbidigo - fullFolderName = stdpath.Join(parentFolder, childDisplayName) - ) - - if filters.PathContains([]string{dataFolder}).Compare(fullFolderName) { - messageCount[fullFolderName] = ptr.Val(child.GetTotalItemCount()) - // recursively check for subfolders - if childFolderCount > 0 { - parentFolder := fullFolderName - - getAllMailSubFolders(ctx, client, testUser, child, parentFolder, dataFolder, messageCount) - } - } - } -} - -// checkAllSubFolder will recursively traverse inside the restore folder and -// verify that data matched in all subfolders -func checkAllSubFolder( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - r models.MailFolderable, - testUser, - parentFolder, - dataFolder string, - restoreMessageCount map[string]int32, -) { - var ( - folderID = ptr.Val(r.GetId()) - count int32 = 99 - options = &users.ItemMailFoldersItemChildFoldersRequestBuilderGetRequestConfiguration{ - QueryParameters: &users.ItemMailFoldersItemChildFoldersRequestBuilderGetQueryParameters{ - Top: &count, - }, - } - ) - - childFolder, err := client. - Users(). - ByUserId(testUser). - MailFolders(). - ByMailFolderId(folderID). - ChildFolders(). - Get(ctx, options) - if err != nil { - fatal(ctx, "getting mail subfolders", err) - } - - for _, child := range childFolder.GetValue() { - var ( - childDisplayName = ptr.Val(child.GetDisplayName()) - //nolint:forbidigo - fullFolderName = stdpath.Join(parentFolder, childDisplayName) - ) - - if filters.PathContains([]string{dataFolder}).Compare(fullFolderName) { - childTotalCount, _ := ptr.ValOK(child.GetTotalItemCount()) - restoreMessageCount[fullFolderName] = childTotalCount - } - - childFolderCount := ptr.Val(child.GetChildFolderCount()) - - if childFolderCount > 0 { - parentFolder := fullFolderName - checkAllSubFolder(ctx, client, child, testUser, parentFolder, dataFolder, restoreMessageCount) - } - } -} - -// --------------------------------------------------------------------------- -// oneDrive -// --------------------------------------------------------------------------- - -func checkOneDriveRestoration( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - userID, folderName, dataFolder string, - startTime time.Time, -) { - drive, err := client. - Users(). - ByUserId(userID). - Drive(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting the drive:", err) - } - - checkDriveRestoration( - ctx, - client, - path.OneDriveService, - folderName, - ptr.Val(drive.GetId()), - ptr.Val(drive.GetName()), - dataFolder, - startTime, - false) -} - -// --------------------------------------------------------------------------- -// sharePoint -// --------------------------------------------------------------------------- - -func checkSharePointRestoration( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - siteID, userID, folderName, dataFolder string, - startTime time.Time, -) { - drive, err := client. - Sites(). - BySiteId(siteID). - Drive(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting the drive:", err) - } - - checkDriveRestoration( - ctx, - client, - path.SharePointService, - folderName, - ptr.Val(drive.GetId()), - ptr.Val(drive.GetName()), - dataFolder, - startTime, - true) -} - -// --------------------------------------------------------------------------- -// shared drive tests -// --------------------------------------------------------------------------- - -func checkDriveRestoration( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - service path.ServiceType, - folderName, - driveID, - driveName, - dataFolder string, - startTime time.Time, - skipPermissionTest bool, -) { - var ( - // map itemID -> item size - fileSizes = make(map[string]int64) - // map itemID -> permission id -> []permission roles - folderPermissions = make(map[string][]permissionInfo) - restoreFile = make(map[string]int64) - restoredFolderPermissions = make(map[string][]permissionInfo) - ) - - var restoreFolderID string - - ctx = clues.Add(ctx, "drive_id", driveID, "drive_name", driveName) - - response, err := client. - Drives(). - ByDriveId(driveID). - Items(). - ByDriveItemId("root"). - Children(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting drive by id", err) - } - - for _, driveItem := range response.GetValue() { - var ( - itemID = ptr.Val(driveItem.GetId()) - itemName = ptr.Val(driveItem.GetName()) - ) - - if itemName == folderName { - restoreFolderID = itemID - continue - } - - if itemName != dataFolder { - logAndPrint(ctx, "test data for folder: %s", dataFolder) - continue - } - - // if it's a file check the size - if driveItem.GetFile() != nil { - fileSizes[itemName] = ptr.Val(driveItem.GetSize()) - } - - if driveItem.GetFolder() == nil && driveItem.GetPackage() == nil { - continue - } - - // currently we don't restore blank folders. - // skip permission check for empty folders - if ptr.Val(driveItem.GetFolder().GetChildCount()) == 0 { - logAndPrint(ctx, "skipped empty folder: %s", itemName) - continue - } - - folderPermissions[itemName] = permissionIn(ctx, client, driveID, itemID) - getOneDriveChildFolder(ctx, client, driveID, itemID, itemName, fileSizes, folderPermissions, startTime) - } - - getRestoredDrive(ctx, client, driveID, restoreFolderID, restoreFile, restoredFolderPermissions, startTime) - - checkRestoredDriveItemPermissions( - ctx, - service, - skipPermissionTest, - folderPermissions, - restoredFolderPermissions) - - for fileName, expected := range fileSizes { - logAndPrint(ctx, "checking for file: %s", fileName) - - got := restoreFile[fileName] - - assert( - ctx, - func() bool { return expected == got }, - fmt.Sprintf("different file size: %s", fileName), - expected, - got) - } - - fmt.Println("Success") -} - -func checkRestoredDriveItemPermissions( - ctx context.Context, - service path.ServiceType, - skip bool, - folderPermissions map[string][]permissionInfo, - restoredFolderPermissions map[string][]permissionInfo, -) { - if skip { - return - } - - /** - TODO: replace this check with testElementsMatch - from internal/connecter/graph_connector_helper_test.go - **/ - - for folderName, permissions := range folderPermissions { - logAndPrint(ctx, "checking for folder: %s", folderName) - - restoreFolderPerm := restoredFolderPermissions[folderName] - - if len(permissions) < 1 { - logAndPrint(ctx, "no permissions found in: %s", folderName) - continue - } - - permCheck := func() bool { return len(permissions) == len(restoreFolderPerm) } - - if service == path.SharePointService { - permCheck = func() bool { return len(permissions) <= len(restoreFolderPerm) } - } - - assert( - ctx, - permCheck, - fmt.Sprintf("wrong number of restored permissions: %s", folderName), - permissions, - restoreFolderPerm) - - for _, perm := range permissions { - eqID := func(pi permissionInfo) bool { return strings.EqualFold(pi.entityID, perm.entityID) } - i := slices.IndexFunc(restoreFolderPerm, eqID) - - assert( - ctx, - func() bool { return i >= 0 }, - fmt.Sprintf("permission was restored in: %s", folderName), - perm.entityID, - restoreFolderPerm) - - // permissions should be sorted, so a by-index comparison works - restored := restoreFolderPerm[i] - - assert( - ctx, - func() bool { return slices.Equal(perm.roles, restored.roles) }, - fmt.Sprintf("different roles restored: %s", folderName), - perm.roles, - restored.roles) - } - } -} - -func getOneDriveChildFolder( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - driveID, itemID, parentName string, - fileSizes map[string]int64, - folderPermission map[string][]permissionInfo, - startTime time.Time, -) { - response, err := client.Drives().ByDriveId(driveID).Items().ByDriveItemId(itemID).Children().Get(ctx, nil) - if err != nil { - fatal(ctx, "getting child folder", err) - } - - for _, driveItem := range response.GetValue() { - var ( - itemID = ptr.Val(driveItem.GetId()) - itemName = ptr.Val(driveItem.GetName()) - fullName = parentName + "/" + itemName - ) - - folderTime, hasTime := mustGetTimeFromName(ctx, itemName) - if !isWithinTimeBound(ctx, startTime, folderTime, hasTime) { - continue - } - - // if it's a file check the size - if driveItem.GetFile() != nil { - fileSizes[fullName] = ptr.Val(driveItem.GetSize()) - } - - if driveItem.GetFolder() == nil && driveItem.GetPackage() == nil { - continue - } - - // currently we don't restore blank folders. - // skip permission check for empty folders - if ptr.Val(driveItem.GetFolder().GetChildCount()) == 0 { - logAndPrint(ctx, "skipped empty folder: %s", fullName) - - continue - } - - folderPermission[fullName] = permissionIn(ctx, client, driveID, itemID) - getOneDriveChildFolder(ctx, client, driveID, itemID, fullName, fileSizes, folderPermission, startTime) - } -} - -func getRestoredDrive( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - driveID, restoreFolderID string, - restoreFile map[string]int64, - restoreFolder map[string][]permissionInfo, - startTime time.Time, -) { - restored, err := client. - Drives(). - ByDriveId(driveID). - Items(). - ByDriveItemId(restoreFolderID). - Children(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting child folder", err) - } - - for _, item := range restored.GetValue() { - var ( - itemID = ptr.Val(item.GetId()) - itemName = ptr.Val(item.GetName()) - itemSize = ptr.Val(item.GetSize()) - ) - - if item.GetFile() != nil { - restoreFile[itemName] = itemSize - continue - } - - if item.GetFolder() == nil && item.GetPackage() == nil { - continue - } - - restoreFolder[itemName] = permissionIn(ctx, client, driveID, itemID) - getOneDriveChildFolder(ctx, client, driveID, itemID, itemName, restoreFile, restoreFolder, startTime) - } -} - -// --------------------------------------------------------------------------- -// permission helpers -// --------------------------------------------------------------------------- - -func permissionIn( - ctx context.Context, - client *msgraphsdk.GraphServiceClient, - driveID, itemID string, -) []permissionInfo { - pi := []permissionInfo{} - - pcr, err := client. - Drives(). - ByDriveId(driveID). - Items(). - ByDriveItemId(itemID). - Permissions(). - Get(ctx, nil) - if err != nil { - fatal(ctx, "getting permission", err) - } - - for _, perm := range pcr.GetValue() { - if perm.GetGrantedToV2() == nil { - continue - } - - var ( - gv2 = perm.GetGrantedToV2() - permInfo = permissionInfo{} - entityID string - ) - - // TODO: replace with filterUserPermissions in onedrive item.go - if gv2.GetUser() != nil { - entityID = ptr.Val(gv2.GetUser().GetId()) - } else if gv2.GetGroup() != nil { - entityID = ptr.Val(gv2.GetGroup().GetId()) - } - - roles := filterSlice(perm.GetRoles(), owner) - for _, role := range roles { - permInfo.entityID = entityID - permInfo.roles = append(permInfo.roles, role) - } - - if len(roles) > 0 { - slices.Sort(permInfo.roles) - pi = append(pi, permInfo) - } - } - - return pi -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -func fatal(ctx context.Context, msg string, err error) { - logger.CtxErr(ctx, err).Error("test failure: " + msg) - fmt.Println(msg+": ", err) - os.Exit(1) -} - -func mustGetTimeFromName(ctx context.Context, name string) (time.Time, bool) { - t, err := dttm.ExtractTime(name) - if err != nil && !errors.Is(err, dttm.ErrNoTimeString) { - fatal(ctx, "extracting time from name: "+name, err) - } - - return t, !errors.Is(err, dttm.ErrNoTimeString) -} - -func isWithinTimeBound(ctx context.Context, bound, check time.Time, hasTime bool) bool { - if hasTime { - if bound.Before(check) { - logger.Ctx(ctx). - With("boundary_time", bound, "check_time", check). - Info("skipping restore folder: not older than time bound") - - return false - } - } - - return true -} - -func filterSlice(sl []string, remove string) []string { - r := []string{} - - for _, s := range sl { - if !strings.EqualFold(s, remove) { - r = append(r, s) - } - } - - return r -} - -func assert( - ctx context.Context, - passes func() bool, - header string, - expect, current any, -) { - if passes() { - return - } - - header = "Error: " + header - expected := fmt.Sprintf("* Expected: %+v", expect) - got := fmt.Sprintf("* Current: %+v", current) - - logger.Ctx(ctx).Info(strings.Join([]string{header, expected, got}, " ")) - - fmt.Println(header) - fmt.Println(expected) - fmt.Println(got) - - os.Exit(1) -} - -func logAndPrint(ctx context.Context, tmpl string, vs ...any) { - logger.Ctx(ctx).Infof(tmpl, vs...) - fmt.Printf(tmpl+"\n", vs...) -} diff --git a/src/cmd/sanity_test/utils/utils.go b/src/cmd/sanity_test/utils/utils.go new file mode 100644 index 000000000..d0a877d61 --- /dev/null +++ b/src/cmd/sanity_test/utils/utils.go @@ -0,0 +1,82 @@ +package utils + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/alcionai/corso/src/internal/common/dttm" + "github.com/alcionai/corso/src/pkg/logger" +) + +func Assert( + ctx context.Context, + passes func() bool, + header string, + expect, current any, +) { + if passes() { + return + } + + header = "Error: " + header + expected := fmt.Sprintf("* Expected: %+v", expect) + got := fmt.Sprintf("* Current: %+v", current) + + logger.Ctx(ctx).Info(strings.Join([]string{header, expected, got}, " ")) + + fmt.Println(header) + fmt.Println(expected) + fmt.Println(got) + + os.Exit(1) +} + +func Fatal(ctx context.Context, msg string, err error) { + logger.CtxErr(ctx, err).Error("test failure: " + msg) + fmt.Println(msg+": ", err) + os.Exit(1) +} + +func MustGetTimeFromName(ctx context.Context, name string) (time.Time, bool) { + t, err := dttm.ExtractTime(name) + if err != nil && !errors.Is(err, dttm.ErrNoTimeString) { + Fatal(ctx, "extracting time from name: "+name, err) + } + + return t, !errors.Is(err, dttm.ErrNoTimeString) +} + +func IsWithinTimeBound(ctx context.Context, bound, check time.Time, hasTime bool) bool { + if hasTime { + if bound.Before(check) { + logger.Ctx(ctx). + With("boundary_time", bound, "check_time", check). + Info("skipping restore folder: not older than time bound") + + return false + } + } + + return true +} + +func FilterSlice(sl []string, remove string) []string { + r := []string{} + + for _, s := range sl { + if !strings.EqualFold(s, remove) { + r = append(r, s) + } + } + + return r +} + +func LogAndPrint(ctx context.Context, tmpl string, vs ...any) { + logger.Ctx(ctx).Infof(tmpl, vs...) + fmt.Printf(tmpl+"\n", vs...) +}