diff --git a/src/cli/backup/backup.go b/src/cli/backup/backup.go index 93fcce74f..610c80646 100644 --- a/src/cli/backup/backup.go +++ b/src/cli/backup/backup.go @@ -12,7 +12,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/data" @@ -290,7 +289,10 @@ func genericDeleteCommand( ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) - r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, repo.S3Overrides(cmd)) + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + pst) if err != nil { return Only(ctx, err) } @@ -316,7 +318,10 @@ func genericListCommand( ) error { ctx := cmd.Context() - r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, repo.S3Overrides(cmd)) + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + service) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/exchange.go b/src/cli/backup/exchange.go index 298569da6..c01c1e530 100644 --- a/src/cli/backup/exchange.go +++ b/src/cli/backup/exchange.go @@ -10,7 +10,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/pkg/backup/details" @@ -168,7 +167,10 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, repo.S3Overrides(cmd)) + r, acct, err := utils.AccountConnectAndWriteRepoConfig( + ctx, + cmd, + path.ExchangeService) if err != nil { return Only(ctx, err) } @@ -277,7 +279,10 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeExchangeOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, repo.S3Overrides(cmd)) + r, _, _, ctrlOpts, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + path.ExchangeService) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/groups.go b/src/cli/backup/groups.go index cef2bbf49..59fee09cf 100644 --- a/src/cli/backup/groups.go +++ b/src/cli/backup/groups.go @@ -12,7 +12,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/data" @@ -154,7 +153,10 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, repo.S3Overrides(cmd)) + r, acct, err := utils.AccountConnectAndWriteRepoConfig( + ctx, + cmd, + path.GroupsService) if err != nil { return Only(ctx, err) } @@ -226,7 +228,10 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeGroupsOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, repo.S3Overrides(cmd)) + r, _, _, ctrlOpts, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + path.GroupsService) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/onedrive.go b/src/cli/backup/onedrive.go index 87a8a2236..c1fb87291 100644 --- a/src/cli/backup/onedrive.go +++ b/src/cli/backup/onedrive.go @@ -10,7 +10,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/pkg/backup/details" @@ -149,7 +148,10 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, repo.S3Overrides(cmd)) + r, acct, err := utils.AccountConnectAndWriteRepoConfig( + ctx, + cmd, + path.OneDriveService) if err != nil { return Only(ctx, err) } @@ -235,7 +237,10 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeOneDriveOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, repo.S3Overrides(cmd)) + r, _, _, ctrlOpts, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + path.OneDriveService) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/sharepoint.go b/src/cli/backup/sharepoint.go index c80076512..8f79ed6be 100644 --- a/src/cli/backup/sharepoint.go +++ b/src/cli/backup/sharepoint.go @@ -11,7 +11,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/data" @@ -159,7 +158,10 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, repo.S3Overrides(cmd)) + r, acct, err := utils.AccountConnectAndWriteRepoConfig( + ctx, + cmd, + path.SharePointService) if err != nil { return Only(ctx, err) } @@ -319,7 +321,10 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeSharePointOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, repo.S3Overrides(cmd)) + r, _, _, ctrlOpts, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + path.SharePointService) if err != nil { return Only(ctx, err) } diff --git a/src/cli/cli.go b/src/cli/cli.go index 230433479..afab6a604 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -64,15 +64,24 @@ func preRun(cc *cobra.Command, args []string) error { avoidTheseDescription := []string{ "Initialize a repository.", "Initialize a S3 repository", + "Connect to a S3 repository", "Help about any command", "Free, Secure, Open-Source Backup for M365.", "env var guide", } if !slices.Contains(avoidTheseDescription, cc.Short) { - overrides := repo.S3Overrides(cc) + provider, overrides, err := utils.GetStorageProviderAndOverrides(ctx, cc) + if err != nil { + return err + } - cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides) + cfg, err := config.GetConfigRepoDetails( + ctx, + provider, + true, + false, + overrides) if err != nil { log.Error("Error while getting config info to run command: ", cc.Use) return err diff --git a/src/cli/config/account.go b/src/cli/config/account.go index 2f630e6a5..8d87880d9 100644 --- a/src/cli/config/account.go +++ b/src/cli/config/account.go @@ -16,17 +16,17 @@ import ( func m365ConfigsFromViper(vpr *viper.Viper) (account.M365Config, error) { var m365 account.M365Config - m365.AzureClientID = vpr.GetString(AzureClientID) - m365.AzureClientSecret = vpr.GetString(AzureSecret) - m365.AzureTenantID = vpr.GetString(AzureTenantIDKey) + m365.AzureClientID = vpr.GetString(account.AzureClientID) + m365.AzureClientSecret = vpr.GetString(account.AzureSecret) + m365.AzureTenantID = vpr.GetString(account.AzureTenantIDKey) return m365, nil } func m365Overrides(in map[string]string) map[string]string { return map[string]string{ - account.AzureTenantID: in[account.AzureTenantID], - AccountProviderTypeKey: in[AccountProviderTypeKey], + account.AzureTenantID: in[account.AzureTenantID], + account.AccountProviderTypeKey: in[account.AccountProviderTypeKey], } } @@ -52,7 +52,7 @@ func configureAccount( } if matchFromConfig { - providerType := vpr.GetString(AccountProviderTypeKey) + providerType := vpr.GetString(account.AccountProviderTypeKey) if providerType != account.ProviderM365.String() { return acct, clues.New("unsupported account provider: " + providerType) } diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 6cb820bd6..41e26422c 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -21,7 +21,6 @@ import ( const ( // S3 config - StorageProviderTypeKey = "provider" BucketNameKey = "bucket" EndpointKey = "endpoint" PrefixKey = "prefix" @@ -33,12 +32,6 @@ const ( SecretAccessKey = "aws_secret_access_key" SessionToken = "aws_session_token" - // M365 config - AccountProviderTypeKey = "account_provider" - AzureTenantIDKey = "azure_tenantid" - AzureClientID = "azure_client_id" - AzureSecret = "azure_secret" - // Corso passphrase in config CorsoPassphrase = "passphrase" CorsoUser = "corso_user" @@ -228,7 +221,7 @@ func writeRepoConfigWithViper( s3Config = s3Config.Normalize() // Rudimentary support for persisting repo config // TODO: Handle conflicts, support other config types - vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String()) + vpr.Set(storage.StorageProviderTypeKey, storage.ProviderS3.String()) vpr.Set(BucketNameKey, s3Config.Bucket) vpr.Set(EndpointKey, s3Config.Endpoint) vpr.Set(PrefixKey, s3Config.Prefix) @@ -245,8 +238,8 @@ func writeRepoConfigWithViper( vpr.Set(CorsoHost, repoOpts.Host) } - vpr.Set(AccountProviderTypeKey, account.ProviderM365.String()) - vpr.Set(AzureTenantIDKey, m365Config.AzureTenantID) + vpr.Set(account.AccountProviderTypeKey, account.ProviderM365.String()) + vpr.Set(account.AzureTenantIDKey, m365Config.AzureTenantID) if err := vpr.SafeWriteConfig(); err != nil { if _, ok := err.(viper.ConfigFileAlreadyExistsError); ok { @@ -263,6 +256,7 @@ func writeRepoConfigWithViper( // data sources (config file, env vars, flag overrides) and the config file. func GetConfigRepoDetails( ctx context.Context, + provider storage.ProviderType, readFromFile bool, mustMatchFromConfig bool, overrides map[string]string, @@ -270,7 +264,13 @@ func GetConfigRepoDetails( RepoDetails, error, ) { - config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides) + config, err := getStorageAndAccountWithViper( + GetViper(ctx), + provider, + readFromFile, + mustMatchFromConfig, + overrides) + return config, err } @@ -278,6 +278,7 @@ func GetConfigRepoDetails( // struct for testing. func getStorageAndAccountWithViper( vpr *viper.Viper, + provider storage.ProviderType, readFromFile bool, mustMatchFromConfig bool, overrides map[string]string, @@ -312,7 +313,7 @@ func getStorageAndAccountWithViper( return config, clues.Wrap(err, "retrieving account configuration details") } - config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides) + config.Storage, err = configureStorage(vpr, provider, readConfigFromViper, mustMatchFromConfig, overrides) if err != nil { return config, clues.Wrap(err, "retrieving storage provider details") } @@ -336,12 +337,12 @@ func getUserHost(vpr *viper.Viper, readConfigFromViper bool) (string, string) { // --------------------------------------------------------------------------- var constToTomlKeyMap = map[string]string{ - account.AzureTenantID: AzureTenantIDKey, - AccountProviderTypeKey: AccountProviderTypeKey, - storage.Bucket: BucketNameKey, - storage.Endpoint: EndpointKey, - storage.Prefix: PrefixKey, - StorageProviderTypeKey: StorageProviderTypeKey, + account.AzureTenantID: account.AzureTenantIDKey, + account.AccountProviderTypeKey: account.AccountProviderTypeKey, + storage.Bucket: BucketNameKey, + storage.Endpoint: EndpointKey, + storage.Prefix: PrefixKey, + storage.StorageProviderTypeKey: storage.StorageProviderTypeKey, } // mustMatchConfig compares the values of each key to their config file value in viper. diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index b697f0080..d9c5152f1 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -29,15 +29,15 @@ const ( ` + BucketNameKey + ` = '%s' ` + EndpointKey + ` = 's3.amazonaws.com' ` + PrefixKey + ` = 'test-prefix/' -` + StorageProviderTypeKey + ` = 'S3' -` + AccountProviderTypeKey + ` = 'M365' -` + AzureTenantIDKey + ` = '%s' +` + storage.StorageProviderTypeKey + ` = 'S3' +` + account.AccountProviderTypeKey + ` = 'M365' +` + account.AzureTenantIDKey + ` = '%s' ` + AccessKey + ` = '%s' ` + SecretAccessKey + ` = '%s' ` + SessionToken + ` = '%s' ` + CorsoPassphrase + ` = '%s' -` + AzureClientID + ` = '%s' -` + AzureSecret + ` = '%s' +` + account.AzureClientID + ` = '%s' +` + account.AzureSecret + ` = '%s' ` + DisableTLSKey + ` = '%s' ` + DisableTLSVerificationKey + ` = '%s' ` @@ -326,6 +326,7 @@ func (suite *ConfigSuite) TestReadFromFlags() { repoDetails, err := getStorageAndAccountWithViper( vpr, + storage.ProviderS3, true, false, overrides) @@ -400,7 +401,7 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { err = vpr.ReadInConfig() require.NoError(t, err, "reading repo config", clues.ToCore(err)) - cfg, err := getStorageAndAccountWithViper(vpr, true, true, nil) + cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) readS3Cfg, err := cfg.Storage.S3Config() @@ -438,17 +439,17 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride m365 := account.M365Config{AzureTenantID: tid} overrides := map[string]string{ - account.AzureTenantID: tid, - AccountProviderTypeKey: account.ProviderM365.String(), - storage.Bucket: bkt, - storage.Endpoint: end, - storage.Prefix: pfx, - storage.DoNotUseTLS: "true", - storage.DoNotVerifyTLS: "true", - StorageProviderTypeKey: storage.ProviderS3.String(), + account.AzureTenantID: tid, + account.AccountProviderTypeKey: account.ProviderM365.String(), + storage.Bucket: bkt, + storage.Endpoint: end, + storage.Prefix: pfx, + storage.DoNotUseTLS: "true", + storage.DoNotVerifyTLS: "true", + storage.StorageProviderTypeKey: storage.ProviderS3.String(), } - cfg, err := getStorageAndAccountWithViper(vpr, false, true, overrides) + cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) readS3Cfg, err := cfg.Storage.S3Config() diff --git a/src/cli/config/storage.go b/src/cli/config/storage.go index 964c740fd..e63a723c9 100644 --- a/src/cli/config/storage.go +++ b/src/cli/config/storage.go @@ -1,6 +1,7 @@ package config import ( + "context" "os" "path/filepath" "strconv" @@ -40,12 +41,12 @@ func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Co func s3Overrides(in map[string]string) map[string]string { return map[string]string{ - storage.Bucket: in[storage.Bucket], - storage.Endpoint: in[storage.Endpoint], - storage.Prefix: in[storage.Prefix], - storage.DoNotUseTLS: in[storage.DoNotUseTLS], - storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS], - StorageProviderTypeKey: in[StorageProviderTypeKey], + storage.Bucket: in[storage.Bucket], + storage.Endpoint: in[storage.Endpoint], + storage.Prefix: in[storage.Prefix], + storage.DoNotUseTLS: in[storage.DoNotUseTLS], + storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS], + storage.StorageProviderTypeKey: in[storage.StorageProviderTypeKey], } } @@ -53,6 +54,7 @@ func s3Overrides(in map[string]string) map[string]string { // viper properties and manual overrides. func configureStorage( vpr *viper.Viper, + provider storage.ProviderType, readConfigFromViper bool, matchFromConfig bool, overrides map[string]string, @@ -77,7 +79,7 @@ func configureStorage( } if matchFromConfig { - providerType := vpr.GetString(StorageProviderTypeKey) + providerType := vpr.GetString(storage.StorageProviderTypeKey) if providerType != storage.ProviderS3.String() { return store, clues.New("unsupported storage provider: " + providerType) } @@ -92,7 +94,6 @@ func configureStorage( return store, clues.Wrap(err, "reading s3 configs from corso config file") } - s3Overrides(overrides) aws := credentials.GetAWS(overrides) if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { @@ -152,7 +153,7 @@ func configureStorage( } // build the storage - store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) + store, err = storage.NewStorage(provider, s3Cfg, cCfg) if err != nil { return store, clues.Wrap(err, "configuring repository storage") } @@ -170,3 +171,22 @@ func GetAndInsertCorso(passphase string) credentials.Corso { CorsoPassphrase: corsoPassph, } } + +// GetStorageProviderFromConfigFile reads the storage provider from the config file. +// Storage provider can only be sourced from config file with the exception of +// commands that create or connect to a repo. +func GetStorageProviderFromConfigFile(ctx context.Context) (storage.ProviderType, error) { + vpr := GetViper(ctx) + + err := vpr.ReadInConfig() + if err != nil { + return storage.ProviderUnknown, clues.Wrap(err, "reading config file") + } + + provider := vpr.GetString(storage.StorageProviderTypeKey) + if provider != storage.ProviderS3.String() { + return storage.ProviderUnknown, clues.New("unsupported storage provider: " + provider) + } + + return storage.StringToProviderType[provider], nil +} diff --git a/src/cli/export/export.go b/src/cli/export/export.go index 89a2111fc..9689ff9f7 100644 --- a/src/cli/export/export.go +++ b/src/cli/export/export.go @@ -9,7 +9,6 @@ import ( "github.com/spf13/cobra" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/data" @@ -70,7 +69,10 @@ func runExport( sel selectors.Selector, backupID, serviceName string, ) error { - r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + sel.PathService()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/flags/s3.go b/src/cli/flags/s3.go index add56dc6c..a7e4e490e 100644 --- a/src/cli/flags/s3.go +++ b/src/cli/flags/s3.go @@ -1,6 +1,14 @@ package flags -import "github.com/spf13/cobra" +import ( + "strconv" + + "github.com/spf13/cobra" + + "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/credentials" + "github.com/alcionai/corso/src/pkg/storage" +) // S3 bucket flags const ( @@ -39,3 +47,49 @@ func AddS3BucketFlags(cmd *cobra.Command) { fs.BoolVar(&SucceedIfExistsFV, SucceedIfExistsFN, false, "Exit with success if the repo has already been initialized.") cobra.CheckErr(fs.MarkHidden("succeed-if-exists")) } + +func S3FlagOverrides(cmd *cobra.Command) map[string]string { + fs := GetPopulatedFlags(cmd) + return PopulateS3Flags(fs) +} + +func PopulateS3Flags(flagset PopulatedFlags) map[string]string { + s3Overrides := make(map[string]string) + // TODO(pandeyabs): Move account overrides out of s3 flags + s3Overrides[account.AccountProviderTypeKey] = account.ProviderM365.String() + s3Overrides[storage.StorageProviderTypeKey] = storage.ProviderS3.String() + + if _, ok := flagset[AWSAccessKeyFN]; ok { + s3Overrides[credentials.AWSAccessKeyID] = AWSAccessKeyFV + } + + if _, ok := flagset[AWSSecretAccessKeyFN]; ok { + s3Overrides[credentials.AWSSecretAccessKey] = AWSSecretAccessKeyFV + } + + if _, ok := flagset[AWSSessionTokenFN]; ok { + s3Overrides[credentials.AWSSessionToken] = AWSSessionTokenFV + } + + if _, ok := flagset[BucketFN]; ok { + s3Overrides[storage.Bucket] = BucketFV + } + + if _, ok := flagset[PrefixFN]; ok { + s3Overrides[storage.Prefix] = PrefixFV + } + + if _, ok := flagset[DoNotUseTLSFN]; ok { + s3Overrides[storage.DoNotUseTLS] = strconv.FormatBool(DoNotUseTLSFV) + } + + if _, ok := flagset[DoNotVerifyTLSFN]; ok { + s3Overrides[storage.DoNotVerifyTLS] = strconv.FormatBool(DoNotVerifyTLSFV) + } + + if _, ok := flagset[EndpointFN]; ok { + s3Overrides[storage.Endpoint] = EndpointFV + } + + return s3Overrides +} diff --git a/src/cli/repo/repo.go b/src/cli/repo/repo.go index dddafa406..34d538670 100644 --- a/src/cli/repo/repo.go +++ b/src/cli/repo/repo.go @@ -123,10 +123,10 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error { r, _, err := utils.AccountConnectAndWriteRepoConfig( ctx, + cmd, // Need to give it a valid service so it won't error out on us even though // we don't need the graph client. - path.OneDriveService, - S3Overrides(cmd)) + path.OneDriveService) if err != nil { return print.Only(ctx, err) } diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index af18cb65e..cbe83c951 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -1,7 +1,6 @@ package repo import ( - "strconv" "strings" "github.com/alcionai/clues" @@ -13,8 +12,6 @@ import ( . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/events" - "github.com/alcionai/corso/src/pkg/account" - "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/repository" "github.com/alcionai/corso/src/pkg/storage" ) @@ -89,10 +86,12 @@ func s3InitCmd() *cobra.Command { func initS3Cmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // s3 values from flags - s3Override := S3Overrides(cmd) - - cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override) + cfg, err := config.GetConfigRepoDetails( + ctx, + storage.ProviderS3, + true, + false, + flags.S3FlagOverrides(cmd)) if err != nil { return Only(ctx, err) } @@ -175,10 +174,12 @@ func s3ConnectCmd() *cobra.Command { func connectS3Cmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // s3 values from flags - s3Override := S3Overrides(cmd) - - cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override) + cfg, err := config.GetConfigRepoDetails( + ctx, + storage.ProviderS3, + true, + true, + flags.S3FlagOverrides(cmd)) if err != nil { return Only(ctx, err) } @@ -227,48 +228,3 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { return nil } - -func S3Overrides(cmd *cobra.Command) map[string]string { - fs := flags.GetPopulatedFlags(cmd) - return PopulateS3Flags(fs) -} - -func PopulateS3Flags(flagset flags.PopulatedFlags) map[string]string { - s3Overrides := make(map[string]string) - s3Overrides[config.AccountProviderTypeKey] = account.ProviderM365.String() - s3Overrides[config.StorageProviderTypeKey] = storage.ProviderS3.String() - - if _, ok := flagset[flags.AWSAccessKeyFN]; ok { - s3Overrides[credentials.AWSAccessKeyID] = flags.AWSAccessKeyFV - } - - if _, ok := flagset[flags.AWSSecretAccessKeyFN]; ok { - s3Overrides[credentials.AWSSecretAccessKey] = flags.AWSSecretAccessKeyFV - } - - if _, ok := flagset[flags.AWSSessionTokenFN]; ok { - s3Overrides[credentials.AWSSessionToken] = flags.AWSSessionTokenFV - } - - if _, ok := flagset[flags.BucketFN]; ok { - s3Overrides[storage.Bucket] = flags.BucketFV - } - - if _, ok := flagset[flags.PrefixFN]; ok { - s3Overrides[storage.Prefix] = flags.PrefixFV - } - - if _, ok := flagset[flags.DoNotUseTLSFN]; ok { - s3Overrides[storage.DoNotUseTLS] = strconv.FormatBool(flags.DoNotUseTLSFV) - } - - if _, ok := flagset[flags.DoNotVerifyTLSFN]; ok { - s3Overrides[storage.DoNotVerifyTLS] = strconv.FormatBool(flags.DoNotVerifyTLSFV) - } - - if _, ok := flagset[flags.EndpointFN]; ok { - s3Overrides[storage.Endpoint] = flags.EndpointFV - } - - return s3Overrides -} diff --git a/src/cli/restore/restore.go b/src/cli/restore/restore.go index 3f62ab0ae..940a1f084 100644 --- a/src/cli/restore/restore.go +++ b/src/cli/restore/restore.go @@ -10,7 +10,6 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" - "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/pkg/count" @@ -103,7 +102,10 @@ func runRestore( sel selectors.Selector, backupID, serviceName string, ) error { - r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) + r, _, _, _, err := utils.GetAccountAndConnectWithOverrides( + ctx, + cmd, + sel.PathService()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/utils/utils.go b/src/cli/utils/utils.go index 5a639474a..d1bc9cb0c 100644 --- a/src/cli/utils/utils.go +++ b/src/cli/utils/utils.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/pflag" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/internal/events" "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/control" @@ -21,12 +22,34 @@ import ( var ErrNotYetImplemented = clues.New("not yet implemented") +// GetAccountAndConnectWithOverrides is a wrapper for GetAccountAndConnect +// that also gets the storage provider and any storage provider specific +// flag overrides from the command line. +func GetAccountAndConnectWithOverrides( + ctx context.Context, + cmd *cobra.Command, + pst path.ServiceType, +) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { + provider, overrides, err := GetStorageProviderAndOverrides(ctx, cmd) + if err != nil { + return nil, nil, nil, nil, err + } + + return GetAccountAndConnect(ctx, pst, provider, overrides) +} + func GetAccountAndConnect( ctx context.Context, pst path.ServiceType, + provider storage.ProviderType, overrides map[string]string, ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { - cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides) + cfg, err := config.GetConfigRepoDetails( + ctx, + provider, + true, + true, + overrides) if err != nil { return nil, nil, nil, nil, err } @@ -54,10 +77,13 @@ func GetAccountAndConnect( func AccountConnectAndWriteRepoConfig( ctx context.Context, + cmd *cobra.Command, pst path.ServiceType, - overrides map[string]string, ) (repository.Repository, *account.Account, error) { - r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, overrides) + r, stg, acc, opts, err := GetAccountAndConnectWithOverrides( + ctx, + cmd, + pst) if err != nil { logger.CtxErr(ctx, err).Info("getting and connecting account") return nil, nil, err @@ -203,3 +229,24 @@ func SendStartCorsoEvent( bus.SetRepoID(repoID) bus.Event(ctx, events.CorsoStart, data) } + +// GetStorageProviderAndOverrides returns the storage provider type and +// any flags specified on the command line which are storage provider specific. +func GetStorageProviderAndOverrides( + ctx context.Context, + cmd *cobra.Command, +) (storage.ProviderType, map[string]string, error) { + provider, err := config.GetStorageProviderFromConfigFile(ctx) + if err != nil { + return provider, nil, clues.Stack(err) + } + + overrides := map[string]string{} + + switch provider { + case storage.ProviderS3: + overrides = flags.S3FlagOverrides(cmd) + } + + return provider, overrides, nil +} diff --git a/src/cmd/longevity_test/longevity.go b/src/cmd/longevity_test/longevity.go index b3d6f865d..5072f9683 100644 --- a/src/cmd/longevity_test/longevity.go +++ b/src/cmd/longevity_test/longevity.go @@ -16,12 +16,14 @@ import ( "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/repository" + "github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/store" ) // deleteBackups connects to the repository and deletes all backups for // service that are at least deletionDays old. Returns the IDs of all backups // that were deleted. +// Only supported for S3 repos currently. func deleteBackups( ctx context.Context, service path.ServiceType, @@ -29,7 +31,11 @@ func deleteBackups( ) ([]string, error) { ctx = clues.Add(ctx, "cutoff_days", deletionDays) - r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, nil) + r, _, _, _, err := utils.GetAccountAndConnect( + ctx, + service, + storage.ProviderS3, + nil) if err != nil { return nil, clues.Wrap(err, "connecting to account").WithClues(ctx) } @@ -67,6 +73,7 @@ func deleteBackups( // pitrListBackups connects to the repository at the given point in time and // lists the backups for service. It then checks the list of backups contains // the backups in backupIDs. +// Only supported for S3 repos currently. func pitrListBackups( ctx context.Context, service path.ServiceType, @@ -87,7 +94,12 @@ func pitrListBackups( // TODO(ashmrtn): This may be moved into CLI layer at some point when we add // flags for opening a repo at a point in time. - cfg, err := config.GetConfigRepoDetails(ctx, true, true, nil) + cfg, err := config.GetConfigRepoDetails( + ctx, + storage.ProviderS3, + true, + true, + nil) if err != nil { return clues.Wrap(err, "getting config info") } diff --git a/src/cmd/s3checker/s3checker.go b/src/cmd/s3checker/s3checker.go index 6413f7d83..086cb51a4 100644 --- a/src/cmd/s3checker/s3checker.go +++ b/src/cmd/s3checker/s3checker.go @@ -187,7 +187,12 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error { storage.Prefix: f.bucketPrefix, } - repoDetails, err := config.GetConfigRepoDetails(ctx, false, false, overrides) + repoDetails, err := config.GetConfigRepoDetails( + ctx, + storage.ProviderS3, + false, + false, + overrides) if err != nil { return clues.Wrap(err, "getting storage config") } diff --git a/src/internal/operations/backup_test.go b/src/internal/operations/backup_test.go index a191f9235..a06e3c3bf 100644 --- a/src/internal/operations/backup_test.go +++ b/src/internal/operations/backup_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/internal/common/prefixmatcher" "github.com/alcionai/corso/src/internal/data" dataMock "github.com/alcionai/corso/src/internal/data/mock" @@ -1641,7 +1640,7 @@ func makeMockItem( func (suite *AssistBackupIntegrationSuite) TestBackupTypesForFailureModes() { var ( acct = tconfig.NewM365Account(suite.T()) - tenantID = acct.Config[config.AzureTenantIDKey] + tenantID = acct.Config[account.AzureTenantIDKey] opts = control.DefaultOptions() osel = selectors.NewOneDriveBackup([]string{userID}) ) @@ -1905,7 +1904,7 @@ func selectFilesFromDeets(d details.Details) map[string]details.Entry { func (suite *AssistBackupIntegrationSuite) TestExtensionsIncrementals() { var ( acct = tconfig.NewM365Account(suite.T()) - tenantID = acct.Config[config.AzureTenantIDKey] + tenantID = acct.Config[account.AzureTenantIDKey] opts = control.DefaultOptions() osel = selectors.NewOneDriveBackup([]string{userID}) // Default policy used by SDK clients diff --git a/src/pkg/account/account.go b/src/pkg/account/account.go index 4c1591818..ea1eb3070 100644 --- a/src/pkg/account/account.go +++ b/src/pkg/account/account.go @@ -19,6 +19,14 @@ var ( errMissingRequired = clues.New("missing required storage configuration") ) +const ( + // M365 config + AccountProviderTypeKey = "account_provider" + AzureTenantIDKey = "azure_tenantid" + AzureClientID = "azure_client_id" + AzureSecret = "azure_secret" +) + // Account defines an account provider, along with any credentials // and identifiers required to set up or communicate with that provider. type Account struct { diff --git a/src/pkg/storage/storage.go b/src/pkg/storage/storage.go index d1a1067a6..3bf9d82f7 100644 --- a/src/pkg/storage/storage.go +++ b/src/pkg/storage/storage.go @@ -17,6 +17,16 @@ const ( ProviderFilesystem ProviderType = 2 // Filesystem ) +var StringToProviderType = map[string]ProviderType{ + ProviderUnknown.String(): ProviderUnknown, + ProviderS3.String(): ProviderS3, + ProviderFilesystem.String(): ProviderFilesystem, +} + +const ( + StorageProviderTypeKey = "provider" +) + // storage parsing errors var ( errMissingRequired = clues.New("missing required storage configuration")