diff --git a/CHANGELOG.md b/CHANGELOG.md index d528b18c8..3ecc05a1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] (beta) +### Added + +### Fixed +- Return a ServiceNotEnabled error when a tenant has no active SharePoint license. + +## [v0.10.0] (beta) - 2023-06-26 + +### Added +- Exceptions and cancellations for recurring events are now backed up and restored +- Introduced a URL cache for OneDrive that helps reduce Graph API calls for long running (>1hr) backups +- Improve incremental backup behavior by leveraging information from incomplete backups +- Improve restore performance and memory use for Exchange and OneDrive + +### Fixed +- Handle OLE conversion errors when trying to fetch attachments +- Fix uploading large attachments for emails and calendar +- Fixed high memory use in OneDrive backup related to logging + +### Changed +- Switched to Go 1.20 + ## [v0.9.0] (beta) - 2023-06-05 ### Added @@ -18,7 +39,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix Exchange folder cache population error when parent folder isn't found. - Fix Exchange backup issue caused by incorrect json serialization - Fix issues with details model containing duplicate entry for api consumers -- Handle OLE conversion errors when trying to fetch attachments ### Changed - Do not display all the items that we restored at the end if there are more than 15. You can override this with `--verbose`. diff --git a/src/cli/backup/backup.go b/src/cli/backup/backup.go index c721e4c3f..f43cd6474 100644 --- a/src/cli/backup/backup.go +++ b/src/cli/backup/backup.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" @@ -58,31 +59,21 @@ func AddCommands(cmd *cobra.Command) { // common flags and flag attachers for commands // --------------------------------------------------------------------------- -// list output filter flags -var ( - failedItemsFN = "failed-items" - listFailedItems string - skippedItemsFN = "skipped-items" - listSkippedItems string - recoveredErrorsFN = "recovered-errors" - listRecoveredErrors string -) - func addFailedItemsFN(cmd *cobra.Command) { cmd.Flags().StringVar( - &listFailedItems, failedItemsFN, "show", + &flags.ListFailedItemsFV, flags.FailedItemsFN, "show", "Toggles showing or hiding the list of items that failed.") } func addSkippedItemsFN(cmd *cobra.Command) { cmd.Flags().StringVar( - &listSkippedItems, skippedItemsFN, "show", + &flags.ListSkippedItemsFV, flags.SkippedItemsFN, "show", "Toggles showing or hiding the list of items that were skipped.") } func addRecoveredErrorsFN(cmd *cobra.Command) { cmd.Flags().StringVar( - &listRecoveredErrors, recoveredErrorsFN, "show", + &flags.ListRecoveredErrorsFV, flags.RecoveredErrorsFN, "show", "Toggles showing or hiding the list of errors which corso recovered from.") } @@ -318,7 +309,11 @@ func genericListCommand(cmd *cobra.Command, bID string, service path.ServiceType } b.Print(ctx) - fe.PrintItems(ctx, !ifShow(listFailedItems), !ifShow(listSkippedItems), !ifShow(listRecoveredErrors)) + fe.PrintItems( + ctx, + !ifShow(flags.ListFailedItemsFV), + !ifShow(flags.ListSkippedItemsFV), + !ifShow(flags.ListRecoveredErrorsFV)) return nil } diff --git a/src/cli/backup/exchange.go b/src/cli/backup/exchange.go index af71c6a30..06a231a3d 100644 --- a/src/cli/backup/exchange.go +++ b/src/cli/backup/exchange.go @@ -8,7 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" @@ -31,7 +31,7 @@ const ( const ( exchangeServiceCommand = "exchange" - exchangeServiceCommandCreateUseSuffix = "--mailbox | '" + utils.Wildcard + "'" + exchangeServiceCommandCreateUseSuffix = "--mailbox | '" + flags.Wildcard + "'" exchangeServiceCommandDeleteUseSuffix = "--backup " exchangeServiceCommandDetailsUseSuffix = "--backup " ) @@ -82,20 +82,20 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { // Flags addition ordering should follow the order we want them to appear in help and docs: // More generic (ex: --user) and more frequently used flags take precedence. - utils.AddMailBoxFlag(c) - utils.AddDataFlag(c, []string{dataEmail, dataContacts, dataEvents}, false) - options.AddFetchParallelismFlag(c) - options.AddFailFastFlag(c) - options.AddDisableIncrementalsFlag(c) - options.AddDisableDeltaFlag(c) - options.AddEnableImmutableIDFlag(c) - options.AddDisableConcurrencyLimiterFlag(c) + flags.AddMailBoxFlag(c) + flags.AddDataFlag(c, []string{dataEmail, dataContacts, dataEvents}, false) + flags.AddFetchParallelismFlag(c) + flags.AddFailFastFlag(c) + flags.AddDisableIncrementalsFlag(c) + flags.AddDisableDeltaFlag(c) + flags.AddEnableImmutableIDFlag(c) + flags.AddDisableConcurrencyLimiterFlag(c) case listCommand: c, fs = utils.AddCommand(cmd, exchangeListCmd()) fs.SortFlags = false - utils.AddBackupIDFlag(c, false) + flags.AddBackupIDFlag(c, false) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -107,12 +107,12 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + exchangeServiceCommandDetailsUseSuffix c.Example = exchangeServiceCommandDetailsExamples - options.AddSkipReduceFlag(c) + flags.AddSkipReduceFlag(c) // Flags addition ordering should follow the order we want them to appear in help and docs: // More generic (ex: --user) and more frequently used flags take precedence. - utils.AddBackupIDFlag(c, true) - utils.AddExchangeDetailsAndRestoreFlags(c) + flags.AddBackupIDFlag(c, true) + flags.AddExchangeDetailsAndRestoreFlags(c) case deleteCommand: c, fs = utils.AddCommand(cmd, exchangeDeleteCmd()) @@ -121,7 +121,7 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + exchangeServiceCommandDeleteUseSuffix c.Example = exchangeServiceCommandDeleteExamples - utils.AddBackupIDFlag(c, true) + flags.AddBackupIDFlag(c, true) } return c @@ -149,7 +149,7 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error { return nil } - if err := validateExchangeBackupCreateFlags(utils.UserFV, utils.CategoryDataFV); err != nil { + if err := validateExchangeBackupCreateFlags(flags.UserFV, flags.CategoryDataFV); err != nil { return err } @@ -160,7 +160,7 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - sel := exchangeBackupCreateSelectors(utils.UserFV, utils.CategoryDataFV) + sel := exchangeBackupCreateSelectors(flags.UserFV, flags.CategoryDataFV) ins, err := utils.UsersMap(ctx, *acct, fault.New(true)) if err != nil { @@ -235,7 +235,7 @@ func exchangeListCmd() *cobra.Command { // lists the history of backup operations func listExchangeCmd(cmd *cobra.Command, args []string) error { - return genericListCommand(cmd, utils.BackupIDFV, path.ExchangeService, args) + return genericListCommand(cmd, flags.BackupIDFV, path.ExchangeService, args) } // ------------------------------------------------------------------------------------------------ @@ -269,9 +269,9 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - ctrlOpts := options.Control() + ctrlOpts := utils.Control() - ds, err := runDetailsExchangeCmd(ctx, r, utils.BackupIDFV, opts, ctrlOpts.SkipReduce) + ds, err := runDetailsExchangeCmd(ctx, r, flags.BackupIDFV, opts, ctrlOpts.SkipReduce) if err != nil { return Only(ctx, err) } @@ -340,5 +340,5 @@ func exchangeDeleteCmd() *cobra.Command { // deletes an exchange service backup. func deleteExchangeCmd(cmd *cobra.Command, args []string) error { - return genericDeleteCommand(cmd, utils.BackupIDFV, "Exchange", args) + return genericDeleteCommand(cmd, flags.BackupIDFV, "Exchange", args) } diff --git a/src/cli/backup/exchange_e2e_test.go b/src/cli/backup/exchange_e2e_test.go index 9400f0d90..517f42e88 100644 --- a/src/cli/backup/exchange_e2e_test.go +++ b/src/cli/backup/exchange_e2e_test.go @@ -16,8 +16,8 @@ import ( "github.com/alcionai/corso/src/cli" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/m365/exchange" "github.com/alcionai/corso/src/internal/operations" @@ -469,7 +469,7 @@ func runExchangeDetailsCmdTest(suite *PreparedBackupExchangeE2ESuite, category p cmd := tester.StubRootCmd( "backup", "details", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(bID)) + "--"+flags.BackupFN, string(bID)) cli.BuildCommandTree(cmd) cmd.SetOut(&suite.recorder) @@ -568,7 +568,7 @@ func (suite *BackupDeleteExchangeE2ESuite) TestExchangeBackupDeleteCmd() { cmd := tester.StubRootCmd( "backup", "delete", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOp.Results.BackupID)) + "--"+flags.BackupFN, string(suite.backupOp.Results.BackupID)) cli.BuildCommandTree(cmd) // run the command @@ -597,7 +597,7 @@ func (suite *BackupDeleteExchangeE2ESuite) TestExchangeBackupDeleteCmd_UnknownID cmd := tester.StubRootCmd( "backup", "delete", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, uuid.NewString()) + "--"+flags.BackupFN, uuid.NewString()) cli.BuildCommandTree(cmd) // unknown backupIDs should error since the modelStore can't find the backup @@ -617,8 +617,8 @@ func buildExchangeBackupCmd( cmd := tester.StubRootCmd( "backup", "create", "exchange", "--config-file", configFile, - "--"+utils.UserFN, user, - "--"+utils.CategoryDataFN, category) + "--"+flags.UserFN, user, + "--"+flags.CategoryDataFN, category) cli.BuildCommandTree(cmd) cmd.SetOut(recorder) diff --git a/src/cli/backup/exchange_test.go b/src/cli/backup/exchange_test.go index ec78978a2..6bd078797 100644 --- a/src/cli/backup/exchange_test.go +++ b/src/cli/backup/exchange_test.go @@ -10,8 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/alcionai/corso/src/cli/options" - "github.com/alcionai/corso/src/cli/utils" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/version" @@ -43,14 +42,14 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { expectUse + " " + exchangeServiceCommandCreateUseSuffix, exchangeCreateCmd().Short, []string{ - utils.UserFN, - utils.CategoryDataFN, - options.DisableIncrementalsFN, - options.DisableDeltaFN, - options.FailFastFN, - options.FetchParallelismFN, - options.SkipReduceFN, - options.NoStatsFN, + flags.UserFN, + flags.CategoryDataFN, + flags.DisableIncrementalsFN, + flags.DisableDeltaFN, + flags.FailFastFN, + flags.FetchParallelismFN, + flags.SkipReduceFN, + flags.NoStatsFN, }, createExchangeCmd, }, @@ -60,10 +59,10 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { expectUse, exchangeListCmd().Short, []string{ - utils.BackupFN, - failedItemsFN, - skippedItemsFN, - recoveredErrorsFN, + flags.BackupFN, + flags.FailedItemsFN, + flags.SkippedItemsFN, + flags.RecoveredErrorsFN, }, listExchangeCmd, }, @@ -73,23 +72,23 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { expectUse + " " + exchangeServiceCommandDetailsUseSuffix, exchangeDetailsCmd().Short, []string{ - utils.BackupFN, - utils.ContactFN, - utils.ContactFolderFN, - utils.ContactNameFN, - utils.EmailFN, - utils.EmailFolderFN, - utils.EmailReceivedAfterFN, - utils.EmailReceivedBeforeFN, - utils.EmailSenderFN, - utils.EmailSubjectFN, - utils.EventFN, - utils.EventCalendarFN, - utils.EventOrganizerFN, - utils.EventRecursFN, - utils.EventStartsAfterFN, - utils.EventStartsBeforeFN, - utils.EventSubjectFN, + flags.BackupFN, + flags.ContactFN, + flags.ContactFolderFN, + flags.ContactNameFN, + flags.EmailFN, + flags.EmailFolderFN, + flags.EmailReceivedAfterFN, + flags.EmailReceivedBeforeFN, + flags.EmailSenderFN, + flags.EmailSubjectFN, + flags.EventFN, + flags.EventCalendarFN, + flags.EventOrganizerFN, + flags.EventRecursFN, + flags.EventStartsAfterFN, + flags.EventStartsBeforeFN, + flags.EventSubjectFN, }, detailsExchangeCmd, }, @@ -98,7 +97,7 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { deleteCommand, expectUse + " " + exchangeServiceCommandDeleteUseSuffix, exchangeDeleteCmd().Short, - []string{utils.BackupFN}, + []string{flags.BackupFN}, deleteExchangeCmd, }, } @@ -171,7 +170,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, no data", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, expectIncludeLen: 3, }, { @@ -181,7 +180,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, contacts", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataContacts}, expectIncludeLen: 1, }, @@ -193,7 +192,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, email", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataEmail}, expectIncludeLen: 1, }, @@ -205,7 +204,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, events", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataEvents}, expectIncludeLen: 1, }, @@ -217,7 +216,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, contacts + email", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataContacts, dataEmail}, expectIncludeLen: 2, }, @@ -229,7 +228,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, email + events", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataEmail, dataEvents}, expectIncludeLen: 2, }, @@ -241,7 +240,7 @@ func (suite *ExchangeUnitSuite) TestExchangeBackupCreateSelectors() { }, { name: "any users, events + contacts", - user: []string{utils.Wildcard}, + user: []string{flags.Wildcard}, data: []string{dataEvents, dataContacts}, expectIncludeLen: 2, }, diff --git a/src/cli/backup/onedrive.go b/src/cli/backup/onedrive.go index b47acd496..11efd93fe 100644 --- a/src/cli/backup/onedrive.go +++ b/src/cli/backup/onedrive.go @@ -8,7 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" @@ -25,7 +25,7 @@ import ( const ( oneDriveServiceCommand = "onedrive" - oneDriveServiceCommandCreateUseSuffix = "--user | '" + utils.Wildcard + "'" + oneDriveServiceCommandCreateUseSuffix = "--user | '" + flags.Wildcard + "'" oneDriveServiceCommandDeleteUseSuffix = "--backup " oneDriveServiceCommandDetailsUseSuffix = "--backup " ) @@ -70,15 +70,15 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + oneDriveServiceCommandCreateUseSuffix c.Example = oneDriveServiceCommandCreateExamples - utils.AddUserFlag(c) - options.AddFailFastFlag(c) - options.AddDisableIncrementalsFlag(c) + flags.AddUserFlag(c) + flags.AddFailFastFlag(c) + flags.AddDisableIncrementalsFlag(c) case listCommand: c, fs = utils.AddCommand(cmd, oneDriveListCmd()) fs.SortFlags = false - utils.AddBackupIDFlag(c, false) + flags.AddBackupIDFlag(c, false) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -90,9 +90,9 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + oneDriveServiceCommandDetailsUseSuffix c.Example = oneDriveServiceCommandDetailsExamples - options.AddSkipReduceFlag(c) - utils.AddBackupIDFlag(c, true) - utils.AddOneDriveDetailsAndRestoreFlags(c) + flags.AddSkipReduceFlag(c) + flags.AddBackupIDFlag(c, true) + flags.AddOneDriveDetailsAndRestoreFlags(c) case deleteCommand: c, fs = utils.AddCommand(cmd, oneDriveDeleteCmd()) @@ -101,7 +101,7 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + oneDriveServiceCommandDeleteUseSuffix c.Example = oneDriveServiceCommandDeleteExamples - utils.AddBackupIDFlag(c, true) + flags.AddBackupIDFlag(c, true) } return c @@ -130,7 +130,7 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error { return nil } - if err := validateOneDriveBackupCreateFlags(utils.UserFV); err != nil { + if err := validateOneDriveBackupCreateFlags(flags.UserFV); err != nil { return err } @@ -141,7 +141,7 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - sel := oneDriveBackupCreateSelectors(utils.UserFV) + sel := oneDriveBackupCreateSelectors(flags.UserFV) ins, err := utils.UsersMap(ctx, *acct, fault.New(true)) if err != nil { @@ -193,7 +193,7 @@ func oneDriveListCmd() *cobra.Command { // lists the history of backup operations func listOneDriveCmd(cmd *cobra.Command, args []string) error { - return genericListCommand(cmd, utils.BackupIDFV, path.OneDriveService, args) + return genericListCommand(cmd, flags.BackupIDFV, path.OneDriveService, args) } // ------------------------------------------------------------------------------------------------ @@ -227,9 +227,9 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - ctrlOpts := options.Control() + ctrlOpts := utils.Control() - ds, err := runDetailsOneDriveCmd(ctx, r, utils.BackupIDFV, opts, ctrlOpts.SkipReduce) + ds, err := runDetailsOneDriveCmd(ctx, r, flags.BackupIDFV, opts, ctrlOpts.SkipReduce) if err != nil { return Only(ctx, err) } @@ -295,5 +295,5 @@ func oneDriveDeleteCmd() *cobra.Command { // deletes a oneDrive service backup. func deleteOneDriveCmd(cmd *cobra.Command, args []string) error { - return genericDeleteCommand(cmd, utils.BackupIDFV, "OneDrive", args) + return genericDeleteCommand(cmd, flags.BackupIDFV, "OneDrive", args) } diff --git a/src/cli/backup/onedrive_e2e_test.go b/src/cli/backup/onedrive_e2e_test.go index 6ae96f368..07024f612 100644 --- a/src/cli/backup/onedrive_e2e_test.go +++ b/src/cli/backup/onedrive_e2e_test.go @@ -14,8 +14,8 @@ import ( "github.com/alcionai/corso/src/cli" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/operations" "github.com/alcionai/corso/src/internal/tester" @@ -108,7 +108,7 @@ func (suite *NoBackupOneDriveE2ESuite) TestOneDriveBackupCmd_UserNotInTenant() { cmd := tester.StubRootCmd( "backup", "create", "onedrive", "--config-file", suite.cfgFP, - "--"+utils.UserFN, "foo@nothere.com") + "--"+flags.UserFN, "foo@nothere.com") cli.BuildCommandTree(cmd) cmd.SetOut(&recorder) @@ -200,7 +200,7 @@ func (suite *BackupDeleteOneDriveE2ESuite) TestOneDriveBackupDeleteCmd() { cmd := tester.StubRootCmd( "backup", "delete", "onedrive", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOp.Results.BackupID)) + "--"+flags.BackupFN, string(suite.backupOp.Results.BackupID)) cli.BuildCommandTree(cmd) cmd.SetErr(&suite.recorder) @@ -240,7 +240,7 @@ func (suite *BackupDeleteOneDriveE2ESuite) TestOneDriveBackupDeleteCmd_unknownID cmd := tester.StubRootCmd( "backup", "delete", "onedrive", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, uuid.NewString()) + "--"+flags.BackupFN, uuid.NewString()) cli.BuildCommandTree(cmd) // unknown backupIDs should error since the modelStore can't find the backup diff --git a/src/cli/backup/onedrive_test.go b/src/cli/backup/onedrive_test.go index caa52561b..3ac476aa7 100644 --- a/src/cli/backup/onedrive_test.go +++ b/src/cli/backup/onedrive_test.go @@ -10,8 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/alcionai/corso/src/cli/options" - "github.com/alcionai/corso/src/cli/utils" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/version" @@ -43,9 +42,9 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { expectUse + " " + oneDriveServiceCommandCreateUseSuffix, oneDriveCreateCmd().Short, []string{ - utils.UserFN, - options.DisableIncrementalsFN, - options.FailFastFN, + flags.UserFN, + flags.DisableIncrementalsFN, + flags.FailFastFN, }, createOneDriveCmd, }, @@ -55,10 +54,10 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { expectUse, oneDriveListCmd().Short, []string{ - utils.BackupFN, - failedItemsFN, - skippedItemsFN, - recoveredErrorsFN, + flags.BackupFN, + flags.FailedItemsFN, + flags.SkippedItemsFN, + flags.RecoveredErrorsFN, }, listOneDriveCmd, }, @@ -68,13 +67,13 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { expectUse + " " + oneDriveServiceCommandDetailsUseSuffix, oneDriveDetailsCmd().Short, []string{ - utils.BackupFN, - utils.FolderFN, - utils.FileFN, - utils.FileCreatedAfterFN, - utils.FileCreatedBeforeFN, - utils.FileModifiedAfterFN, - utils.FileModifiedBeforeFN, + flags.BackupFN, + flags.FolderFN, + flags.FileFN, + flags.FileCreatedAfterFN, + flags.FileCreatedBeforeFN, + flags.FileModifiedAfterFN, + flags.FileModifiedBeforeFN, }, detailsOneDriveCmd, }, @@ -83,7 +82,7 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { deleteCommand, expectUse + " " + oneDriveServiceCommandDeleteUseSuffix, oneDriveDeleteCmd().Short, - []string{utils.BackupFN}, + []string{flags.BackupFN}, deleteOneDriveCmd, }, } diff --git a/src/cli/backup/sharepoint.go b/src/cli/backup/sharepoint.go index 2197252ea..2d730e51c 100644 --- a/src/cli/backup/sharepoint.go +++ b/src/cli/backup/sharepoint.go @@ -9,7 +9,7 @@ import ( "github.com/spf13/pflag" "golang.org/x/exp/slices" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" @@ -34,7 +34,7 @@ const ( const ( sharePointServiceCommand = "sharepoint" - sharePointServiceCommandCreateUseSuffix = "--site | '" + utils.Wildcard + "'" + sharePointServiceCommandCreateUseSuffix = "--site | '" + flags.Wildcard + "'" sharePointServiceCommandDeleteUseSuffix = "--backup " sharePointServiceCommandDetailsUseSuffix = "--backup " ) @@ -84,17 +84,17 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + sharePointServiceCommandCreateUseSuffix c.Example = sharePointServiceCommandCreateExamples - utils.AddSiteFlag(c) - utils.AddSiteIDFlag(c) - utils.AddDataFlag(c, []string{dataLibraries}, true) - options.AddFailFastFlag(c) - options.AddDisableIncrementalsFlag(c) + flags.AddSiteFlag(c) + flags.AddSiteIDFlag(c) + flags.AddDataFlag(c, []string{dataLibraries}, true) + flags.AddFailFastFlag(c) + flags.AddDisableIncrementalsFlag(c) case listCommand: c, fs = utils.AddCommand(cmd, sharePointListCmd()) fs.SortFlags = false - utils.AddBackupIDFlag(c, false) + flags.AddBackupIDFlag(c, false) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -106,9 +106,9 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + sharePointServiceCommandDetailsUseSuffix c.Example = sharePointServiceCommandDetailsExamples - options.AddSkipReduceFlag(c) - utils.AddBackupIDFlag(c, true) - utils.AddSharePointDetailsAndRestoreFlags(c) + flags.AddSkipReduceFlag(c) + flags.AddBackupIDFlag(c, true) + flags.AddSharePointDetailsAndRestoreFlags(c) case deleteCommand: c, fs = utils.AddCommand(cmd, sharePointDeleteCmd()) @@ -117,7 +117,7 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + sharePointServiceCommandDeleteUseSuffix c.Example = sharePointServiceCommandDeleteExamples - utils.AddBackupIDFlag(c, true) + flags.AddBackupIDFlag(c, true) } return c @@ -146,7 +146,7 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { return nil } - if err := validateSharePointBackupCreateFlags(utils.SiteIDFV, utils.WebURLFV, utils.CategoryDataFV); err != nil { + if err := validateSharePointBackupCreateFlags(flags.SiteIDFV, flags.WebURLFV, flags.CategoryDataFV); err != nil { return err } @@ -165,7 +165,7 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { return Only(ctx, clues.Wrap(err, "Failed to retrieve M365 sites")) } - sel, err := sharePointBackupCreateSelectors(ctx, ins, utils.SiteIDFV, utils.WebURLFV, utils.CategoryDataFV) + sel, err := sharePointBackupCreateSelectors(ctx, ins, flags.SiteIDFV, flags.WebURLFV, flags.CategoryDataFV) if err != nil { return Only(ctx, clues.Wrap(err, "Retrieving up sharepoint sites by ID and URL")) } @@ -188,8 +188,8 @@ func validateSharePointBackupCreateFlags(sites, weburls, cats []string) error { if len(sites) == 0 && len(weburls) == 0 { return clues.New( "requires one or more --" + - utils.SiteFN + " urls, or the wildcard --" + - utils.SiteFN + " *", + flags.SiteFN + " urls, or the wildcard --" + + flags.SiteFN + " *", ) } @@ -214,11 +214,11 @@ func sharePointBackupCreateSelectors( return selectors.NewSharePointBackup(selectors.None()), nil } - if filters.PathContains(sites).Compare(utils.Wildcard) { + if filters.PathContains(sites).Compare(flags.Wildcard) { return includeAllSitesWithCategories(ins, cats), nil } - if filters.PathContains(weburls).Compare(utils.Wildcard) { + if filters.PathContains(weburls).Compare(flags.Wildcard) { return includeAllSitesWithCategories(ins, cats), nil } @@ -265,7 +265,7 @@ func sharePointListCmd() *cobra.Command { // lists the history of backup operations func listSharePointCmd(cmd *cobra.Command, args []string) error { - return genericListCommand(cmd, utils.BackupIDFV, path.SharePointService, args) + return genericListCommand(cmd, flags.BackupIDFV, path.SharePointService, args) } // ------------------------------------------------------------------------------------------------ @@ -285,7 +285,7 @@ func sharePointDeleteCmd() *cobra.Command { // deletes a sharePoint service backup. func deleteSharePointCmd(cmd *cobra.Command, args []string) error { - return genericDeleteCommand(cmd, utils.BackupIDFV, "SharePoint", args) + return genericDeleteCommand(cmd, flags.BackupIDFV, "SharePoint", args) } // ------------------------------------------------------------------------------------------------ @@ -319,9 +319,9 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - ctrlOpts := options.Control() + ctrlOpts := utils.Control() - ds, err := runDetailsSharePointCmd(ctx, r, utils.BackupIDFV, opts, ctrlOpts.SkipReduce) + ds, err := runDetailsSharePointCmd(ctx, r, flags.BackupIDFV, opts, ctrlOpts.SkipReduce) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/sharepoint_e2e_test.go b/src/cli/backup/sharepoint_e2e_test.go index e3c3c5570..25afc1f8e 100644 --- a/src/cli/backup/sharepoint_e2e_test.go +++ b/src/cli/backup/sharepoint_e2e_test.go @@ -14,8 +14,8 @@ import ( "github.com/alcionai/corso/src/cli" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/operations" "github.com/alcionai/corso/src/internal/tester" @@ -164,7 +164,7 @@ func (suite *BackupDeleteSharePointE2ESuite) TestSharePointBackupDeleteCmd() { cmd := tester.StubRootCmd( "backup", "delete", "sharepoint", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOp.Results.BackupID)) + "--"+flags.BackupFN, string(suite.backupOp.Results.BackupID)) cli.BuildCommandTree(cmd) cmd.SetErr(&suite.recorder) @@ -205,7 +205,7 @@ func (suite *BackupDeleteSharePointE2ESuite) TestSharePointBackupDeleteCmd_unkno cmd := tester.StubRootCmd( "backup", "delete", "sharepoint", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, uuid.NewString()) + "--"+flags.BackupFN, uuid.NewString()) cli.BuildCommandTree(cmd) // unknown backupIDs should error since the modelStore can't find the backup diff --git a/src/cli/backup/sharepoint_test.go b/src/cli/backup/sharepoint_test.go index 0469a40ef..648d3e8c4 100644 --- a/src/cli/backup/sharepoint_test.go +++ b/src/cli/backup/sharepoint_test.go @@ -10,8 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/alcionai/corso/src/cli/options" - "github.com/alcionai/corso/src/cli/utils" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/tester" @@ -45,9 +44,9 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { expectUse + " " + sharePointServiceCommandCreateUseSuffix, sharePointCreateCmd().Short, []string{ - utils.SiteFN, - options.DisableIncrementalsFN, - options.FailFastFN, + flags.SiteFN, + flags.DisableIncrementalsFN, + flags.FailFastFN, }, createSharePointCmd, }, @@ -57,10 +56,10 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { expectUse, sharePointListCmd().Short, []string{ - utils.BackupFN, - failedItemsFN, - skippedItemsFN, - recoveredErrorsFN, + flags.BackupFN, + flags.FailedItemsFN, + flags.SkippedItemsFN, + flags.RecoveredErrorsFN, }, listSharePointCmd, }, @@ -70,14 +69,14 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { expectUse + " " + sharePointServiceCommandDetailsUseSuffix, sharePointDetailsCmd().Short, []string{ - utils.BackupFN, - utils.LibraryFN, - utils.FolderFN, - utils.FileFN, - utils.FileCreatedAfterFN, - utils.FileCreatedBeforeFN, - utils.FileModifiedAfterFN, - utils.FileModifiedBeforeFN, + flags.BackupFN, + flags.LibraryFN, + flags.FolderFN, + flags.FileFN, + flags.FileCreatedAfterFN, + flags.FileCreatedBeforeFN, + flags.FileModifiedAfterFN, + flags.FileModifiedBeforeFN, }, detailsSharePointCmd, }, @@ -86,7 +85,7 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { deleteCommand, expectUse + " " + sharePointServiceCommandDeleteUseSuffix, sharePointDeleteCmd().Short, - []string{utils.BackupFN}, + []string{flags.BackupFN}, deleteSharePointCmd, }, } @@ -183,13 +182,13 @@ func (suite *SharePointUnitSuite) TestSharePointBackupCreateSelectors() { }, { name: "site wildcard", - site: []string{utils.Wildcard}, + site: []string{flags.Wildcard}, expect: bothIDs, expectScopesLen: 2, }, { name: "url wildcard", - weburl: []string{utils.Wildcard}, + weburl: []string{flags.Wildcard}, expect: bothIDs, expectScopesLen: 2, }, @@ -221,7 +220,7 @@ func (suite *SharePointUnitSuite) TestSharePointBackupCreateSelectors() { }, { name: "unnecessary site wildcard", - site: []string{id1, utils.Wildcard}, + site: []string{id1, flags.Wildcard}, weburl: []string{url1, url2}, expect: bothIDs, expectScopesLen: 2, @@ -229,7 +228,7 @@ func (suite *SharePointUnitSuite) TestSharePointBackupCreateSelectors() { { name: "unnecessary url wildcard", site: []string{id1}, - weburl: []string{url1, utils.Wildcard}, + weburl: []string{url1, flags.Wildcard}, expect: bothIDs, expectScopesLen: 2, }, diff --git a/src/cli/cli.go b/src/cli/cli.go index eb1276cb5..e69d89eb5 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -11,8 +11,8 @@ import ( "github.com/alcionai/corso/src/cli/backup" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/help" - "github.com/alcionai/corso/src/cli/options" "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/restore" @@ -44,11 +44,11 @@ func preRun(cc *cobra.Command, args []string) error { ctx := cc.Context() log := logger.Ctx(ctx) - flags := utils.GetPopulatedFlags(cc) - flagSl := make([]string, 0, len(flags)) + fs := flags.GetPopulatedFlags(cc) + flagSl := make([]string, 0, len(fs)) // currently only tracking flag names to avoid pii leakage. - for f := range flags { + for f := range fs { flagSl = append(flagSl, f) } @@ -87,7 +87,7 @@ func preRun(cc *cobra.Command, args []string) error { cfg.Account.ID(), map[string]any{"command": cc.CommandPath()}, cfg.RepoID, - options.Control()) + utils.Control()) } // handle deprecated user flag in Backup exchange command @@ -138,7 +138,7 @@ func CorsoCommand() *cobra.Command { func BuildCommandTree(cmd *cobra.Command) { // want to order flags explicitly cmd.PersistentFlags().SortFlags = false - utils.AddRunModeFlag(cmd, true) + flags.AddRunModeFlag(cmd, true) cmd.Flags().BoolP("version", "v", false, "current version info") cmd.PersistentPreRunE = preRun @@ -146,7 +146,7 @@ func BuildCommandTree(cmd *cobra.Command) { logger.AddLoggingFlags(cmd) observe.AddProgressBarFlags(cmd) print.AddOutputFlag(cmd) - options.AddGlobalOperationFlags(cmd) + flags.AddGlobalOperationFlags(cmd) cmd.SetUsageTemplate(indentExamplesTemplate(corsoCmd.UsageTemplate())) cmd.CompletionOptions.DisableDefaultCmd = true diff --git a/src/cli/flags/exchange.go b/src/cli/flags/exchange.go new file mode 100644 index 000000000..c859e84a1 --- /dev/null +++ b/src/cli/flags/exchange.go @@ -0,0 +1,124 @@ +package flags + +import ( + "github.com/spf13/cobra" +) + +const ( + ContactFN = "contact" + ContactFolderFN = "contact-folder" + ContactNameFN = "contact-name" + + EmailFN = "email" + EmailFolderFN = "email-folder" + EmailReceivedAfterFN = "email-received-after" + EmailReceivedBeforeFN = "email-received-before" + EmailSenderFN = "email-sender" + EmailSubjectFN = "email-subject" + + EventFN = "event" + EventCalendarFN = "event-calendar" + EventOrganizerFN = "event-organizer" + EventRecursFN = "event-recurs" + EventStartsAfterFN = "event-starts-after" + EventStartsBeforeFN = "event-starts-before" + EventSubjectFN = "event-subject" +) + +// flag values (ie: FV) +var ( + ContactFV []string + ContactFolderFV []string + ContactNameFV string + + EmailFV []string + EmailFolderFV []string + EmailReceivedAfterFV string + EmailReceivedBeforeFV string + EmailSenderFV string + EmailSubjectFV string + + EventFV []string + EventCalendarFV []string + EventOrganizerFV string + EventRecursFV string + EventStartsAfterFV string + EventStartsBeforeFV string + EventSubjectFV string +) + +// AddExchangeDetailsAndRestoreFlags adds flags that are common to both the +// details and restore commands. +func AddExchangeDetailsAndRestoreFlags(cmd *cobra.Command) { + fs := cmd.Flags() + + // email flags + fs.StringSliceVar( + &EmailFV, + EmailFN, nil, + "Select email messages by ID; accepts '"+Wildcard+"' to select all emails.") + fs.StringSliceVar( + &EmailFolderFV, + EmailFolderFN, nil, + "Select emails within a folder; accepts '"+Wildcard+"' to select all email folders.") + fs.StringVar( + &EmailSubjectFV, + EmailSubjectFN, "", + "Select emails with a subject containing this value.") + fs.StringVar( + &EmailSenderFV, + EmailSenderFN, "", + "Select emails from a specific sender.") + fs.StringVar( + &EmailReceivedAfterFV, + EmailReceivedAfterFN, "", + "Select emails received after this datetime.") + fs.StringVar( + &EmailReceivedBeforeFV, + EmailReceivedBeforeFN, "", + "Select emails received before this datetime.") + + // event flags + fs.StringSliceVar( + &EventFV, + EventFN, nil, + "Select events by event ID; accepts '"+Wildcard+"' to select all events.") + fs.StringSliceVar( + &EventCalendarFV, + EventCalendarFN, nil, + "Select events under a calendar; accepts '"+Wildcard+"' to select all events.") + fs.StringVar( + &EventSubjectFV, + EventSubjectFN, "", + "Select events with a subject containing this value.") + fs.StringVar( + &EventOrganizerFV, + EventOrganizerFN, "", + "Select events from a specific organizer.") + fs.StringVar( + &EventRecursFV, + EventRecursFN, "", + "Select recurring events. Use `--event-recurs false` to select non-recurring events.") + fs.StringVar( + &EventStartsAfterFV, + EventStartsAfterFN, "", + "Select events starting after this datetime.") + fs.StringVar( + &EventStartsBeforeFV, + EventStartsBeforeFN, "", + "Select events starting before this datetime.") + + // contact flags + fs.StringSliceVar( + &ContactFV, + ContactFN, nil, + "Select contacts by contact ID; accepts '"+Wildcard+"' to select all contacts.") + fs.StringSliceVar( + &ContactFolderFV, + ContactFolderFN, nil, + "Select contacts within a folder; accepts '"+Wildcard+"' to select all contact folders.") + fs.StringVar( + &ContactNameFV, + ContactNameFN, "", + "Select contacts whose contact name contains this value.") +} diff --git a/src/cli/flags/flags.go b/src/cli/flags/flags.go new file mode 100644 index 000000000..dd9ac1ddc --- /dev/null +++ b/src/cli/flags/flags.go @@ -0,0 +1,36 @@ +package flags + +import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +const Wildcard = "*" + +type PopulatedFlags map[string]struct{} + +func (fs PopulatedFlags) populate(pf *pflag.Flag) { + if pf == nil { + return + } + + if pf.Changed { + fs[pf.Name] = struct{}{} + } +} + +// GetPopulatedFlags returns a map of flags that have been +// populated by the user. Entry keys match the flag's long +// name. Values are empty. +func GetPopulatedFlags(cmd *cobra.Command) PopulatedFlags { + pop := PopulatedFlags{} + + fs := cmd.Flags() + if fs == nil { + return pop + } + + fs.VisitAll(pop.populate) + + return pop +} diff --git a/src/cli/flags/m365_common.go b/src/cli/flags/m365_common.go new file mode 100644 index 000000000..d4a0c0231 --- /dev/null +++ b/src/cli/flags/m365_common.go @@ -0,0 +1,42 @@ +package flags + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" +) + +var CategoryDataFV []string + +const CategoryDataFN = "data" + +func AddDataFlag(cmd *cobra.Command, allowed []string, hide bool) { + var ( + allowedMsg string + fs = cmd.Flags() + ) + + switch len(allowed) { + case 0: + return + case 1: + allowedMsg = allowed[0] + case 2: + allowedMsg = fmt.Sprintf("%s or %s", allowed[0], allowed[1]) + default: + allowedMsg = fmt.Sprintf( + "%s or %s", + strings.Join(allowed[:len(allowed)-1], ", "), + allowed[len(allowed)-1]) + } + + fs.StringSliceVar( + &CategoryDataFV, + CategoryDataFN, nil, + "Select one or more types of data to backup: "+allowedMsg+".") + + if hide { + cobra.CheckErr(fs.MarkHidden(CategoryDataFN)) + } +} diff --git a/src/cli/flags/m365_resource.go b/src/cli/flags/m365_resource.go new file mode 100644 index 000000000..d00897cf2 --- /dev/null +++ b/src/cli/flags/m365_resource.go @@ -0,0 +1,40 @@ +package flags + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +const ( + UserFN = "user" + MailBoxFN = "mailbox" +) + +var UserFV []string + +// AddUserFlag adds the --user flag. +func AddUserFlag(cmd *cobra.Command) { + cmd.Flags().StringSliceVar( + &UserFV, + UserFN, nil, + "Backup a specific user's data; accepts '"+Wildcard+"' to select all users.") + cobra.CheckErr(cmd.MarkFlagRequired(UserFN)) +} + +// AddMailBoxFlag adds the --user and --mailbox flag. +func AddMailBoxFlag(cmd *cobra.Command) { + flags := cmd.Flags() + + flags.StringSliceVar( + &UserFV, + UserFN, nil, + "Backup a specific user's data; accepts '"+Wildcard+"' to select all users.") + + cobra.CheckErr(flags.MarkDeprecated(UserFN, fmt.Sprintf("use --%s instead", MailBoxFN))) + + flags.StringSliceVar( + &UserFV, + MailBoxFN, nil, + "Backup a specific mailbox's data; accepts '"+Wildcard+"' to select all mailbox.") +} diff --git a/src/cli/flags/maintenance.go b/src/cli/flags/maintenance.go new file mode 100644 index 000000000..2c512603a --- /dev/null +++ b/src/cli/flags/maintenance.go @@ -0,0 +1,41 @@ +package flags + +import ( + "github.com/spf13/cobra" + + "github.com/alcionai/corso/src/pkg/control/repository" +) + +const ( + MaintenanceModeFN = "mode" + ForceMaintenanceFN = "force" +) + +var ( + MaintenanceModeFV string + ForceMaintenanceFV bool +) + +func AddMaintenanceModeFlag(cmd *cobra.Command) { + fs := cmd.Flags() + fs.StringVar( + &MaintenanceModeFV, + MaintenanceModeFN, + repository.CompleteMaintenance.String(), + "Type of maintenance operation to run. Pass '"+ + repository.MetadataMaintenance.String()+"' to run a faster maintenance "+ + "that does minimal clean-up and optimization. Pass '"+ + repository.CompleteMaintenance.String()+"' to fully compact existing "+ + "data and delete unused data.") + cobra.CheckErr(fs.MarkHidden(MaintenanceModeFN)) +} + +func AddForceMaintenanceFlag(cmd *cobra.Command) { + fs := cmd.Flags() + fs.BoolVar( + &ForceMaintenanceFV, + ForceMaintenanceFN, + false, + "Force maintenance. Caution: user must ensure this is not run concurrently on a single repo") + cobra.CheckErr(fs.MarkHidden(ForceMaintenanceFN)) +} diff --git a/src/cli/flags/onedrive.go b/src/cli/flags/onedrive.go new file mode 100644 index 000000000..62f69f0b7 --- /dev/null +++ b/src/cli/flags/onedrive.go @@ -0,0 +1,60 @@ +package flags + +import ( + "github.com/spf13/cobra" +) + +const ( + FileFN = "file" + FolderFN = "folder" + + FileCreatedAfterFN = "file-created-after" + FileCreatedBeforeFN = "file-created-before" + FileModifiedAfterFN = "file-modified-after" + FileModifiedBeforeFN = "file-modified-before" +) + +var ( + FolderPathFV []string + FileNameFV []string + + FileCreatedAfterFV string + FileCreatedBeforeFV string + FileModifiedAfterFV string + FileModifiedBeforeFV string +) + +// AddOneDriveDetailsAndRestoreFlags adds flags that are common to both the +// details and restore commands. +func AddOneDriveDetailsAndRestoreFlags(cmd *cobra.Command) { + fs := cmd.Flags() + + fs.StringSliceVar( + &FolderPathFV, + FolderFN, nil, + "Select files by OneDrive folder; defaults to root.") + + fs.StringSliceVar( + &FileNameFV, + FileFN, nil, + "Select files by name.") + + fs.StringVar( + &FileCreatedAfterFV, + FileCreatedAfterFN, "", + "Select files created after this datetime.") + fs.StringVar( + &FileCreatedBeforeFV, + FileCreatedBeforeFN, "", + "Select files created before this datetime.") + + fs.StringVar( + &FileModifiedAfterFV, + FileModifiedAfterFN, "", + "Select files modified after this datetime.") + + fs.StringVar( + &FileModifiedBeforeFV, + FileModifiedBeforeFN, "", + "Select files modified before this datetime.") +} diff --git a/src/cli/options/options.go b/src/cli/flags/options.go similarity index 66% rename from src/cli/options/options.go rename to src/cli/flags/options.go index ac76b41b8..046d3c8d7 100644 --- a/src/cli/options/options.go +++ b/src/cli/flags/options.go @@ -1,65 +1,59 @@ -package options +package flags import ( "github.com/spf13/cobra" - - "github.com/alcionai/corso/src/pkg/control" ) -// Control produces the control options based on the user's flags. -func Control() control.Options { - opt := control.Defaults() - - if failFastFV { - opt.FailureHandling = control.FailFast - } - - opt.DisableMetrics = noStatsFV - opt.RestorePermissions = restorePermissionsFV - opt.SkipReduce = skipReduceFV - opt.ToggleFeatures.DisableIncrementals = disableIncrementalsFV - opt.ToggleFeatures.DisableDelta = disableDeltaFV - opt.ToggleFeatures.ExchangeImmutableIDs = enableImmutableID - opt.ToggleFeatures.DisableConcurrencyLimiter = disableConcurrencyLimiterFV - opt.Parallelism.ItemFetch = fetchParallelismFV - - return opt -} - -// --------------------------------------------------------------------------- -// Operations Flags -// --------------------------------------------------------------------------- - const ( - FailFastFN = "fail-fast" - FetchParallelismFN = "fetch-parallelism" - NoStatsFN = "no-stats" - RestorePermissionsFN = "restore-permissions" - SkipReduceFN = "skip-reduce" + DisableConcurrencyLimiterFN = "disable-concurrency-limiter" DisableDeltaFN = "disable-delta" DisableIncrementalsFN = "disable-incrementals" EnableImmutableIDFN = "enable-immutable-id" - DisableConcurrencyLimiterFN = "disable-concurrency-limiter" + FailFastFN = "fail-fast" + FailedItemsFN = "failed-items" + FetchParallelismFN = "fetch-parallelism" + NoStatsFN = "no-stats" + RecoveredErrorsFN = "recovered-errors" + RestorePermissionsFN = "restore-permissions" + RunModeFN = "run-mode" + SkippedItemsFN = "skipped-items" + SkipReduceFN = "skip-reduce" ) var ( - failFastFV bool - fetchParallelismFV int - noStatsFV bool - restorePermissionsFV bool - skipReduceFV bool + DisableConcurrencyLimiterFV bool + DisableDeltaFV bool + DisableIncrementalsFV bool + EnableImmutableIDFV bool + FailFastFV bool + FetchParallelismFV int + ListFailedItemsFV string + ListSkippedItemsFV string + ListRecoveredErrorsFV string + NoStatsFV bool + // RunMode describes the type of run, such as: + // flagtest, dry, run. Should default to 'run'. + RunModeFV string + RestorePermissionsFV bool + SkipReduceFV bool +) + +// well-known flag values +const ( + RunModeFlagTest = "flag-test" + RunModeRun = "run" ) // AddGlobalOperationFlags adds the global operations flag set. func AddGlobalOperationFlags(cmd *cobra.Command) { fs := cmd.PersistentFlags() - fs.BoolVar(&noStatsFV, NoStatsFN, false, "disable anonymous usage statistics gathering") + fs.BoolVar(&NoStatsFV, NoStatsFN, false, "disable anonymous usage statistics gathering") } // AddFailFastFlag adds a flag to toggle fail-fast error handling behavior. func AddFailFastFlag(cmd *cobra.Command) { fs := cmd.Flags() - fs.BoolVar(&failFastFV, FailFastFN, false, "stop processing immediately if any error occurs") + fs.BoolVar(&FailFastFV, FailFastFN, false, "stop processing immediately if any error occurs") // TODO: reveal this flag when fail-fast support is implemented cobra.CheckErr(fs.MarkHidden(FailFastFN)) } @@ -67,14 +61,14 @@ func AddFailFastFlag(cmd *cobra.Command) { // AddRestorePermissionsFlag adds OneDrive flag for restoring permissions func AddRestorePermissionsFlag(cmd *cobra.Command) { fs := cmd.Flags() - fs.BoolVar(&restorePermissionsFV, RestorePermissionsFN, false, "Restore permissions for files and folders") + fs.BoolVar(&RestorePermissionsFV, RestorePermissionsFN, false, "Restore permissions for files and folders") } // AddSkipReduceFlag adds a hidden flag that allows callers to skip the selector // reduction step. Currently only intended for details commands, not restore. func AddSkipReduceFlag(cmd *cobra.Command) { fs := cmd.Flags() - fs.BoolVar(&skipReduceFV, SkipReduceFN, false, "Skip the selector reduce filtering") + fs.BoolVar(&SkipReduceFV, SkipReduceFN, false, "Skip the selector reduce filtering") cobra.CheckErr(fs.MarkHidden(SkipReduceFN)) } @@ -83,28 +77,19 @@ func AddSkipReduceFlag(cmd *cobra.Command) { func AddFetchParallelismFlag(cmd *cobra.Command) { fs := cmd.Flags() fs.IntVar( - &fetchParallelismFV, + &FetchParallelismFV, FetchParallelismFN, 4, "Control the number of concurrent data fetches for Exchange. Valid range is [1-4]. Default: 4") cobra.CheckErr(fs.MarkHidden(FetchParallelismFN)) } -// --------------------------------------------------------------------------- -// Feature Flags -// --------------------------------------------------------------------------- - -var ( - disableIncrementalsFV bool - disableDeltaFV bool -) - // Adds the hidden '--disable-incrementals' cli flag which, when set, disables // incremental backups. func AddDisableIncrementalsFlag(cmd *cobra.Command) { fs := cmd.Flags() fs.BoolVar( - &disableIncrementalsFV, + &DisableIncrementalsFV, DisableIncrementalsFN, false, "Disable incremental data retrieval in backups.") @@ -116,38 +101,45 @@ func AddDisableIncrementalsFlag(cmd *cobra.Command) { func AddDisableDeltaFlag(cmd *cobra.Command) { fs := cmd.Flags() fs.BoolVar( - &disableDeltaFV, + &DisableDeltaFV, DisableDeltaFN, false, "Disable delta based data retrieval in backups.") cobra.CheckErr(fs.MarkHidden(DisableDeltaFN)) } -var enableImmutableID bool - // Adds the hidden '--enable-immutable-id' cli flag which, when set, enables // immutable IDs for Exchange func AddEnableImmutableIDFlag(cmd *cobra.Command) { fs := cmd.Flags() fs.BoolVar( - &enableImmutableID, + &EnableImmutableIDFV, EnableImmutableIDFN, false, "Enable exchange immutable ID.") cobra.CheckErr(fs.MarkHidden(EnableImmutableIDFN)) } -var disableConcurrencyLimiterFV bool - // AddDisableConcurrencyLimiterFlag adds a hidden cli flag which, when set, // removes concurrency limits when communicating with graph API. This // flag is only relevant for exchange backups for now func AddDisableConcurrencyLimiterFlag(cmd *cobra.Command) { fs := cmd.Flags() fs.BoolVar( - &disableConcurrencyLimiterFV, + &DisableConcurrencyLimiterFV, DisableConcurrencyLimiterFN, false, "Disable concurrency limiter middleware. Default: false") cobra.CheckErr(fs.MarkHidden(DisableConcurrencyLimiterFN)) } + +// AddRunModeFlag adds the hidden --run-mode flag. +func AddRunModeFlag(cmd *cobra.Command, persistent bool) { + fs := cmd.Flags() + if persistent { + fs = cmd.PersistentFlags() + } + + fs.StringVar(&RunModeFV, RunModeFN, "run", "What mode to run: dry, test, run. Defaults to run.") + cobra.CheckErr(fs.MarkHidden(RunModeFN)) +} diff --git a/src/cli/flags/repo.go b/src/cli/flags/repo.go new file mode 100644 index 000000000..67bf6b0db --- /dev/null +++ b/src/cli/flags/repo.go @@ -0,0 +1,18 @@ +package flags + +import ( + "github.com/spf13/cobra" +) + +const BackupFN = "backup" + +var BackupIDFV string + +// AddBackupIDFlag adds the --backup flag. +func AddBackupIDFlag(cmd *cobra.Command, require bool) { + cmd.Flags().StringVar(&BackupIDFV, BackupFN, "", "ID of the backup to retrieve.") + + if require { + cobra.CheckErr(cmd.MarkFlagRequired(BackupFN)) + } +} diff --git a/src/cli/flags/sharepoint.go b/src/cli/flags/sharepoint.go new file mode 100644 index 000000000..31ba29bff --- /dev/null +++ b/src/cli/flags/sharepoint.go @@ -0,0 +1,113 @@ +package flags + +import ( + "github.com/spf13/cobra" +) + +const ( + LibraryFN = "library" + ListFolderFN = "list" + ListItemFN = "list-item" + PageFolderFN = "page-folder" + PageFN = "page" + SiteFN = "site" // site only accepts WebURL values + SiteIDFN = "site-id" // site-id accepts actual site ids +) + +var ( + LibraryFV string + ListFolderFV []string + ListItemFV []string + PageFolderFV []string + PageFV []string + SiteIDFV []string + WebURLFV []string +) + +// AddSharePointDetailsAndRestoreFlags adds flags that are common to both the +// details and restore commands. +func AddSharePointDetailsAndRestoreFlags(cmd *cobra.Command) { + fs := cmd.Flags() + + // libraries + + fs.StringVar( + &LibraryFV, + LibraryFN, "", + "Select only this library; defaults to all libraries.") + fs.StringSliceVar( + &FolderPathFV, + FolderFN, nil, + "Select by folder; defaults to root.") + fs.StringSliceVar( + &FileNameFV, + FileFN, nil, + "Select by file name.") + fs.StringVar( + &FileCreatedAfterFV, + FileCreatedAfterFN, "", + "Select files created after this datetime.") + fs.StringVar( + &FileCreatedBeforeFV, + FileCreatedBeforeFN, "", + "Select files created before this datetime.") + fs.StringVar( + &FileModifiedAfterFV, + FileModifiedAfterFN, "", + "Select files modified after this datetime.") + fs.StringVar( + &FileModifiedBeforeFV, + FileModifiedBeforeFN, "", + "Select files modified before this datetime.") + + // lists + + fs.StringSliceVar( + &ListFolderFV, + ListFolderFN, nil, + "Select lists by name; accepts '"+Wildcard+"' to select all lists.") + cobra.CheckErr(fs.MarkHidden(ListFolderFN)) + fs.StringSliceVar( + &ListItemFV, + ListItemFN, nil, + "Select lists by item name; accepts '"+Wildcard+"' to select all lists.") + cobra.CheckErr(fs.MarkHidden(ListItemFN)) + + // pages + + fs.StringSliceVar( + &PageFolderFV, + PageFolderFN, nil, + "Select pages by folder name; accepts '"+Wildcard+"' to select all pages.") + cobra.CheckErr(fs.MarkHidden(PageFolderFN)) + fs.StringSliceVar( + &PageFV, + PageFN, nil, + "Select pages by item name; accepts '"+Wildcard+"' to select all pages.") + cobra.CheckErr(fs.MarkHidden(PageFN)) +} + +// AddSiteIDFlag adds the --site-id flag, which accepts site ID values. +// This flag is hidden, since we expect users to prefer the --site url +// and do not want to encourage confusion. +func AddSiteIDFlag(cmd *cobra.Command) { + fs := cmd.Flags() + + // note string ARRAY var. IDs naturally contain commas, so we cannot accept + // duplicate values within a flag declaration. ie: --site-id a,b,c does not + // work. Users must call --site-id a --site-id b --site-id c. + fs.StringArrayVar( + &SiteIDFV, + SiteIDFN, nil, + //nolint:lll + "Backup data by site ID; accepts '"+Wildcard+"' to select all sites. Args cannot be comma-delimited and must use multiple flags.") + cobra.CheckErr(fs.MarkHidden(SiteIDFN)) +} + +// AddSiteFlag adds the --site flag, which accepts webURL values. +func AddSiteFlag(cmd *cobra.Command) { + cmd.Flags().StringSliceVar( + &WebURLFV, + SiteFN, nil, + "Backup data by site URL; accepts '"+Wildcard+"' to select all sites.") +} diff --git a/src/cli/options/options_test.go b/src/cli/options/options_test.go deleted file mode 100644 index 8538e3441..000000000 --- a/src/cli/options/options_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package options - -import ( - "testing" - - "github.com/alcionai/clues" - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - - "github.com/alcionai/corso/src/internal/tester" -) - -type OptionsUnitSuite struct { - tester.Suite -} - -func TestOptionsUnitSuite(t *testing.T) { - suite.Run(t, &OptionsUnitSuite{Suite: tester.NewUnitSuite(t)}) -} - -func (suite *OptionsUnitSuite) TestAddExchangeCommands() { - t := suite.T() - - cmd := &cobra.Command{ - Use: "test", - Run: func(cmd *cobra.Command, args []string) { - assert.True(t, failFastFV, FailFastFN) - assert.True(t, disableIncrementalsFV, DisableIncrementalsFN) - assert.True(t, disableDeltaFV, DisableDeltaFN) - assert.True(t, noStatsFV, NoStatsFN) - assert.True(t, restorePermissionsFV, RestorePermissionsFN) - assert.True(t, skipReduceFV, SkipReduceFN) - assert.Equal(t, 2, fetchParallelismFV, FetchParallelismFN) - assert.True(t, disableConcurrencyLimiterFV, DisableConcurrencyLimiterFN) - }, - } - - // adds no-stats - AddGlobalOperationFlags(cmd) - - AddFailFastFlag(cmd) - AddDisableIncrementalsFlag(cmd) - AddDisableDeltaFlag(cmd) - AddRestorePermissionsFlag(cmd) - AddSkipReduceFlag(cmd) - AddFetchParallelismFlag(cmd) - AddDisableConcurrencyLimiterFlag(cmd) - - // Test arg parsing for few args - cmd.SetArgs([]string{ - "test", - "--" + FailFastFN, - "--" + DisableIncrementalsFN, - "--" + DisableDeltaFN, - "--" + NoStatsFN, - "--" + RestorePermissionsFN, - "--" + SkipReduceFN, - "--" + FetchParallelismFN, "2", - "--" + DisableConcurrencyLimiterFN, - }) - - err := cmd.Execute() - require.NoError(t, err, clues.ToCore(err)) -} diff --git a/src/cli/repo/repo.go b/src/cli/repo/repo.go index 6d36d1608..c6cba55be 100644 --- a/src/cli/repo/repo.go +++ b/src/cli/repo/repo.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" "golang.org/x/exp/maps" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/pkg/control/repository" @@ -42,8 +43,8 @@ func AddCommands(cmd *cobra.Command) { maintenanceCmd, utils.HideCommand(), utils.MarkPreReleaseCommand()) - utils.AddMaintenanceModeFlag(maintenanceCmd) - utils.AddForceMaintenanceFlag(maintenanceCmd) + flags.AddMaintenanceModeFlag(maintenanceCmd) + flags.AddForceMaintenanceFlag(maintenanceCmd) for _, addRepoTo := range repoCommands { addRepoTo(initCmd) @@ -116,7 +117,7 @@ func maintenanceCmd() *cobra.Command { func handleMaintenanceCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - t, err := getMaintenanceType(utils.MaintenanceModeFV) + t, err := getMaintenanceType(flags.MaintenanceModeFV) if err != nil { return err } @@ -133,7 +134,7 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error { repository.Maintenance{ Type: t, Safety: repository.FullMaintenanceSafety, - Force: utils.ForceMaintenanceFV, + Force: flags.ForceMaintenanceFV, }) if err != nil { return print.Only(ctx, err) diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index feba087a8..2480cf0fa 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -10,7 +10,6 @@ import ( "github.com/spf13/pflag" "github.com/alcionai/corso/src/cli/config" - "github.com/alcionai/corso/src/cli/options" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/events" @@ -124,7 +123,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { cfg.Account.ID(), map[string]any{"command": "init repo"}, cfg.Account.ID(), - options.Control()) + utils.Control()) s3Cfg, err := cfg.Storage.S3Config() if err != nil { @@ -143,7 +142,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) } - r, err := repository.Initialize(ctx, cfg.Account, cfg.Storage, options.Control()) + r, err := repository.Initialize(ctx, cfg.Account, cfg.Storage, utils.Control()) if err != nil { if succeedIfExists && errors.Is(err, repository.ErrorRepoAlreadyExists) { return nil @@ -214,7 +213,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { return Only(ctx, clues.New(invalidEndpointErr)) } - r, err := repository.ConnectAndSendConnectEvent(ctx, cfg.Account, cfg.Storage, repoID, options.Control()) + r, err := repository.ConnectAndSendConnectEvent(ctx, cfg.Account, cfg.Storage, repoID, utils.Control()) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to connect to the S3 repository")) } diff --git a/src/cli/restore/exchange.go b/src/cli/restore/exchange.go index 514e6102c..be5b83dfc 100644 --- a/src/cli/restore/exchange.go +++ b/src/cli/restore/exchange.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" @@ -32,9 +32,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { // general flags fs.SortFlags = false - utils.AddBackupIDFlag(c, true) - utils.AddExchangeDetailsAndRestoreFlags(c) - options.AddFailFastFlag(c) + flags.AddBackupIDFlag(c, true) + flags.AddExchangeDetailsAndRestoreFlags(c) + flags.AddFailFastFlag(c) } return c @@ -81,11 +81,11 @@ func restoreExchangeCmd(cmd *cobra.Command, args []string) error { opts := utils.MakeExchangeOpts(cmd) - if utils.RunModeFV == utils.RunModeFlagTest { + if flags.RunModeFV == flags.RunModeFlagTest { return nil } - if err := utils.ValidateExchangeRestoreFlags(utils.BackupIDFV, opts); err != nil { + if err := utils.ValidateExchangeRestoreFlags(flags.BackupIDFV, opts); err != nil { return err } @@ -102,7 +102,7 @@ func restoreExchangeCmd(cmd *cobra.Command, args []string) error { sel := utils.IncludeExchangeRestoreDataSelectors(opts) utils.FilterExchangeRestoreInfoSelectors(sel, opts) - ro, err := r.NewRestore(ctx, utils.BackupIDFV, sel.Selector, restoreCfg) + ro, err := r.NewRestore(ctx, flags.BackupIDFV, sel.Selector, restoreCfg) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to initialize Exchange restore")) } @@ -110,7 +110,7 @@ func restoreExchangeCmd(cmd *cobra.Command, args []string) error { ds, err := ro.Run(ctx) if err != nil { if errors.Is(err, data.ErrNotFound) { - return Only(ctx, clues.New("Backup or backup details missing for id "+utils.BackupIDFV)) + return Only(ctx, clues.New("Backup or backup details missing for id "+flags.BackupIDFV)) } return Only(ctx, clues.Wrap(err, "Failed to run Exchange restore")) diff --git a/src/cli/restore/exchange_e2e_test.go b/src/cli/restore/exchange_e2e_test.go index 1f4f93601..5512dca4f 100644 --- a/src/cli/restore/exchange_e2e_test.go +++ b/src/cli/restore/exchange_e2e_test.go @@ -12,7 +12,7 @@ import ( "github.com/alcionai/corso/src/cli" "github.com/alcionai/corso/src/cli/config" - "github.com/alcionai/corso/src/cli/utils" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/m365/exchange" "github.com/alcionai/corso/src/internal/operations" @@ -135,7 +135,7 @@ func (suite *RestoreExchangeE2ESuite) TestExchangeRestoreCmd() { cmd := tester.StubRootCmd( "restore", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOps[set].Results.BackupID)) + "--"+flags.BackupFN, string(suite.backupOps[set].Results.BackupID)) cli.BuildCommandTree(cmd) // run the command @@ -162,15 +162,15 @@ func (suite *RestoreExchangeE2ESuite) TestExchangeRestoreCmd_badTimeFlags() { var timeFilter string switch set { case email: - timeFilter = "--" + utils.EmailReceivedAfterFN + timeFilter = "--" + flags.EmailReceivedAfterFN case events: - timeFilter = "--" + utils.EventStartsAfterFN + timeFilter = "--" + flags.EventStartsAfterFN } cmd := tester.StubRootCmd( "restore", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOps[set].Results.BackupID), + "--"+flags.BackupFN, string(suite.backupOps[set].Results.BackupID), timeFilter, "smarf") cli.BuildCommandTree(cmd) @@ -198,13 +198,13 @@ func (suite *RestoreExchangeE2ESuite) TestExchangeRestoreCmd_badBoolFlags() { var timeFilter string switch set { case events: - timeFilter = "--" + utils.EventRecursFN + timeFilter = "--" + flags.EventRecursFN } cmd := tester.StubRootCmd( "restore", "exchange", "--config-file", suite.cfgFP, - "--"+utils.BackupFN, string(suite.backupOps[set].Results.BackupID), + "--"+flags.BackupFN, string(suite.backupOps[set].Results.BackupID), timeFilter, "wingbat") cli.BuildCommandTree(cmd) diff --git a/src/cli/restore/exchange_test.go b/src/cli/restore/exchange_test.go index 4f5915c21..955df3267 100644 --- a/src/cli/restore/exchange_test.go +++ b/src/cli/restore/exchange_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/tester" @@ -43,7 +44,7 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { // normally a persistent flag from the root. // required to ensure a dry run. - utils.AddRunModeFlag(cmd, true) + flags.AddRunModeFlag(cmd, true) c := addExchangeCommands(cmd) require.NotNil(t, c) @@ -59,27 +60,24 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { // Test arg parsing for few args cmd.SetArgs([]string{ "exchange", - "--" + utils.RunModeFN, utils.RunModeFlagTest, - "--" + utils.BackupFN, testdata.BackupInput, - - "--" + utils.ContactFN, testdata.FlgInputs(testdata.ContactInput), - "--" + utils.ContactFolderFN, testdata.FlgInputs(testdata.ContactFldInput), - "--" + utils.ContactNameFN, testdata.ContactNameInput, - - "--" + utils.EmailFN, testdata.FlgInputs(testdata.EmailInput), - "--" + utils.EmailFolderFN, testdata.FlgInputs(testdata.EmailFldInput), - "--" + utils.EmailReceivedAfterFN, testdata.EmailReceivedAfterInput, - "--" + utils.EmailReceivedBeforeFN, testdata.EmailReceivedBeforeInput, - "--" + utils.EmailSenderFN, testdata.EmailSenderInput, - "--" + utils.EmailSubjectFN, testdata.EmailSubjectInput, - - "--" + utils.EventFN, testdata.FlgInputs(testdata.EventInput), - "--" + utils.EventCalendarFN, testdata.FlgInputs(testdata.EventCalInput), - "--" + utils.EventOrganizerFN, testdata.EventOrganizerInput, - "--" + utils.EventRecursFN, testdata.EventRecursInput, - "--" + utils.EventStartsAfterFN, testdata.EventStartsAfterInput, - "--" + utils.EventStartsBeforeFN, testdata.EventStartsBeforeInput, - "--" + utils.EventSubjectFN, testdata.EventSubjectInput, + "--" + flags.RunModeFN, flags.RunModeFlagTest, + "--" + flags.BackupFN, testdata.BackupInput, + "--" + flags.ContactFN, testdata.FlgInputs(testdata.ContactInput), + "--" + flags.ContactFolderFN, testdata.FlgInputs(testdata.ContactFldInput), + "--" + flags.ContactNameFN, testdata.ContactNameInput, + "--" + flags.EmailFN, testdata.FlgInputs(testdata.EmailInput), + "--" + flags.EmailFolderFN, testdata.FlgInputs(testdata.EmailFldInput), + "--" + flags.EmailReceivedAfterFN, testdata.EmailReceivedAfterInput, + "--" + flags.EmailReceivedBeforeFN, testdata.EmailReceivedBeforeInput, + "--" + flags.EmailSenderFN, testdata.EmailSenderInput, + "--" + flags.EmailSubjectFN, testdata.EmailSubjectInput, + "--" + flags.EventFN, testdata.FlgInputs(testdata.EventInput), + "--" + flags.EventCalendarFN, testdata.FlgInputs(testdata.EventCalInput), + "--" + flags.EventOrganizerFN, testdata.EventOrganizerInput, + "--" + flags.EventRecursFN, testdata.EventRecursInput, + "--" + flags.EventStartsAfterFN, testdata.EventStartsAfterInput, + "--" + flags.EventStartsBeforeFN, testdata.EventStartsBeforeInput, + "--" + flags.EventSubjectFN, testdata.EventSubjectInput, }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -88,7 +86,7 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { assert.NoError(t, err, clues.ToCore(err)) opts := utils.MakeExchangeOpts(cmd) - assert.Equal(t, testdata.BackupInput, utils.BackupIDFV) + assert.Equal(t, testdata.BackupInput, flags.BackupIDFV) assert.ElementsMatch(t, testdata.ContactInput, opts.Contact) assert.ElementsMatch(t, testdata.ContactFldInput, opts.ContactFolder) diff --git a/src/cli/restore/onedrive.go b/src/cli/restore/onedrive.go index 85b159370..ad3ac36d0 100644 --- a/src/cli/restore/onedrive.go +++ b/src/cli/restore/onedrive.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" @@ -31,12 +31,10 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { // More generic (ex: --user) and more frequently used flags take precedence. fs.SortFlags = false - utils.AddBackupIDFlag(c, true) - utils.AddOneDriveDetailsAndRestoreFlags(c) - - // restore permissions - options.AddRestorePermissionsFlag(c) - options.AddFailFastFlag(c) + flags.AddBackupIDFlag(c, true) + flags.AddOneDriveDetailsAndRestoreFlags(c) + flags.AddRestorePermissionsFlag(c) + flags.AddFailFastFlag(c) } return c @@ -82,11 +80,11 @@ func restoreOneDriveCmd(cmd *cobra.Command, args []string) error { opts := utils.MakeOneDriveOpts(cmd) - if utils.RunModeFV == utils.RunModeFlagTest { + if flags.RunModeFV == flags.RunModeFlagTest { return nil } - if err := utils.ValidateOneDriveRestoreFlags(utils.BackupIDFV, opts); err != nil { + if err := utils.ValidateOneDriveRestoreFlags(flags.BackupIDFV, opts); err != nil { return err } @@ -103,7 +101,7 @@ func restoreOneDriveCmd(cmd *cobra.Command, args []string) error { sel := utils.IncludeOneDriveRestoreDataSelectors(opts) utils.FilterOneDriveRestoreInfoSelectors(sel, opts) - ro, err := r.NewRestore(ctx, utils.BackupIDFV, sel.Selector, restoreCfg) + ro, err := r.NewRestore(ctx, flags.BackupIDFV, sel.Selector, restoreCfg) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to initialize OneDrive restore")) } @@ -111,7 +109,7 @@ func restoreOneDriveCmd(cmd *cobra.Command, args []string) error { ds, err := ro.Run(ctx) if err != nil { if errors.Is(err, data.ErrNotFound) { - return Only(ctx, clues.New("Backup or backup details missing for id "+utils.BackupIDFV)) + return Only(ctx, clues.New("Backup or backup details missing for id "+flags.BackupIDFV)) } return Only(ctx, clues.Wrap(err, "Failed to run OneDrive restore")) diff --git a/src/cli/restore/onedrive_test.go b/src/cli/restore/onedrive_test.go index cf119c328..922698c55 100644 --- a/src/cli/restore/onedrive_test.go +++ b/src/cli/restore/onedrive_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/tester" @@ -43,7 +44,7 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { // normally a persistent flag from the root. // required to ensure a dry run. - utils.AddRunModeFlag(cmd, true) + flags.AddRunModeFlag(cmd, true) c := addOneDriveCommands(cmd) require.NotNil(t, c) @@ -58,15 +59,14 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { cmd.SetArgs([]string{ "onedrive", - "--" + utils.RunModeFN, utils.RunModeFlagTest, - "--" + utils.BackupFN, testdata.BackupInput, - - "--" + utils.FileFN, testdata.FlgInputs(testdata.FileNameInput), - "--" + utils.FolderFN, testdata.FlgInputs(testdata.FolderPathInput), - "--" + utils.FileCreatedAfterFN, testdata.FileCreatedAfterInput, - "--" + utils.FileCreatedBeforeFN, testdata.FileCreatedBeforeInput, - "--" + utils.FileModifiedAfterFN, testdata.FileModifiedAfterInput, - "--" + utils.FileModifiedBeforeFN, testdata.FileModifiedBeforeInput, + "--" + flags.RunModeFN, flags.RunModeFlagTest, + "--" + flags.BackupFN, testdata.BackupInput, + "--" + flags.FileFN, testdata.FlgInputs(testdata.FileNameInput), + "--" + flags.FolderFN, testdata.FlgInputs(testdata.FolderPathInput), + "--" + flags.FileCreatedAfterFN, testdata.FileCreatedAfterInput, + "--" + flags.FileCreatedBeforeFN, testdata.FileCreatedBeforeInput, + "--" + flags.FileModifiedAfterFN, testdata.FileModifiedAfterInput, + "--" + flags.FileModifiedBeforeFN, testdata.FileModifiedBeforeInput, }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -75,7 +75,7 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { assert.NoError(t, err, clues.ToCore(err)) opts := utils.MakeOneDriveOpts(cmd) - assert.Equal(t, testdata.BackupInput, utils.BackupIDFV) + assert.Equal(t, testdata.BackupInput, flags.BackupIDFV) assert.ElementsMatch(t, testdata.FileNameInput, opts.FileName) assert.ElementsMatch(t, testdata.FolderPathInput, opts.FolderPath) diff --git a/src/cli/restore/sharepoint.go b/src/cli/restore/sharepoint.go index a52f5bb2a..8ab849996 100644 --- a/src/cli/restore/sharepoint.go +++ b/src/cli/restore/sharepoint.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/alcionai/corso/src/cli/options" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" @@ -31,11 +31,10 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { // More generic (ex: --site) and more frequently used flags take precedence. fs.SortFlags = false - utils.AddBackupIDFlag(c, true) - utils.AddSharePointDetailsAndRestoreFlags(c) - - options.AddRestorePermissionsFlag(c) - options.AddFailFastFlag(c) + flags.AddBackupIDFlag(c, true) + flags.AddSharePointDetailsAndRestoreFlags(c) + flags.AddRestorePermissionsFlag(c) + flags.AddFailFastFlag(c) } return c @@ -87,11 +86,11 @@ func restoreSharePointCmd(cmd *cobra.Command, args []string) error { opts := utils.MakeSharePointOpts(cmd) - if utils.RunModeFV == utils.RunModeFlagTest { + if flags.RunModeFV == flags.RunModeFlagTest { return nil } - if err := utils.ValidateSharePointRestoreFlags(utils.BackupIDFV, opts); err != nil { + if err := utils.ValidateSharePointRestoreFlags(flags.BackupIDFV, opts); err != nil { return err } @@ -108,7 +107,7 @@ func restoreSharePointCmd(cmd *cobra.Command, args []string) error { sel := utils.IncludeSharePointRestoreDataSelectors(ctx, opts) utils.FilterSharePointRestoreInfoSelectors(sel, opts) - ro, err := r.NewRestore(ctx, utils.BackupIDFV, sel.Selector, restoreCfg) + ro, err := r.NewRestore(ctx, flags.BackupIDFV, sel.Selector, restoreCfg) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to initialize SharePoint restore")) } @@ -116,7 +115,7 @@ func restoreSharePointCmd(cmd *cobra.Command, args []string) error { ds, err := ro.Run(ctx) if err != nil { if errors.Is(err, data.ErrNotFound) { - return Only(ctx, clues.New("Backup or backup details missing for id "+utils.BackupIDFV)) + return Only(ctx, clues.New("Backup or backup details missing for id "+flags.BackupIDFV)) } return Only(ctx, clues.Wrap(err, "Failed to run SharePoint restore")) diff --git a/src/cli/restore/sharepoint_test.go b/src/cli/restore/sharepoint_test.go index ce7e3c73d..09b056975 100644 --- a/src/cli/restore/sharepoint_test.go +++ b/src/cli/restore/sharepoint_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils/testdata" "github.com/alcionai/corso/src/internal/tester" @@ -43,7 +44,7 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { // normally a persistent flag from the root. // required to ensure a dry run. - utils.AddRunModeFlag(cmd, true) + flags.AddRunModeFlag(cmd, true) c := addSharePointCommands(cmd) require.NotNil(t, c) @@ -58,22 +59,19 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { cmd.SetArgs([]string{ "sharepoint", - "--" + utils.RunModeFN, utils.RunModeFlagTest, - "--" + utils.BackupFN, testdata.BackupInput, - - "--" + utils.LibraryFN, testdata.LibraryInput, - "--" + utils.FileFN, testdata.FlgInputs(testdata.FileNameInput), - "--" + utils.FolderFN, testdata.FlgInputs(testdata.FolderPathInput), - "--" + utils.FileCreatedAfterFN, testdata.FileCreatedAfterInput, - "--" + utils.FileCreatedBeforeFN, testdata.FileCreatedBeforeInput, - "--" + utils.FileModifiedAfterFN, testdata.FileModifiedAfterInput, - "--" + utils.FileModifiedBeforeFN, testdata.FileModifiedBeforeInput, - - "--" + utils.ListItemFN, testdata.FlgInputs(testdata.ListItemInput), - "--" + utils.ListFolderFN, testdata.FlgInputs(testdata.ListFolderInput), - - "--" + utils.PageFN, testdata.FlgInputs(testdata.PageInput), - "--" + utils.PageFolderFN, testdata.FlgInputs(testdata.PageFolderInput), + "--" + flags.RunModeFN, flags.RunModeFlagTest, + "--" + flags.BackupFN, testdata.BackupInput, + "--" + flags.LibraryFN, testdata.LibraryInput, + "--" + flags.FileFN, testdata.FlgInputs(testdata.FileNameInput), + "--" + flags.FolderFN, testdata.FlgInputs(testdata.FolderPathInput), + "--" + flags.FileCreatedAfterFN, testdata.FileCreatedAfterInput, + "--" + flags.FileCreatedBeforeFN, testdata.FileCreatedBeforeInput, + "--" + flags.FileModifiedAfterFN, testdata.FileModifiedAfterInput, + "--" + flags.FileModifiedBeforeFN, testdata.FileModifiedBeforeInput, + "--" + flags.ListItemFN, testdata.FlgInputs(testdata.ListItemInput), + "--" + flags.ListFolderFN, testdata.FlgInputs(testdata.ListFolderInput), + "--" + flags.PageFN, testdata.FlgInputs(testdata.PageInput), + "--" + flags.PageFolderFN, testdata.FlgInputs(testdata.PageFolderInput), }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -82,7 +80,7 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { assert.NoError(t, err, clues.ToCore(err)) opts := utils.MakeSharePointOpts(cmd) - assert.Equal(t, testdata.BackupInput, utils.BackupIDFV) + assert.Equal(t, testdata.BackupInput, flags.BackupIDFV) assert.Equal(t, testdata.LibraryInput, opts.Library) assert.ElementsMatch(t, testdata.FileNameInput, opts.FileName) diff --git a/src/cli/utils/exchange.go b/src/cli/utils/exchange.go index d167c710e..7051d5904 100644 --- a/src/cli/utils/exchange.go +++ b/src/cli/utils/exchange.go @@ -4,53 +4,10 @@ import ( "github.com/alcionai/clues" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/pkg/selectors" ) -// flag names (id: FN) -const ( - ContactFN = "contact" - ContactFolderFN = "contact-folder" - ContactNameFN = "contact-name" - - EmailFN = "email" - EmailFolderFN = "email-folder" - EmailReceivedAfterFN = "email-received-after" - EmailReceivedBeforeFN = "email-received-before" - EmailSenderFN = "email-sender" - EmailSubjectFN = "email-subject" - - EventFN = "event" - EventCalendarFN = "event-calendar" - EventOrganizerFN = "event-organizer" - EventRecursFN = "event-recurs" - EventStartsAfterFN = "event-starts-after" - EventStartsBeforeFN = "event-starts-before" - EventSubjectFN = "event-subject" -) - -// flag values (ie: FV) -var ( - ContactFV []string - ContactFolderFV []string - ContactNameFV string - - EmailFV []string - EmailFolderFV []string - EmailReceivedAfterFV string - EmailReceivedBeforeFV string - EmailSenderFV string - EmailSubjectFV string - - EventFV []string - EventCalendarFV []string - EventOrganizerFV string - EventRecursFV string - EventStartsAfterFV string - EventStartsBeforeFV string - EventSubjectFV string -) - type ExchangeOpts struct { Users []string @@ -73,113 +30,37 @@ type ExchangeOpts struct { EventStartsBefore string EventSubject string - Populated PopulatedFlags + Populated flags.PopulatedFlags } // populates an ExchangeOpts struct with the command's current flags. func MakeExchangeOpts(cmd *cobra.Command) ExchangeOpts { return ExchangeOpts{ - Users: UserFV, + Users: flags.UserFV, - Contact: ContactFV, - ContactFolder: ContactFolderFV, - ContactName: ContactNameFV, + Contact: flags.ContactFV, + ContactFolder: flags.ContactFolderFV, + ContactName: flags.ContactNameFV, - Email: EmailFV, - EmailFolder: EmailFolderFV, - EmailReceivedAfter: EmailReceivedAfterFV, - EmailReceivedBefore: EmailReceivedBeforeFV, - EmailSender: EmailSenderFV, - EmailSubject: EmailSubjectFV, + Email: flags.EmailFV, + EmailFolder: flags.EmailFolderFV, + EmailReceivedAfter: flags.EmailReceivedAfterFV, + EmailReceivedBefore: flags.EmailReceivedBeforeFV, + EmailSender: flags.EmailSenderFV, + EmailSubject: flags.EmailSubjectFV, - Event: EventFV, - EventCalendar: EventCalendarFV, - EventOrganizer: EventOrganizerFV, - EventRecurs: EventRecursFV, - EventStartsAfter: EventStartsAfterFV, - EventStartsBefore: EventStartsBeforeFV, - EventSubject: EventSubjectFV, + Event: flags.EventFV, + EventCalendar: flags.EventCalendarFV, + EventOrganizer: flags.EventOrganizerFV, + EventRecurs: flags.EventRecursFV, + EventStartsAfter: flags.EventStartsAfterFV, + EventStartsBefore: flags.EventStartsBeforeFV, + EventSubject: flags.EventSubjectFV, - Populated: GetPopulatedFlags(cmd), + Populated: flags.GetPopulatedFlags(cmd), } } -// AddExchangeDetailsAndRestoreFlags adds flags that are common to both the -// details and restore commands. -func AddExchangeDetailsAndRestoreFlags(cmd *cobra.Command) { - fs := cmd.Flags() - - // email flags - fs.StringSliceVar( - &EmailFV, - EmailFN, nil, - "Select email messages by ID; accepts '"+Wildcard+"' to select all emails.") - fs.StringSliceVar( - &EmailFolderFV, - EmailFolderFN, nil, - "Select emails within a folder; accepts '"+Wildcard+"' to select all email folders.") - fs.StringVar( - &EmailSubjectFV, - EmailSubjectFN, "", - "Select emails with a subject containing this value.") - fs.StringVar( - &EmailSenderFV, - EmailSenderFN, "", - "Select emails from a specific sender.") - fs.StringVar( - &EmailReceivedAfterFV, - EmailReceivedAfterFN, "", - "Select emails received after this datetime.") - fs.StringVar( - &EmailReceivedBeforeFV, - EmailReceivedBeforeFN, "", - "Select emails received before this datetime.") - - // event flags - fs.StringSliceVar( - &EventFV, - EventFN, nil, - "Select events by event ID; accepts '"+Wildcard+"' to select all events.") - fs.StringSliceVar( - &EventCalendarFV, - EventCalendarFN, nil, - "Select events under a calendar; accepts '"+Wildcard+"' to select all events.") - fs.StringVar( - &EventSubjectFV, - EventSubjectFN, "", - "Select events with a subject containing this value.") - fs.StringVar( - &EventOrganizerFV, - EventOrganizerFN, "", - "Select events from a specific organizer.") - fs.StringVar( - &EventRecursFV, - EventRecursFN, "", - "Select recurring events. Use `--event-recurs false` to select non-recurring events.") - fs.StringVar( - &EventStartsAfterFV, - EventStartsAfterFN, "", - "Select events starting after this datetime.") - fs.StringVar( - &EventStartsBeforeFV, - EventStartsBeforeFN, "", - "Select events starting before this datetime.") - - // contact flags - fs.StringSliceVar( - &ContactFV, - ContactFN, nil, - "Select contacts by contact ID; accepts '"+Wildcard+"' to select all contacts.") - fs.StringSliceVar( - &ContactFolderFV, - ContactFolderFN, nil, - "Select contacts within a folder; accepts '"+Wildcard+"' to select all contact folders.") - fs.StringVar( - &ContactNameFV, - ContactNameFN, "", - "Select contacts whose contact name contains this value.") -} - // AddExchangeInclude adds the scope of the provided values to the selector's // inclusion set. Any unpopulated slice will be replaced with selectors.Any() // to act as a wildcard. @@ -231,23 +112,23 @@ func ValidateExchangeRestoreFlags(backupID string, opts ExchangeOpts) error { return clues.New("a backup ID is required") } - if _, ok := opts.Populated[EmailReceivedAfterFN]; ok && !IsValidTimeFormat(opts.EmailReceivedAfter) { + if _, ok := opts.Populated[flags.EmailReceivedAfterFN]; ok && !IsValidTimeFormat(opts.EmailReceivedAfter) { return clues.New("invalid time format for email-received-after") } - if _, ok := opts.Populated[EmailReceivedBeforeFN]; ok && !IsValidTimeFormat(opts.EmailReceivedBefore) { + if _, ok := opts.Populated[flags.EmailReceivedBeforeFN]; ok && !IsValidTimeFormat(opts.EmailReceivedBefore) { return clues.New("invalid time format for email-received-before") } - if _, ok := opts.Populated[EventStartsAfterFN]; ok && !IsValidTimeFormat(opts.EventStartsAfter) { + if _, ok := opts.Populated[flags.EventStartsAfterFN]; ok && !IsValidTimeFormat(opts.EventStartsAfter) { return clues.New("invalid time format for event-starts-after") } - if _, ok := opts.Populated[EventStartsBeforeFN]; ok && !IsValidTimeFormat(opts.EventStartsBefore) { + if _, ok := opts.Populated[flags.EventStartsBeforeFN]; ok && !IsValidTimeFormat(opts.EventStartsBefore) { return clues.New("invalid time format for event-starts-before") } - if _, ok := opts.Populated[EventRecursFN]; ok && !IsValidBool(opts.EventRecurs) { + if _, ok := opts.Populated[flags.EventRecursFN]; ok && !IsValidBool(opts.EventRecurs) { return clues.New("invalid format for event-recurs") } diff --git a/src/cli/utils/exchange_test.go b/src/cli/utils/exchange_test.go index c61e8da77..662532743 100644 --- a/src/cli/utils/exchange_test.go +++ b/src/cli/utils/exchange_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/tester" @@ -62,7 +63,7 @@ func (suite *ExchangeUtilsSuite) TestValidateRestoreFlags() { func (suite *ExchangeUtilsSuite) TestIncludeExchangeRestoreDataSelectors() { stub := []string{"id-stub"} many := []string{"fnord", "smarf"} - a := []string{utils.Wildcard} + a := []string{flags.Wildcard} table := []struct { name string diff --git a/src/cli/utils/flags.go b/src/cli/utils/flags.go index ab1503034..66d3e4fcd 100644 --- a/src/cli/utils/flags.go +++ b/src/cli/utils/flags.go @@ -1,233 +1,13 @@ package utils import ( - "fmt" "strconv" - "strings" - - "github.com/spf13/cobra" - "github.com/spf13/pflag" "github.com/alcionai/corso/src/internal/common/dttm" - "github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/selectors" ) -// common flag vars (eg: FV) -var ( - // RunMode describes the type of run, such as: - // flagtest, dry, run. Should default to 'run'. - RunModeFV string - - BackupIDFV string - - FolderPathFV []string - FileNameFV []string - - FileCreatedAfterFV string - FileCreatedBeforeFV string - FileModifiedAfterFV string - FileModifiedBeforeFV string - - LibraryFV string - SiteIDFV []string - WebURLFV []string - - UserFV []string - - // for selection of data by category. eg: `--data email,contacts` - CategoryDataFV []string - - MaintenanceModeFV string - ForceMaintenanceFV bool -) - -// common flag names (eg: FN) -const ( - RunModeFN = "run-mode" - - BackupFN = "backup" - CategoryDataFN = "data" - - SiteFN = "site" // site only accepts WebURL values - SiteIDFN = "site-id" // site-id accepts actual site ids - UserFN = "user" - MailBoxFN = "mailbox" - - LibraryFN = "library" - FileFN = "file" - FolderFN = "folder" - - FileCreatedAfterFN = "file-created-after" - FileCreatedBeforeFN = "file-created-before" - FileModifiedAfterFN = "file-modified-after" - FileModifiedBeforeFN = "file-modified-before" - - // Maintenance stuff. - MaintenanceModeFN = "mode" - ForceMaintenanceFN = "force" -) - -// well-known flag values -const ( - RunModeFlagTest = "flag-test" - RunModeRun = "run" -) - -// AddBackupIDFlag adds the --backup flag. -func AddBackupIDFlag(cmd *cobra.Command, require bool) { - cmd.Flags().StringVar(&BackupIDFV, BackupFN, "", "ID of the backup to retrieve.") - - if require { - cobra.CheckErr(cmd.MarkFlagRequired(BackupFN)) - } -} - -func AddDataFlag(cmd *cobra.Command, allowed []string, hide bool) { - var ( - allowedMsg string - fs = cmd.Flags() - ) - - switch len(allowed) { - case 0: - return - case 1: - allowedMsg = allowed[0] - case 2: - allowedMsg = fmt.Sprintf("%s or %s", allowed[0], allowed[1]) - default: - allowedMsg = fmt.Sprintf( - "%s or %s", - strings.Join(allowed[:len(allowed)-1], ", "), - allowed[len(allowed)-1]) - } - - fs.StringSliceVar( - &CategoryDataFV, - CategoryDataFN, nil, - "Select one or more types of data to backup: "+allowedMsg+".") - - if hide { - cobra.CheckErr(fs.MarkHidden(CategoryDataFN)) - } -} - -// AddRunModeFlag adds the hidden --run-mode flag. -func AddRunModeFlag(cmd *cobra.Command, persistent bool) { - fs := cmd.Flags() - if persistent { - fs = cmd.PersistentFlags() - } - - fs.StringVar(&RunModeFV, RunModeFN, "run", "What mode to run: dry, test, run. Defaults to run.") - cobra.CheckErr(fs.MarkHidden(RunModeFN)) -} - -// AddUserFlag adds the --user flag. -func AddUserFlag(cmd *cobra.Command) { - cmd.Flags().StringSliceVar( - &UserFV, - UserFN, nil, - "Backup a specific user's data; accepts '"+Wildcard+"' to select all users.") - cobra.CheckErr(cmd.MarkFlagRequired(UserFN)) -} - -// AddMailBoxFlag adds the --user and --mailbox flag. -func AddMailBoxFlag(cmd *cobra.Command) { - flags := cmd.Flags() - - flags.StringSliceVar( - &UserFV, - UserFN, nil, - "Backup a specific user's data; accepts '"+Wildcard+"' to select all users.") - - cobra.CheckErr(flags.MarkDeprecated(UserFN, fmt.Sprintf("use --%s instead", MailBoxFN))) - - flags.StringSliceVar( - &UserFV, - MailBoxFN, nil, - "Backup a specific mailbox's data; accepts '"+Wildcard+"' to select all mailbox.") -} - -// AddSiteIDFlag adds the --site-id flag, which accepts site ID values. -// This flag is hidden, since we expect users to prefer the --site url -// and do not want to encourage confusion. -func AddSiteIDFlag(cmd *cobra.Command) { - fs := cmd.Flags() - - // note string ARRAY var. IDs naturally contain commas, so we cannot accept - // duplicate values within a flag declaration. ie: --site-id a,b,c does not - // work. Users must call --site-id a --site-id b --site-id c. - fs.StringArrayVar( - &SiteIDFV, - SiteIDFN, nil, - //nolint:lll - "Backup data by site ID; accepts '"+Wildcard+"' to select all sites. Args cannot be comma-delimited and must use multiple flags.") - cobra.CheckErr(fs.MarkHidden(SiteIDFN)) -} - -// AddSiteFlag adds the --site flag, which accepts webURL values. -func AddSiteFlag(cmd *cobra.Command) { - cmd.Flags().StringSliceVar( - &WebURLFV, - SiteFN, nil, - "Backup data by site URL; accepts '"+Wildcard+"' to select all sites.") -} - -func AddMaintenanceModeFlag(cmd *cobra.Command) { - fs := cmd.Flags() - fs.StringVar( - &MaintenanceModeFV, - MaintenanceModeFN, - repository.CompleteMaintenance.String(), - "Type of maintenance operation to run. Pass '"+ - repository.MetadataMaintenance.String()+"' to run a faster maintenance "+ - "that does minimal clean-up and optimization. Pass '"+ - repository.CompleteMaintenance.String()+"' to fully compact existing "+ - "data and delete unused data.") - cobra.CheckErr(fs.MarkHidden(MaintenanceModeFN)) -} - -func AddForceMaintenanceFlag(cmd *cobra.Command) { - fs := cmd.Flags() - fs.BoolVar( - &ForceMaintenanceFV, - ForceMaintenanceFN, - false, - "Force maintenance. Caution: user must ensure this is not run concurrently on a single repo") - cobra.CheckErr(fs.MarkHidden(ForceMaintenanceFN)) -} - -type PopulatedFlags map[string]struct{} - -func (fs PopulatedFlags) populate(pf *pflag.Flag) { - if pf == nil { - return - } - - if pf.Changed { - fs[pf.Name] = struct{}{} - } -} - -// GetPopulatedFlags returns a map of flags that have been -// populated by the user. Entry keys match the flag's long -// name. Values are empty. -func GetPopulatedFlags(cmd *cobra.Command) PopulatedFlags { - pop := PopulatedFlags{} - - fs := cmd.Flags() - if fs == nil { - return pop - } - - fs.VisitAll(pop.populate) - - return pop -} - // IsValidTimeFormat returns true if the input is recognized as a // supported format by the common time parser. func IsValidTimeFormat(in string) bool { diff --git a/src/cli/utils/onedrive.go b/src/cli/utils/onedrive.go index ff75951c3..16c9c8d8f 100644 --- a/src/cli/utils/onedrive.go +++ b/src/cli/utils/onedrive.go @@ -4,6 +4,7 @@ import ( "github.com/alcionai/clues" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/pkg/selectors" ) @@ -17,78 +18,43 @@ type OneDriveOpts struct { FileModifiedAfter string FileModifiedBefore string - Populated PopulatedFlags + Populated flags.PopulatedFlags } func MakeOneDriveOpts(cmd *cobra.Command) OneDriveOpts { return OneDriveOpts{ - Users: UserFV, + Users: flags.UserFV, - FileName: FileNameFV, - FolderPath: FolderPathFV, - FileCreatedAfter: FileCreatedAfterFV, - FileCreatedBefore: FileCreatedBeforeFV, - FileModifiedAfter: FileModifiedAfterFV, - FileModifiedBefore: FileModifiedBeforeFV, + FileName: flags.FileNameFV, + FolderPath: flags.FolderPathFV, + FileCreatedAfter: flags.FileCreatedAfterFV, + FileCreatedBefore: flags.FileCreatedBeforeFV, + FileModifiedAfter: flags.FileModifiedAfterFV, + FileModifiedBefore: flags.FileModifiedBeforeFV, - Populated: GetPopulatedFlags(cmd), + Populated: flags.GetPopulatedFlags(cmd), } } -// AddOneDriveDetailsAndRestoreFlags adds flags that are common to both the -// details and restore commands. -func AddOneDriveDetailsAndRestoreFlags(cmd *cobra.Command) { - fs := cmd.Flags() - - fs.StringSliceVar( - &FolderPathFV, - FolderFN, nil, - "Select files by OneDrive folder; defaults to root.") - - fs.StringSliceVar( - &FileNameFV, - FileFN, nil, - "Select files by name.") - - fs.StringVar( - &FileCreatedAfterFV, - FileCreatedAfterFN, "", - "Select files created after this datetime.") - fs.StringVar( - &FileCreatedBeforeFV, - FileCreatedBeforeFN, "", - "Select files created before this datetime.") - - fs.StringVar( - &FileModifiedAfterFV, - FileModifiedAfterFN, "", - "Select files modified after this datetime.") - - fs.StringVar( - &FileModifiedBeforeFV, - FileModifiedBeforeFN, "", - "Select files modified before this datetime.") -} - // ValidateOneDriveRestoreFlags checks common flags for correctness and interdependencies func ValidateOneDriveRestoreFlags(backupID string, opts OneDriveOpts) error { if len(backupID) == 0 { return clues.New("a backup ID is required") } - if _, ok := opts.Populated[FileCreatedAfterFN]; ok && !IsValidTimeFormat(opts.FileCreatedAfter) { + if _, ok := opts.Populated[flags.FileCreatedAfterFN]; ok && !IsValidTimeFormat(opts.FileCreatedAfter) { return clues.New("invalid time format for created-after") } - if _, ok := opts.Populated[FileCreatedBeforeFN]; ok && !IsValidTimeFormat(opts.FileCreatedBefore) { + if _, ok := opts.Populated[flags.FileCreatedBeforeFN]; ok && !IsValidTimeFormat(opts.FileCreatedBefore) { return clues.New("invalid time format for created-before") } - if _, ok := opts.Populated[FileModifiedAfterFN]; ok && !IsValidTimeFormat(opts.FileModifiedAfter) { + if _, ok := opts.Populated[flags.FileModifiedAfterFN]; ok && !IsValidTimeFormat(opts.FileModifiedAfter) { return clues.New("invalid time format for modified-after") } - if _, ok := opts.Populated[FileModifiedBeforeFN]; ok && !IsValidTimeFormat(opts.FileModifiedBefore) { + if _, ok := opts.Populated[flags.FileModifiedBeforeFN]; ok && !IsValidTimeFormat(opts.FileModifiedBefore) { return clues.New("invalid time format for modified-before") } diff --git a/src/cli/utils/options.go b/src/cli/utils/options.go new file mode 100644 index 000000000..72bfac0a1 --- /dev/null +++ b/src/cli/utils/options.go @@ -0,0 +1,26 @@ +package utils + +import ( + "github.com/alcionai/corso/src/cli/flags" + "github.com/alcionai/corso/src/pkg/control" +) + +// Control produces the control options based on the user's flags. +func Control() control.Options { + opt := control.Defaults() + + if flags.FailFastFV { + opt.FailureHandling = control.FailFast + } + + opt.DisableMetrics = flags.NoStatsFV + opt.RestorePermissions = flags.RestorePermissionsFV + opt.SkipReduce = flags.SkipReduceFV + opt.ToggleFeatures.DisableIncrementals = flags.DisableIncrementalsFV + opt.ToggleFeatures.DisableDelta = flags.DisableDeltaFV + opt.ToggleFeatures.ExchangeImmutableIDs = flags.EnableImmutableIDFV + opt.ToggleFeatures.DisableConcurrencyLimiter = flags.DisableConcurrencyLimiterFV + opt.Parallelism.ItemFetch = flags.FetchParallelismFV + + return opt +} diff --git a/src/cli/utils/options_test.go b/src/cli/utils/options_test.go new file mode 100644 index 000000000..746558aa1 --- /dev/null +++ b/src/cli/utils/options_test.go @@ -0,0 +1,67 @@ +package utils + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/cli/flags" + "github.com/alcionai/corso/src/internal/tester" +) + +type OptionsUnitSuite struct { + tester.Suite +} + +func TestOptionsUnitSuite(t *testing.T) { + suite.Run(t, &OptionsUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *OptionsUnitSuite) TestAddExchangeCommands() { + t := suite.T() + + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) { + assert.True(t, flags.FailFastFV, flags.FailFastFN) + assert.True(t, flags.DisableIncrementalsFV, flags.DisableIncrementalsFN) + assert.True(t, flags.DisableDeltaFV, flags.DisableDeltaFN) + assert.True(t, flags.NoStatsFV, flags.NoStatsFN) + assert.True(t, flags.RestorePermissionsFV, flags.RestorePermissionsFN) + assert.True(t, flags.SkipReduceFV, flags.SkipReduceFN) + assert.Equal(t, 2, flags.FetchParallelismFV, flags.FetchParallelismFN) + assert.True(t, flags.DisableConcurrencyLimiterFV, flags.DisableConcurrencyLimiterFN) + }, + } + + // adds no-stats + flags.AddGlobalOperationFlags(cmd) + + flags.AddFailFastFlag(cmd) + flags.AddDisableIncrementalsFlag(cmd) + flags.AddDisableDeltaFlag(cmd) + flags.AddRestorePermissionsFlag(cmd) + flags.AddSkipReduceFlag(cmd) + flags.AddFetchParallelismFlag(cmd) + flags.AddDisableConcurrencyLimiterFlag(cmd) + + // Test arg parsing for few args + cmd.SetArgs([]string{ + "test", + "--" + flags.FailFastFN, + "--" + flags.DisableIncrementalsFN, + "--" + flags.DisableDeltaFN, + "--" + flags.NoStatsFN, + "--" + flags.RestorePermissionsFN, + "--" + flags.SkipReduceFN, + "--" + flags.FetchParallelismFN, "2", + "--" + flags.DisableConcurrencyLimiterFN, + }) + + err := cmd.Execute() + require.NoError(t, err, clues.ToCore(err)) +} diff --git a/src/cli/utils/sharepoint.go b/src/cli/utils/sharepoint.go index bb37eb532..78b6e947f 100644 --- a/src/cli/utils/sharepoint.go +++ b/src/cli/utils/sharepoint.go @@ -8,25 +8,11 @@ import ( "github.com/alcionai/clues" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/selectors" ) -const ( - ListFolderFN = "list" - ListItemFN = "list-item" - PageFolderFN = "page-folder" - PageFN = "page" -) - -// flag population variables -var ( - ListFolder []string - ListItem []string - PageFolder []string - Page []string -) - type SharePointOpts struct { SiteID []string WebURL []string @@ -45,95 +31,32 @@ type SharePointOpts struct { PageFolder []string Page []string - Populated PopulatedFlags + Populated flags.PopulatedFlags } func MakeSharePointOpts(cmd *cobra.Command) SharePointOpts { return SharePointOpts{ - SiteID: SiteIDFV, - WebURL: WebURLFV, + SiteID: flags.SiteIDFV, + WebURL: flags.WebURLFV, - Library: LibraryFV, - FileName: FileNameFV, - FolderPath: FolderPathFV, - FileCreatedAfter: FileCreatedAfterFV, - FileCreatedBefore: FileCreatedBeforeFV, - FileModifiedAfter: FileModifiedAfterFV, - FileModifiedBefore: FileModifiedBeforeFV, + Library: flags.LibraryFV, + FileName: flags.FileNameFV, + FolderPath: flags.FolderPathFV, + FileCreatedAfter: flags.FileCreatedAfterFV, + FileCreatedBefore: flags.FileCreatedBeforeFV, + FileModifiedAfter: flags.FileModifiedAfterFV, + FileModifiedBefore: flags.FileModifiedBeforeFV, - ListFolder: ListFolder, - ListItem: ListItem, + ListFolder: flags.ListFolderFV, + ListItem: flags.ListItemFV, - Page: Page, - PageFolder: PageFolder, + Page: flags.PageFV, + PageFolder: flags.PageFolderFV, - Populated: GetPopulatedFlags(cmd), + Populated: flags.GetPopulatedFlags(cmd), } } -// AddSharePointDetailsAndRestoreFlags adds flags that are common to both the -// details and restore commands. -func AddSharePointDetailsAndRestoreFlags(cmd *cobra.Command) { - fs := cmd.Flags() - - // libraries - - fs.StringVar( - &LibraryFV, - LibraryFN, "", - "Select only this library; defaults to all libraries.") - fs.StringSliceVar( - &FolderPathFV, - FolderFN, nil, - "Select by folder; defaults to root.") - fs.StringSliceVar( - &FileNameFV, - FileFN, nil, - "Select by file name.") - fs.StringVar( - &FileCreatedAfterFV, - FileCreatedAfterFN, "", - "Select files created after this datetime.") - fs.StringVar( - &FileCreatedBeforeFV, - FileCreatedBeforeFN, "", - "Select files created before this datetime.") - fs.StringVar( - &FileModifiedAfterFV, - FileModifiedAfterFN, "", - "Select files modified after this datetime.") - fs.StringVar( - &FileModifiedBeforeFV, - FileModifiedBeforeFN, "", - "Select files modified before this datetime.") - - // lists - - fs.StringSliceVar( - &ListFolder, - ListFolderFN, nil, - "Select lists by name; accepts '"+Wildcard+"' to select all lists.") - cobra.CheckErr(fs.MarkHidden(ListFolderFN)) - fs.StringSliceVar( - &ListItem, - ListItemFN, nil, - "Select lists by item name; accepts '"+Wildcard+"' to select all lists.") - cobra.CheckErr(fs.MarkHidden(ListItemFN)) - - // pages - - fs.StringSliceVar( - &PageFolder, - PageFolderFN, nil, - "Select pages by folder name; accepts '"+Wildcard+"' to select all pages.") - cobra.CheckErr(fs.MarkHidden(PageFolderFN)) - fs.StringSliceVar( - &Page, - PageFN, nil, - "Select pages by item name; accepts '"+Wildcard+"' to select all pages.") - cobra.CheckErr(fs.MarkHidden(PageFN)) -} - // ValidateSharePointRestoreFlags checks common flags for correctness and interdependencies func ValidateSharePointRestoreFlags(backupID string, opts SharePointOpts) error { if len(backupID) == 0 { @@ -141,7 +64,7 @@ func ValidateSharePointRestoreFlags(backupID string, opts SharePointOpts) error } // ensure url can parse all weburls provided by --site. - if _, ok := opts.Populated[SiteFN]; ok { + if _, ok := opts.Populated[flags.SiteFN]; ok { for _, wu := range opts.WebURL { if _, err := url.Parse(wu); err != nil { return clues.New("invalid site url: " + wu) @@ -149,20 +72,20 @@ func ValidateSharePointRestoreFlags(backupID string, opts SharePointOpts) error } } - if _, ok := opts.Populated[FileCreatedAfterFN]; ok && !IsValidTimeFormat(opts.FileCreatedAfter) { - return clues.New("invalid time format for " + FileCreatedAfterFN) + if _, ok := opts.Populated[flags.FileCreatedAfterFN]; ok && !IsValidTimeFormat(opts.FileCreatedAfter) { + return clues.New("invalid time format for " + flags.FileCreatedAfterFN) } - if _, ok := opts.Populated[FileCreatedBeforeFN]; ok && !IsValidTimeFormat(opts.FileCreatedBefore) { - return clues.New("invalid time format for " + FileCreatedBeforeFN) + if _, ok := opts.Populated[flags.FileCreatedBeforeFN]; ok && !IsValidTimeFormat(opts.FileCreatedBefore) { + return clues.New("invalid time format for " + flags.FileCreatedBeforeFN) } - if _, ok := opts.Populated[FileModifiedAfterFN]; ok && !IsValidTimeFormat(opts.FileModifiedAfter) { - return clues.New("invalid time format for " + FileModifiedAfterFN) + if _, ok := opts.Populated[flags.FileModifiedAfterFN]; ok && !IsValidTimeFormat(opts.FileModifiedAfter) { + return clues.New("invalid time format for " + flags.FileModifiedAfterFN) } - if _, ok := opts.Populated[FileModifiedBeforeFN]; ok && !IsValidTimeFormat(opts.FileModifiedBefore) { - return clues.New("invalid time format for " + FileModifiedBeforeFN) + if _, ok := opts.Populated[flags.FileModifiedBeforeFN]; ok && !IsValidTimeFormat(opts.FileModifiedBefore) { + return clues.New("invalid time format for " + flags.FileModifiedBeforeFN) } return nil diff --git a/src/cli/utils/sharepoint_test.go b/src/cli/utils/sharepoint_test.go index 5714adc2a..0e5cbd411 100644 --- a/src/cli/utils/sharepoint_test.go +++ b/src/cli/utils/sharepoint_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/tester" @@ -297,12 +298,12 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { FileCreatedBefore: dttm.Now(), FileModifiedAfter: dttm.Now(), FileModifiedBefore: dttm.Now(), - Populated: utils.PopulatedFlags{ - utils.SiteFN: {}, - utils.FileCreatedAfterFN: {}, - utils.FileCreatedBeforeFN: {}, - utils.FileModifiedAfterFN: {}, - utils.FileModifiedBeforeFN: {}, + Populated: flags.PopulatedFlags{ + flags.SiteFN: struct{}{}, + flags.FileCreatedAfterFN: struct{}{}, + flags.FileCreatedBeforeFN: struct{}{}, + flags.FileModifiedAfterFN: struct{}{}, + flags.FileModifiedBeforeFN: struct{}{}, }, }, expect: assert.NoError, @@ -318,8 +319,8 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { backupID: "id", opts: utils.SharePointOpts{ WebURL: []string{"slander://:vree.garbles/:"}, - Populated: utils.PopulatedFlags{ - utils.SiteFN: {}, + Populated: flags.PopulatedFlags{ + flags.SiteFN: struct{}{}, }, }, expect: assert.Error, @@ -329,8 +330,8 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { backupID: "id", opts: utils.SharePointOpts{ FileCreatedAfter: "1235", - Populated: utils.PopulatedFlags{ - utils.FileCreatedAfterFN: {}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedAfterFN: struct{}{}, }, }, expect: assert.Error, @@ -340,8 +341,8 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { backupID: "id", opts: utils.SharePointOpts{ FileCreatedBefore: "1235", - Populated: utils.PopulatedFlags{ - utils.FileCreatedBeforeFN: {}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedBeforeFN: struct{}{}, }, }, expect: assert.Error, @@ -351,8 +352,8 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { backupID: "id", opts: utils.SharePointOpts{ FileModifiedAfter: "1235", - Populated: utils.PopulatedFlags{ - utils.FileModifiedAfterFN: {}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedAfterFN: struct{}{}, }, }, expect: assert.Error, @@ -362,8 +363,8 @@ func (suite *SharePointUtilsSuite) TestValidateSharePointRestoreFlags() { backupID: "id", opts: utils.SharePointOpts{ FileModifiedBefore: "1235", - Populated: utils.PopulatedFlags{ - utils.FileModifiedBeforeFN: {}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedBeforeFN: struct{}{}, }, }, expect: assert.Error, diff --git a/src/cli/utils/testdata/opts.go b/src/cli/utils/testdata/opts.go index 614434c11..bb42f856a 100644 --- a/src/cli/utils/testdata/opts.go +++ b/src/cli/utils/testdata/opts.go @@ -7,6 +7,7 @@ import ( "github.com/alcionai/clues" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/pkg/backup" @@ -37,8 +38,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EmailReceivedAfter: "foo", - Populated: utils.PopulatedFlags{ - utils.EmailReceivedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EmailReceivedAfterFN: struct{}{}, }, } }, @@ -48,8 +49,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EmailReceivedAfter: "", - Populated: utils.PopulatedFlags{ - utils.EmailReceivedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EmailReceivedAfterFN: struct{}{}, }, } }, @@ -59,8 +60,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EmailReceivedBefore: "foo", - Populated: utils.PopulatedFlags{ - utils.EmailReceivedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EmailReceivedBeforeFN: struct{}{}, }, } }, @@ -70,8 +71,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EmailReceivedBefore: "", - Populated: utils.PopulatedFlags{ - utils.EmailReceivedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EmailReceivedBeforeFN: struct{}{}, }, } }, @@ -81,8 +82,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventRecurs: "foo", - Populated: utils.PopulatedFlags{ - utils.EventRecursFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventRecursFN: struct{}{}, }, } }, @@ -92,8 +93,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventRecurs: "", - Populated: utils.PopulatedFlags{ - utils.EventRecursFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventRecursFN: struct{}{}, }, } }, @@ -103,8 +104,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventStartsAfter: "foo", - Populated: utils.PopulatedFlags{ - utils.EventStartsAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventStartsAfterFN: struct{}{}, }, } }, @@ -114,8 +115,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventStartsAfter: "", - Populated: utils.PopulatedFlags{ - utils.EventStartsAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventStartsAfterFN: struct{}{}, }, } }, @@ -125,8 +126,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventStartsBefore: "foo", - Populated: utils.PopulatedFlags{ - utils.EventStartsBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventStartsBeforeFN: struct{}{}, }, } }, @@ -136,8 +137,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.ExchangeOpts { return utils.ExchangeOpts{ EventStartsBefore: "", - Populated: utils.PopulatedFlags{ - utils.EventStartsBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.EventStartsBeforeFN: struct{}{}, }, } }, @@ -441,8 +442,8 @@ var ( return utils.OneDriveOpts{ Users: selectors.Any(), FileCreatedAfter: "foo", - Populated: utils.PopulatedFlags{ - utils.FileCreatedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedAfterFN: struct{}{}, }, } }, @@ -452,8 +453,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileCreatedAfter: "", - Populated: utils.PopulatedFlags{ - utils.FileCreatedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedAfterFN: struct{}{}, }, } }, @@ -463,8 +464,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileCreatedBefore: "foo", - Populated: utils.PopulatedFlags{ - utils.FileCreatedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedBeforeFN: struct{}{}, }, } }, @@ -474,8 +475,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileCreatedBefore: "", - Populated: utils.PopulatedFlags{ - utils.FileCreatedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileCreatedBeforeFN: struct{}{}, }, } }, @@ -485,8 +486,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileModifiedAfter: "foo", - Populated: utils.PopulatedFlags{ - utils.FileModifiedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedAfterFN: struct{}{}, }, } }, @@ -496,8 +497,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileModifiedAfter: "", - Populated: utils.PopulatedFlags{ - utils.FileModifiedAfterFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedAfterFN: struct{}{}, }, } }, @@ -507,8 +508,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileModifiedBefore: "foo", - Populated: utils.PopulatedFlags{ - utils.FileModifiedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedBeforeFN: struct{}{}, }, } }, @@ -518,8 +519,8 @@ var ( Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { return utils.OneDriveOpts{ FileModifiedBefore: "", - Populated: utils.PopulatedFlags{ - utils.FileModifiedBeforeFN: struct{}{}, + Populated: flags.PopulatedFlags{ + flags.FileModifiedBeforeFN: struct{}{}, }, } }, @@ -751,8 +752,8 @@ var ( // Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { // return utils.SharePointOpts{ // FileCreatedBefore: "foo", - // Populated: utils.PopulatedFlags{ - // utils.FileCreatedBeforeFN: struct{}{}, + // Populated: flags.PopulatedFlags{ + // flags.FileCreatedBeforeFN: struct{}{}, // }, // } // }, @@ -762,8 +763,8 @@ var ( // Opts: func(t *testing.T, wantedVersion int) utils.OneDriveOpts { // return utils.SharePointOpts{ // FileCreatedBefore: "", - // Populated: utils.PopulatedFlags{ - // utils.FileCreatedBeforeFN: struct{}{}, + // Populated: flags.PopulatedFlags{ + // flags.FileCreatedBeforeFN: struct{}{}, // }, // } // }, diff --git a/src/cli/utils/utils.go b/src/cli/utils/utils.go index e0b4c5276..277f11c5c 100644 --- a/src/cli/utils/utils.go +++ b/src/cli/utils/utils.go @@ -9,7 +9,6 @@ import ( "github.com/spf13/pflag" "github.com/alcionai/corso/src/cli/config" - "github.com/alcionai/corso/src/cli/options" "github.com/alcionai/corso/src/internal/events" "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/control" @@ -20,10 +19,6 @@ import ( "github.com/alcionai/corso/src/pkg/storage" ) -const ( - Wildcard = "*" -) - func GetAccountAndConnect(ctx context.Context) (repository.Repository, *storage.Storage, *account.Account, error) { cfg, err := config.GetConfigRepoDetails(ctx, true, nil) if err != nil { @@ -35,7 +30,7 @@ func GetAccountAndConnect(ctx context.Context) (repository.Repository, *storage. repoID = events.RepoIDNotFound } - r, err := repository.Connect(ctx, cfg.Account, cfg.Storage, repoID, options.Control()) + r, err := repository.Connect(ctx, cfg.Account, cfg.Storage, repoID, Control()) if err != nil { return nil, nil, nil, clues.Wrap(err, "connecting to the "+cfg.Storage.Provider.String()+" repository") } diff --git a/src/cmd/factory/impl/exchange.go b/src/cmd/factory/impl/exchange.go index dd304e2e9..bf8450da2 100644 --- a/src/cmd/factory/impl/exchange.go +++ b/src/cmd/factory/impl/exchange.go @@ -114,7 +114,10 @@ func handleExchangeCalendarEventFactory(cmd *cobra.Command, args []string) error func(id, now, subject, body string) []byte { return exchMock.EventWith( User, subject, body, body, - now, now, exchMock.NoRecurrence, exchMock.NoAttendees, false) + exchMock.NoOriginalStartDate, now, now, + exchMock.NoRecurrence, exchMock.NoAttendees, + exchMock.NoAttachments, exchMock.NoCancelledOccurrences, + exchMock.NoExceptionOccurrences) }, control.Defaults(), errs) diff --git a/src/go.mod b/src/go.mod index f16f280c5..bae9e2071 100644 --- a/src/go.mod +++ b/src/go.mod @@ -8,7 +8,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.0 github.com/alcionai/clues v0.0.0-20230613181047-258ea4f19225 github.com/armon/go-metrics v0.4.1 - github.com/aws/aws-sdk-go v1.44.283 + github.com/aws/aws-sdk-go v1.44.287 github.com/aws/aws-xray-sdk-go v1.8.1 github.com/cenkalti/backoff/v4 v4.2.1 github.com/google/uuid v1.3.0 diff --git a/src/go.sum b/src/go.sum index 8178b821a..fad888a58 100644 --- a/src/go.sum +++ b/src/go.sum @@ -66,8 +66,8 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= -github.com/aws/aws-sdk-go v1.44.283 h1:ObMaIvdhHJM2sIrbcljd7muHBaFb+Kp/QsX6iflGDg4= -github.com/aws/aws-sdk-go v1.44.283/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/aws/aws-sdk-go v1.44.287 h1:CUq2/h0gZ2LOCF61AgQSEMPMfas4gTiQfHBO88gGET0= +github.com/aws/aws-sdk-go v1.44.287/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-xray-sdk-go v1.8.1 h1:O4pXV+hnCskaamGsZnFpzHyAmgPGusBMN6i7nnsy0Fo= github.com/aws/aws-xray-sdk-go v1.8.1/go.mod h1:wMmVYzej3sykAttNBkXQHK/+clAPWTOrPiajEk7Cp3A= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= diff --git a/src/internal/m365/discovery/discovery.go b/src/internal/m365/discovery/discovery.go index cba4a25a7..fd6d073e1 100644 --- a/src/internal/m365/discovery/discovery.go +++ b/src/internal/m365/discovery/discovery.go @@ -29,6 +29,10 @@ type getWithInfoer interface { GetInfoer } +type GetDefaultDriver interface { + GetDefaultDrive(ctx context.Context, userID string) (models.Driveable, error) +} + type getAller interface { GetAll(ctx context.Context, errs *fault.Bus) ([]models.Userable, error) } diff --git a/src/internal/m365/exchange/attachment.go b/src/internal/m365/exchange/attachment.go index d09124523..ba8153a63 100644 --- a/src/internal/m365/exchange/attachment.go +++ b/src/internal/m365/exchange/attachment.go @@ -9,6 +9,7 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/pkg/logger" + "github.com/alcionai/corso/src/pkg/services/m365/api" ) type attachmentPoster interface { @@ -20,15 +21,14 @@ type attachmentPoster interface { PostLargeAttachment( ctx context.Context, userID, containerID, itemID, name string, - size int64, - body models.Attachmentable, - ) (models.UploadSessionable, error) + content []byte, + ) (string, error) } const ( // Use large attachment logic for attachments > 3MB // https://learn.microsoft.com/en-us/graph/outlook-large-attachments - largeAttachmentSize = int32(3 * 1024 * 1024) + largeAttachmentSize = 3 * 1024 * 1024 fileAttachmentOdataValue = "#microsoft.graph.fileAttachment" itemAttachmentOdataValue = "#microsoft.graph.itemAttachment" referenceAttachmentOdataValue = "#microsoft.graph.referenceAttachment" @@ -53,7 +53,7 @@ func attachmentType(attachment models.Attachmentable) models.AttachmentType { // uploadAttachment will upload the specified message attachment to M365 func uploadAttachment( ctx context.Context, - cli attachmentPoster, + ap attachmentPoster, userID, containerID, parentItemID string, attachment models.Attachmentable, ) error { @@ -95,12 +95,20 @@ func uploadAttachment( // for file attachments sized >= 3MB if attachmentType == models.FILE_ATTACHMENTTYPE && size >= largeAttachmentSize { - _, err := cli.PostLargeAttachment(ctx, userID, containerID, parentItemID, name, int64(size), attachment) + // We expect the entire attachment to fit in memory. + // Max attachment size is 150MB. + content, err := api.GetAttachmentContent(attachment) + if err != nil { + return clues.Wrap(err, "serializing attachment content").WithClues(ctx) + } + + _, err = ap.PostLargeAttachment(ctx, userID, containerID, parentItemID, name, content) + return err } // for all other attachments - return cli.PostSmallAttachment(ctx, userID, containerID, parentItemID, attachment) + return ap.PostSmallAttachment(ctx, userID, containerID, parentItemID, attachment) } func getOutlookOdataType(query models.Attachmentable) string { diff --git a/src/internal/m365/exchange/backup_test.go b/src/internal/m365/exchange/backup_test.go index 0b900945d..0af604c90 100644 --- a/src/internal/m365/exchange/backup_test.go +++ b/src/internal/m365/exchange/backup_test.go @@ -382,7 +382,7 @@ func newStatusUpdater(t *testing.T, wg *sync.WaitGroup) func(status *support.Con return updater } -type DataCollectionsIntegrationSuite struct { +type BackupIntgSuite struct { tester.Suite user string site string @@ -390,16 +390,15 @@ type DataCollectionsIntegrationSuite struct { ac api.Client } -func TestDataCollectionsIntegrationSuite(t *testing.T) { - suite.Run(t, &DataCollectionsIntegrationSuite{ +func TestBackupIntgSuite(t *testing.T) { + suite.Run(t, &BackupIntgSuite{ Suite: tester.NewIntegrationSuite( t, - [][]string{tester.M365AcctCredEnvs}, - ), + [][]string{tester.M365AcctCredEnvs}), }) } -func (suite *DataCollectionsIntegrationSuite) SetupSuite() { +func (suite *BackupIntgSuite) SetupSuite() { suite.user = tester.M365UserID(suite.T()) suite.site = tester.M365SiteID(suite.T()) @@ -415,7 +414,7 @@ func (suite *DataCollectionsIntegrationSuite) SetupSuite() { tester.LogTimeOfTest(suite.T()) } -func (suite *DataCollectionsIntegrationSuite) TestMailFetch() { +func (suite *BackupIntgSuite) TestMailFetch() { var ( userID = tester.M365UserID(suite.T()) users = []string{userID} @@ -499,7 +498,7 @@ func (suite *DataCollectionsIntegrationSuite) TestMailFetch() { } } -func (suite *DataCollectionsIntegrationSuite) TestDelta() { +func (suite *BackupIntgSuite) TestDelta() { var ( userID = tester.M365UserID(suite.T()) users = []string{userID} @@ -604,7 +603,7 @@ func (suite *DataCollectionsIntegrationSuite) TestDelta() { // TestMailSerializationRegression verifies that all mail data stored in the // test account can be successfully downloaded into bytes and restored into // M365 mail objects -func (suite *DataCollectionsIntegrationSuite) TestMailSerializationRegression() { +func (suite *BackupIntgSuite) TestMailSerializationRegression() { t := suite.T() ctx, flush := tester.NewContext(t) @@ -668,7 +667,7 @@ func (suite *DataCollectionsIntegrationSuite) TestMailSerializationRegression() // TestContactSerializationRegression verifies ability to query contact items // and to store contact within Collection. Downloaded contacts are run through // a regression test to ensure that downloaded items can be uploaded. -func (suite *DataCollectionsIntegrationSuite) TestContactSerializationRegression() { +func (suite *BackupIntgSuite) TestContactSerializationRegression() { var ( users = []string{suite.user} handlers = BackupHandlers(suite.ac) @@ -756,7 +755,7 @@ func (suite *DataCollectionsIntegrationSuite) TestContactSerializationRegression // TestEventsSerializationRegression ensures functionality of createCollections // to be able to successfully query, download and restore event objects -func (suite *DataCollectionsIntegrationSuite) TestEventsSerializationRegression() { +func (suite *BackupIntgSuite) TestEventsSerializationRegression() { t := suite.T() ctx, flush := tester.NewContext(t) diff --git a/src/internal/m365/exchange/contacts_restore.go b/src/internal/m365/exchange/contacts_restore.go index 82ff1364a..076cd4a22 100644 --- a/src/internal/m365/exchange/contacts_restore.go +++ b/src/internal/m365/exchange/contacts_restore.go @@ -9,7 +9,9 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) @@ -18,7 +20,6 @@ var _ itemRestorer = &contactRestoreHandler{} type contactRestoreHandler struct { ac api.Contacts - ip itemPoster[models.Contactable] } func newContactRestoreHandler( @@ -26,7 +27,6 @@ func newContactRestoreHandler( ) contactRestoreHandler { return contactRestoreHandler{ ac: ac.Contacts(), - ip: ac.Contacts(), } } @@ -65,6 +65,27 @@ func (h contactRestoreHandler) restore( ctx context.Context, body []byte, userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, + errs *fault.Bus, +) (*details.ExchangeInfo, error) { + return restoreContact( + ctx, + h.ac, + body, + userID, destinationID, + collisionKeyToItemID, + collisionPolicy, + errs) +} + +func restoreContact( + ctx context.Context, + pi postItemer[models.Contactable], + body []byte, + userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, errs *fault.Bus, ) (*details.ExchangeInfo, error) { contact, err := api.BytesToContactable(body) @@ -73,8 +94,20 @@ func (h contactRestoreHandler) restore( } ctx = clues.Add(ctx, "item_id", ptr.Val(contact.GetId())) + collisionKey := api.ContactCollisionKey(contact) - item, err := h.ip.PostItem(ctx, userID, destinationID, contact) + if _, ok := collisionKeyToItemID[collisionKey]; ok { + log := logger.Ctx(ctx).With("collision_key", clues.Hide(collisionKey)) + log.Debug("item collision") + + // TODO(rkeepers): Replace probably shouldn't no-op. Just a starting point. + if collisionPolicy == control.Skip || collisionPolicy == control.Replace { + log.Debug("skipping item with collision") + return nil, graph.ErrItemAlreadyExistsConflict + } + } + + item, err := pi.PostItem(ctx, userID, destinationID, contact) if err != nil { return nil, graph.Wrap(ctx, err, "restoring mail message") } @@ -84,3 +117,15 @@ func (h contactRestoreHandler) restore( return info, nil } + +func (h contactRestoreHandler) getItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + m, err := h.ac.GetItemsInContainerByCollisionKey(ctx, userID, containerID) + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/src/internal/m365/exchange/contacts_restore_test.go b/src/internal/m365/exchange/contacts_restore_test.go index de53f59e2..6ed6a3c5e 100644 --- a/src/internal/m365/exchange/contacts_restore_test.go +++ b/src/internal/m365/exchange/contacts_restore_test.go @@ -1,24 +1,46 @@ package exchange import ( + "context" "testing" "github.com/alcionai/clues" + "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/m365/exchange/mock" + "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/tester" - "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control/testdata" + "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) +var _ postItemer[models.Contactable] = &mockContactRestorer{} + +type mockContactRestorer struct { + postItemErr error +} + +func (m mockContactRestorer) PostItem( + ctx context.Context, + userID, containerID string, + body models.Contactable, +) (models.Contactable, error) { + return models.NewContact(), m.postItemErr +} + +// --------------------------------------------------------------------------- +// tests +// --------------------------------------------------------------------------- + type ContactsRestoreIntgSuite struct { tester.Suite - creds account.M365Config - ac api.Client - userID string + its intgTesterSetup } func TestContactsRestoreIntgSuite(t *testing.T) { @@ -30,29 +52,110 @@ func TestContactsRestoreIntgSuite(t *testing.T) { } func (suite *ContactsRestoreIntgSuite) SetupSuite() { - t := suite.T() - - a := tester.NewM365Account(t) - creds, err := a.M365Config() - require.NoError(t, err, clues.ToCore(err)) - - suite.creds = creds - - suite.ac, err = api.NewClient(creds) - require.NoError(t, err, clues.ToCore(err)) - - suite.userID = tester.M365UserID(t) + suite.its = newIntegrationTesterSetup(suite.T()) } // Testing to ensure that cache system works for in multiple different environments func (suite *ContactsRestoreIntgSuite) TestCreateContainerDestination() { runCreateDestinationTest( suite.T(), - newMailRestoreHandler(suite.ac), - path.EmailCategory, - suite.creds.AzureTenantID, - suite.userID, + newContactRestoreHandler(suite.its.ac), + path.ContactsCategory, + suite.its.creds.AzureTenantID, + suite.its.userID, testdata.DefaultRestoreConfig("").Location, []string{"Hufflepuff"}, []string{"Ravenclaw"}) } + +func (suite *ContactsRestoreIntgSuite) TestRestoreContact() { + body := mock.ContactBytes("middlename") + + stub, err := api.BytesToContactable(body) + require.NoError(suite.T(), err, clues.ToCore(err)) + + collisionKey := api.ContactCollisionKey(stub) + + table := []struct { + name string + apiMock postItemer[models.Contactable] + collisionMap map[string]string + onCollision control.CollisionPolicy + expectErr func(*testing.T, error) + }{ + { + name: "no collision: skip", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: copy", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: replace", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: skip", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + { + name: "collision: copy", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: replace", + apiMock: mockContactRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := restoreContact( + ctx, + test.apiMock, + body, + suite.its.userID, + "destination", + test.collisionMap, + test.onCollision, + fault.New(true)) + + test.expectErr(t, err) + }) + } +} diff --git a/src/internal/m365/exchange/events_restore.go b/src/internal/m365/exchange/events_restore.go index 18540ecaf..8ccb1232c 100644 --- a/src/internal/m365/exchange/events_restore.go +++ b/src/internal/m365/exchange/events_restore.go @@ -1,15 +1,23 @@ package exchange import ( + "bytes" "context" + "fmt" + "strings" + "time" "github.com/alcionai/clues" "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) @@ -18,17 +26,13 @@ var _ itemRestorer = &eventRestoreHandler{} type eventRestoreHandler struct { ac api.Events - ip itemPoster[models.Eventable] } func newEventRestoreHandler( ac api.Client, ) eventRestoreHandler { - ace := ac.Events() - return eventRestoreHandler{ - ac: ace, - ip: ace, + ac: ac.Events(), } } @@ -67,6 +71,32 @@ func (h eventRestoreHandler) restore( ctx context.Context, body []byte, userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, + errs *fault.Bus, +) (*details.ExchangeInfo, error) { + return restoreEvent( + ctx, + h.ac, + body, + userID, destinationID, + collisionKeyToItemID, + collisionPolicy, + errs) +} + +type eventRestorer interface { + postItemer[models.Eventable] + eventInstanceAndAttachmenter +} + +func restoreEvent( + ctx context.Context, + er eventRestorer, + body []byte, + userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, errs *fault.Bus, ) (*details.ExchangeInfo, error) { event, err := api.BytesToEventable(body) @@ -75,6 +105,18 @@ func (h eventRestoreHandler) restore( } ctx = clues.Add(ctx, "item_id", ptr.Val(event.GetId())) + collisionKey := api.EventCollisionKey(event) + + if _, ok := collisionKeyToItemID[collisionKey]; ok { + log := logger.Ctx(ctx).With("collision_key", clues.Hide(collisionKey)) + log.Debug("item collision") + + // TODO(rkeepers): Replace probably shouldn't no-op. Just a starting point. + if collisionPolicy == control.Skip || collisionPolicy == control.Replace { + log.Debug("skipping item with collision") + return nil, graph.ErrItemAlreadyExistsConflict + } + } event = toEventSimplified(event) @@ -82,17 +124,19 @@ func (h eventRestoreHandler) restore( if ptr.Val(event.GetHasAttachments()) { attachments = event.GetAttachments() - event.SetAttachments([]models.Attachmentable{}) + // We cannot use `[]models.Attbachmentable{}` instead of nil + // for beta endpoint. + event.SetAttachments(nil) } - item, err := h.ip.PostItem(ctx, userID, destinationID, event) + item, err := er.PostItem(ctx, userID, destinationID, event) if err != nil { - return nil, graph.Wrap(ctx, err, "restoring mail message") + return nil, graph.Wrap(ctx, err, "restoring calendar item") } err = uploadAttachments( ctx, - h.ac, + er, attachments, userID, destinationID, @@ -102,8 +146,359 @@ func (h eventRestoreHandler) restore( return nil, clues.Stack(err) } + // Have to parse event again as we modified the original event and + // removed cancelled and exceptions events form it + event, err = api.BytesToEventable(body) + if err != nil { + return nil, clues.Wrap(err, "creating event from bytes").WithClues(ctx) + } + + // Fix up event instances in case we have a recurring event + err = updateRecurringEvents( + ctx, + er, + userID, + destinationID, + ptr.Val(item.GetId()), + event, + errs, + ) + if err != nil { + return nil, clues.Stack(err) + } + info := api.EventInfo(event) info.Size = int64(len(body)) return info, nil } + +func updateRecurringEvents( + ctx context.Context, + eiaa eventInstanceAndAttachmenter, + userID, containerID, itemID string, + event models.Eventable, + errs *fault.Bus, +) error { + if event.GetRecurrence() == nil { + return nil + } + + // Cancellations and exceptions are currently in additional data + // but will get their own fields once the beta API lands and + // should be moved then + cancelledOccurrences := event.GetAdditionalData()["cancelledOccurrences"] + exceptionOccurrences := event.GetAdditionalData()["exceptionOccurrences"] + + err := updateCancelledOccurrences(ctx, eiaa, userID, itemID, cancelledOccurrences) + if err != nil { + return clues.Wrap(err, "update cancelled occurrences") + } + + err = updateExceptionOccurrences(ctx, eiaa, userID, containerID, itemID, exceptionOccurrences, errs) + if err != nil { + return clues.Wrap(err, "update exception occurrences") + } + + return nil +} + +type eventInstanceAndAttachmenter interface { + attachmentGetDeletePoster + DeleteItem( + ctx context.Context, + userID, itemID string, + ) error + GetItemInstances( + ctx context.Context, + userID, itemID string, + startDate, endDate string, + ) ([]models.Eventable, error) + PatchItem( + ctx context.Context, + userID, eventID string, + body models.Eventable, + ) (models.Eventable, error) +} + +// updateExceptionOccurrences take events that have exceptions, uses +// the originalStart date to find the instance and modify it to match +// the backup by updating the instance to match the backed up one +func updateExceptionOccurrences( + ctx context.Context, + eiaa eventInstanceAndAttachmenter, + userID string, + containerID string, + itemID string, + exceptionOccurrences any, + errs *fault.Bus, +) error { + if exceptionOccurrences == nil { + return nil + } + + eo, ok := exceptionOccurrences.([]any) + if !ok { + return clues.New("converting exceptionOccurrences to []any"). + With("type", fmt.Sprintf("%T", exceptionOccurrences)) + } + + for _, instance := range eo { + instance, ok := instance.(map[string]any) + if !ok { + return clues.New("converting instance to map[string]any"). + With("type", fmt.Sprintf("%T", instance)) + } + + evt, err := api.EventFromMap(instance) + if err != nil { + return clues.Wrap(err, "parsing exception event") + } + + start := ptr.Val(evt.GetOriginalStart()) + startStr := dttm.FormatTo(start, dttm.DateOnly) + endStr := dttm.FormatTo(start.Add(24*time.Hour), dttm.DateOnly) + + ictx := clues.Add(ctx, "event_instance_id", ptr.Val(evt.GetId()), "event_instance_date", start) + + // Get all instances on the day of the instance which should + // just the one we need to modify + instances, err := eiaa.GetItemInstances(ictx, userID, itemID, startStr, endStr) + if err != nil { + return clues.Wrap(err, "getting instances") + } + + // Since the min recurrence interval is 1 day and we are + // querying for only a single day worth of instances, we + // should not have more than one instance here. + if len(instances) != 1 { + return clues.New("invalid number of instances for modified"). + With("instances_count", len(instances), "search_start", startStr, "search_end", endStr) + } + + evt = toEventSimplified(evt) + + _, err = eiaa.PatchItem(ictx, userID, ptr.Val(instances[0].GetId()), evt) + if err != nil { + return clues.Wrap(err, "updating event instance") + } + + // We are creating event again from map as `toEventSimplified` + // removed the attachments and creating a clone from start of + // the event is non-trivial + evt, err = api.EventFromMap(instance) + if err != nil { + return clues.Wrap(err, "parsing event instance") + } + + err = updateAttachments( + ictx, + eiaa, + userID, + containerID, + ptr.Val(instances[0].GetId()), + evt, + errs) + if err != nil { + return clues.Wrap(err, "updating event instance attachments") + } + } + + return nil +} + +type attachmentGetDeletePoster interface { + attachmentPoster + GetAttachments( + ctx context.Context, + immutableIDs bool, + userID string, + itemID string, + ) ([]models.Attachmentable, error) + DeleteAttachment( + ctx context.Context, + userID, calendarID, eventID, attachmentID string, + ) error +} + +// updateAttachments updates the attachments of an event to match what +// is present in the backed up event. Ideally we could make use of the +// id of the series master event's attachments to see if we had +// added/removed any attachments, but as soon an event is modified, +// the id changes which makes the ids unusable. In this function, we +// use the name and content bytes to detect the changes. This function +// can be used to update the attachments of any event irrespective of +// whether they are event instances of a series master although for +// newer event, since we probably won't already have any events it +// would be better use Post[Small|Large]Attachment. +func updateAttachments( + ctx context.Context, + agdp attachmentGetDeletePoster, + userID, containerID, eventID string, + event models.Eventable, + errs *fault.Bus, +) error { + el := errs.Local() + + attachments, err := agdp.GetAttachments(ctx, false, userID, eventID) + if err != nil { + return clues.Wrap(err, "getting attachments") + } + + // Delete attachments that are not present in the backup but are + // present in the event(ones that were automatically inherited + // from series master). + for _, att := range attachments { + if el.Failure() != nil { + return el.Failure() + } + + name := ptr.Val(att.GetName()) + id := ptr.Val(att.GetId()) + + content, err := api.GetAttachmentContent(att) + if err != nil { + return clues.Wrap(err, "getting attachment").With("attachment_id", id) + } + + found := false + + for _, nAtt := range event.GetAttachments() { + nName := ptr.Val(nAtt.GetName()) + + nContent, err := api.GetAttachmentContent(nAtt) + if err != nil { + return clues.Wrap(err, "getting attachment").With("attachment_id", ptr.Val(nAtt.GetId())) + } + + if name == nName && bytes.Equal(content, nContent) { + found = true + break + } + } + + if !found { + err = agdp.DeleteAttachment(ctx, userID, containerID, eventID, id) + if err != nil { + logger.CtxErr(ctx, err).With("attachment_name", name).Info("attachment delete failed") + el.AddRecoverable(ctx, clues.Wrap(err, "deleting event attachment"). + WithClues(ctx).With("attachment_name", name)) + } + } + } + + // Upload missing(attachments that are present in the individual + // instance but not in the series master event) attachments + for _, att := range event.GetAttachments() { + name := ptr.Val(att.GetName()) + id := ptr.Val(att.GetId()) + + content, err := api.GetAttachmentContent(att) + if err != nil { + return clues.Wrap(err, "getting attachment").With("attachment_id", id) + } + + found := false + + for _, nAtt := range attachments { + nName := ptr.Val(nAtt.GetName()) + + bContent, err := api.GetAttachmentContent(nAtt) + if err != nil { + return clues.Wrap(err, "getting attachment").With("attachment_id", ptr.Val(nAtt.GetId())) + } + + // Max size allowed for an outlook attachment is 150MB + if name == nName && bytes.Equal(content, bContent) { + found = true + break + } + } + + if !found { + err = uploadAttachment(ctx, agdp, userID, containerID, eventID, att) + if err != nil { + return clues.Wrap(err, "uploading attachment"). + With("attachment_id", id) + } + } + } + + return el.Failure() +} + +// updateCancelledOccurrences get the cancelled occurrences which is a +// list of strings of the format ".", parses the date out of +// that and uses the to get the event instance at that date to delete. +func updateCancelledOccurrences( + ctx context.Context, + eiaa eventInstanceAndAttachmenter, + userID string, + itemID string, + cancelledOccurrences any, +) error { + if cancelledOccurrences == nil { + return nil + } + + co, ok := cancelledOccurrences.([]any) + if !ok { + return clues.New("converting cancelledOccurrences to []any"). + With("type", fmt.Sprintf("%T", cancelledOccurrences)) + } + + // OPTIMIZATION: We can fetch a date range instead of fetching + // instances if we have multiple cancelled events which are nearby + // and reduce the number of API calls that we have to make + for _, instance := range co { + instance, err := str.AnyToString(instance) + if err != nil { + return err + } + + splits := strings.Split(instance, ".") + + startStr := splits[len(splits)-1] + + start, err := dttm.ParseTime(startStr) + if err != nil { + return clues.Wrap(err, "parsing cancelled event date") + } + + endStr := dttm.FormatTo(start.Add(24*time.Hour), dttm.DateOnly) + + // Get all instances on the day of the instance which should + // just the one we need to modify + instances, err := eiaa.GetItemInstances(ctx, userID, itemID, startStr, endStr) + if err != nil { + return clues.Wrap(err, "getting instances") + } + + // Since the min recurrence interval is 1 day and we are + // querying for only a single day worth of instances, we + // should not have more than one instance here. + if len(instances) != 1 { + return clues.New("invalid number of instances for cancelled"). + With("instances_count", len(instances), "search_start", startStr, "search_end", endStr) + } + + err = eiaa.DeleteItem(ctx, userID, ptr.Val(instances[0].GetId())) + if err != nil { + return clues.Wrap(err, "deleting event instance") + } + } + + return nil +} + +func (h eventRestoreHandler) getItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + m, err := h.ac.GetItemsInContainerByCollisionKey(ctx, userID, containerID) + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/src/internal/m365/exchange/events_restore_test.go b/src/internal/m365/exchange/events_restore_test.go index 156d191d1..ddd9983b8 100644 --- a/src/internal/m365/exchange/events_restore_test.go +++ b/src/internal/m365/exchange/events_restore_test.go @@ -1,24 +1,101 @@ package exchange import ( + "context" "testing" "github.com/alcionai/clues" + "github.com/google/uuid" + "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/m365/exchange/mock" + "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/tester" - "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control/testdata" + "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) +var _ eventRestorer = &mockEventRestorer{} + +type mockEventRestorer struct { + postItemErr error + postAttachmentErr error +} + +func (m mockEventRestorer) PostItem( + ctx context.Context, + userID, containerID string, + body models.Eventable, +) (models.Eventable, error) { + return models.NewEvent(), m.postItemErr +} + +func (m mockEventRestorer) PostSmallAttachment( + _ context.Context, + _, _, _ string, + _ models.Attachmentable, +) error { + return m.postAttachmentErr +} + +func (m mockEventRestorer) PostLargeAttachment( + _ context.Context, + _, _, _, _ string, + _ []byte, +) (string, error) { + return uuid.NewString(), m.postAttachmentErr +} + +func (m mockEventRestorer) DeleteAttachment( + ctx context.Context, + userID, calendarID, eventID, attachmentID string, +) error { + return nil +} + +func (m mockEventRestorer) DeleteItem( + ctx context.Context, + userID, itemID string, +) error { + return nil +} + +func (m mockEventRestorer) GetAttachments( + _ context.Context, + _ bool, + _, _ string, +) ([]models.Attachmentable, error) { + return []models.Attachmentable{}, nil +} + +func (m mockEventRestorer) GetItemInstances( + _ context.Context, + _, _, _, _ string, +) ([]models.Eventable, error) { + return []models.Eventable{}, nil +} + +func (m mockEventRestorer) PatchItem( + _ context.Context, + _, _ string, + _ models.Eventable, +) (models.Eventable, error) { + return models.NewEvent(), nil +} + +// --------------------------------------------------------------------------- +// tests +// --------------------------------------------------------------------------- + type EventsRestoreIntgSuite struct { tester.Suite - creds account.M365Config - ac api.Client - userID string + its intgTesterSetup } func TestEventsRestoreIntgSuite(t *testing.T) { @@ -30,29 +107,110 @@ func TestEventsRestoreIntgSuite(t *testing.T) { } func (suite *EventsRestoreIntgSuite) SetupSuite() { - t := suite.T() - - a := tester.NewM365Account(t) - creds, err := a.M365Config() - require.NoError(t, err, clues.ToCore(err)) - - suite.creds = creds - - suite.ac, err = api.NewClient(creds) - require.NoError(t, err, clues.ToCore(err)) - - suite.userID = tester.M365UserID(t) + suite.its = newIntegrationTesterSetup(suite.T()) } // Testing to ensure that cache system works for in multiple different environments func (suite *EventsRestoreIntgSuite) TestCreateContainerDestination() { runCreateDestinationTest( suite.T(), - newMailRestoreHandler(suite.ac), - path.EmailCategory, - suite.creds.AzureTenantID, - suite.userID, + newEventRestoreHandler(suite.its.ac), + path.EventsCategory, + suite.its.creds.AzureTenantID, + suite.its.userID, testdata.DefaultRestoreConfig("").Location, []string{"Durmstrang"}, []string{"Beauxbatons"}) } + +func (suite *EventsRestoreIntgSuite) TestRestoreEvent() { + body := mock.EventBytes("subject") + + stub, err := api.BytesToEventable(body) + require.NoError(suite.T(), err, clues.ToCore(err)) + + collisionKey := api.EventCollisionKey(stub) + + table := []struct { + name string + apiMock eventRestorer + collisionMap map[string]string + onCollision control.CollisionPolicy + expectErr func(*testing.T, error) + }{ + { + name: "no collision: skip", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: copy", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: replace", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: skip", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + { + name: "collision: copy", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: replace", + apiMock: mockEventRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := restoreEvent( + ctx, + test.apiMock, + body, + suite.its.userID, + "destination", + test.collisionMap, + test.onCollision, + fault.New(true)) + + test.expectErr(t, err) + }) + } +} diff --git a/src/internal/m365/exchange/handlers.go b/src/internal/m365/exchange/handlers.go index 9eb7d1fe1..243ad11fd 100644 --- a/src/internal/m365/exchange/handlers.go +++ b/src/internal/m365/exchange/handlers.go @@ -7,6 +7,7 @@ import ( "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" @@ -60,6 +61,7 @@ func BackupHandlers(ac api.Client) map[path.CategoryType]backupHandler { type restoreHandler interface { itemRestorer containerAPI + getItemsByCollisionKeyser newContainerCache(userID string) graph.ContainerResolver formatRestoreDestination( destinationContainerName string, @@ -75,19 +77,12 @@ type itemRestorer interface { ctx context.Context, body []byte, userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, errs *fault.Bus, ) (*details.ExchangeInfo, error) } -// runs the actual graph API post request. -type itemPoster[T any] interface { - PostItem( - ctx context.Context, - userID, dirID string, - body T, - ) (T, error) -} - // produces structs that interface with the graph/cache_container // CachedContainer interface. type containerAPI interface { @@ -129,3 +124,24 @@ func restoreHandlers( path.EventsCategory: newEventRestoreHandler(ac), } } + +type getItemsByCollisionKeyser interface { + // GetItemsInContainerByCollisionKey looks up all items currently in + // the container, and returns them in a map[collisionKey]itemID. + // The collision key is uniquely defined by each category of data. + // Collision key checks are used during restore to handle the on- + // collision restore configurations that cause the item restore to get + // skipped, replaced, or copied. + getItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, + ) (map[string]string, error) +} + +type postItemer[T any] interface { + PostItem( + ctx context.Context, + userID, containerID string, + body T, + ) (T, error) +} diff --git a/src/internal/m365/exchange/helper_test.go b/src/internal/m365/exchange/helper_test.go new file mode 100644 index 000000000..222179d6c --- /dev/null +++ b/src/internal/m365/exchange/helper_test.go @@ -0,0 +1,38 @@ +package exchange + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/require" + + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +type intgTesterSetup struct { + ac api.Client + creds account.M365Config + userID string +} + +func newIntegrationTesterSetup(t *testing.T) intgTesterSetup { + its := intgTesterSetup{} + + ctx, flush := tester.NewContext(t) + defer flush() + + a := tester.NewM365Account(t) + creds, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + its.creds = creds + + its.ac, err = api.NewClient(creds) + require.NoError(t, err, clues.ToCore(err)) + + its.userID = tester.GetM365UserID(ctx) + + return its +} diff --git a/src/internal/m365/exchange/mail_restore.go b/src/internal/m365/exchange/mail_restore.go index ce0979859..1ddb3e52b 100644 --- a/src/internal/m365/exchange/mail_restore.go +++ b/src/internal/m365/exchange/mail_restore.go @@ -10,7 +10,9 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) @@ -19,17 +21,13 @@ var _ itemRestorer = &mailRestoreHandler{} type mailRestoreHandler struct { ac api.Mail - ip itemPoster[models.Messageable] } func newMailRestoreHandler( ac api.Client, ) mailRestoreHandler { - acm := ac.Mail() - return mailRestoreHandler{ - ac: acm, - ip: acm, + ac: ac.Mail(), } } @@ -72,6 +70,32 @@ func (h mailRestoreHandler) restore( ctx context.Context, body []byte, userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, + errs *fault.Bus, +) (*details.ExchangeInfo, error) { + return restoreMail( + ctx, + h.ac, + body, + userID, destinationID, + collisionKeyToItemID, + collisionPolicy, + errs) +} + +type mailRestorer interface { + postItemer[models.Messageable] + attachmentPoster +} + +func restoreMail( + ctx context.Context, + mr mailRestorer, + body []byte, + userID, destinationID string, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, errs *fault.Bus, ) (*details.ExchangeInfo, error) { msg, err := api.BytesToMessageable(body) @@ -80,20 +104,33 @@ func (h mailRestoreHandler) restore( } ctx = clues.Add(ctx, "item_id", ptr.Val(msg.GetId())) + collisionKey := api.MailCollisionKey(msg) + + if _, ok := collisionKeyToItemID[collisionKey]; ok { + log := logger.Ctx(ctx).With("collision_key", clues.Hide(collisionKey)) + log.Debug("item collision") + + // TODO(rkeepers): Replace probably shouldn't no-op. Just a starting point. + if collisionPolicy == control.Skip || collisionPolicy == control.Replace { + log.Debug("skipping item with collision") + return nil, graph.ErrItemAlreadyExistsConflict + } + } + msg = setMessageSVEPs(toMessage(msg)) attachments := msg.GetAttachments() // Item.Attachments --> HasAttachments doesn't always have a value populated when deserialized msg.SetAttachments([]models.Attachmentable{}) - item, err := h.ip.PostItem(ctx, userID, destinationID, msg) + item, err := mr.PostItem(ctx, userID, destinationID, msg) if err != nil { return nil, graph.Wrap(ctx, err, "restoring mail message") } err = uploadAttachments( ctx, - h.ac, + mr, attachments, userID, destinationID, @@ -138,3 +175,15 @@ func setMessageSVEPs(msg models.Messageable) models.Messageable { return msg } + +func (h mailRestoreHandler) getItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + m, err := h.ac.GetItemsInContainerByCollisionKey(ctx, userID, containerID) + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/src/internal/m365/exchange/mail_restore_test.go b/src/internal/m365/exchange/mail_restore_test.go index 9d71de800..fc2837c85 100644 --- a/src/internal/m365/exchange/mail_restore_test.go +++ b/src/internal/m365/exchange/mail_restore_test.go @@ -1,24 +1,64 @@ package exchange import ( + "context" "testing" "github.com/alcionai/clues" + "github.com/google/uuid" + "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/m365/exchange/mock" + "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/tester" - "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control/testdata" + "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) +var _ mailRestorer = &mockMailRestorer{} + +type mockMailRestorer struct { + postItemErr error + postAttachmentErr error +} + +func (m mockMailRestorer) PostItem( + ctx context.Context, + userID, containerID string, + body models.Messageable, +) (models.Messageable, error) { + return models.NewMessage(), m.postItemErr +} + +func (m mockMailRestorer) PostSmallAttachment( + _ context.Context, + _, _, _ string, + _ models.Attachmentable, +) error { + return m.postAttachmentErr +} + +func (m mockMailRestorer) PostLargeAttachment( + _ context.Context, + _, _, _, _ string, + _ []byte, +) (string, error) { + return uuid.NewString(), m.postAttachmentErr +} + +// --------------------------------------------------------------------------- +// tests +// --------------------------------------------------------------------------- + type MailRestoreIntgSuite struct { tester.Suite - creds account.M365Config - ac api.Client - userID string + its intgTesterSetup } func TestMailRestoreIntgSuite(t *testing.T) { @@ -30,29 +70,109 @@ func TestMailRestoreIntgSuite(t *testing.T) { } func (suite *MailRestoreIntgSuite) SetupSuite() { - t := suite.T() - - a := tester.NewM365Account(t) - creds, err := a.M365Config() - require.NoError(t, err, clues.ToCore(err)) - - suite.creds = creds - - suite.ac, err = api.NewClient(creds) - require.NoError(t, err, clues.ToCore(err)) - - suite.userID = tester.M365UserID(t) + suite.its = newIntegrationTesterSetup(suite.T()) } -// Testing to ensure that cache system works for in multiple different environments func (suite *MailRestoreIntgSuite) TestCreateContainerDestination() { runCreateDestinationTest( suite.T(), - newMailRestoreHandler(suite.ac), + newMailRestoreHandler(suite.its.ac), path.EmailCategory, - suite.creds.AzureTenantID, - suite.userID, + suite.its.creds.AzureTenantID, + suite.its.userID, testdata.DefaultRestoreConfig("").Location, []string{"Griffindor", "Croix"}, []string{"Griffindor", "Felicius"}) } + +func (suite *MailRestoreIntgSuite) TestRestoreMail() { + body := mock.MessageBytes("subject") + + stub, err := api.BytesToMessageable(body) + require.NoError(suite.T(), err, clues.ToCore(err)) + + collisionKey := api.MailCollisionKey(stub) + + table := []struct { + name string + apiMock mailRestorer + collisionMap map[string]string + onCollision control.CollisionPolicy + expectErr func(*testing.T, error) + }{ + { + name: "no collision: skip", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: copy", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no collision: replace", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: skip", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Skip, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + { + name: "collision: copy", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Copy, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "collision: replace", + apiMock: mockMailRestorer{}, + collisionMap: map[string]string{collisionKey: "smarf"}, + onCollision: control.Replace, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrItemAlreadyExistsConflict, clues.ToCore(err)) + }, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := restoreMail( + ctx, + test.apiMock, + body, + suite.its.userID, + "destination", + test.collisionMap, + test.onCollision, + fault.New(true)) + + test.expectErr(t, err) + }) + } +} diff --git a/src/internal/m365/exchange/mock/event.go b/src/internal/m365/exchange/mock/event.go index 7df667af6..bc7be481c 100644 --- a/src/internal/m365/exchange/mock/event.go +++ b/src/internal/m365/exchange/mock/event.go @@ -23,13 +23,8 @@ import ( // 10. attendees //nolint:lll -const ( +var ( eventTmpl = `{ - "id":"AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwBGAAAAAADCNgjhM9QmQYWNcI7hCpPrBwDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAADSEBNbUIB9RL6ePDeF3FIYAAAAAG76AAA=", - "calendar@odata.navigationLink":"https://graph.microsoft.com/v1.0/users('foobar@8qzvrj.onmicrosoft.com')/calendars('AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwAuAAAAAADCNgjhM9QmQYWNcI7hCpPrAQDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAAA=')", - "calendar@odata.associationLink":"https://graph.microsoft.com/v1.0/users('foobar@8qzvrj.onmicrosoft.com')/calendars('AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwAuAAAAAADCNgjhM9QmQYWNcI7hCpPrAQDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAAA=')/$ref", - "@odata.etag":"W/\"0hATW1CAfUS+njw3hdxSGAAAJIxNug==\"", - "@odata.context":"https://graph.microsoft.com/v1.0/$metadata#users('foobar%%408qzvrj.onmicrosoft.com')/events/$entity", "categories":[], "changeKey":"0hATW1CAfUS+njw3hdxSGAAAJIxNug==", "createdDateTime":"2022-03-28T03:42:03Z", @@ -46,7 +41,6 @@ const ( "timeZone":"UTC" }, "hideAttendees":false, - "iCalUId":"040000008200E00074C5B7101A82E0080000000035723BC75542D801000000000000000010000000E1E7C8F785242E4894DA13AEFB947B85", "importance":"normal", "isAllDay":false, "isCancelled":false, @@ -75,6 +69,7 @@ const ( "name":"Anu Pierson" } }, + %s "originalEndTimeZone":"UTC", "originalStartTimeZone":"UTC", "reminderMinutesBeforeStart":15, @@ -90,19 +85,23 @@ const ( "timeZone":"UTC" }, "subject":"%s", - "type":"singleInstance", + "type":"%s", "hasAttachments":%v, %s "webLink":"https://outlook.office365.com/owa/?itemid=AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwBGAAAAAADCNgjhM9QmQYWNcI7hCpPrBwDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAADSEBNbUIB9RL6ePDeF3FIYAAAAAG76AAA%%3D&exvsurl=1&path=/calendar/item", "recurrence":%s, + %s + %s "attendees":%s }` defaultEventBody = "This meeting is to review the latest Tailspin Toys project proposal.
\\r\\nBut why not eat some sushi while we’re at it? :)" defaultEventBodyPreview = "This meeting is to review the latest Tailspin Toys project proposal.\\r\\nBut why not eat some sushi while we’re at it? :)" defaultEventOrganizer = "foobar@8qzvrj.onmicrosoft.com" - eventAttachment = "\"attachments\":[{\"id\":\"AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwBGAAAAAADCNgjhM9QmQYWNcI7hCpPrBwDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAADSEBNbUIB9RL6ePDeF3FIYAACLjfLQAAABEgAQAHoI0xBbBBVEh6bFMU78ZUo=\",\"@odata.type\":\"#microsoft.graph.fileAttachment\"," + - "\"@odata.mediaContentType\":\"application/octet-stream\",\"contentType\":\"application/octet-stream\",\"isInline\":false,\"lastModifiedDateTime\":\"2022-10-26T15:19:42Z\",\"name\":\"database.db\",\"size\":11418," + + + NoAttachments = "" + eventAttachmentFormat = "{\"id\":\"AAMkAGZmNjNlYjI3LWJlZWYtNGI4Mi04YjMyLTIxYThkNGQ4NmY1MwBGAAAAAADCNgjhM9QmQYWNcI7hCpPrBwDSEBNbUIB9RL6ePDeF3FIYAAAAAAENAADSEBNbUIB9RL6ePDeF3FIYAACLjfLQAAABEgAQAHoI0xBbBBVEh6bFMU78ZUo=\",\"@odata.type\":\"#microsoft.graph.fileAttachment\"," + + "\"@odata.mediaContentType\":\"application/octet-stream\",\"contentType\":\"application/octet-stream\",\"isInline\":false,\"lastModifiedDateTime\":\"2022-10-26T15:19:42Z\",\"name\":\"%s\",\"size\":11418," + "\"contentBytes\":\"U1FMaXRlIGZvcm1hdCAzAAQAAQEAQCAgAAAATQAAAAsAAAAEAAAACAAAAAsAAAAEAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABNAC3mBw0DZwACAg8AAxUCDwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACCAwMHFxUVAYNpdGFibGVkYXRhZGF0YQJDUkVBVEUgVEFCTEUgZGF0YSAoCiAgICAgICAgIGlkIGludGVnZXIgcHJpbWFyeSBrZXkgYXV0b2luY3JlbWVudCwKICAgICAgICAgbWVhbiB0ZXh0IG5vdCBudWxsLAogICAgICAgICBtYXggdGV4dCBub3QgbnVsbCwKICAgICAgICAgbWluIHRleHQgbm90IG51bGwsCiAgICAgICAgIGRhdGEgdGV" + @@ -149,15 +148,19 @@ const ( "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + - "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"}]," + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"}" + defaultEventAttachments = "\"attachments\":[" + fmt.Sprintf(eventAttachmentFormat, "database.db") + "]," + + originalStartDateFormat = `"originalStart": "%s",` + NoOriginalStartDate = `` NoRecurrence = `null` recurrenceTmpl = `{ "pattern": { "type": "absoluteYearly", "interval": 1, - "month": 1, - "dayOfMonth": 1, + "month": %s, + "dayOfMonth": %s, "firstDayOfWeek": "sunday", "index": "first" }, @@ -170,6 +173,13 @@ const ( } }` + cancelledOccurrencesFormat = `"cancelledOccurrences": [%s],` + cancelledOccurrenceInstanceFormat = `"OID.AAMkAGJiZmE2NGU4LTQ4YjktNDI1Mi1iMWQzLTQ1MmMxODJkZmQyNABGAAAAAABFdiK7oifWRb4ADuqgSRcnBwBBFDg0JJk7TY1fmsJrh7tNAAAAAAENAABBFDg0JJk7TY1fmsJrh7tNAADHGTZoAAA=.%s"` + NoCancelledOccurrences = "" + + exceptionOccurrencesFormat = `"exceptionOccurrences": [%s],` + NoExceptionOccurrences = "" + NoAttendees = `[]` attendeesTmpl = `[{ "emailAddress": { @@ -219,38 +229,48 @@ func EventBytes(subject string) []byte { } func EventWithSubjectBytes(subject string) []byte { - tomorrow := time.Now().UTC().AddDate(0, 0, 1) - at := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) - atTime := dttm.Format(at) - endTime := dttm.Format(at.Add(30 * time.Minute)) + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + endTime = dttm.Format(at.Add(30 * time.Minute)) + ) return EventWith( defaultEventOrganizer, subject, defaultEventBody, defaultEventBodyPreview, - atTime, endTime, NoRecurrence, NoAttendees, false, + NoOriginalStartDate, atTime, endTime, NoRecurrence, NoAttendees, + NoAttachments, NoCancelledOccurrences, NoExceptionOccurrences, ) } func EventWithAttachment(subject string) []byte { - tomorrow := time.Now().UTC().AddDate(0, 0, 1) - at := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) - atTime := dttm.Format(at) + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + ) return EventWith( defaultEventOrganizer, subject, defaultEventBody, defaultEventBodyPreview, - atTime, atTime, NoRecurrence, NoAttendees, true, + NoOriginalStartDate, atTime, atTime, NoRecurrence, NoAttendees, + defaultEventAttachments, NoCancelledOccurrences, NoExceptionOccurrences, ) } func EventWithRecurrenceBytes(subject, recurrenceTimeZone string) []byte { - tomorrow := time.Now().UTC().AddDate(0, 0, 1) - at := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) - atTime := dttm.Format(at) - timeSlice := strings.Split(atTime, "T") + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + timeSlice = strings.Split(atTime, "T") + ) recurrence := string(fmt.Sprintf( recurrenceTmpl, + strconv.Itoa(int(at.Month())), + strconv.Itoa(at.Day()), timeSlice[0], recurrenceTimeZone, )) @@ -258,19 +278,125 @@ func EventWithRecurrenceBytes(subject, recurrenceTimeZone string) []byte { return EventWith( defaultEventOrganizer, subject, defaultEventBody, defaultEventBodyPreview, - atTime, atTime, recurrence, attendeesTmpl, true, + NoOriginalStartDate, atTime, atTime, recurrence, attendeesTmpl, + NoAttachments, NoCancelledOccurrences, NoExceptionOccurrences, ) } -func EventWithAttendeesBytes(subject string) []byte { - tomorrow := time.Now().UTC().AddDate(0, 0, 1) - at := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) - atTime := dttm.Format(at) +func EventWithRecurrenceAndCancellationBytes(subject string) []byte { + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + timeSlice = strings.Split(atTime, "T") + nextYear = tomorrow.AddDate(1, 0, 0) + ) + + recurrence := string(fmt.Sprintf( + recurrenceTmpl, + strconv.Itoa(int(at.Month())), + strconv.Itoa(at.Day()), + timeSlice[0], + `"UTC"`, + )) + + cancelledInstances := []string{fmt.Sprintf(cancelledOccurrenceInstanceFormat, dttm.FormatTo(nextYear, dttm.DateOnly))} + cancelledOccurrences := fmt.Sprintf(cancelledOccurrencesFormat, strings.Join(cancelledInstances, ",")) return EventWith( defaultEventOrganizer, subject, defaultEventBody, defaultEventBodyPreview, - atTime, atTime, NoRecurrence, attendeesTmpl, true, + NoOriginalStartDate, atTime, atTime, recurrence, attendeesTmpl, + defaultEventAttachments, cancelledOccurrences, NoExceptionOccurrences, + ) +} + +func EventWithRecurrenceAndExceptionBytes(subject string) []byte { + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + timeSlice = strings.Split(atTime, "T") + newTime = dttm.Format(tomorrow.AddDate(0, 0, 1)) + originalStartDate = dttm.FormatTo(at, dttm.TabularOutput) + ) + + recurrence := string(fmt.Sprintf( + recurrenceTmpl, + strconv.Itoa(int(at.Month())), + strconv.Itoa(at.Day()), + timeSlice[0], + `"UTC"`, + )) + + exceptionEvent := EventWith( + defaultEventOrganizer, subject+"(modified)", + defaultEventBody, defaultEventBodyPreview, + fmt.Sprintf(originalStartDateFormat, originalStartDate), + newTime, newTime, NoRecurrence, attendeesTmpl, + NoAttachments, NoCancelledOccurrences, NoExceptionOccurrences, + ) + exceptionOccurrences := fmt.Sprintf(exceptionOccurrencesFormat, exceptionEvent) + + return EventWith( + defaultEventOrganizer, subject, + defaultEventBody, defaultEventBodyPreview, + NoOriginalStartDate, atTime, atTime, recurrence, attendeesTmpl, + defaultEventAttachments, NoCancelledOccurrences, exceptionOccurrences, + ) +} + +func EventWithRecurrenceAndExceptionAndAttachmentBytes(subject string) []byte { + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + timeSlice = strings.Split(atTime, "T") + newTime = dttm.Format(tomorrow.AddDate(0, 0, 1)) + originalStartDate = dttm.FormatTo(at, dttm.TabularOutput) + ) + + recurrence := string(fmt.Sprintf( + recurrenceTmpl, + strconv.Itoa(int(at.Month())), + strconv.Itoa(at.Day()), + timeSlice[0], + `"UTC"`, + )) + + exceptionEvent := EventWith( + defaultEventOrganizer, subject+"(modified)", + defaultEventBody, defaultEventBodyPreview, + fmt.Sprintf(originalStartDateFormat, originalStartDate), + newTime, newTime, NoRecurrence, attendeesTmpl, + "\"attachments\":["+fmt.Sprintf(eventAttachmentFormat, "exception-database.db")+"],", + NoCancelledOccurrences, NoExceptionOccurrences, + ) + exceptionOccurrences := fmt.Sprintf( + exceptionOccurrencesFormat, + strings.Join([]string{string(exceptionEvent)}, ","), + ) + + return EventWith( + defaultEventOrganizer, subject, + defaultEventBody, defaultEventBodyPreview, + NoOriginalStartDate, atTime, atTime, recurrence, attendeesTmpl, + defaultEventAttachments, NoCancelledOccurrences, exceptionOccurrences, + ) +} + +func EventWithAttendeesBytes(subject string) []byte { + var ( + tomorrow = time.Now().UTC().AddDate(0, 0, 1) + at = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), tomorrow.Hour(), 0, 0, 0, time.UTC) + atTime = dttm.Format(at) + ) + + return EventWith( + defaultEventOrganizer, subject, + defaultEventBody, defaultEventBodyPreview, + NoOriginalStartDate, atTime, atTime, NoRecurrence, attendeesTmpl, + defaultEventAttachments, NoCancelledOccurrences, NoExceptionOccurrences, ) } @@ -281,14 +407,10 @@ func EventWithAttendeesBytes(subject string) []byte { // Body must contain a well-formatted string, consumable in a json payload. IE: no unescaped newlines. func EventWith( organizer, subject, body, bodyPreview, - startDateTime, endDateTime, recurrence, attendees string, - hasAttachments bool, + originalStartDate, startDateTime, endDateTime, recurrence, attendees string, + attachments string, cancelledOccurrences, exceptionOccurrences string, ) []byte { - var attachments string - if hasAttachments { - attachments = eventAttachment - } - + hasAttachments := len(attachments) > 0 startDateTime = strings.TrimSuffix(startDateTime, "Z") endDateTime = strings.TrimSuffix(endDateTime, "Z") @@ -300,17 +422,26 @@ func EventWith( endDateTime += ".0000000" } + eventType := "singleInstance" + if recurrence != "null" { + eventType = "seriesMaster" + } + return []byte(fmt.Sprintf( eventTmpl, body, bodyPreview, endDateTime, organizer, + originalStartDate, startDateTime, subject, + eventType, hasAttachments, attachments, recurrence, + cancelledOccurrences, + exceptionOccurrences, attendees, )) } diff --git a/src/internal/m365/exchange/restore.go b/src/internal/m365/exchange/restore.go index 59d6b167e..6d519d4c3 100644 --- a/src/internal/m365/exchange/restore.go +++ b/src/internal/m365/exchange/restore.go @@ -41,9 +41,7 @@ func ConsumeRestoreCollections( directoryCache = make(map[path.CategoryType]graph.ContainerResolver) handlers = restoreHandlers(ac) metrics support.CollectionMetrics - // TODO policy to be updated from external source after completion of refactoring - policy = control.Copy - el = errs.Local() + el = errs.Local() ) ctx = clues.Add(ctx, "resource_owner", clues.Hide(userID)) @@ -87,16 +85,22 @@ func ConsumeRestoreCollections( } directoryCache[category] = gcc - ictx = clues.Add(ictx, "restore_destination_id", containerID) + collisionKeyToItemID, err := handler.getItemsInContainerByCollisionKey(ctx, userID, containerID) + if err != nil { + el.AddRecoverable(ctx, clues.Wrap(err, "building item collision cache")) + continue + } + temp, err := restoreCollection( ictx, handler, dc, userID, containerID, - policy, + collisionKeyToItemID, + restoreCfg.OnCollision, deets, errs) @@ -127,7 +131,8 @@ func restoreCollection( ir itemRestorer, dc data.RestoreCollection, userID, destinationID string, - policy control.CollisionPolicy, + collisionKeyToItemID map[string]string, + collisionPolicy control.CollisionPolicy, deets *details.Builder, errs *fault.Bus, ) (support.CollectionMetrics, error) { @@ -172,9 +177,19 @@ func restoreCollection( body := buf.Bytes() - info, err := ir.restore(ictx, body, userID, destinationID, errs) + info, err := ir.restore( + ictx, + body, + userID, + destinationID, + collisionKeyToItemID, + collisionPolicy, + errs) if err != nil { - el.AddRecoverable(ictx, err) + if !graph.IsErrItemAlreadyExistsConflict(err) { + el.AddRecoverable(ictx, err) + } + continue } diff --git a/src/internal/m365/exchange/restore_test.go b/src/internal/m365/exchange/restore_test.go index 6fa083b5a..ad1f06f8a 100644 --- a/src/internal/m365/exchange/restore_test.go +++ b/src/internal/m365/exchange/restore_test.go @@ -13,6 +13,7 @@ import ( exchMock "github.com/alcionai/corso/src/internal/m365/exchange/mock" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control/testdata" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" @@ -74,6 +75,8 @@ func (suite *RestoreIntgSuite) TestRestoreContact() { ctx, exchMock.ContactBytes("Corso TestContact"), userID, folderID, + nil, + control.Copy, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotNil(t, info, "contact item info") @@ -116,9 +119,26 @@ func (suite *RestoreIntgSuite) TestRestoreEvent() { name: "Test recurrenceTimeZone: Empty", bytes: exchMock.EventWithRecurrenceBytes(subject, `""`), }, + { + name: "Test cancelledOccurrences", + bytes: exchMock.EventWithRecurrenceAndCancellationBytes(subject), + }, + { + name: "Test exceptionOccurrences", + bytes: exchMock.EventWithRecurrenceAndExceptionBytes(subject), + }, + { + name: "Test exceptionOccurrences with different attachments", + bytes: exchMock.EventWithRecurrenceAndExceptionAndAttachmentBytes(subject), + }, } for _, test := range tests { + // Skip till https://github.com/alcionai/corso/issues/3675 is fixed + if test.name == "Test exceptionOccurrences" { + t.Skip("Bug 3675") + } + suite.Run(test.name, func() { t := suite.T() @@ -129,6 +149,8 @@ func (suite *RestoreIntgSuite) TestRestoreEvent() { ctx, test.bytes, userID, calendarID, + nil, + control.Copy, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotNil(t, info, "event item info") @@ -357,9 +379,82 @@ func (suite *RestoreIntgSuite) TestRestoreExchangeObject() { ctx, test.bytes, userID, destination, + nil, + control.Copy, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotNil(t, info, "item info was not populated") }) } } + +func (suite *RestoreIntgSuite) TestRestoreAndBackupEvent_recurringInstancesWithAttachments() { + t := suite.T() + + t.Skip("Bug 3675") + + ctx, flush := tester.NewContext(t) + defer flush() + + var ( + userID = tester.M365UserID(t) + subject = testdata.DefaultRestoreConfig("event").Location + handler = newEventRestoreHandler(suite.ac) + ) + + calendar, err := handler.ac.CreateContainer(ctx, userID, subject, "") + require.NoError(t, err, clues.ToCore(err)) + + calendarID := ptr.Val(calendar.GetId()) + + bytes := exchMock.EventWithRecurrenceAndExceptionAndAttachmentBytes("Reoccurring event restore and backup test") + info, err := handler.restore( + ctx, + bytes, + userID, calendarID, + nil, + control.Copy, + fault.New(true)) + require.NoError(t, err, clues.ToCore(err)) + assert.NotNil(t, info, "event item info") + + ec, err := handler.ac.Stable. + Client(). + Users(). + ByUserId(userID). + Calendars(). + ByCalendarId(calendarID). + Events(). + Get(ctx, nil) + require.NoError(t, err, clues.ToCore(err)) + + evts := ec.GetValue() + assert.Len(t, evts, 1, "count of events") + + sp, info, err := suite.ac.Events().GetItem(ctx, userID, ptr.Val(evts[0].GetId()), false, fault.New(true)) + require.NoError(t, err, clues.ToCore(err)) + assert.NotNil(t, info, "event item info") + + body, err := suite.ac.Events().Serialize(ctx, sp, userID, ptr.Val(evts[0].GetId())) + require.NoError(t, err, clues.ToCore(err)) + + event, err := api.BytesToEventable(body) + require.NoError(t, err, clues.ToCore(err)) + + assert.NotNil(t, event.GetRecurrence(), "recurrence") + eo := event.GetAdditionalData()["exceptionOccurrences"] + assert.NotNil(t, eo, "exceptionOccurrences") + + assert.NotEqual( + t, + ptr.Val(event.GetSubject()), + ptr.Val(eo.([]any)[0].(map[string]any)["subject"].(*string)), + "name equal") + + atts := eo.([]any)[0].(map[string]any)["attachments"] + assert.NotEqual( + t, + ptr.Val(event.GetAttachments()[0].GetName()), + ptr.Val(atts.([]any)[0].(map[string]any)["name"].(*string)), + "attachment name equal") +} diff --git a/src/internal/m365/exchange/transform.go b/src/internal/m365/exchange/transform.go index 425030601..7a6132b0c 100644 --- a/src/internal/m365/exchange/transform.go +++ b/src/internal/m365/exchange/transform.go @@ -70,7 +70,6 @@ func toEventSimplified(orig models.Eventable) models.Eventable { newContent := insertStringToBody(origBody, attendees) newBody := models.NewItemBody() newBody.SetContentType(origBody.GetContentType()) - newBody.SetAdditionalData(origBody.GetAdditionalData()) newBody.SetOdataType(origBody.GetOdataType()) newBody.SetContent(&newContent) orig.SetBody(newBody) @@ -89,6 +88,14 @@ func toEventSimplified(orig models.Eventable) models.Eventable { } } + // Remove exceptions for recurring events + // These will be present in objects once we start using the API + // that is currently in beta + additionalData := orig.GetAdditionalData() + delete(additionalData, "cancelledOccurrences") + delete(additionalData, "exceptionOccurrences") + orig.SetAdditionalData(additionalData) + return orig } diff --git a/src/internal/m365/exchange/transform_test.go b/src/internal/m365/exchange/transform_test.go index 4e3ce4278..823a65f2c 100644 --- a/src/internal/m365/exchange/transform_test.go +++ b/src/internal/m365/exchange/transform_test.go @@ -121,6 +121,32 @@ func (suite *TransformUnitTest) TestToEventSimplified_recurrence() { return ptr.Val(e.GetRecurrence().GetRange().GetRecurrenceTimeZone()) == "Pacific Standard Time" }, }, + { + name: "Test cancelledOccurrences", + event: func() models.Eventable { + bytes := exchMock.EventWithRecurrenceAndCancellationBytes(subject) + event, err := api.BytesToEventable(bytes) + require.NoError(t, err, clues.ToCore(err)) + return event + }, + + validateOutput: func(e models.Eventable) bool { + return e.GetAdditionalData()["cancelledOccurrences"] == nil + }, + }, + { + name: "Test exceptionOccurrences", + event: func() models.Eventable { + bytes := exchMock.EventWithRecurrenceAndExceptionBytes(subject) + event, err := api.BytesToEventable(bytes) + require.NoError(t, err, clues.ToCore(err)) + return event + }, + + validateOutput: func(e models.Eventable) bool { + return e.GetAdditionalData()["exceptionOccurrences"] == nil + }, + }, } for _, test := range tests { diff --git a/src/internal/m365/graph/concurrency_middleware.go b/src/internal/m365/graph/concurrency_middleware.go index c70988f65..7da508217 100644 --- a/src/internal/m365/graph/concurrency_middleware.go +++ b/src/internal/m365/graph/concurrency_middleware.go @@ -73,17 +73,13 @@ func (cl *concurrencyLimiter) Intercept( const ( // Default goal is to keep calls below the 10k-per-10-minute threshold. - // 14 tokens every second nets 840 per minute. That's 8400 every 10 minutes, + // 16 tokens every second nets 960 per minute. That's 9600 every 10 minutes, // which is a bit below the mark. - // But suppose we have a minute-long dry spell followed by a 10 minute tsunami. - // We'll have built up 750 tokens in reserve, so the first 750 calls go through - // immediately. Over the next 10 minutes, we'll partition out the other calls - // at a rate of 840-per-minute, ending at a total of 9150. Theoretically, if - // the volume keeps up after that, we'll always stay between 8400 and 9150 out - // of 10k. Worst case scenario, we have an extra minute of padding to allow - // up to 9990. - defaultPerSecond = 14 // 14 * 60 = 840 - defaultMaxCap = 750 // real cap is 10k-per-10-minutes + // If the bucket is full, we can push out 200 calls immediately, which brings + // the total in the first 10 minutes to 9800. We can toe that line if we want, + // but doing so risks timeouts. It's better to give the limits breathing room. + defaultPerSecond = 16 // 16 * 60 * 10 = 9600 + defaultMaxCap = 200 // real cap is 10k-per-10-minutes // since drive runs on a per-minute, rather than per-10-minute bucket, we have // to keep the max cap equal to the per-second cap. A large maxCap pool (say, // 1200, similar to the per-minute cap) would allow us to make a flood of 2400 diff --git a/src/internal/m365/graph/errors.go b/src/internal/m365/graph/errors.go index 3f56914f4..414f73036 100644 --- a/src/internal/m365/graph/errors.go +++ b/src/internal/m365/graph/errors.go @@ -37,7 +37,7 @@ const ( // @microsoft.graph.conflictBehavior=fail finds a conflicting file. nameAlreadyExists errorCode = "nameAlreadyExists" quotaExceeded errorCode = "ErrorQuotaExceeded" - requestResourceNotFound errorCode = "Request_ResourceNotFound" + RequestResourceNotFound errorCode = "Request_ResourceNotFound" resourceNotFound errorCode = "ResourceNotFound" resyncRequired errorCode = "ResyncRequired" // alt: resyncRequired syncFolderNotFound errorCode = "ErrorSyncFolderNotFound" @@ -56,17 +56,16 @@ const ( type errorMessage string const ( - IOErrDuringRead errorMessage = "IO error during request payload read" + IOErrDuringRead errorMessage = "IO error during request payload read" + MysiteURLNotFound errorMessage = "unable to retrieve user's mysite url" + MysiteNotFound errorMessage = "user's mysite not found" + NoSPLicense errorMessage = "Tenant does not have a SPO license" ) const ( - mysiteURLNotFound = "unable to retrieve user's mysite url" - mysiteNotFound = "user's mysite not found" -) - -const ( - LabelsMalware = "malware_detected" - LabelsMysiteNotFound = "mysite_not_found" + LabelsMalware = "malware_detected" + LabelsMysiteNotFound = "mysite_not_found" + LabelsNoSharePointLicense = "no_sharepoint_license" // LabelsSkippable is used to determine if an error is skippable LabelsSkippable = "skippable_errors" @@ -132,7 +131,7 @@ func IsErrExchangeMailFolderNotFound(err error) bool { } func IsErrUserNotFound(err error) bool { - return hasErrorCode(err, requestResourceNotFound) + return hasErrorCode(err, RequestResourceNotFound) } func IsErrResourceNotFound(err error) bool { @@ -297,11 +296,17 @@ func setLabels(err *clues.Err, msg string) *clues.Err { return nil } - ml := strings.ToLower(msg) - if strings.Contains(ml, mysiteNotFound) || strings.Contains(ml, mysiteURLNotFound) { + f := filters.Contains([]string{msg}) + + if f.Compare(string(MysiteNotFound)) || + f.Compare(string(MysiteURLNotFound)) { err = err.Label(LabelsMysiteNotFound) } + if f.Compare(string(NoSPLicense)) { + err = err.Label(LabelsNoSharePointLicense) + } + if IsMalware(err) { err = err.Label(LabelsMalware) } diff --git a/src/internal/m365/graph/errors_test.go b/src/internal/m365/graph/errors_test.go index 714677179..7a415c8c4 100644 --- a/src/internal/m365/graph/errors_test.go +++ b/src/internal/m365/graph/errors_test.go @@ -33,6 +33,16 @@ func odErr(code string) *odataerrors.ODataError { return odErr } +func odErrMsg(code, message string) *odataerrors.ODataError { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(&code) + merr.SetMessage(&message) + odErr.SetError(merr) + + return odErr +} + func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() { table := []struct { name string @@ -223,7 +233,7 @@ func (suite *GraphErrorsUnitSuite) TestIsErrUserNotFound() { }, { name: "request resource not found oDataErr", - err: odErr(string(requestResourceNotFound)), + err: odErr(string(RequestResourceNotFound)), expect: assert.True, }, } @@ -423,3 +433,56 @@ func (suite *GraphErrorsUnitSuite) TestIsErrCannotOpenFileAttachment() { }) } } + +func (suite *GraphErrorsUnitSuite) TestGraphStack_labels() { + table := []struct { + name string + err error + expect []string + }{ + { + name: "nil", + err: nil, + expect: []string{}, + }, + { + name: "not-odata", + err: assert.AnError, + expect: []string{}, + }, + { + name: "oDataErr matches no labels", + err: odErr("code"), + expect: []string{}, + }, + { + name: "mysite not found", + err: odErrMsg("code", string(MysiteNotFound)), + expect: []string{}, + }, + { + name: "mysite url not found", + err: odErrMsg("code", string(MysiteURLNotFound)), + expect: []string{}, + }, + { + name: "no sp license", + err: odErrMsg("code", string(NoSPLicense)), + expect: []string{}, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + result := Stack(ctx, test.err) + + for _, e := range test.expect { + assert.True(t, clues.HasLabel(result, e), clues.ToCore(result)) + } + }) + } +} diff --git a/src/internal/m365/graph/uploadsession.go b/src/internal/m365/graph/uploadsession.go index 77fefd5c8..74a696373 100644 --- a/src/internal/m365/graph/uploadsession.go +++ b/src/internal/m365/graph/uploadsession.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/http" + "strings" "github.com/alcionai/clues" @@ -21,8 +22,11 @@ const ( // Writer implements an io.Writer for a M365 // UploadSession URL type largeItemWriter struct { + // ID is the id of the item created. + // Will be available after the upload is complete + ID string // Identifier - id string + parentID string // Upload URL for this item url string // Tracks how much data will be written @@ -32,8 +36,13 @@ type largeItemWriter struct { client httpWrapper } -func NewLargeItemWriter(id, url string, size int64) *largeItemWriter { - return &largeItemWriter{id: id, url: url, contentLength: size, client: *NewNoTimeoutHTTPWrapper()} +func NewLargeItemWriter(parentID, url string, size int64) *largeItemWriter { + return &largeItemWriter{ + parentID: parentID, + url: url, + contentLength: size, + client: *NewNoTimeoutHTTPWrapper(), + } } // Write will upload the provided data to M365. It sets the `Content-Length` and `Content-Range` headers based on @@ -44,7 +53,7 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { logger.Ctx(ctx). Debugf("WRITE for %s. Size:%d, Offset: %d, TotalSize: %d", - iw.id, rangeLength, iw.lastWrittenOffset, iw.contentLength) + iw.parentID, rangeLength, iw.lastWrittenOffset, iw.contentLength) endOffset := iw.lastWrittenOffset + int64(rangeLength) @@ -58,7 +67,7 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { iw.contentLength) headers[contentLengthHeaderKey] = fmt.Sprintf("%d", rangeLength) - _, err := iw.client.Request( + resp, err := iw.client.Request( ctx, http.MethodPut, iw.url, @@ -66,7 +75,7 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { headers) if err != nil { return 0, clues.Wrap(err, "uploading item").With( - "upload_id", iw.id, + "upload_id", iw.parentID, "upload_chunk_size", rangeLength, "upload_offset", iw.lastWrittenOffset, "upload_size", iw.contentLength) @@ -75,5 +84,22 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { // Update last offset iw.lastWrittenOffset = endOffset + // Once the upload is complete, we get a Location header in the + // below format from which we can get the id of the uploaded + // item. This will only be available after we have uploaded the + // entire content(based on the size in the req header). + // https://outlook.office.com/api/v2.0/Users('')/Messages('')/Attachments('') + // Ref: https://learn.microsoft.com/en-us/graph/outlook-large-attachments?tabs=http + loc := resp.Header.Get("Location") + if loc != "" { + splits := strings.Split(loc, "'") + if len(splits) != 7 || splits[4] != ")/Attachments(" || len(splits[5]) == 0 { + return 0, clues.New("invalid format for upload completion url"). + With("location", loc) + } + + iw.ID = splits[5] + } + return rangeLength, nil } diff --git a/src/internal/m365/onedrive/collection.go b/src/internal/m365/onedrive/collection.go index afeb0bcb0..197bee01f 100644 --- a/src/internal/m365/onedrive/collection.go +++ b/src/internal/m365/onedrive/collection.go @@ -84,6 +84,8 @@ type Collection struct { // should only be true if the old delta token expired doNotMergeItems bool + + urlCache getItemPropertyer } func pathToLocation(p path.Path) (*path.Builder, error) { @@ -109,6 +111,7 @@ func NewCollection( ctrlOpts control.Options, colScope collectionScope, doNotMergeItems bool, + urlCache getItemPropertyer, ) (*Collection, error) { // TODO(ashmrtn): If OneDrive switches to using folder IDs then this will need // to be changed as we won't be able to extract path information from the @@ -132,7 +135,8 @@ func NewCollection( statusUpdater, ctrlOpts, colScope, - doNotMergeItems) + doNotMergeItems, + urlCache) c.locPath = locPath c.prevLocPath = prevLocPath @@ -149,6 +153,7 @@ func newColl( ctrlOpts control.Options, colScope collectionScope, doNotMergeItems bool, + urlCache getItemPropertyer, ) *Collection { c := &Collection{ handler: handler, @@ -162,6 +167,7 @@ func newColl( state: data.StateOf(prevPath, currPath), scope: colScope, doNotMergeItems: doNotMergeItems, + urlCache: urlCache, } return c @@ -267,7 +273,7 @@ func (oc *Collection) getDriveItemContent( el = errs.Local() ) - itemData, err := downloadContent(ctx, oc.handler, item, oc.driveID) + itemData, err := downloadContent(ctx, oc.handler, oc.urlCache, item, oc.driveID) if err != nil { if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) { logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware") @@ -320,9 +326,13 @@ type itemAndAPIGetter interface { func downloadContent( ctx context.Context, iaag itemAndAPIGetter, + uc getItemPropertyer, item models.DriveItemable, driveID string, ) (io.ReadCloser, error) { + itemID := ptr.Val(item.GetId()) + ctx = clues.Add(ctx, "item_id", itemID) + content, err := downloadItem(ctx, iaag, item) if err == nil { return content, nil @@ -332,8 +342,19 @@ func downloadContent( // Assume unauthorized requests are a sign of an expired jwt // token, and that we've overrun the available window to - // download the actual file. Re-downloading the item will - // refresh that download url. + // download the file. Get a fresh url from the cache and attempt to + // download again. + content, err = readItemContents(ctx, iaag, uc, itemID) + if err == nil { + logger.Ctx(ctx).Debug("found item in url cache") + return content, nil + } + + // Consider cache errors(including deleted items) as cache misses. This is + // to preserve existing behavior. Fallback to refetching the item using the + // API. + logger.CtxErr(ctx, err).Info("url cache miss: refetching from API") + di, err := iaag.GetItem(ctx, driveID, ptr.Val(item.GetId())) if err != nil { return nil, clues.Wrap(err, "retrieving expired item") @@ -347,6 +368,41 @@ func downloadContent( return content, nil } +// readItemContents fetches latest download URL from the cache and attempts to +// download the file using the new URL. +func readItemContents( + ctx context.Context, + iaag itemAndAPIGetter, + uc getItemPropertyer, + itemID string, +) (io.ReadCloser, error) { + if uc == nil { + return nil, clues.New("nil url cache") + } + + props, err := uc.getItemProperties(ctx, itemID) + if err != nil { + return nil, err + } + + // Handle newly deleted items + if props.isDeleted { + logger.Ctx(ctx).Info("item deleted in cache") + return nil, graph.ErrDeletedInFlight + } + + rc, err := downloadFile(ctx, iaag, props.downloadURL) + if graph.IsErrUnauthorized(err) { + logger.CtxErr(ctx, err).Info("stale item in cache") + } + + if err != nil { + return nil, err + } + + return rc, nil +} + // populateItems iterates through items added to the collection // and uses the collection `itemReader` to read the item func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) { diff --git a/src/internal/m365/onedrive/collection_test.go b/src/internal/m365/onedrive/collection_test.go index 2cfb65cae..4b8cc53b6 100644 --- a/src/internal/m365/onedrive/collection_test.go +++ b/src/internal/m365/onedrive/collection_test.go @@ -2,6 +2,7 @@ package onedrive import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -204,7 +205,8 @@ func (suite *CollectionUnitTestSuite) TestCollection() { suite.testStatusUpdater(&wg, &collStatus), control.Options{ToggleFeatures: control.Toggles{}}, CollectionScopeFolder, - true) + true, + nil) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, coll) assert.Equal(t, folderPath, coll.FullPath()) @@ -312,7 +314,8 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() { suite.testStatusUpdater(&wg, &collStatus), control.Options{ToggleFeatures: control.Toggles{}}, CollectionScopeFolder, - true) + true, + nil) require.NoError(t, err, clues.ToCore(err)) stubItem := odTD.NewStubDriveItem( @@ -388,7 +391,8 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry() suite.testStatusUpdater(&wg, &collStatus), control.Options{ToggleFeatures: control.Toggles{}}, CollectionScopeFolder, - true) + true, + nil) require.NoError(t, err, clues.ToCore(err)) coll.Add(stubItem) @@ -442,7 +446,8 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim suite.testStatusUpdater(&wg, &collStatus), control.Options{ToggleFeatures: control.Toggles{}}, CollectionScopeFolder, - true) + true, + nil) require.NoError(t, err, clues.ToCore(err)) mtime := time.Now().AddDate(0, -1, 0) @@ -600,6 +605,19 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItem_error() { } } +var _ getItemPropertyer = &mockURLCache{} + +type mockURLCache struct { + Get func(ctx context.Context, itemID string) (itemProps, error) +} + +func (muc *mockURLCache) getItemProperties( + ctx context.Context, + itemID string, +) (itemProps, error) { + return muc.Get(ctx, itemID) +} + func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { var ( driveID string @@ -611,6 +629,12 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { itemWID.SetId(ptr.To("brainhooldy")) + m := &mockURLCache{ + Get: func(ctx context.Context, itemID string) (itemProps, error) { + return itemProps{}, clues.Stack(assert.AnError) + }, + } + table := []struct { name string mgi mock.GetsItem @@ -619,6 +643,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { getErr []error expectErr require.ErrorAssertionFunc expect require.ValueAssertionFunc + muc *mockURLCache }{ { name: "good", @@ -627,6 +652,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { getErr: []error{nil}, expectErr: require.NoError, expect: require.NotNil, + muc: m, }, { name: "expired url redownloads", @@ -636,6 +662,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { getErr: []error{errUnauth, nil}, expectErr: require.NoError, expect: require.NotNil, + muc: m, }, { name: "immediate error", @@ -643,6 +670,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { getErr: []error{assert.AnError}, expectErr: require.Error, expect: require.Nil, + muc: m, }, { name: "re-fetching the item fails", @@ -651,6 +679,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { mgi: mock.GetsItem{Item: nil, Err: assert.AnError}, expectErr: require.Error, expect: require.Nil, + muc: m, }, { name: "expired url fails redownload", @@ -660,6 +689,57 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { getErr: []error{errUnauth, assert.AnError}, expectErr: require.Error, expect: require.Nil, + muc: m, + }, + { + name: "url refreshed from cache", + mgi: mock.GetsItem{Item: itemWID, Err: nil}, + itemInfo: details.ItemInfo{}, + respBody: []io.ReadCloser{nil, iorc}, + getErr: []error{errUnauth, nil}, + expectErr: require.NoError, + expect: require.NotNil, + muc: &mockURLCache{ + Get: func(ctx context.Context, itemID string) (itemProps, error) { + return itemProps{ + downloadURL: "http://example.com", + isDeleted: false, + }, + nil + }, + }, + }, + { + name: "url refreshed from cache but item deleted", + mgi: mock.GetsItem{Item: itemWID, Err: graph.ErrDeletedInFlight}, + itemInfo: details.ItemInfo{}, + respBody: []io.ReadCloser{nil, nil, nil}, + getErr: []error{errUnauth, graph.ErrDeletedInFlight, graph.ErrDeletedInFlight}, + expectErr: require.Error, + expect: require.Nil, + muc: &mockURLCache{ + Get: func(ctx context.Context, itemID string) (itemProps, error) { + return itemProps{ + downloadURL: "http://example.com", + isDeleted: true, + }, + nil + }, + }, + }, + { + name: "fallback to item fetch on any cache error", + mgi: mock.GetsItem{Item: itemWID, Err: nil}, + itemInfo: details.ItemInfo{}, + respBody: []io.ReadCloser{nil, iorc}, + getErr: []error{errUnauth, nil}, + expectErr: require.NoError, + expect: require.NotNil, + muc: &mockURLCache{ + Get: func(ctx context.Context, itemID string) (itemProps, error) { + return itemProps{}, assert.AnError + }, + }, }, } for _, test := range table { @@ -685,7 +765,7 @@ func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { mbh.GetResps = resps mbh.GetErrs = test.getErr - r, err := downloadContent(ctx, mbh, item, driveID) + r, err := downloadContent(ctx, mbh, test.muc, item, driveID) test.expect(t, r) test.expectErr(t, err, clues.ToCore(err)) }) diff --git a/src/internal/m365/onedrive/collections.go b/src/internal/m365/onedrive/collections.go index 7122a2361..0e4a2549c 100644 --- a/src/internal/m365/onedrive/collections.go +++ b/src/internal/m365/onedrive/collections.go @@ -255,7 +255,8 @@ func (c *Collections) Get( // Drive ID -> delta URL for drive deltaURLs = map[string]string{} // Drive ID -> folder ID -> folder path - folderPaths = map[string]map[string]string{} + folderPaths = map[string]map[string]string{} + numPrevItems = 0 ) for _, d := range drives { @@ -322,6 +323,23 @@ func (c *Collections) Get( "num_deltas_entries", numDeltas, "delta_reset", delta.Reset) + numDriveItems := c.NumItems - numPrevItems + numPrevItems = c.NumItems + + // Attach an url cache + if numDriveItems < urlCacheDriveItemThreshold { + logger.Ctx(ictx).Info("adding url cache for drive") + + err = c.addURLCacheToDriveCollections( + ictx, + driveID, + prevDelta, + errs) + if err != nil { + return nil, false, err + } + } + // For both cases we don't need to do set difference on folder map if the // delta token was valid because we should see all the changes. if !delta.Reset { @@ -370,7 +388,8 @@ func (c *Collections) Get( c.statusUpdater, c.ctrl, CollectionScopeUnknown, - true) + true, + nil) if err != nil { return nil, false, clues.Wrap(err, "making collection").WithClues(ictx) } @@ -405,7 +424,8 @@ func (c *Collections) Get( c.statusUpdater, c.ctrl, CollectionScopeUnknown, - true) + true, + nil) if err != nil { return nil, false, clues.Wrap(err, "making drive tombstone").WithClues(ctx) } @@ -438,6 +458,33 @@ func (c *Collections) Get( return collections, canUsePreviousBackup, nil } +// addURLCacheToDriveCollections adds an URL cache to all collections belonging to +// a drive. +func (c *Collections) addURLCacheToDriveCollections( + ctx context.Context, + driveID, prevDelta string, + errs *fault.Bus, +) error { + uc, err := newURLCache( + driveID, + prevDelta, + urlCacheRefreshInterval, + c.handler.NewItemPager(driveID, "", api.DriveItemSelectDefault()), + errs) + if err != nil { + return err + } + + // Set the URL cache for all collections in this drive + for _, driveColls := range c.CollectionMap { + for _, coll := range driveColls { + coll.urlCache = uc + } + } + + return nil +} + func updateCollectionPaths( driveID, itemID string, cmap map[string]map[string]*Collection, @@ -557,7 +604,8 @@ func (c *Collections) handleDelete( c.ctrl, CollectionScopeUnknown, // DoNotMerge is not checked for deleted items. - false) + false, + nil) if err != nil { return clues.Wrap(err, "making collection").With( "drive_id", driveID, @@ -740,7 +788,8 @@ func (c *Collections) UpdateCollections( c.statusUpdater, c.ctrl, colScope, - invalidPrevDelta) + invalidPrevDelta, + nil) if err != nil { return clues.Stack(err).WithClues(ictx) } diff --git a/src/internal/m365/onedrive/collections_test.go b/src/internal/m365/onedrive/collections_test.go index bc64875f4..93f76e147 100644 --- a/src/internal/m365/onedrive/collections_test.go +++ b/src/internal/m365/onedrive/collections_test.go @@ -2,6 +2,7 @@ package onedrive import ( "context" + "strconv" "testing" "github.com/alcionai/clues" @@ -2678,3 +2679,86 @@ func (suite *OneDriveCollectionsUnitSuite) TestCollectItems() { }) } } + +func (suite *OneDriveCollectionsUnitSuite) TestAddURLCacheToDriveCollections() { + driveID := "test-drive" + collCount := 3 + anyFolder := (&selectors.OneDriveBackup{}).Folders(selectors.Any())[0] + + table := []struct { + name string + items []deltaPagerResult + deltaURL string + prevDeltaSuccess bool + prevDelta string + err error + }{ + { + name: "cache is attached", + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + itemPagers := map[string]api.DriveItemEnumerator{} + itemPagers[driveID] = &mockItemPager{} + + mbh := mock.DefaultOneDriveBH() + mbh.ItemPagerV = itemPagers + + c := NewCollections( + mbh, + "test-tenant", + "test-user", + nil, + control.Options{ToggleFeatures: control.Toggles{}}) + + if _, ok := c.CollectionMap[driveID]; !ok { + c.CollectionMap[driveID] = map[string]*Collection{} + } + + // Add a few collections + for i := 0; i < collCount; i++ { + coll, err := NewCollection( + &itemBackupHandler{api.Drives{}, anyFolder}, + nil, + nil, + driveID, + nil, + control.Options{ToggleFeatures: control.Toggles{}}, + CollectionScopeFolder, + true, + nil) + require.NoError(t, err, clues.ToCore(err)) + + c.CollectionMap[driveID][strconv.Itoa(i)] = coll + require.Equal(t, nil, coll.urlCache, "cache not nil") + } + + err := c.addURLCacheToDriveCollections( + ctx, + driveID, + "", + fault.New(true)) + require.NoError(t, err, clues.ToCore(err)) + + // Check that all collections have the same cache instance attached + // to them + var uc *urlCache + for _, driveColls := range c.CollectionMap { + for _, coll := range driveColls { + require.NotNil(t, coll.urlCache, "cache is nil") + if uc == nil { + uc = coll.urlCache.(*urlCache) + } else { + require.Equal(t, uc, coll.urlCache, "cache not equal") + } + } + } + }) + } +} diff --git a/src/internal/m365/onedrive/item.go b/src/internal/m365/onedrive/item.go index c6215e9ae..bf12e91c4 100644 --- a/src/internal/m365/onedrive/item.go +++ b/src/internal/m365/onedrive/item.go @@ -77,7 +77,7 @@ func downloadFile( return nil, clues.New("malware detected").Label(graph.LabelsMalware) } - if (resp.StatusCode / 100) != 2 { + if resp != nil && (resp.StatusCode/100) != 2 { // upstream error checks can compare the status with // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) return nil, clues. diff --git a/src/internal/m365/onedrive/mock/handlers.go b/src/internal/m365/onedrive/mock/handlers.go index 23ef8a4d5..83da2dee9 100644 --- a/src/internal/m365/onedrive/mock/handlers.go +++ b/src/internal/m365/onedrive/mock/handlers.go @@ -228,7 +228,7 @@ func (m GetsItemPermission) GetItemPermission( // --------------------------------------------------------------------------- // Restore Handler -// --------------------------------------------------------------------------- +// -------------------------------------------------------------------------- type RestoreHandler struct { ItemInfo details.ItemInfo diff --git a/src/internal/m365/onedrive/url_cache.go b/src/internal/m365/onedrive/url_cache.go index bb5e61b94..1c7f2d93c 100644 --- a/src/internal/m365/onedrive/url_cache.go +++ b/src/internal/m365/onedrive/url_cache.go @@ -15,14 +15,29 @@ import ( "github.com/alcionai/corso/src/pkg/services/m365/api" ) +const ( + urlCacheDriveItemThreshold = 300 * 1000 + urlCacheRefreshInterval = 1 * time.Hour +) + +type getItemPropertyer interface { + getItemProperties( + ctx context.Context, + itemID string, + ) (itemProps, error) +} + type itemProps struct { downloadURL string isDeleted bool } +var _ getItemPropertyer = &urlCache{} + // urlCache caches download URLs for drive items type urlCache struct { driveID string + prevDelta string idToProps map[string]itemProps lastRefreshTime time.Time refreshInterval time.Duration @@ -39,7 +54,7 @@ type urlCache struct { // newURLache creates a new URL cache for the specified drive ID func newURLCache( - driveID string, + driveID, prevDelta string, refreshInterval time.Duration, itemPager api.DriveItemEnumerator, errs *fault.Bus, @@ -56,6 +71,7 @@ func newURLCache( idToProps: make(map[string]itemProps), lastRefreshTime: time.Time{}, driveID: driveID, + prevDelta: prevDelta, refreshInterval: refreshInterval, itemPager: itemPager, errs: errs, @@ -165,6 +181,8 @@ func (uc *urlCache) deltaQuery( ctx context.Context, ) error { logger.Ctx(ctx).Debug("starting delta query") + // Reset item pager to remove any previous state + uc.itemPager.Reset() _, _, _, err := collectItems( ctx, @@ -173,7 +191,7 @@ func (uc *urlCache) deltaQuery( "", uc.updateCache, map[string]string{}, - "", + uc.prevDelta, uc.errs) if err != nil { return clues.Wrap(err, "delta query") diff --git a/src/internal/m365/onedrive/url_cache_test.go b/src/internal/m365/onedrive/url_cache_test.go index 6e5da998c..4b2e4e96a 100644 --- a/src/internal/m365/onedrive/url_cache_test.go +++ b/src/internal/m365/onedrive/url_cache_test.go @@ -1,6 +1,7 @@ package onedrive import ( + "context" "errors" "math/rand" "net/http" @@ -89,10 +90,38 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { nfid := ptr.Val(newFolder.GetId()) + collectorFunc := func( + context.Context, + string, + string, + []models.DriveItemable, + map[string]string, + map[string]string, + map[string]struct{}, + map[string]map[string]string, + bool, + *fault.Bus, + ) error { + return nil + } + + // Get the previous delta to feed into url cache + prevDelta, _, _, err := collectItems( + ctx, + suite.ac.Drives().NewItemPager(driveID, "", api.DriveItemSelectDefault()), + suite.driveID, + "drive-name", + collectorFunc, + map[string]string{}, + "", + fault.New(true)) + require.NoError(t, err, clues.ToCore(err)) + require.NotNil(t, prevDelta.URL) + // Create a bunch of files in the new folder var items []models.DriveItemable - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { newItemName := "test_url_cache_basic_" + dttm.FormatNow(dttm.SafeForTesting) item, err := ac.Drives().PostItemInContainer( @@ -110,15 +139,12 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { } // Create a new URL cache with a long TTL - cache, err := newURLCache( + uc, err := newURLCache( suite.driveID, + prevDelta.URL, 1*time.Hour, driveItemPager, fault.New(true)) - - require.NoError(t, err, clues.ToCore(err)) - - err = cache.refreshCache(ctx) require.NoError(t, err, clues.ToCore(err)) // Launch parallel requests to the cache, one per item @@ -130,11 +156,11 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { defer wg.Done() // Read item from URL cache - props, err := cache.getItemProperties( + props, err := uc.getItemProperties( ctx, ptr.Val(items[i].GetId())) - require.NoError(t, err, clues.ToCore(err)) + require.NotNil(t, props) require.NotEmpty(t, props.downloadURL) require.Equal(t, false, props.isDeleted) @@ -148,15 +174,14 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { props.downloadURL, nil, nil) - require.NoError(t, err, clues.ToCore(err)) require.Equal(t, http.StatusOK, resp.StatusCode) }(i) } wg.Wait() - // Validate that <= 1 delta queries were made - require.LessOrEqual(t, cache.deltaQueryCount, 1) + // Validate that <= 1 delta queries were made by url cache + require.LessOrEqual(t, uc.deltaQueryCount, 1) } type URLCacheUnitSuite struct { @@ -407,6 +432,7 @@ func (suite *URLCacheUnitSuite) TestGetItemProperties() { cache, err := newURLCache( driveID, + "", 1*time.Hour, itemPager, fault.New(true)) @@ -449,6 +475,7 @@ func (suite *URLCacheUnitSuite) TestNeedsRefresh() { cache, err := newURLCache( driveID, + "", refreshInterval, &mockItemPager{}, fault.New(true)) @@ -522,6 +549,7 @@ func (suite *URLCacheUnitSuite) TestNewURLCache() { t := suite.T() _, err := newURLCache( test.driveID, + "", test.refreshInt, test.itemPager, test.errors) diff --git a/src/internal/operations/backup_integration_test.go b/src/internal/operations/backup_integration_test.go index 36ec0cfa3..5c1871ded 100644 --- a/src/internal/operations/backup_integration_test.go +++ b/src/internal/operations/backup_integration_test.go @@ -813,7 +813,10 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont eventDBF := func(id, timeStamp, subject, body string) []byte { return exchMock.EventWith( suite.user, subject, body, body, - now, now, exchMock.NoRecurrence, exchMock.NoAttendees, false) + exchMock.NoOriginalStartDate, now, now, + exchMock.NoRecurrence, exchMock.NoAttendees, + exchMock.NoAttachments, exchMock.NoCancelledOccurrences, + exchMock.NoExceptionOccurrences) } // test data set @@ -961,7 +964,8 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont table := []struct { name string // performs the incremental update required for the test. - updateUserData func(t *testing.T) + //revive:disable-next-line:context-as-argument + updateUserData func(t *testing.T, ctx context.Context) deltaItemsRead int deltaItemsWritten int nonDeltaItemsRead int @@ -970,7 +974,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }{ { name: "clean, no changes", - updateUserData: func(t *testing.T) {}, + updateUserData: func(t *testing.T, ctx context.Context) {}, deltaItemsRead: 0, deltaItemsWritten: 0, nonDeltaItemsRead: 8, @@ -979,7 +983,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "move an email folder to a subfolder", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { cat := path.EmailCategory // contacts and events cannot be sufoldered; this is an email-only change @@ -1003,7 +1007,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "delete a folder", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { for category, d := range dataset { containerID := d.dests[container2].containerID @@ -1030,7 +1034,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "add a new folder", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { for category, gen := range dataset { deets := generateContainerOfItems( t, @@ -1075,7 +1079,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "rename a folder", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { for category, d := range dataset { containerID := d.dests[container3].containerID newLoc := containerRename @@ -1131,7 +1135,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "add a new item", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { for category, d := range dataset { containerID := d.dests[container1].containerID @@ -1185,7 +1189,7 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont }, { name: "delete an existing item", - updateUserData: func(t *testing.T) { + updateUserData: func(t *testing.T, ctx context.Context) { for category, d := range dataset { containerID := d.dests[container1].containerID @@ -1244,11 +1248,22 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont var ( t = suite.T() incMB = evmock.NewBus() - incBO = newTestBackupOp(t, ctx, kw, ms, ctrl, acct, sels, incMB, toggles, closer) atid = creds.AzureTenantID ) - test.updateUserData(t) + ctx, flush := tester.WithContext(t, ctx) + defer flush() + + incBO := newTestBackupOp(t, ctx, kw, ms, ctrl, acct, sels, incMB, toggles, closer) + + suite.Run("PreTestSetup", func() { + t := suite.T() + + ctx, flush := tester.WithContext(t, ctx) + defer flush() + + test.updateUserData(t, ctx) + }) err := incBO.Run(ctx) require.NoError(t, err, clues.ToCore(err)) @@ -1259,16 +1274,21 @@ func testExchangeContinuousBackups(suite *BackupOpIntegrationSuite, toggles cont checkMetadataFilesExist(t, ctx, bupID, kw, ms, atid, uidn.ID(), service, categories) deeTD.CheckBackupDetails(t, ctx, bupID, whatSet, ms, ss, expectDeets, true) + // FIXME: commented tests are flaky due to interference with other tests + // we need to find a better way to make good assertions here. + // The addition of the deeTD package gives us enough coverage to comment + // out the tests for now and look to their improvemeng later. + // do some additional checks to ensure the incremental dealt with fewer items. // +4 on read/writes to account for metadata: 1 delta and 1 path for each type. - if !toggles.DisableDelta { - assert.Equal(t, test.deltaItemsRead+4, incBO.Results.ItemsRead, "incremental items read") - assert.Equal(t, test.deltaItemsWritten+4, incBO.Results.ItemsWritten, "incremental items written") - } else { - assert.Equal(t, test.nonDeltaItemsRead+4, incBO.Results.ItemsRead, "non delta items read") - assert.Equal(t, test.nonDeltaItemsWritten+4, incBO.Results.ItemsWritten, "non delta items written") - } - assert.Equal(t, test.nonMetaItemsWritten, incBO.Results.ItemsWritten, "non meta incremental items write") + // if !toggles.DisableDelta { + // assert.Equal(t, test.deltaItemsRead+4, incBO.Results.ItemsRead, "incremental items read") + // assert.Equal(t, test.deltaItemsWritten+4, incBO.Results.ItemsWritten, "incremental items written") + // } else { + // assert.Equal(t, test.nonDeltaItemsRead+4, incBO.Results.ItemsRead, "non delta items read") + // assert.Equal(t, test.nonDeltaItemsWritten+4, incBO.Results.ItemsWritten, "non delta items written") + // } + // assert.Equal(t, test.nonMetaItemsWritten, incBO.Results.ItemsWritten, "non meta incremental items write") assert.NoError(t, incBO.Errors.Failure(), "incremental non-recoverable error", clues.ToCore(incBO.Errors.Failure())) assert.Empty(t, incBO.Errors.Recovered(), "incremental recoverable/iteration errors") assert.Equal(t, 1, incMB.TimesCalled[events.BackupStart], "incremental backup-start events") @@ -1542,20 +1562,21 @@ func runDriveIncrementalTest( table := []struct { name string // performs the incremental update required for the test. - updateFiles func(t *testing.T) + //revive:disable-next-line:context-as-argument + updateFiles func(t *testing.T, ctx context.Context) itemsRead int itemsWritten int nonMetaItemsWritten int }{ { name: "clean incremental, no changes", - updateFiles: func(t *testing.T) {}, + updateFiles: func(t *testing.T, ctx context.Context) {}, itemsRead: 0, itemsWritten: 0, }, { name: "create a new file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { targetContainer := containerIDs[container1] driveItem := models.NewDriveItem() driveItem.SetName(&newFileName) @@ -1578,7 +1599,7 @@ func runDriveIncrementalTest( }, { name: "add permission to new file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { err = onedrive.UpdatePermissions( ctx, rh, @@ -1596,7 +1617,7 @@ func runDriveIncrementalTest( }, { name: "remove permission from new file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { err = onedrive.UpdatePermissions( ctx, rh, @@ -1614,7 +1635,7 @@ func runDriveIncrementalTest( }, { name: "add permission to container", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { targetContainer := containerIDs[container1] err = onedrive.UpdatePermissions( ctx, @@ -1633,7 +1654,7 @@ func runDriveIncrementalTest( }, { name: "remove permission from container", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { targetContainer := containerIDs[container1] err = onedrive.UpdatePermissions( ctx, @@ -1652,7 +1673,7 @@ func runDriveIncrementalTest( }, { name: "update contents of a file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { err := suite.ac.Drives().PutItemContent( ctx, driveID, @@ -1667,7 +1688,7 @@ func runDriveIncrementalTest( }, { name: "rename a file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { container := containerIDs[container1] driveItem := models.NewDriveItem() @@ -1691,7 +1712,7 @@ func runDriveIncrementalTest( }, { name: "move a file between folders", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { dest := containerIDs[container2] driveItem := models.NewDriveItem() @@ -1719,7 +1740,7 @@ func runDriveIncrementalTest( }, { name: "delete file", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { err := suite.ac.Drives().DeleteItem( ctx, driveID, @@ -1734,7 +1755,7 @@ func runDriveIncrementalTest( }, { name: "move a folder to a subfolder", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { parent := containerIDs[container1] child := containerIDs[container2] @@ -1762,7 +1783,7 @@ func runDriveIncrementalTest( }, { name: "rename a folder", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { parent := containerIDs[container1] child := containerIDs[container2] @@ -1792,7 +1813,7 @@ func runDriveIncrementalTest( }, { name: "delete a folder", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { container := containerIDs[containerRename] err := suite.ac.Drives().DeleteItem( ctx, @@ -1808,7 +1829,7 @@ func runDriveIncrementalTest( }, { name: "add a new folder", - updateFiles: func(t *testing.T) { + updateFiles: func(t *testing.T, ctx context.Context) { generateContainerOfItems( t, ctx, @@ -1850,9 +1871,17 @@ func runDriveIncrementalTest( incBO = newTestBackupOp(t, ctx, kw, ms, cleanCtrl, acct, sel, incMB, ffs, closer) ) - tester.LogTimeOfTest(suite.T()) + ctx, flush := tester.WithContext(t, ctx) + defer flush() - test.updateFiles(t) + suite.Run("PreTestSetup", func() { + t := suite.T() + + ctx, flush := tester.WithContext(t, ctx) + defer flush() + + test.updateFiles(t, ctx) + }) err = incBO.Run(ctx) require.NoError(t, err, clues.ToCore(err)) diff --git a/src/pkg/logger/logger.go b/src/pkg/logger/logger.go index fb4d37e4b..2cf8bb33b 100644 --- a/src/pkg/logger/logger.go +++ b/src/pkg/logger/logger.go @@ -8,6 +8,7 @@ import ( "time" "github.com/alcionai/clues" + "github.com/kopia/kopia/repo/logging" "github.com/spf13/cobra" "github.com/spf13/pflag" "go.uber.org/zap" @@ -60,6 +61,7 @@ const ( LogLevelFN = "log-level" ReadableLogsFN = "readable-logs" MaskSensitiveDataFN = "mask-sensitive-data" + logStorageFN = "log-storage" ) // flag values @@ -70,6 +72,7 @@ var ( LogLevelFV string ReadableLogsFV bool MaskSensitiveDataFV bool + logStorageFV bool ResolvedLogFile string // logFileFV after processing piiHandling string // piiHandling after MaskSensitiveDataFV processing @@ -131,6 +134,13 @@ func addFlags(fs *pflag.FlagSet, defaultFile string) { MaskSensitiveDataFN, false, "anonymize personal data in log output") + + fs.BoolVar( + &logStorageFV, + logStorageFN, + false, + "include logs produced by the downstream storage systems. Uses the same log level as the corso logger") + cobra.CheckErr(fs.MarkHidden(logStorageFN)) } // Due to races between the lazy evaluation of flags in cobra and the @@ -197,6 +207,18 @@ func PreloadLoggingFlags(args []string) Settings { set.PIIHandling = PIIHash } + // retrieve the user's preferred settings for storage engine logging in the + // corso log. + // defaults to not logging it. + storageLog, err := fs.GetBool(logStorageFN) + if err != nil { + return set + } + + if storageLog { + set.LogStorage = storageLog + } + return set } @@ -241,6 +263,7 @@ type Settings struct { Format logFormat // whether to format as text (console) or json (cloud) Level logLevel // what level to log at PIIHandling piiAlg // how to obscure pii + LogStorage bool // Whether kopia logs should be added to the corso log. } // EnsureDefaults sets any non-populated settings to their default value. @@ -390,7 +413,7 @@ const ctxKey loggingKey = "corsoLogger" // a seeded context prior to cobra evaluating flags. func Seed(ctx context.Context, set Settings) (context.Context, *zap.SugaredLogger) { zsl := singleton(set) - return Set(ctx, zsl), zsl + return SetWithSettings(ctx, zsl, set), zsl } func setCluesSecretsHash(alg piiAlg) { @@ -412,7 +435,7 @@ func CtxOrSeed(ctx context.Context, set Settings) (context.Context, *zap.Sugared l := ctx.Value(ctxKey) if l == nil { zsl := singleton(set) - return Set(ctx, zsl), zsl + return SetWithSettings(ctx, zsl, set), zsl } return ctx, l.(*zap.SugaredLogger) @@ -420,10 +443,31 @@ func CtxOrSeed(ctx context.Context, set Settings) (context.Context, *zap.Sugared // Set allows users to embed their own zap.SugaredLogger within the context. func Set(ctx context.Context, logger *zap.SugaredLogger) context.Context { + set := Settings{}.EnsureDefaults() + + return SetWithSettings(ctx, logger, set) +} + +// SetWithSettings allows users to embed their own zap.SugaredLogger within the +// context and with the given logger settings. +func SetWithSettings( + ctx context.Context, + logger *zap.SugaredLogger, + set Settings, +) context.Context { if logger == nil { return ctx } + // Add the kopia logger as well. Unfortunately we need to do this here instead + // of a kopia-specific package because we want it to be in the context that's + // used for the rest of execution. + if set.LogStorage { + ctx = logging.WithLogger(ctx, func(module string) logging.Logger { + return logger.Named("kopia-lib/" + module) + }) + } + return context.WithValue(ctx, ctxKey, logger) } diff --git a/src/pkg/services/m365/api/config.go b/src/pkg/services/m365/api/config.go index 1e3a1ce04..3f02505db 100644 --- a/src/pkg/services/m365/api/config.go +++ b/src/pkg/services/m365/api/config.go @@ -17,8 +17,15 @@ const ( // get easily misspelled. // eg: we don't need a const for "id" const ( - parentFolderID = "parentFolderId" + attendees = "attendees" + bccRecipients = "bccRecipients" + ccRecipients = "ccRecipients" + createdDateTime = "createdDateTime" displayName = "displayName" + givenName = "givenName" + parentFolderID = "parentFolderId" + surname = "surname" + toRecipients = "toRecipients" userPrincipalName = "userPrincipalName" ) diff --git a/src/pkg/services/m365/api/contacts.go b/src/pkg/services/m365/api/contacts.go index c253212cd..80f4d583e 100644 --- a/src/pkg/services/m365/api/contacts.go +++ b/src/pkg/services/m365/api/contacts.go @@ -265,3 +265,17 @@ func ContactInfo(contact models.Contactable) *details.ExchangeInfo { Modified: ptr.OrNow(contact.GetLastModifiedDateTime()), } } + +func contactCollisionKeyProps() []string { + return idAnd(givenName) +} + +// ContactCollisionKey constructs a key from the contactable's creation time and either displayName or given+surname. +// collision keys are used to identify duplicate item conflicts for handling advanced restoration config. +func ContactCollisionKey(item models.Contactable) string { + if item == nil { + return "" + } + + return ptr.Val(item.GetId()) +} diff --git a/src/pkg/services/m365/api/contacts_pager.go b/src/pkg/services/m365/api/contacts_pager.go index da79b3ce9..e2082af7c 100644 --- a/src/pkg/services/m365/api/contacts_pager.go +++ b/src/pkg/services/m365/api/contacts_pager.go @@ -90,22 +90,98 @@ func (c Contacts) EnumerateContainers( // item pager // --------------------------------------------------------------------------- -var _ itemPager = &contactPager{} +var _ itemPager[models.Contactable] = &contactsPageCtrl{} -type contactPager struct { +type contactsPageCtrl struct { gs graph.Servicer builder *users.ItemContactFoldersItemContactsRequestBuilder options *users.ItemContactFoldersItemContactsRequestBuilderGetRequestConfiguration } -func (c Contacts) NewContactPager( +func (c Contacts) NewContactsPager( + userID, containerID string, + selectProps ...string, +) itemPager[models.Contactable] { + options := &users.ItemContactFoldersItemContactsRequestBuilderGetRequestConfiguration{ + Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize)), + QueryParameters: &users.ItemContactFoldersItemContactsRequestBuilderGetQueryParameters{ + Top: ptr.To[int32](maxNonDeltaPageSize), + }, + } + + if len(selectProps) > 0 { + options.QueryParameters.Select = selectProps + } + + builder := c.Stable. + Client(). + Users(). + ByUserId(userID). + ContactFolders(). + ByContactFolderId(containerID). + Contacts() + + return &contactsPageCtrl{c.Stable, builder, options} +} + +//lint:ignore U1000 False Positive +func (p *contactsPageCtrl) getPage(ctx context.Context) (PageLinkValuer[models.Contactable], error) { + resp, err := p.builder.Get(ctx, p.options) + if err != nil { + return nil, graph.Stack(ctx, err) + } + + return EmptyDeltaLinker[models.Contactable]{PageLinkValuer: resp}, nil +} + +//lint:ignore U1000 False Positive +func (p *contactsPageCtrl) setNext(nextLink string) { + p.builder = users.NewItemContactFoldersItemContactsRequestBuilder(nextLink, p.gs.Adapter()) +} + +//lint:ignore U1000 False Positive +func (c Contacts) GetItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + ctx = clues.Add(ctx, "container_id", containerID) + pager := c.NewContactsPager(userID, containerID, contactCollisionKeyProps()...) + + items, err := enumerateItems(ctx, pager) + if err != nil { + return nil, graph.Wrap(ctx, err, "enumerating contacts") + } + + m := map[string]string{} + + for _, item := range items { + m[ContactCollisionKey(item)] = ptr.Val(item.GetId()) + } + + return m, nil +} + +// --------------------------------------------------------------------------- +// item ID pager +// --------------------------------------------------------------------------- + +var _ itemIDPager = &contactIDPager{} + +type contactIDPager struct { + gs graph.Servicer + builder *users.ItemContactFoldersItemContactsRequestBuilder + options *users.ItemContactFoldersItemContactsRequestBuilderGetRequestConfiguration +} + +func (c Contacts) NewContactIDsPager( ctx context.Context, userID, containerID string, immutableIDs bool, -) itemPager { +) itemIDPager { config := &users.ItemContactFoldersItemContactsRequestBuilderGetRequestConfiguration{ QueryParameters: &users.ItemContactFoldersItemContactsRequestBuilderGetQueryParameters{ Select: idAnd(parentFolderID), + Top: ptr.To[int32](maxNonDeltaPageSize), }, Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), } @@ -118,10 +194,10 @@ func (c Contacts) NewContactPager( ByContactFolderId(containerID). Contacts() - return &contactPager{c.Stable, builder, config} + return &contactIDPager{c.Stable, builder, config} } -func (p *contactPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *contactIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { resp, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -130,24 +206,24 @@ func (p *contactPager) getPage(ctx context.Context) (DeltaPageLinker, error) { return EmptyDeltaLinker[models.Contactable]{PageLinkValuer: resp}, nil } -func (p *contactPager) setNext(nextLink string) { +func (p *contactIDPager) setNext(nextLink string) { p.builder = users.NewItemContactFoldersItemContactsRequestBuilder(nextLink, p.gs.Adapter()) } // non delta pagers don't need reset -func (p *contactPager) reset(context.Context) {} +func (p *contactIDPager) reset(context.Context) {} -func (p *contactPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *contactIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Contactable](pl) } // --------------------------------------------------------------------------- -// delta item pager +// delta item ID pager // --------------------------------------------------------------------------- -var _ itemPager = &contactDeltaPager{} +var _ itemIDPager = &contactDeltaIDPager{} -type contactDeltaPager struct { +type contactDeltaIDPager struct { gs graph.Servicer userID string containerID string @@ -165,14 +241,15 @@ func getContactDeltaBuilder( return builder } -func (c Contacts) NewContactDeltaPager( +func (c Contacts) NewContactDeltaIDsPager( ctx context.Context, userID, containerID, oldDelta string, immutableIDs bool, -) itemPager { +) itemIDPager { options := &users.ItemContactFoldersItemContactsDeltaRequestBuilderGetRequestConfiguration{ QueryParameters: &users.ItemContactFoldersItemContactsDeltaRequestBuilderGetQueryParameters{ Select: idAnd(parentFolderID), + // TOP is not allowed }, Headers: newPreferHeaders(preferPageSize(maxDeltaPageSize), preferImmutableIDs(immutableIDs)), } @@ -184,10 +261,10 @@ func (c Contacts) NewContactDeltaPager( builder = getContactDeltaBuilder(ctx, c.Stable, userID, containerID, options) } - return &contactDeltaPager{c.Stable, userID, containerID, builder, options} + return &contactDeltaIDPager{c.Stable, userID, containerID, builder, options} } -func (p *contactDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *contactDeltaIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { resp, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -196,15 +273,15 @@ func (p *contactDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error return resp, nil } -func (p *contactDeltaPager) setNext(nextLink string) { +func (p *contactDeltaIDPager) setNext(nextLink string) { p.builder = users.NewItemContactFoldersItemContactsDeltaRequestBuilder(nextLink, p.gs.Adapter()) } -func (p *contactDeltaPager) reset(ctx context.Context) { +func (p *contactDeltaIDPager) reset(ctx context.Context) { p.builder = getContactDeltaBuilder(ctx, p.gs, p.userID, p.containerID, p.options) } -func (p *contactDeltaPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *contactDeltaIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Contactable](pl) } @@ -219,8 +296,8 @@ func (c Contacts) GetAddedAndRemovedItemIDs( "category", selectors.ExchangeContact, "container_id", containerID) - pager := c.NewContactPager(ctx, userID, containerID, immutableIDs) - deltaPager := c.NewContactDeltaPager(ctx, userID, containerID, oldDelta, immutableIDs) + pager := c.NewContactIDsPager(ctx, userID, containerID, immutableIDs) + deltaPager := c.NewContactDeltaIDsPager(ctx, userID, containerID, oldDelta, immutableIDs) return getAddedAndRemovedItemIDs(ctx, c.Stable, pager, deltaPager, oldDelta, canMakeDeltaQueries) } diff --git a/src/pkg/services/m365/api/contacts_pager_test.go b/src/pkg/services/m365/api/contacts_pager_test.go new file mode 100644 index 000000000..d29be16c9 --- /dev/null +++ b/src/pkg/services/m365/api/contacts_pager_test.go @@ -0,0 +1,73 @@ +package api_test + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +type ContactsPagerIntgSuite struct { + tester.Suite + cts clientTesterSetup +} + +func TestContactsPagerIntgSuite(t *testing.T) { + suite.Run(t, &ContactsPagerIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *ContactsPagerIntgSuite) SetupSuite() { + suite.cts = newClientTesterSetup(suite.T()) +} + +func (suite *ContactsPagerIntgSuite) TestGetItemsInContainerByCollisionKey() { + t := suite.T() + ac := suite.cts.ac.Contacts() + + ctx, flush := tester.NewContext(t) + defer flush() + + container, err := ac.GetContainerByID(ctx, suite.cts.userID, "contacts") + require.NoError(t, err, clues.ToCore(err)) + + conts, err := ac.Stable. + Client(). + Users(). + ByUserId(suite.cts.userID). + ContactFolders(). + ByContactFolderId(ptr.Val(container.GetId())). + Contacts(). + Get(ctx, nil) + require.NoError(t, err, clues.ToCore(err)) + + cs := conts.GetValue() + expect := make([]string, 0, len(cs)) + + for _, c := range cs { + expect = append(expect, api.ContactCollisionKey(c)) + } + + results, err := ac.GetItemsInContainerByCollisionKey(ctx, suite.cts.userID, "contacts") + require.NoError(t, err, clues.ToCore(err)) + require.Less(t, 0, len(results), "requires at least one result") + + for k, v := range results { + assert.NotEmpty(t, k, "all keys should be populated") + assert.NotEmpty(t, v, "all values should be populated") + } + + for _, e := range expect { + _, ok := results[e] + assert.Truef(t, ok, "expected results to contain collision key: %s", e) + } +} diff --git a/src/pkg/services/m365/api/drive_pager.go b/src/pkg/services/m365/api/drive_pager.go index 8199dc8e8..7d06be397 100644 --- a/src/pkg/services/m365/api/drive_pager.go +++ b/src/pkg/services/m365/api/drive_pager.go @@ -292,8 +292,8 @@ func GetAllDrives( for i := 0; i <= maxRetryCount; i++ { page, err = pager.GetPage(ctx) if err != nil { - if clues.HasLabel(err, graph.LabelsMysiteNotFound) { - logger.Ctx(ctx).Infof("resource owner does not have a drive") + if clues.HasLabel(err, graph.LabelsMysiteNotFound) || clues.HasLabel(err, graph.LabelsNoSharePointLicense) { + logger.CtxErr(ctx, err).Infof("resource owner does not have a drive") return make([]models.Driveable, 0), nil // no license or drives. } diff --git a/src/pkg/services/m365/api/events.go b/src/pkg/services/m365/api/events.go index 574e2de21..4ed1b83b8 100644 --- a/src/pkg/services/m365/api/events.go +++ b/src/pkg/services/m365/api/events.go @@ -3,8 +3,10 @@ package api import ( "bytes" "context" + "encoding/json" "fmt" "io" + "strings" "time" "github.com/alcionai/clues" @@ -15,6 +17,7 @@ import ( "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/backup/details" "github.com/alcionai/corso/src/pkg/fault" @@ -189,6 +192,14 @@ func (c Events) PatchCalendar( return nil } +const ( + // Beta version cannot have /calendars/%s for get and Patch + // https://stackoverflow.com/questions/50492177/microsoft-graph-get-user-calendar-event-with-beta-version + eventExceptionsBetaURLTemplate = "https://graph.microsoft.com/beta/users/%s/events/%s?$expand=exceptionOccurrences" + eventPostBetaURLTemplate = "https://graph.microsoft.com/beta/users/%s/calendars/%s/events" + eventPatchBetaURLTemplate = "https://graph.microsoft.com/beta/users/%s/events/%s" +) + // --------------------------------------------------------------------------- // items // --------------------------------------------------------------------------- @@ -208,41 +219,233 @@ func (c Events) GetItem( } ) - event, err = c.Stable. + // Beta endpoint helps us fetch the event exceptions, but since we + // don't use the beta SDK, the exceptionOccurrences and + // cancelledOccurrences end up in AdditionalData + // https://learn.microsoft.com/en-us/graph/api/resources/event?view=graph-rest-beta#properties + rawURL := fmt.Sprintf(eventExceptionsBetaURLTemplate, userID, itemID) + builder := users.NewItemEventsEventItemRequestBuilder(rawURL, c.Stable.Adapter()) + + event, err = builder.Get(ctx, config) + if err != nil { + return nil, nil, graph.Stack(ctx, err) + } + + err = validateCancelledOccurrences(event) + if err != nil { + return nil, nil, clues.Wrap(err, "verify cancelled occurrences") + } + + err = fixupExceptionOccurrences(ctx, c, event, immutableIDs, userID) + if err != nil { + return nil, nil, clues.Wrap(err, "fixup exception occurrences") + } + + var attachments []models.Attachmentable + if ptr.Val(event.GetHasAttachments()) || HasAttachments(event.GetBody()) { + attachments, err = c.GetAttachments(ctx, immutableIDs, userID, itemID) + if err != nil { + return nil, nil, err + } + } + + event.SetAttachments(attachments) + + return event, EventInfo(event), nil +} + +// fixupExceptionOccurrences gets attachments and converts the data +// into a format that gets serialized when storing to kopia +func fixupExceptionOccurrences( + ctx context.Context, + client Events, + event models.Eventable, + immutableIDs bool, + userID string, +) error { + // Fetch attachments for exceptions + exceptionOccurrences := event.GetAdditionalData()["exceptionOccurrences"] + if exceptionOccurrences == nil { + return nil + } + + eo, ok := exceptionOccurrences.([]any) + if !ok { + return clues.New("converting exceptionOccurrences to []any"). + With("type", fmt.Sprintf("%T", exceptionOccurrences)) + } + + for _, instance := range eo { + instance, ok := instance.(map[string]any) + if !ok { + return clues.New("converting instance to map[string]any"). + With("type", fmt.Sprintf("%T", instance)) + } + + evt, err := EventFromMap(instance) + if err != nil { + return clues.Wrap(err, "parsing exception event") + } + + // OPTIMIZATION: We don't have to store any of the + // attachments that carry over from the original + + var attachments []models.Attachmentable + if ptr.Val(event.GetHasAttachments()) || HasAttachments(event.GetBody()) { + attachments, err = client.GetAttachments(ctx, immutableIDs, userID, ptr.Val(evt.GetId())) + if err != nil { + return clues.Wrap(err, "getting event instance attachments"). + With("event_instance_id", ptr.Val(evt.GetId())) + } + } + + // This odd roundabout way of doing this is required as + // the json serialization at the end does not serialize if + // you just pass in a models.Attachmentable + convertedAttachments := []map[string]interface{}{} + + for _, attachment := range attachments { + am, err := parseableToMap(attachment) + if err != nil { + return clues.Wrap(err, "converting attachment") + } + + convertedAttachments = append(convertedAttachments, am) + } + + instance["attachments"] = convertedAttachments + } + + return nil +} + +// Adding checks to ensure that the data is in the format that we expect M365 to return +func validateCancelledOccurrences(event models.Eventable) error { + cancelledOccurrences := event.GetAdditionalData()["cancelledOccurrences"] + if cancelledOccurrences != nil { + co, ok := cancelledOccurrences.([]any) + if !ok { + return clues.New("converting cancelledOccurrences to []any"). + With("type", fmt.Sprintf("%T", cancelledOccurrences)) + } + + for _, instance := range co { + instance, err := str.AnyToString(instance) + if err != nil { + return err + } + + // There might be multiple `.` in the ID and hence >2 + splits := strings.Split(instance, ".") + if len(splits) < 2 { + return clues.New("unexpected cancelled event format"). + With("instance", instance) + } + + startStr := splits[len(splits)-1] + + _, err = dttm.ParseTime(startStr) + if err != nil { + return clues.Wrap(err, "parsing cancelled event date") + } + } + } + + return nil +} + +func parseableToMap(att serialization.Parsable) (map[string]any, error) { + var item map[string]any + + writer := kjson.NewJsonSerializationWriter() + defer writer.Close() + + if err := writer.WriteObjectValue("", att); err != nil { + return nil, err + } + + ats, err := writer.GetSerializedContent() + if err != nil { + return nil, err + } + + err = json.Unmarshal(ats, &item) + if err != nil { + return nil, clues.Wrap(err, "unmarshalling serialized attachment") + } + + return item, nil +} + +func (c Events) GetAttachments( + ctx context.Context, + immutableIDs bool, + userID, itemID string, +) ([]models.Attachmentable, error) { + config := &users.ItemEventsItemAttachmentsRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemEventsItemAttachmentsRequestBuilderGetQueryParameters{ + Expand: []string{"microsoft.graph.itemattachment/item"}, + }, + Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), + } + + attached, err := c.LargeItem. Client(). Users(). ByUserId(userID). Events(). ByEventId(itemID). + Attachments(). Get(ctx, config) if err != nil { - return nil, nil, graph.Stack(ctx, err) + return nil, graph.Wrap(ctx, err, "event attachment download") } - if ptr.Val(event.GetHasAttachments()) || HasAttachments(event.GetBody()) { - config := &users.ItemEventsItemAttachmentsRequestBuilderGetRequestConfiguration{ - QueryParameters: &users.ItemEventsItemAttachmentsRequestBuilderGetQueryParameters{ - Expand: []string{"microsoft.graph.itemattachment/item"}, - }, - Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), - } + return attached.GetValue(), nil +} - attached, err := c.LargeItem. - Client(). - Users(). - ByUserId(userID). - Events(). - ByEventId(itemID). - Attachments(). - Get(ctx, config) - if err != nil { - return nil, nil, graph.Wrap(ctx, err, "event attachment download") - } +func (c Events) DeleteAttachment( + ctx context.Context, + userID, calendarID, eventID, attachmentID string, +) error { + return c.Stable. + Client(). + Users(). + ByUserId(userID). + Calendars(). + ByCalendarId(calendarID). + Events(). + ByEventId(eventID). + Attachments(). + ByAttachmentId(attachmentID). + Delete(ctx, nil) +} - event.SetAttachments(attached.GetValue()) +func (c Events) GetItemInstances( + ctx context.Context, + userID, itemID, startDate, endDate string, +) ([]models.Eventable, error) { + config := &users.ItemEventsItemInstancesRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemEventsItemInstancesRequestBuilderGetQueryParameters{ + Select: []string{"id"}, + StartDateTime: ptr.To(startDate), + EndDateTime: ptr.To(endDate), + }, } - return event, EventInfo(event), nil + events, err := c.Stable. + Client(). + Users(). + ByUserId(userID). + Events(). + ByEventId(itemID). + Instances(). + Get(ctx, config) + if err != nil { + return nil, graph.Stack(ctx, err) + } + + return events.GetValue(), nil } func (c Events) PostItem( @@ -250,14 +453,10 @@ func (c Events) PostItem( userID, containerID string, body models.Eventable, ) (models.Eventable, error) { - itm, err := c.Stable. - Client(). - Users(). - ByUserId(userID). - Calendars(). - ByCalendarId(containerID). - Events(). - Post(ctx, body, nil) + rawURL := fmt.Sprintf(eventPostBetaURLTemplate, userID, containerID) + builder := users.NewItemCalendarsItemEventsRequestBuilder(rawURL, c.Stable.Adapter()) + + itm, err := builder.Post(ctx, body, nil) if err != nil { return nil, graph.Wrap(ctx, err, "creating calendar event") } @@ -265,6 +464,22 @@ func (c Events) PostItem( return itm, nil } +func (c Events) PatchItem( + ctx context.Context, + userID, eventID string, + body models.Eventable, +) (models.Eventable, error) { + rawURL := fmt.Sprintf(eventPatchBetaURLTemplate, userID, eventID) + builder := users.NewItemCalendarsItemEventsEventItemRequestBuilder(rawURL, c.Stable.Adapter()) + + itm, err := builder.Patch(ctx, body, nil) + if err != nil { + return nil, graph.Wrap(ctx, err, "updating calendar event") + } + + return itm, nil +} + func (c Events) DeleteItem( ctx context.Context, userID, itemID string, @@ -315,14 +530,9 @@ func (c Events) PostSmallAttachment( func (c Events) PostLargeAttachment( ctx context.Context, userID, containerID, parentItemID, itemName string, - size int64, - body models.Attachmentable, -) (models.UploadSessionable, error) { - bs, err := GetAttachmentContent(body) - if err != nil { - return nil, clues.Wrap(err, "serializing attachment content").WithClues(ctx) - } - + content []byte, +) (string, error) { + size := int64(len(content)) session := users.NewItemCalendarEventsItemAttachmentsCreateUploadSessionPostRequestBody() session.SetAttachmentItem(makeSessionAttachment(itemName, size)) @@ -338,19 +548,19 @@ func (c Events) PostLargeAttachment( CreateUploadSession(). Post(ctx, session, nil) if err != nil { - return nil, graph.Wrap(ctx, err, "uploading large event attachment") + return "", graph.Wrap(ctx, err, "uploading large event attachment") } url := ptr.Val(us.GetUploadUrl()) w := graph.NewLargeItemWriter(parentItemID, url, size) copyBuffer := make([]byte, graph.AttachmentChunkSize) - _, err = io.CopyBuffer(w, bytes.NewReader(bs), copyBuffer) + _, err = io.CopyBuffer(w, bytes.NewReader(content), copyBuffer) if err != nil { - return nil, clues.Wrap(err, "buffering large attachment content").WithClues(ctx) + return "", clues.Wrap(err, "buffering large attachment content").WithClues(ctx) } - return us, nil + return w.ID, nil } // --------------------------------------------------------------------------- @@ -472,3 +682,31 @@ func EventInfo(evt models.Eventable) *details.ExchangeInfo { Modified: ptr.OrNow(evt.GetLastModifiedDateTime()), } } + +func EventFromMap(ev map[string]any) (models.Eventable, error) { + instBytes, err := json.Marshal(ev) + if err != nil { + return nil, clues.Wrap(err, "marshaling event exception instance") + } + + body, err := BytesToEventable(instBytes) + if err != nil { + return nil, clues.Wrap(err, "converting exception event bytes to Eventable") + } + + return body, nil +} + +func eventCollisionKeyProps() []string { + return idAnd("subject") +} + +// EventCollisionKey constructs a key from the eventable's creation time, subject, and organizer. +// collision keys are used to identify duplicate item conflicts for handling advanced restoration config. +func EventCollisionKey(item models.Eventable) string { + if item == nil { + return "" + } + + return ptr.Val(item.GetSubject()) +} diff --git a/src/pkg/services/m365/api/events_pager.go b/src/pkg/services/m365/api/events_pager.go index bb390a288..782a26fc6 100644 --- a/src/pkg/services/m365/api/events_pager.go +++ b/src/pkg/services/m365/api/events_pager.go @@ -98,21 +98,27 @@ func (c Events) EnumerateContainers( // item pager // --------------------------------------------------------------------------- -var _ itemPager = &eventPager{} +var _ itemPager[models.Eventable] = &eventsPageCtrl{} -type eventPager struct { +type eventsPageCtrl struct { gs graph.Servicer builder *users.ItemCalendarsItemEventsRequestBuilder options *users.ItemCalendarsItemEventsRequestBuilderGetRequestConfiguration } -func (c Events) NewEventPager( - ctx context.Context, +func (c Events) NewEventsPager( userID, containerID string, - immutableIDs bool, -) (itemPager, error) { + selectProps ...string, +) itemPager[models.Eventable] { options := &users.ItemCalendarsItemEventsRequestBuilderGetRequestConfiguration{ - Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), + Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize)), + QueryParameters: &users.ItemCalendarsItemEventsRequestBuilderGetQueryParameters{ + Top: ptr.To[int32](maxNonDeltaPageSize), + }, + } + + if len(selectProps) > 0 { + options.QueryParameters.Select = selectProps } builder := c.Stable. @@ -123,10 +129,82 @@ func (c Events) NewEventPager( ByCalendarId(containerID). Events() - return &eventPager{c.Stable, builder, options}, nil + return &eventsPageCtrl{c.Stable, builder, options} } -func (p *eventPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +//lint:ignore U1000 False Positive +func (p *eventsPageCtrl) getPage(ctx context.Context) (PageLinkValuer[models.Eventable], error) { + resp, err := p.builder.Get(ctx, p.options) + if err != nil { + return nil, graph.Stack(ctx, err) + } + + return resp, nil +} + +//lint:ignore U1000 False Positive +func (p *eventsPageCtrl) setNext(nextLink string) { + p.builder = users.NewItemCalendarsItemEventsRequestBuilder(nextLink, p.gs.Adapter()) +} + +//lint:ignore U1000 False Positive +func (c Events) GetItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + ctx = clues.Add(ctx, "container_id", containerID) + pager := c.NewEventsPager(userID, containerID, eventCollisionKeyProps()...) + + items, err := enumerateItems(ctx, pager) + if err != nil { + return nil, graph.Wrap(ctx, err, "enumerating events") + } + + m := map[string]string{} + + for _, item := range items { + m[EventCollisionKey(item)] = ptr.Val(item.GetId()) + } + + return m, nil +} + +// --------------------------------------------------------------------------- +// item ID pager +// --------------------------------------------------------------------------- + +var _ itemIDPager = &eventIDPager{} + +type eventIDPager struct { + gs graph.Servicer + builder *users.ItemCalendarsItemEventsRequestBuilder + options *users.ItemCalendarsItemEventsRequestBuilderGetRequestConfiguration +} + +func (c Events) NewEventIDsPager( + ctx context.Context, + userID, containerID string, + immutableIDs bool, +) (itemIDPager, error) { + options := &users.ItemCalendarsItemEventsRequestBuilderGetRequestConfiguration{ + Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), + QueryParameters: &users.ItemCalendarsItemEventsRequestBuilderGetQueryParameters{ + Top: ptr.To[int32](maxNonDeltaPageSize), + }, + } + + builder := c.Stable. + Client(). + Users(). + ByUserId(userID). + Calendars(). + ByCalendarId(containerID). + Events() + + return &eventIDPager{c.Stable, builder, options}, nil +} + +func (p *eventIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { resp, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -135,24 +213,24 @@ func (p *eventPager) getPage(ctx context.Context) (DeltaPageLinker, error) { return EmptyDeltaLinker[models.Eventable]{PageLinkValuer: resp}, nil } -func (p *eventPager) setNext(nextLink string) { +func (p *eventIDPager) setNext(nextLink string) { p.builder = users.NewItemCalendarsItemEventsRequestBuilder(nextLink, p.gs.Adapter()) } // non delta pagers don't need reset -func (p *eventPager) reset(context.Context) {} +func (p *eventIDPager) reset(context.Context) {} -func (p *eventPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *eventIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Eventable](pl) } // --------------------------------------------------------------------------- -// delta item pager +// delta item ID pager // --------------------------------------------------------------------------- -var _ itemPager = &eventDeltaPager{} +var _ itemIDPager = &eventDeltaIDPager{} -type eventDeltaPager struct { +type eventDeltaIDPager struct { gs graph.Servicer userID string containerID string @@ -160,13 +238,16 @@ type eventDeltaPager struct { options *users.ItemCalendarsItemEventsDeltaRequestBuilderGetRequestConfiguration } -func (c Events) NewEventDeltaPager( +func (c Events) NewEventDeltaIDsPager( ctx context.Context, userID, containerID, oldDelta string, immutableIDs bool, -) (itemPager, error) { +) (itemIDPager, error) { options := &users.ItemCalendarsItemEventsDeltaRequestBuilderGetRequestConfiguration{ Headers: newPreferHeaders(preferPageSize(maxDeltaPageSize), preferImmutableIDs(immutableIDs)), + QueryParameters: &users.ItemCalendarsItemEventsDeltaRequestBuilderGetQueryParameters{ + Top: ptr.To[int32](maxDeltaPageSize), + }, } var builder *users.ItemCalendarsItemEventsDeltaRequestBuilder @@ -177,7 +258,7 @@ func (c Events) NewEventDeltaPager( builder = users.NewItemCalendarsItemEventsDeltaRequestBuilder(oldDelta, c.Stable.Adapter()) } - return &eventDeltaPager{c.Stable, userID, containerID, builder, options}, nil + return &eventDeltaIDPager{c.Stable, userID, containerID, builder, options}, nil } func getEventDeltaBuilder( @@ -200,7 +281,7 @@ func getEventDeltaBuilder( return builder } -func (p *eventDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *eventDeltaIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { resp, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -209,15 +290,15 @@ func (p *eventDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error) return resp, nil } -func (p *eventDeltaPager) setNext(nextLink string) { +func (p *eventDeltaIDPager) setNext(nextLink string) { p.builder = users.NewItemCalendarsItemEventsDeltaRequestBuilder(nextLink, p.gs.Adapter()) } -func (p *eventDeltaPager) reset(ctx context.Context) { +func (p *eventDeltaIDPager) reset(ctx context.Context) { p.builder = getEventDeltaBuilder(ctx, p.gs, p.userID, p.containerID, p.options) } -func (p *eventDeltaPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *eventDeltaIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Eventable](pl) } @@ -229,12 +310,12 @@ func (c Events) GetAddedAndRemovedItemIDs( ) ([]string, []string, DeltaUpdate, error) { ctx = clues.Add(ctx, "container_id", containerID) - pager, err := c.NewEventPager(ctx, userID, containerID, immutableIDs) + pager, err := c.NewEventIDsPager(ctx, userID, containerID, immutableIDs) if err != nil { return nil, nil, DeltaUpdate{}, graph.Wrap(ctx, err, "creating non-delta pager") } - deltaPager, err := c.NewEventDeltaPager(ctx, userID, containerID, oldDelta, immutableIDs) + deltaPager, err := c.NewEventDeltaIDsPager(ctx, userID, containerID, oldDelta, immutableIDs) if err != nil { return nil, nil, DeltaUpdate{}, graph.Wrap(ctx, err, "creating delta pager") } diff --git a/src/pkg/services/m365/api/events_pager_test.go b/src/pkg/services/m365/api/events_pager_test.go new file mode 100644 index 000000000..e95f933d8 --- /dev/null +++ b/src/pkg/services/m365/api/events_pager_test.go @@ -0,0 +1,73 @@ +package api_test + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +type EventsPagerIntgSuite struct { + tester.Suite + cts clientTesterSetup +} + +func TestEventsPagerIntgSuite(t *testing.T) { + suite.Run(t, &EventsPagerIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *EventsPagerIntgSuite) SetupSuite() { + suite.cts = newClientTesterSetup(suite.T()) +} + +func (suite *EventsPagerIntgSuite) TestGetItemsInContainerByCollisionKey() { + t := suite.T() + ac := suite.cts.ac.Events() + + ctx, flush := tester.NewContext(t) + defer flush() + + container, err := ac.GetContainerByID(ctx, suite.cts.userID, "calendar") + require.NoError(t, err, clues.ToCore(err)) + + evts, err := ac.Stable. + Client(). + Users(). + ByUserId(suite.cts.userID). + Calendars(). + ByCalendarId(ptr.Val(container.GetId())). + Events(). + Get(ctx, nil) + require.NoError(t, err, clues.ToCore(err)) + + es := evts.GetValue() + expect := make([]string, 0, len(es)) + + for _, e := range es { + expect = append(expect, api.EventCollisionKey(e)) + } + + results, err := ac.GetItemsInContainerByCollisionKey(ctx, suite.cts.userID, "calendar") + require.NoError(t, err, clues.ToCore(err)) + require.Less(t, 0, len(results), "requires at least one result") + + for k, v := range results { + assert.NotEmpty(t, k, "all keys should be populated") + assert.NotEmpty(t, v, "all values should be populated") + } + + for _, e := range expect { + _, ok := results[e] + assert.Truef(t, ok, "expected results to contain collision key: %s", e) + } +} diff --git a/src/pkg/services/m365/api/events_test.go b/src/pkg/services/m365/api/events_test.go index 2daa66454..1d4c39cc9 100644 --- a/src/pkg/services/m365/api/events_test.go +++ b/src/pkg/services/m365/api/events_test.go @@ -11,9 +11,12 @@ import ( "github.com/stretchr/testify/suite" "github.com/alcionai/corso/src/internal/common/dttm" + "github.com/alcionai/corso/src/internal/common/ptr" exchMock "github.com/alcionai/corso/src/internal/m365/exchange/mock" "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control/testdata" ) type EventsAPIUnitSuite struct { @@ -212,3 +215,70 @@ func (suite *EventsAPIUnitSuite) TestBytesToEventable() { }) } } + +type EventsAPIIntgSuite struct { + tester.Suite + credentials account.M365Config + ac Client +} + +func TestEventsAPIntgSuite(t *testing.T) { + suite.Run(t, &EventsAPIIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *EventsAPIIntgSuite) SetupSuite() { + t := suite.T() + + a := tester.NewM365Account(t) + m365, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + suite.credentials = m365 + suite.ac, err = NewClient(m365) + require.NoError(t, err, clues.ToCore(err)) +} + +func (suite *EventsAPIIntgSuite) TestRestoreLargeAttachment() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + userID := tester.M365UserID(suite.T()) + + folderName := testdata.DefaultRestoreConfig("eventlargeattachmenttest").Location + evts := suite.ac.Events() + calendar, err := evts.CreateContainer(ctx, userID, folderName, "") + require.NoError(t, err, clues.ToCore(err)) + + tomorrow := time.Now().Add(24 * time.Hour) + evt := models.NewEvent() + sdtz := models.NewDateTimeTimeZone() + edtz := models.NewDateTimeTimeZone() + + evt.SetSubject(ptr.To("Event with attachment")) + sdtz.SetDateTime(ptr.To(dttm.Format(tomorrow))) + sdtz.SetTimeZone(ptr.To("UTC")) + edtz.SetDateTime(ptr.To(dttm.Format(tomorrow.Add(30 * time.Minute)))) + edtz.SetTimeZone(ptr.To("UTC")) + evt.SetStart(sdtz) + evt.SetEnd(edtz) + + item, err := evts.PostItem(ctx, userID, ptr.Val(calendar.GetId()), evt) + require.NoError(t, err, clues.ToCore(err)) + + id, err := evts.PostLargeAttachment( + ctx, + userID, + ptr.Val(calendar.GetId()), + ptr.Val(item.GetId()), + "raboganm", + []byte("mangobar"), + ) + require.NoError(t, err, clues.ToCore(err)) + require.NotEmpty(t, id, "empty id for large attachment") +} diff --git a/src/pkg/services/m365/api/helper_test.go b/src/pkg/services/m365/api/helper_test.go new file mode 100644 index 000000000..0d82db8be --- /dev/null +++ b/src/pkg/services/m365/api/helper_test.go @@ -0,0 +1,34 @@ +package api_test + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/require" + + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +type clientTesterSetup struct { + ac api.Client + userID string +} + +func newClientTesterSetup(t *testing.T) clientTesterSetup { + cts := clientTesterSetup{} + + ctx, flush := tester.NewContext(t) + defer flush() + + a := tester.NewM365Account(t) + creds, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + cts.ac, err = api.NewClient(creds) + require.NoError(t, err, clues.ToCore(err)) + + cts.userID = tester.GetM365UserID(ctx) + + return cts +} diff --git a/src/pkg/services/m365/api/item_pager.go b/src/pkg/services/m365/api/item_pager.go index 00a93ea13..ef54b1a3d 100644 --- a/src/pkg/services/m365/api/item_pager.go +++ b/src/pkg/services/m365/api/item_pager.go @@ -61,27 +61,48 @@ func (e EmptyDeltaLinker[T]) GetValue() []T { } // --------------------------------------------------------------------------- -// generic handler for paging item ids in a container +// generic handler for non-delta item paging in a container // --------------------------------------------------------------------------- -type itemPager interface { +type itemPager[T any] interface { // getPage get a page with the specified options from graph - getPage(context.Context) (DeltaPageLinker, error) + getPage(context.Context) (PageLinkValuer[T], error) // setNext is used to pass in the next url got from graph setNext(string) - // reset is used to clear delta url in delta pagers. When - // reset is called, we reset the state(delta url) that we - // currently have and start a new delta query without the token. - reset(context.Context) - // valuesIn gets us the values in a page - valuesIn(PageLinker) ([]getIDAndAddtler, error) } -type getIDAndAddtler interface { - GetId() *string - GetAdditionalData() map[string]any +func enumerateItems[T any]( + ctx context.Context, + pager itemPager[T], +) ([]T, error) { + var ( + result = make([]T, 0) + // stubbed initial value to ensure we enter the loop. + nextLink = "do-while" + ) + + for len(nextLink) > 0 { + // get the next page of data, check for standard errors + resp, err := pager.getPage(ctx) + if err != nil { + return nil, graph.Stack(ctx, err) + } + + result = append(result, resp.GetValue()...) + nextLink = NextLink(resp) + + pager.setNext(nextLink) + } + + logger.Ctx(ctx).Infow("completed enumeration", "count", len(result)) + + return result, nil } +// --------------------------------------------------------------------------- +// generic handler for delta-based ittem paging in a container +// --------------------------------------------------------------------------- + // uses a models interface compliant with { GetValues() []T } // to transform its results into a slice of getIDer interfaces. // Generics used here to handle the variation of msoft interfaces @@ -110,16 +131,34 @@ func toValues[T any](a any) ([]getIDAndAddtler, error) { return r, nil } +type itemIDPager interface { + // getPage get a page with the specified options from graph + getPage(context.Context) (DeltaPageLinker, error) + // setNext is used to pass in the next url got from graph + setNext(string) + // reset is used to clear delta url in delta pagers. When + // reset is called, we reset the state(delta url) that we + // currently have and start a new delta query without the token. + reset(context.Context) + // valuesIn gets us the values in a page + valuesIn(PageLinker) ([]getIDAndAddtler, error) +} + +type getIDAndAddtler interface { + GetId() *string + GetAdditionalData() map[string]any +} + func getAddedAndRemovedItemIDs( ctx context.Context, service graph.Servicer, - pager itemPager, - deltaPager itemPager, + pager itemIDPager, + deltaPager itemIDPager, oldDelta string, canMakeDeltaQueries bool, ) ([]string, []string, DeltaUpdate, error) { var ( - pgr itemPager + pgr itemIDPager resetDelta bool ) @@ -161,17 +200,16 @@ func getAddedAndRemovedItemIDs( // generic controller for retrieving all item ids in a container. func getItemsAddedAndRemovedFromContainer( ctx context.Context, - pager itemPager, + pager itemIDPager, ) ([]string, []string, string, error) { var ( addedIDs = []string{} removedIDs = []string{} deltaURL string + itemCount int + page int ) - itemCount := 0 - page := 0 - for { // get the next page of data, check for standard errors resp, err := pager.getPage(ctx) diff --git a/src/pkg/services/m365/api/item_pager_test.go b/src/pkg/services/m365/api/item_pager_test.go index 4c6dbfbeb..0ba312b51 100644 --- a/src/pkg/services/m365/api/item_pager_test.go +++ b/src/pkg/services/m365/api/item_pager_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/alcionai/clues" "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,6 +20,8 @@ import ( // mock impls & stubs // --------------------------------------------------------------------------- +// next and delta links + type nextLink struct { nextLink *string } @@ -36,6 +39,8 @@ func (l deltaNextLink) GetOdataDeltaLink() *string { return l.deltaLink } +// mock values + type testPagerValue struct { id string removed bool @@ -50,7 +55,11 @@ func (v testPagerValue) GetAdditionalData() map[string]any { return map[string]any{} } -type testPage struct{} +// mock page + +type testPage struct { + values []any +} func (p testPage) GetOdataNextLink() *string { // no next, just one page @@ -62,9 +71,33 @@ func (p testPage) GetOdataDeltaLink() *string { return ptr.To("") } -var _ itemPager = &testPager{} +func (p testPage) GetValue() []any { + return p.values +} + +// mock item pager + +var _ itemPager[any] = &testPager{} type testPager struct { + t *testing.T + pager testPage + pageErr error +} + +//lint:ignore U1000 False Positive +func (p *testPager) getPage(ctx context.Context) (PageLinkValuer[any], error) { + return p.pager, p.pageErr +} + +//lint:ignore U1000 False Positive +func (p *testPager) setNext(nextLink string) {} + +// mock id pager + +var _ itemIDPager = &testIDsPager{} + +type testIDsPager struct { t *testing.T added []string removed []string @@ -72,7 +105,7 @@ type testPager struct { needsReset bool } -func (p *testPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *testIDsPager) getPage(ctx context.Context) (DeltaPageLinker, error) { if p.errorCode != "" { ierr := odataerrors.NewMainError() ierr.SetCode(&p.errorCode) @@ -85,8 +118,8 @@ func (p *testPager) getPage(ctx context.Context) (DeltaPageLinker, error) { return testPage{}, nil } -func (p *testPager) setNext(string) {} -func (p *testPager) reset(context.Context) { +func (p *testIDsPager) setNext(string) {} +func (p *testIDsPager) reset(context.Context) { if !p.needsReset { require.Fail(p.t, "reset should not be called") } @@ -95,7 +128,7 @@ func (p *testPager) reset(context.Context) { p.errorCode = "" } -func (p *testPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *testIDsPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { items := []getIDAndAddtler{} for _, id := range p.added { @@ -121,11 +154,69 @@ func TestItemPagerUnitSuite(t *testing.T) { suite.Run(t, &ItemPagerUnitSuite{Suite: tester.NewUnitSuite(t)}) } +func (suite *ItemPagerUnitSuite) TestEnumerateItems() { + tests := []struct { + name string + getPager func(*testing.T, context.Context) itemPager[any] + expect []any + expectErr require.ErrorAssertionFunc + }{ + { + name: "happy path", + getPager: func( + t *testing.T, + ctx context.Context, + ) itemPager[any] { + return &testPager{ + t: t, + pager: testPage{[]any{"foo", "bar"}}, + } + }, + expect: []any{"foo", "bar"}, + expectErr: require.NoError, + }, + { + name: "next page err", + getPager: func( + t *testing.T, + ctx context.Context, + ) itemPager[any] { + return &testPager{ + t: t, + pageErr: assert.AnError, + } + }, + expect: nil, + expectErr: require.Error, + }, + } + + for _, test := range tests { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + result, err := enumerateItems(ctx, test.getPager(t, ctx)) + test.expectErr(t, err, clues.ToCore(err)) + + require.EqualValues(t, test.expect, result) + }) + } +} + func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { tests := []struct { - name string - pagerGetter func(context.Context, graph.Servicer, string, string, bool) (itemPager, error) - deltaPagerGetter func(context.Context, graph.Servicer, string, string, string, bool) (itemPager, error) + name string + pagerGetter func(*testing.T, context.Context, graph.Servicer, string, string, bool) (itemIDPager, error) + deltaPagerGetter func( + *testing.T, + context.Context, + graph.Servicer, + string, string, string, + bool, + ) (itemIDPager, error) added []string removed []string deltaUpdate DeltaUpdate @@ -135,25 +226,27 @@ func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { { name: "no prev delta", pagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, immutableIDs bool, - ) (itemPager, error) { + ) (itemIDPager, error) { // this should not be called return nil, assert.AnError }, deltaPagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, delta string, immutableIDs bool, - ) (itemPager, error) { - return &testPager{ - t: suite.T(), + ) (itemIDPager, error) { + return &testIDsPager{ + t: t, added: []string{"uno", "dos"}, removed: []string{"tres", "quatro"}, }, nil @@ -166,25 +259,27 @@ func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { { name: "with prev delta", pagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, immutableIDs bool, - ) (itemPager, error) { + ) (itemIDPager, error) { // this should not be called return nil, assert.AnError }, deltaPagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, delta string, immutableIDs bool, - ) (itemPager, error) { - return &testPager{ - t: suite.T(), + ) (itemIDPager, error) { + return &testIDsPager{ + t: t, added: []string{"uno", "dos"}, removed: []string{"tres", "quatro"}, }, nil @@ -198,25 +293,27 @@ func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { { name: "delta expired", pagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, immutableIDs bool, - ) (itemPager, error) { + ) (itemIDPager, error) { // this should not be called return nil, assert.AnError }, deltaPagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, delta string, immutableIDs bool, - ) (itemPager, error) { - return &testPager{ - t: suite.T(), + ) (itemIDPager, error) { + return &testIDsPager{ + t: t, added: []string{"uno", "dos"}, removed: []string{"tres", "quatro"}, errorCode: "SyncStateNotFound", @@ -232,27 +329,29 @@ func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { { name: "quota exceeded", pagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, immutableIDs bool, - ) (itemPager, error) { - return &testPager{ - t: suite.T(), + ) (itemIDPager, error) { + return &testIDsPager{ + t: t, added: []string{"uno", "dos"}, removed: []string{"tres", "quatro"}, }, nil }, deltaPagerGetter: func( + t *testing.T, ctx context.Context, gs graph.Servicer, user string, directory string, delta string, immutableIDs bool, - ) (itemPager, error) { - return &testPager{errorCode: "ErrorQuotaExceeded"}, nil + ) (itemIDPager, error) { + return &testIDsPager{errorCode: "ErrorQuotaExceeded"}, nil }, added: []string{"uno", "dos"}, removed: []string{"tres", "quatro"}, @@ -268,8 +367,8 @@ func (suite *ItemPagerUnitSuite) TestGetAddedAndRemovedItemIDs() { ctx, flush := tester.NewContext(t) defer flush() - pager, _ := tt.pagerGetter(ctx, graph.Service{}, "user", "directory", false) - deltaPager, _ := tt.deltaPagerGetter(ctx, graph.Service{}, "user", "directory", tt.delta, false) + pager, _ := tt.pagerGetter(t, ctx, graph.Service{}, "user", "directory", false) + deltaPager, _ := tt.deltaPagerGetter(t, ctx, graph.Service{}, "user", "directory", tt.delta, false) added, removed, deltaUpdate, err := getAddedAndRemovedItemIDs( ctx, diff --git a/src/pkg/services/m365/api/mail.go b/src/pkg/services/m365/api/mail.go index f08cbb7c5..ab371074b 100644 --- a/src/pkg/services/m365/api/mail.go +++ b/src/pkg/services/m365/api/mail.go @@ -63,6 +63,23 @@ func (c Mail) CreateMailFolder( return mdl, nil } +func (c Mail) DeleteMailFolder( + ctx context.Context, + userID, id string, +) error { + err := c.Stable.Client(). + Users(). + ByUserId(userID). + MailFolders(). + ByMailFolderId(id). + Delete(ctx, nil) + if err != nil { + return graph.Wrap(ctx, err, "deleting mail folder") + } + + return nil +} + func (c Mail) CreateContainer( ctx context.Context, userID, containerName, parentContainerID string, @@ -407,14 +424,9 @@ func (c Mail) PostSmallAttachment( func (c Mail) PostLargeAttachment( ctx context.Context, userID, containerID, parentItemID, itemName string, - size int64, - body models.Attachmentable, -) (models.UploadSessionable, error) { - bs, err := GetAttachmentContent(body) - if err != nil { - return nil, clues.Wrap(err, "serializing attachment content").WithClues(ctx) - } - + content []byte, +) (string, error) { + size := int64(len(content)) session := users.NewItemMailFoldersItemMessagesItemAttachmentsCreateUploadSessionPostRequestBody() session.SetAttachmentItem(makeSessionAttachment(itemName, size)) @@ -430,19 +442,19 @@ func (c Mail) PostLargeAttachment( CreateUploadSession(). Post(ctx, session, nil) if err != nil { - return nil, graph.Wrap(ctx, err, "uploading large mail attachment") + return "", graph.Wrap(ctx, err, "uploading large mail attachment") } url := ptr.Val(us.GetUploadUrl()) w := graph.NewLargeItemWriter(parentItemID, url, size) copyBuffer := make([]byte, graph.AttachmentChunkSize) - _, err = io.CopyBuffer(w, bytes.NewReader(bs), copyBuffer) + _, err = io.CopyBuffer(w, bytes.NewReader(content), copyBuffer) if err != nil { - return nil, clues.Wrap(err, "buffering large attachment content").WithClues(ctx) + return "", clues.Wrap(err, "buffering large attachment content").WithClues(ctx) } - return us, nil + return w.ID, nil } // --------------------------------------------------------------------------- @@ -528,3 +540,17 @@ func UnwrapEmailAddress(contact models.Recipientable) string { return ptr.Val(contact.GetEmailAddress().GetAddress()) } + +func mailCollisionKeyProps() []string { + return idAnd("subject") +} + +// MailCollisionKey constructs a key from the messageable's subject, sender, and recipients (to, cc, bcc). +// collision keys are used to identify duplicate item conflicts for handling advanced restoration config. +func MailCollisionKey(item models.Messageable) string { + if item == nil { + return "" + } + + return ptr.Val(item.GetSubject()) +} diff --git a/src/pkg/services/m365/api/mail_pager.go b/src/pkg/services/m365/api/mail_pager.go index 71ce09663..075d02cad 100644 --- a/src/pkg/services/m365/api/mail_pager.go +++ b/src/pkg/services/m365/api/mail_pager.go @@ -121,22 +121,76 @@ func (c Mail) EnumerateContainers( // item pager // --------------------------------------------------------------------------- -var _ itemPager = &mailPager{} +var _ itemPager[models.Messageable] = &mailPageCtrl{} -type mailPager struct { +type mailPageCtrl struct { gs graph.Servicer builder *users.ItemMailFoldersItemMessagesRequestBuilder options *users.ItemMailFoldersItemMessagesRequestBuilderGetRequestConfiguration } func (c Mail) NewMailPager( + userID, containerID string, + selectProps ...string, +) itemPager[models.Messageable] { + options := &users.ItemMailFoldersItemMessagesRequestBuilderGetRequestConfiguration{ + Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize)), + QueryParameters: &users.ItemMailFoldersItemMessagesRequestBuilderGetQueryParameters{ + Top: ptr.To[int32](maxNonDeltaPageSize), + }, + } + + if len(selectProps) > 0 { + options.QueryParameters.Select = selectProps + } + + builder := c.Stable. + Client(). + Users(). + ByUserId(userID). + MailFolders(). + ByMailFolderId(containerID). + Messages() + + return &mailPageCtrl{c.Stable, builder, options} +} + +//lint:ignore U1000 False Positive +func (p *mailPageCtrl) getPage(ctx context.Context) (PageLinkValuer[models.Messageable], error) { + page, err := p.builder.Get(ctx, p.options) + if err != nil { + return nil, graph.Stack(ctx, err) + } + + return EmptyDeltaLinker[models.Messageable]{PageLinkValuer: page}, nil +} + +//lint:ignore U1000 False Positive +func (p *mailPageCtrl) setNext(nextLink string) { + p.builder = users.NewItemMailFoldersItemMessagesRequestBuilder(nextLink, p.gs.Adapter()) +} + +// --------------------------------------------------------------------------- +// item ID pager +// --------------------------------------------------------------------------- + +var _ itemIDPager = &mailIDPager{} + +type mailIDPager struct { + gs graph.Servicer + builder *users.ItemMailFoldersItemMessagesRequestBuilder + options *users.ItemMailFoldersItemMessagesRequestBuilderGetRequestConfiguration +} + +func (c Mail) NewMailIDsPager( ctx context.Context, userID, containerID string, immutableIDs bool, -) itemPager { +) itemIDPager { config := &users.ItemMailFoldersItemMessagesRequestBuilderGetRequestConfiguration{ QueryParameters: &users.ItemMailFoldersItemMessagesRequestBuilderGetQueryParameters{ Select: idAnd("isRead"), + Top: ptr.To[int32](maxNonDeltaPageSize), }, Headers: newPreferHeaders(preferPageSize(maxNonDeltaPageSize), preferImmutableIDs(immutableIDs)), } @@ -149,10 +203,10 @@ func (c Mail) NewMailPager( ByMailFolderId(containerID). Messages() - return &mailPager{c.Stable, builder, config} + return &mailIDPager{c.Stable, builder, config} } -func (p *mailPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *mailIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { page, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -161,24 +215,45 @@ func (p *mailPager) getPage(ctx context.Context) (DeltaPageLinker, error) { return EmptyDeltaLinker[models.Messageable]{PageLinkValuer: page}, nil } -func (p *mailPager) setNext(nextLink string) { +func (p *mailIDPager) setNext(nextLink string) { p.builder = users.NewItemMailFoldersItemMessagesRequestBuilder(nextLink, p.gs.Adapter()) } // non delta pagers don't have reset -func (p *mailPager) reset(context.Context) {} +func (p *mailIDPager) reset(context.Context) {} -func (p *mailPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *mailIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Messageable](pl) } +func (c Mail) GetItemsInContainerByCollisionKey( + ctx context.Context, + userID, containerID string, +) (map[string]string, error) { + ctx = clues.Add(ctx, "container_id", containerID) + pager := c.NewMailPager(userID, containerID, mailCollisionKeyProps()...) + + items, err := enumerateItems(ctx, pager) + if err != nil { + return nil, graph.Wrap(ctx, err, "enumerating mail") + } + + m := map[string]string{} + + for _, item := range items { + m[MailCollisionKey(item)] = ptr.Val(item.GetId()) + } + + return m, nil +} + // --------------------------------------------------------------------------- -// delta item pager +// delta item ID pager // --------------------------------------------------------------------------- -var _ itemPager = &mailDeltaPager{} +var _ itemIDPager = &mailDeltaIDPager{} -type mailDeltaPager struct { +type mailDeltaIDPager struct { gs graph.Servicer userID string containerID string @@ -204,14 +279,15 @@ func getMailDeltaBuilder( return builder } -func (c Mail) NewMailDeltaPager( +func (c Mail) NewMailDeltaIDsPager( ctx context.Context, userID, containerID, oldDelta string, immutableIDs bool, -) itemPager { +) itemIDPager { config := &users.ItemMailFoldersItemMessagesDeltaRequestBuilderGetRequestConfiguration{ QueryParameters: &users.ItemMailFoldersItemMessagesDeltaRequestBuilderGetQueryParameters{ Select: idAnd("isRead"), + Top: ptr.To[int32](maxDeltaPageSize), }, Headers: newPreferHeaders(preferPageSize(maxDeltaPageSize), preferImmutableIDs(immutableIDs)), } @@ -224,10 +300,10 @@ func (c Mail) NewMailDeltaPager( builder = getMailDeltaBuilder(ctx, c.Stable, userID, containerID, config) } - return &mailDeltaPager{c.Stable, userID, containerID, builder, config} + return &mailDeltaIDPager{c.Stable, userID, containerID, builder, config} } -func (p *mailDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error) { +func (p *mailDeltaIDPager) getPage(ctx context.Context) (DeltaPageLinker, error) { page, err := p.builder.Get(ctx, p.options) if err != nil { return nil, graph.Stack(ctx, err) @@ -236,11 +312,11 @@ func (p *mailDeltaPager) getPage(ctx context.Context) (DeltaPageLinker, error) { return page, nil } -func (p *mailDeltaPager) setNext(nextLink string) { +func (p *mailDeltaIDPager) setNext(nextLink string) { p.builder = users.NewItemMailFoldersItemMessagesDeltaRequestBuilder(nextLink, p.gs.Adapter()) } -func (p *mailDeltaPager) reset(ctx context.Context) { +func (p *mailDeltaIDPager) reset(ctx context.Context) { p.builder = p.gs. Client(). Users(). @@ -251,7 +327,7 @@ func (p *mailDeltaPager) reset(ctx context.Context) { Delta() } -func (p *mailDeltaPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { +func (p *mailDeltaIDPager) valuesIn(pl PageLinker) ([]getIDAndAddtler, error) { return toValues[models.Messageable](pl) } @@ -266,8 +342,8 @@ func (c Mail) GetAddedAndRemovedItemIDs( "category", selectors.ExchangeMail, "container_id", containerID) - pager := c.NewMailPager(ctx, userID, containerID, immutableIDs) - deltaPager := c.NewMailDeltaPager(ctx, userID, containerID, oldDelta, immutableIDs) + pager := c.NewMailIDsPager(ctx, userID, containerID, immutableIDs) + deltaPager := c.NewMailDeltaIDsPager(ctx, userID, containerID, oldDelta, immutableIDs) return getAddedAndRemovedItemIDs(ctx, c.Stable, pager, deltaPager, oldDelta, canMakeDeltaQueries) } diff --git a/src/pkg/services/m365/api/mail_pager_test.go b/src/pkg/services/m365/api/mail_pager_test.go new file mode 100644 index 000000000..0fde70163 --- /dev/null +++ b/src/pkg/services/m365/api/mail_pager_test.go @@ -0,0 +1,73 @@ +package api_test + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +type MailPagerIntgSuite struct { + tester.Suite + cts clientTesterSetup +} + +func TestMailPagerIntgSuite(t *testing.T) { + suite.Run(t, &MailPagerIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *MailPagerIntgSuite) SetupSuite() { + suite.cts = newClientTesterSetup(suite.T()) +} + +func (suite *MailPagerIntgSuite) TestGetItemsInContainerByCollisionKey() { + t := suite.T() + ac := suite.cts.ac.Mail() + + ctx, flush := tester.NewContext(t) + defer flush() + + container, err := ac.GetContainerByID(ctx, suite.cts.userID, "inbox") + require.NoError(t, err, clues.ToCore(err)) + + msgs, err := ac.Stable. + Client(). + Users(). + ByUserId(suite.cts.userID). + MailFolders(). + ByMailFolderId(ptr.Val(container.GetId())). + Messages(). + Get(ctx, nil) + require.NoError(t, err, clues.ToCore(err)) + + ms := msgs.GetValue() + expect := make([]string, 0, len(ms)) + + for _, m := range ms { + expect = append(expect, api.MailCollisionKey(m)) + } + + results, err := ac.GetItemsInContainerByCollisionKey(ctx, suite.cts.userID, "inbox") + require.NoError(t, err, clues.ToCore(err)) + require.Less(t, 0, len(results), "requires at least one result") + + for k, v := range results { + assert.NotEmpty(t, k, "all keys should be populated") + assert.NotEmpty(t, v, "all values should be populated") + } + + for _, e := range expect { + _, ok := results[e] + assert.Truef(t, ok, "expected results to contain collision key: %s", e) + } +} diff --git a/src/pkg/services/m365/api/mail_test.go b/src/pkg/services/m365/api/mail_test.go index 236bc9b4c..a328de9c1 100644 --- a/src/pkg/services/m365/api/mail_test.go +++ b/src/pkg/services/m365/api/mail_test.go @@ -19,6 +19,7 @@ import ( "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/control/testdata" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api/mock" @@ -202,8 +203,7 @@ func TestMailAPIIntgSuite(t *testing.T) { suite.Run(t, &MailAPIIntgSuite{ Suite: tester.NewIntegrationSuite( t, - [][]string{tester.M365AcctCredEnvs}, - ), + [][]string{tester.M365AcctCredEnvs}), }) } @@ -218,7 +218,7 @@ func (suite *MailAPIIntgSuite) SetupSuite() { suite.ac, err = mock.NewClient(m365) require.NoError(t, err, clues.ToCore(err)) - suite.user = tester.M365UserID(suite.T()) + suite.user = tester.M365UserID(t) } func getJSONObject(t *testing.T, thing serialization.Parsable) map[string]interface{} { @@ -410,3 +410,34 @@ func (suite *MailAPIIntgSuite) TestHugeAttachmentListDownload() { }) } } + +func (suite *MailAPIIntgSuite) TestRestoreLargeAttachment() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + userID := tester.M365UserID(suite.T()) + + folderName := testdata.DefaultRestoreConfig("maillargeattachmenttest").Location + msgs := suite.ac.Mail() + mailfolder, err := msgs.CreateMailFolder(ctx, userID, folderName) + require.NoError(t, err, clues.ToCore(err)) + + msg := models.NewMessage() + msg.SetSubject(ptr.To("Mail with attachment")) + + item, err := msgs.PostItem(ctx, userID, ptr.Val(mailfolder.GetId()), msg) + require.NoError(t, err, clues.ToCore(err)) + + id, err := msgs.PostLargeAttachment( + ctx, + userID, + ptr.Val(mailfolder.GetId()), + ptr.Val(item.GetId()), + "raboganm", + []byte("mangobar"), + ) + require.NoError(t, err, clues.ToCore(err)) + require.NotEmpty(t, id, "empty id for large attachment") +} diff --git a/src/pkg/services/m365/api/users.go b/src/pkg/services/m365/api/users.go index 07d2430ac..286cc52ff 100644 --- a/src/pkg/services/m365/api/users.go +++ b/src/pkg/services/m365/api/users.go @@ -183,7 +183,7 @@ func (c Users) GetInfo(ctx context.Context, userID string) (*UserInfo, error) { // check whether the user is able to access their onedrive drive. // if they cannot, we can assume they are ineligible for onedrive backups. if _, err := c.GetDefaultDrive(ctx, userID); err != nil { - if !clues.HasLabel(err, graph.LabelsMysiteNotFound) { + if !clues.HasLabel(err, graph.LabelsMysiteNotFound) || clues.HasLabel(err, graph.LabelsNoSharePointLicense) { logger.CtxErr(ctx, err).Error("getting user's drive") return nil, graph.Wrap(ctx, err, "getting user's drive") } diff --git a/src/pkg/services/m365/m365.go b/src/pkg/services/m365/m365.go index 305a10bbf..ffffd3625 100644 --- a/src/pkg/services/m365/m365.go +++ b/src/pkg/services/m365/m365.go @@ -73,12 +73,12 @@ func UsersCompatNoInfo(ctx context.Context, acct account.Account) ([]*UserNoInfo // UserHasMailbox returns true if the user has an exchange mailbox enabled // false otherwise, and a nil pointer and an error in case of error func UserHasMailbox(ctx context.Context, acct account.Account, userID string) (bool, error) { - uapi, err := makeUserAPI(acct) + ac, err := makeAC(acct) if err != nil { - return false, clues.Wrap(err, "getting mailbox").WithClues(ctx) + return false, clues.Stack(err).WithClues(ctx) } - _, err = uapi.GetMailInbox(ctx, userID) + _, err = ac.Users().GetMailInbox(ctx, userID) if err != nil { // we consider this a non-error case, since it // answers the question the caller is asking. @@ -103,16 +103,20 @@ func UserHasMailbox(ctx context.Context, acct account.Account, userID string) (b // UserHasDrives returns true if the user has any drives // false otherwise, and a nil pointer and an error in case of error func UserHasDrives(ctx context.Context, acct account.Account, userID string) (bool, error) { - uapi, err := makeUserAPI(acct) + ac, err := makeAC(acct) if err != nil { - return false, clues.Wrap(err, "getting drives").WithClues(ctx) + return false, clues.Stack(err).WithClues(ctx) } - _, err = uapi.GetDefaultDrive(ctx, userID) + return checkUserHasDrives(ctx, ac.Users(), userID) +} + +func checkUserHasDrives(ctx context.Context, dgdd discovery.GetDefaultDriver, userID string) (bool, error) { + _, err := dgdd.GetDefaultDrive(ctx, userID) if err != nil { // we consider this a non-error case, since it // answers the question the caller is asking. - if clues.HasLabel(err, graph.LabelsMysiteNotFound) { + if clues.HasLabel(err, graph.LabelsMysiteNotFound) || clues.HasLabel(err, graph.LabelsNoSharePointLicense) { return false, nil } @@ -130,12 +134,12 @@ func UserHasDrives(ctx context.Context, acct account.Account, userID string) (bo // TODO: Remove this once we remove `Info` from `Users` and instead rely on the `GetUserInfo` API // to get user information func usersNoInfo(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*UserNoInfo, error) { - uapi, err := makeUserAPI(acct) + ac, err := makeAC(acct) if err != nil { - return nil, clues.Wrap(err, "getting users").WithClues(ctx) + return nil, clues.Stack(err).WithClues(ctx) } - us, err := discovery.Users(ctx, uapi, errs) + us, err := discovery.Users(ctx, ac.Users(), errs) if err != nil { return nil, err } @@ -162,12 +166,12 @@ func usersNoInfo(ctx context.Context, acct account.Account, errs *fault.Bus) ([] // Users returns a list of users in the specified M365 tenant func Users(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*User, error) { - uapi, err := makeUserAPI(acct) + ac, err := makeAC(acct) if err != nil { - return nil, clues.Wrap(err, "getting users").WithClues(ctx) + return nil, clues.Stack(err).WithClues(ctx) } - us, err := discovery.Users(ctx, uapi, errs) + us, err := discovery.Users(ctx, ac.Users(), errs) if err != nil { return nil, err } @@ -197,7 +201,7 @@ func Users(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*User, func parseUser(item models.Userable) (*User, error) { if item.GetUserPrincipalName() == nil { return nil, clues.New("user missing principal name"). - With("user_id", *item.GetId()) // TODO: pii + With("user_id", ptr.Val(item.GetId())) } u := &User{ @@ -215,12 +219,12 @@ func GetUserInfo( acct account.Account, userID string, ) (*api.UserInfo, error) { - uapi, err := makeUserAPI(acct) + ac, err := makeAC(acct) if err != nil { - return nil, clues.Wrap(err, "getting user info").WithClues(ctx) + return nil, clues.Stack(err).WithClues(ctx) } - ui, err := discovery.UserInfo(ctx, uapi, userID) + ui, err := discovery.UserInfo(ctx, ac.Users(), userID) if err != nil { return nil, err } @@ -249,9 +253,26 @@ type Site struct { // Sites returns a list of Sites in a specified M365 tenant func Sites(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*Site, error) { - sites, err := discovery.Sites(ctx, acct, errs) + ac, err := makeAC(acct) if err != nil { - return nil, clues.Wrap(err, "initializing M365 api connection") + return nil, clues.Stack(err).WithClues(ctx) + } + + return getAllSites(ctx, ac.Sites()) +} + +type getAllSiteser interface { + GetAll(ctx context.Context, errs *fault.Bus) ([]models.Siteable, error) +} + +func getAllSites(ctx context.Context, gas getAllSiteser) ([]*Site, error) { + sites, err := gas.GetAll(ctx, fault.New(true)) + if err != nil { + if clues.HasLabel(err, graph.LabelsNoSharePointLicense) { + return nil, clues.Stack(graph.ErrServiceNotEnabled, err) + } + + return nil, clues.Wrap(err, "retrieving sites") } ret := make([]*Site, 0, len(sites)) @@ -304,16 +325,16 @@ func SitesMap( // helpers // --------------------------------------------------------------------------- -func makeUserAPI(acct account.Account) (api.Users, error) { +func makeAC(acct account.Account) (api.Client, error) { creds, err := acct.M365Config() if err != nil { - return api.Users{}, clues.Wrap(err, "getting m365 account creds") + return api.Client{}, clues.Wrap(err, "getting m365 account creds") } cli, err := api.NewClient(creds) if err != nil { - return api.Users{}, clues.Wrap(err, "constructing api client") + return api.Client{}, clues.Wrap(err, "constructing api client") } - return cli.Users(), nil + return cli, nil } diff --git a/src/pkg/services/m365/m365_test.go b/src/pkg/services/m365/m365_test.go index 27a385d5d..e77338dea 100644 --- a/src/pkg/services/m365/m365_test.go +++ b/src/pkg/services/m365/m365_test.go @@ -1,17 +1,22 @@ -package m365_test +package m365 import ( + "context" "testing" "github.com/alcionai/clues" + "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/m365/discovery" + "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" - "github.com/alcionai/corso/src/pkg/services/m365" ) type M365IntegrationSuite struct { @@ -22,8 +27,7 @@ func TestM365IntegrationSuite(t *testing.T) { suite.Run(t, &M365IntegrationSuite{ Suite: tester.NewIntegrationSuite( t, - [][]string{tester.M365AcctCredEnvs}, - ), + [][]string{tester.M365AcctCredEnvs}), }) } @@ -35,7 +39,7 @@ func (suite *M365IntegrationSuite) TestUsers() { acct := tester.NewM365Account(suite.T()) - users, err := m365.Users(ctx, acct, fault.New(true)) + users, err := Users(ctx, acct, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, users) @@ -59,7 +63,7 @@ func (suite *M365IntegrationSuite) TestUsersCompat_HasNoInfo() { acct := tester.NewM365Account(suite.T()) - users, err := m365.UsersCompatNoInfo(ctx, acct) + users, err := UsersCompatNoInfo(ctx, acct) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, users) @@ -85,7 +89,7 @@ func (suite *M365IntegrationSuite) TestGetUserInfo() { uid = tester.M365UserID(t) ) - info, err := m365.GetUserInfo(ctx, acct, uid) + info, err := GetUserInfo(ctx, acct, uid) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, info) require.NotEmpty(t, info) @@ -112,7 +116,7 @@ func (suite *M365IntegrationSuite) TestUserHasMailbox() { uid = tester.M365UserID(t) ) - enabled, err := m365.UserHasMailbox(ctx, acct, uid) + enabled, err := UserHasMailbox(ctx, acct, uid) require.NoError(t, err, clues.ToCore(err)) assert.True(t, enabled) } @@ -128,7 +132,7 @@ func (suite *M365IntegrationSuite) TestUserHasDrive() { uid = tester.M365UserID(t) ) - enabled, err := m365.UserHasDrives(ctx, acct, uid) + enabled, err := UserHasDrives(ctx, acct, uid) require.NoError(t, err, clues.ToCore(err)) assert.True(t, enabled) } @@ -139,14 +143,14 @@ func (suite *M365IntegrationSuite) TestSites() { ctx, flush := tester.NewContext(t) defer flush() - acct := tester.NewM365Account(suite.T()) + acct := tester.NewM365Account(t) - sites, err := m365.Sites(ctx, acct, fault.New(true)) + sites, err := Sites(ctx, acct, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, sites) for _, s := range sites { - suite.Run("site", func() { + suite.Run("site_"+s.ID, func() { t := suite.T() assert.NotEmpty(t, s.WebURL) assert.NotEmpty(t, s.ID) @@ -154,3 +158,204 @@ func (suite *M365IntegrationSuite) TestSites() { }) } } + +type m365UnitSuite struct { + tester.Suite +} + +func TestM365UnitSuite(t *testing.T) { + suite.Run(t, &m365UnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +type mockDGDD struct { + response models.Driveable + err error +} + +func (m mockDGDD) GetDefaultDrive(context.Context, string) (models.Driveable, error) { + return m.response, m.err +} + +func (suite *m365UnitSuite) TestCheckUserHasDrives() { + table := []struct { + name string + mock func(context.Context) discovery.GetDefaultDriver + expect assert.BoolAssertionFunc + expectErr func(*testing.T, error) + }{ + { + name: "ok", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + return mockDGDD{models.NewDrive(), nil} + }, + expect: assert.True, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "mysite not found", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To(string(graph.MysiteNotFound))) + odErr.SetError(merr) + + return mockDGDD{nil, graph.Stack(ctx, odErr)} + }, + expect: assert.False, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "mysite URL not found", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To(string(graph.MysiteURLNotFound))) + odErr.SetError(merr) + + return mockDGDD{nil, graph.Stack(ctx, odErr)} + }, + expect: assert.False, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no sharepoint license", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To(string(graph.NoSPLicense))) + odErr.SetError(merr) + + return mockDGDD{nil, graph.Stack(ctx, odErr)} + }, + expect: assert.False, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "user not found", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To(string(graph.RequestResourceNotFound))) + merr.SetMessage(ptr.To("message")) + odErr.SetError(merr) + + return mockDGDD{nil, graph.Stack(ctx, odErr)} + }, + expect: assert.False, + expectErr: func(t *testing.T, err error) { + assert.Error(t, err, clues.ToCore(err)) + }, + }, + { + name: "arbitrary error", + mock: func(ctx context.Context) discovery.GetDefaultDriver { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To("message")) + odErr.SetError(merr) + + return mockDGDD{nil, graph.Stack(ctx, odErr)} + }, + expect: assert.False, + expectErr: func(t *testing.T, err error) { + assert.Error(t, err, clues.ToCore(err)) + }, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + dgdd := test.mock(ctx) + + ok, err := checkUserHasDrives(ctx, dgdd, "foo") + test.expect(t, ok, "has drives flag") + test.expectErr(t, err) + }) + } +} + +type mockGAS struct { + response []models.Siteable + err error +} + +func (m mockGAS) GetAll(context.Context, *fault.Bus) ([]models.Siteable, error) { + return m.response, m.err +} + +func (suite *m365UnitSuite) TestGetAllSites() { + table := []struct { + name string + mock func(context.Context) getAllSiteser + expectErr func(*testing.T, error) + }{ + { + name: "ok", + mock: func(ctx context.Context) getAllSiteser { + return mockGAS{[]models.Siteable{}, nil} + }, + expectErr: func(t *testing.T, err error) { + assert.NoError(t, err, clues.ToCore(err)) + }, + }, + { + name: "no sharepoint license", + mock: func(ctx context.Context) getAllSiteser { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To(string(graph.NoSPLicense))) + odErr.SetError(merr) + + return mockGAS{nil, graph.Stack(ctx, odErr)} + }, + expectErr: func(t *testing.T, err error) { + assert.ErrorIs(t, err, graph.ErrServiceNotEnabled, clues.ToCore(err)) + }, + }, + { + name: "arbitrary error", + mock: func(ctx context.Context) getAllSiteser { + odErr := odataerrors.NewODataError() + merr := odataerrors.NewMainError() + merr.SetCode(ptr.To("code")) + merr.SetMessage(ptr.To("message")) + odErr.SetError(merr) + + return mockGAS{nil, graph.Stack(ctx, odErr)} + }, + expectErr: func(t *testing.T, err error) { + assert.Error(t, err, clues.ToCore(err)) + }, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + gas := test.mock(ctx) + + _, err := getAllSites(ctx, gas) + test.expectErr(t, err) + }) + } +}