diff --git a/src/cli/export/export.go b/src/cli/export/export.go index 2e205ec98..0a1463d2b 100644 --- a/src/cli/export/export.go +++ b/src/cli/export/export.go @@ -61,6 +61,10 @@ func runExport( sel selectors.Selector, backupID, serviceName string, ) error { + if err := utils.ValidateExportConfigFlags(&ueco); err != nil { + return Only(ctx, err) + } + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( ctx, cmd, diff --git a/src/cli/flags/export.go b/src/cli/flags/export.go index 824662c87..5558f7d63 100644 --- a/src/cli/flags/export.go +++ b/src/cli/flags/export.go @@ -1,13 +1,7 @@ package flags import ( - "strings" - - "github.com/alcionai/clues" "github.com/spf13/cobra" - - "github.com/alcionai/corso/src/pkg/control" - "github.com/alcionai/corso/src/pkg/filters" ) const ( @@ -27,20 +21,3 @@ func AddExportConfigFlags(cmd *cobra.Command) { fs.StringVar(&FormatFV, FormatFN, "", "Specify the export file format") cobra.CheckErr(fs.MarkHidden(FormatFN)) } - -// ValidateExportConfigFlags ensures all export config flags that utilize -// enumerated values match a well-known value. -func ValidateExportConfigFlags() error { - acceptedFormatTypes := []string{ - string(control.DefaultFormat), - string(control.JSONFormat), - } - - if !filters.Equal(acceptedFormatTypes).Compare(FormatFV) { - return clues.New("unrecognized format type: " + FormatFV) - } - - FormatFV = strings.ToLower(FormatFV) - - return nil -} diff --git a/src/cli/flags/export_test.go b/src/cli/flags/export_test.go deleted file mode 100644 index 020a97590..000000000 --- a/src/cli/flags/export_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package flags - -import ( - "testing" - - "github.com/alcionai/clues" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - - "github.com/alcionai/corso/src/internal/tester" -) - -type ExportUnitSuite struct { - tester.Suite -} - -func TestExportUnitSuite(t *testing.T) { - suite.Run(t, &ExportUnitSuite{Suite: tester.NewUnitSuite(t)}) -} - -func (suite *ExportUnitSuite) TestValidateExportConfigFlags() { - t := suite.T() - - FormatFV = "" - - err := ValidateExportConfigFlags() - assert.NoError(t, err, clues.ToCore(err)) - - FormatFV = "json" - - err = ValidateExportConfigFlags() - assert.NoError(t, err, clues.ToCore(err)) - - FormatFV = "JsoN" - - err = ValidateExportConfigFlags() - assert.NoError(t, err, clues.ToCore(err)) - - FormatFV = "fnerds" - - err = ValidateExportConfigFlags() - assert.Error(t, err, clues.ToCore(err)) -} diff --git a/src/cli/restore/restore.go b/src/cli/restore/restore.go index 3bd20b4fa..d2407889b 100644 --- a/src/cli/restore/restore.go +++ b/src/cli/restore/restore.go @@ -94,6 +94,10 @@ func runRestore( sel selectors.Selector, backupID, serviceName string, ) error { + if err := utils.ValidateRestoreConfigFlags(urco); err != nil { + return Only(ctx, err) + } + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( ctx, cmd, diff --git a/src/cli/utils/exchange.go b/src/cli/utils/exchange.go index 1c2ba9005..031ce15b6 100644 --- a/src/cli/utils/exchange.go +++ b/src/cli/utils/exchange.go @@ -139,7 +139,7 @@ func ValidateExchangeRestoreFlags(backupID string, opts ExchangeOpts) error { return clues.New("invalid format for event-recurs") } - return validateRestoreConfigFlags(flags.CollisionsFV, opts.RestoreCfg) + return nil } // IncludeExchangeRestoreDataSelectors builds the common data-selector diff --git a/src/cli/utils/export_config.go b/src/cli/utils/export_config.go index 92464ebdb..708ae1d7e 100644 --- a/src/cli/utils/export_config.go +++ b/src/cli/utils/export_config.go @@ -2,11 +2,14 @@ package utils import ( "context" + "strings" + "github.com/alcionai/clues" "github.com/spf13/cobra" "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/pkg/control" + "github.com/alcionai/corso/src/pkg/filters" ) type ExportCfgOpts struct { @@ -39,3 +42,23 @@ func MakeExportConfig( return exportCfg } + +// ValidateExportConfigFlags ensures all export config flags that utilize +// enumerated values match a well-known value. +func ValidateExportConfigFlags(opts *ExportCfgOpts) error { + acceptedFormatTypes := []string{ + string(control.DefaultFormat), + string(control.JSONFormat), + } + + if _, populated := opts.Populated[flags.FormatFN]; !populated { + opts.Format = string(control.DefaultFormat) + } else if !filters.Equal(acceptedFormatTypes).Compare(opts.Format) { + opts.Format = string(control.DefaultFormat) + return clues.New("unrecognized format type: " + opts.Format) + } + + opts.Format = strings.ToLower(opts.Format) + + return nil +} diff --git a/src/cli/utils/export_config_test.go b/src/cli/utils/export_config_test.go index d25d6629b..7da782072 100644 --- a/src/cli/utils/export_config_test.go +++ b/src/cli/utils/export_config_test.go @@ -3,6 +3,7 @@ package utils import ( "testing" + "github.com/alcionai/clues" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -52,3 +53,57 @@ func (suite *ExportCfgUnitSuite) TestMakeExportConfig() { }) } } + +func (suite *ExportCfgUnitSuite) TestValidateExportConfigFlags() { + table := []struct { + name string + input ExportCfgOpts + expectErr assert.ErrorAssertionFunc + expectFormat control.FormatType + }{ + { + name: "default", + input: ExportCfgOpts{ + Format: string(control.DefaultFormat), + Populated: flags.PopulatedFlags{flags.FormatFN: struct{}{}}, + }, + expectErr: assert.NoError, + expectFormat: control.DefaultFormat, + }, + { + name: "json", + input: ExportCfgOpts{ + Format: string(control.JSONFormat), + Populated: flags.PopulatedFlags{flags.FormatFN: struct{}{}}, + }, + expectErr: assert.NoError, + expectFormat: control.JSONFormat, + }, + { + name: "bad format", + input: ExportCfgOpts{ + Format: "smurfs", + Populated: flags.PopulatedFlags{flags.FormatFN: struct{}{}}, + }, + expectErr: assert.Error, + expectFormat: control.DefaultFormat, + }, + { + name: "bad format unpopulated", + input: ExportCfgOpts{ + Format: "smurfs", + }, + expectErr: assert.NoError, + expectFormat: control.DefaultFormat, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + err := ValidateExportConfigFlags(&test.input) + + test.expectErr(t, err, clues.ToCore(err)) + assert.Equal(t, test.expectFormat, control.FormatType(test.input.Format)) + }) + } +} diff --git a/src/cli/utils/groups.go b/src/cli/utils/groups.go index ac08b1f32..3c7d378ec 100644 --- a/src/cli/utils/groups.go +++ b/src/cli/utils/groups.go @@ -138,7 +138,7 @@ func ValidateGroupsRestoreFlags(backupID string, opts GroupsOpts) error { return clues.New("invalid time format for " + flags.MessageLastReplyBeforeFN) } - return validateRestoreConfigFlags(flags.CollisionsFV, opts.RestoreCfg) + return nil } // AddGroupsFilter adds the scope of the provided values to the selector's diff --git a/src/cli/utils/onedrive.go b/src/cli/utils/onedrive.go index dd0a8ad1e..20bcad4b7 100644 --- a/src/cli/utils/onedrive.go +++ b/src/cli/utils/onedrive.go @@ -67,7 +67,7 @@ func ValidateOneDriveRestoreFlags(backupID string, opts OneDriveOpts) error { return clues.New("invalid time format for " + flags.FileModifiedBeforeFN) } - return validateRestoreConfigFlags(flags.CollisionsFV, opts.RestoreCfg) + return nil } // AddOneDriveFilter adds the scope of the provided values to the selector's diff --git a/src/cli/utils/restore_config.go b/src/cli/utils/restore_config.go index 6be54f1ab..6ce92f128 100644 --- a/src/cli/utils/restore_config.go +++ b/src/cli/utils/restore_config.go @@ -2,6 +2,7 @@ package utils import ( "context" + "fmt" "github.com/alcionai/clues" "github.com/spf13/cobra" @@ -40,14 +41,13 @@ func makeRestoreCfgOpts(cmd *cobra.Command) RestoreCfgOpts { } } -// validateRestoreConfigFlags checks common restore flags for -// correctness and interdependencies. -func validateRestoreConfigFlags(fv string, opts RestoreCfgOpts) error { +// ValidateRestoreConfigFlags checks common restore flags for correctness and interdependencies. +func ValidateRestoreConfigFlags(opts RestoreCfgOpts) error { _, populated := opts.Populated[flags.CollisionsFN] - _, foundInValidSet := control.ValidCollisionPolicies()[control.CollisionPolicy(fv)] + isValid := control.IsValidCollisionPolicy(control.CollisionPolicy(opts.Collisions)) - if populated && !foundInValidSet { - return clues.New("invalid entry for " + flags.CollisionsFN) + if populated && !isValid { + return clues.New(fmt.Sprintf("invalid collision policy: %s", flags.CollisionsFN)) } return nil diff --git a/src/cli/utils/restore_config_test.go b/src/cli/utils/restore_config_test.go index c3509e360..eb96ae09b 100644 --- a/src/cli/utils/restore_config_test.go +++ b/src/cli/utils/restore_config_test.go @@ -23,13 +23,11 @@ func TestRestoreCfgUnitSuite(t *testing.T) { func (suite *RestoreCfgUnitSuite) TestValidateRestoreConfigFlags() { table := []struct { name string - fv string opts RestoreCfgOpts expect assert.ErrorAssertionFunc }{ { name: "no error", - fv: string(control.Skip), opts: RestoreCfgOpts{ Collisions: string(control.Skip), Populated: flags.PopulatedFlags{ @@ -40,7 +38,6 @@ func (suite *RestoreCfgUnitSuite) TestValidateRestoreConfigFlags() { }, { name: "bad but not populated", - fv: "foo", opts: RestoreCfgOpts{ Collisions: "foo", Populated: flags.PopulatedFlags{}, @@ -49,7 +46,6 @@ func (suite *RestoreCfgUnitSuite) TestValidateRestoreConfigFlags() { }, { name: "error", - fv: "foo", opts: RestoreCfgOpts{ Collisions: "foo", Populated: flags.PopulatedFlags{ @@ -61,7 +57,7 @@ func (suite *RestoreCfgUnitSuite) TestValidateRestoreConfigFlags() { } for _, test := range table { suite.Run(test.name, func() { - err := validateRestoreConfigFlags(test.fv, test.opts) + err := ValidateRestoreConfigFlags(test.opts) test.expect(suite.T(), err, clues.ToCore(err)) }) } diff --git a/src/cli/utils/sharepoint.go b/src/cli/utils/sharepoint.go index 2ab43d90c..d106187a9 100644 --- a/src/cli/utils/sharepoint.go +++ b/src/cli/utils/sharepoint.go @@ -97,7 +97,7 @@ func ValidateSharePointRestoreFlags(backupID string, opts SharePointOpts) error return clues.New("invalid time format for " + flags.FileModifiedBeforeFN) } - return validateRestoreConfigFlags(flags.CollisionsFV, opts.RestoreCfg) + return nil } // AddSharePointInfo adds the scope of the provided values to the selector's diff --git a/src/pkg/control/restore.go b/src/pkg/control/restore.go index 01491ad91..228ffe303 100644 --- a/src/pkg/control/restore.go +++ b/src/pkg/control/restore.go @@ -7,8 +7,6 @@ import ( "strings" "github.com/alcionai/clues" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/pkg/logger" @@ -29,12 +27,13 @@ const ( Replace CollisionPolicy = "replace" ) -func ValidCollisionPolicies() map[CollisionPolicy]struct{} { - return map[CollisionPolicy]struct{}{ - Skip: {}, - Copy: {}, - Replace: {}, +func IsValidCollisionPolicy(cp CollisionPolicy) bool { + switch cp { + case Skip, Copy, Replace: + return true } + + return false } const RootLocation = "/" @@ -84,7 +83,7 @@ func EnsureRestoreConfigDefaults( ctx context.Context, rc RestoreConfig, ) RestoreConfig { - if !slices.Contains(maps.Keys(ValidCollisionPolicies()), rc.OnCollision) { + if !IsValidCollisionPolicy(rc.OnCollision) { logger.Ctx(ctx). With( "bad_collision_policy", rc.OnCollision,