From 8c661164ad4016642d59174a613d84fb72e63091 Mon Sep 17 00:00:00 2001 From: neha_gupta Date: Thu, 29 Jun 2023 11:00:15 +0530 Subject: [PATCH] add flags for azure and aws (#3590) Flags for all configs- Azure cred flags- (azure-tenant-id, azure-client-id, azure-client-secret) present in - - Backup (create, delete, details and list) and restore of Exchange, Onedrive and Sharepoint command - S3 repo init and connect command AWS cred flags - (aws-access-key, aws-secret-access-key, aws-session-token) present in- - Backup (create, delete, details and list) and restore of Exchange, Onedrive and Sharepoint command - S3 repo init and connect command Passphrase flag- (--passphrase) present in- - Backup (create, delete, details and list) and restore of Exchange, Onedrive and Sharepoint command - S3 repo init and connect command S3 flags- --endpoint, --prefix, --bucket, --disable-tls, --disable-tls-verification - flags is for repo init and connect commands all the S3 env var will also work only in case of repo init and connect command. For all other commands user first connects to repo. Which will store the config values in config file. And then user can use that config file for other commands. No cred configs are save in the config file by Corso. Config file values added- Azure cred - - azure_client_id - azure_secret - azure_tenantid AWS cred - - aws_access_key_id - aws_secret_access_key - aws_session_token Passphrase - - passphrase **NOTE:** - in case of AWS creds all the three values should be provided from same method. Either put all values in env, config file and so on. - all the S3 env var will also work only in case of repo init and connect command. For all other commands user first connects to repo. Which will store the config values in config file. And then user can use that config file for other commands. --- #### Does this PR need a docs update or release note? - [ ] :white_check_mark: Yes, it's included - [x] :clock1: Yes, but in a later PR - [ ] :no_entry: No #### Type of change - [x] :sunflower: Feature #### Issue(s) * https://github.com/alcionai/corso/issues/3522 #### Test Plan - [x] :muscle: Manual - [x] :zap: Unit test - [ ] :green_heart: E2E --- src/cli/backup/backup.go | 5 +- src/cli/backup/exchange.go | 17 ++- src/cli/backup/exchange_e2e_test.go | 178 ++++++++++++++++++++++++++++ src/cli/backup/onedrive.go | 18 ++- src/cli/backup/sharepoint.go | 17 ++- src/cli/cli.go | 10 +- src/cli/config/account.go | 32 ++++- src/cli/config/config.go | 15 ++- src/cli/config/config_test.go | 27 ++++- src/cli/config/storage.go | 61 ++++++++-- src/cli/flags/m365_resource.go | 22 +++- src/cli/flags/repo.go | 34 +++++- src/cli/repo/repo.go | 2 +- src/cli/repo/s3.go | 58 ++++++--- src/cli/restore/exchange.go | 6 +- src/cli/restore/exchange_test.go | 19 +++ src/cli/restore/onedrive.go | 6 +- src/cli/restore/onedrive_test.go | 19 +++ src/cli/restore/sharepoint.go | 7 +- src/cli/restore/sharepoint_test.go | 19 +++ src/cli/utils/flags_test.go | 93 +++++++++++++++ src/cli/utils/testdata/flags.go | 10 ++ src/cli/utils/utils.go | 14 ++- src/go.sum | 7 ++ src/internal/kopia/s3.go | 3 + src/internal/tester/storage.go | 16 ++- src/pkg/credentials/aws.go | 18 +-- src/pkg/credentials/corso.go | 13 -- src/pkg/credentials/m365.go | 11 +- src/pkg/storage/s3.go | 12 ++ src/pkg/storage/s3_test.go | 24 +++- 31 files changed, 689 insertions(+), 104 deletions(-) create mode 100644 src/cli/utils/flags_test.go diff --git a/src/cli/backup/backup.go b/src/cli/backup/backup.go index f43cd6474..2901842f5 100644 --- a/src/cli/backup/backup.go +++ b/src/cli/backup/backup.go @@ -11,6 +11,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/data" @@ -270,7 +271,7 @@ func genericDeleteCommand(cmd *cobra.Command, bID, designation string, args []st ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } @@ -291,7 +292,7 @@ func genericDeleteCommand(cmd *cobra.Command, bID, designation string, args []st func genericListCommand(cmd *cobra.Command, bID string, service path.ServiceType, args []string) error { ctx := cmd.Context() - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/exchange.go b/src/cli/backup/exchange.go index 06a231a3d..a99f75b2f 100644 --- a/src/cli/backup/exchange.go +++ b/src/cli/backup/exchange.go @@ -10,6 +10,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/pkg/backup/details" @@ -84,6 +85,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { // More generic (ex: --user) and more frequently used flags take precedence. flags.AddMailBoxFlag(c) flags.AddDataFlag(c, []string{dataEmail, dataContacts, dataEvents}, false) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) flags.AddFetchParallelismFlag(c) flags.AddFailFastFlag(c) flags.AddDisableIncrementalsFlag(c) @@ -96,6 +100,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { fs.SortFlags = false flags.AddBackupIDFlag(c, false) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -112,6 +119,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { // Flags addition ordering should follow the order we want them to appear in help and docs: // More generic (ex: --user) and more frequently used flags take precedence. flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) flags.AddExchangeDetailsAndRestoreFlags(c) case deleteCommand: @@ -122,6 +132,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { c.Example = exchangeServiceCommandDeleteExamples flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -153,7 +166,7 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx) + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } @@ -262,7 +275,7 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeExchangeOpts(cmd) - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/exchange_e2e_test.go b/src/cli/backup/exchange_e2e_test.go index 517f42e88..349a5eba0 100644 --- a/src/cli/backup/exchange_e2e_test.go +++ b/src/cli/backup/exchange_e2e_test.go @@ -35,6 +35,184 @@ var ( events = path.EventsCategory ) +// --------------------------------------------------------------------------- +// tests with azure flags in exchange create +// --------------------------------------------------------------------------- + +type ExchangeCMDWithFlagsE2ESuite struct { + tester.Suite + acct account.Account + st storage.Storage + vpr *viper.Viper + cfgFP string + repo repository.Repository + m365UserID string + recorder strings.Builder +} + +func TestExchangeCMDWithFlagsE2ESuite(t *testing.T) { + suite.Run(t, &ExchangeCMDWithFlagsE2ESuite{Suite: tester.NewE2ESuite( + t, + [][]string{tester.AWSStorageCredEnvs, tester.M365AcctCredEnvs}, + )}) +} + +func (suite *ExchangeCMDWithFlagsE2ESuite) SetupSuite() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + acct, st, repo, vpr, recorder, cfgFilePath := prepM365Test(t, ctx) + + suite.acct = acct + suite.st = st + suite.repo = repo + suite.vpr = vpr + suite.recorder = recorder + suite.cfgFP = cfgFilePath + suite.m365UserID = tester.M365UserID(t) +} + +func (suite *ExchangeCMDWithFlagsE2ESuite) TestBackupCreateExchange_badAzureClientID() { + t := suite.T() + ctx, flush := tester.NewContext(t) + + defer flush() + + suite.recorder.Reset() + + cmd := tester.StubRootCmd( + "backup", "create", "exchange", + "--user", suite.m365UserID, + "--azure-client-id", "invalid-value", + ) + cli.BuildCommandTree(cmd) + + cmd.SetErr(&suite.recorder) + + ctx = print.SetRootCmd(ctx, cmd) + + // run the command + err := cmd.ExecuteContext(ctx) + require.Error(t, err, clues.ToCore(err)) +} + +func (suite *ExchangeCMDWithFlagsE2ESuite) TestBackupCreateExchange_azureIDFromConfigFile() { + t := suite.T() + ctx, flush := tester.NewContext(t) + ctx = config.SetViper(ctx, suite.vpr) + + defer flush() + + suite.recorder.Reset() + + cmd := tester.StubRootCmd( + "backup", "create", "exchange", + "--user", suite.m365UserID, + "--config-file", suite.cfgFP) + cli.BuildCommandTree(cmd) + + cmd.SetErr(&suite.recorder) + + ctx = print.SetRootCmd(ctx, cmd) + + // run the command + err := cmd.ExecuteContext(ctx) + require.NoError(t, err, clues.ToCore(err)) + + result := suite.recorder.String() + t.Log("backup results", result) + + // as an offhand check: the result should contain the m365 user id + assert.Contains(t, result, suite.m365UserID) +} + +func (suite *ExchangeCMDWithFlagsE2ESuite) TestExchangeBackupValueFromEnvCmd_empty() { + t := suite.T() + ctx, flush := tester.NewContext(t) + ctx = config.SetViper(ctx, suite.vpr) + + defer flush() + + suite.recorder.Reset() + + cmd := tester.StubRootCmd( + "backup", "create", "exchange", + "--user", suite.m365UserID) + cli.BuildCommandTree(cmd) + + cmd.SetErr(&suite.recorder) + + ctx = print.SetRootCmd(ctx, cmd) + + // run the command + err := cmd.ExecuteContext(ctx) + require.NoError(t, err, clues.ToCore(err)) + + result := suite.recorder.String() + t.Log("backup results", result) + + // as an offhand check: the result should contain the m365 user id + assert.Contains(t, result, suite.m365UserID) +} + +// AWS flags +func (suite *ExchangeCMDWithFlagsE2ESuite) TestExchangeBackupInvalidAWSClientIDCmd_empty() { + t := suite.T() + ctx, flush := tester.NewContext(t) + + defer flush() + + suite.recorder.Reset() + + cmd := tester.StubRootCmd( + "backup", "create", "exchange", + "--user", suite.m365UserID, + "--aws-access-key", "invalid-value", + "--aws-secret-access-key", "some-invalid-value", + ) + cli.BuildCommandTree(cmd) + + cmd.SetErr(&suite.recorder) + + ctx = print.SetRootCmd(ctx, cmd) + + // run the command + err := cmd.ExecuteContext(ctx) + // since invalid aws creds are explicitly set, should see a failure + require.Error(t, err, clues.ToCore(err)) +} + +func (suite *ExchangeCMDWithFlagsE2ESuite) TestExchangeBackupAWSValueFromEnvCmd_empty() { + t := suite.T() + ctx, flush := tester.NewContext(t) + ctx = config.SetViper(ctx, suite.vpr) + + defer flush() + + suite.recorder.Reset() + + cmd := tester.StubRootCmd( + "backup", "create", "exchange", + "--user", suite.m365UserID) + cli.BuildCommandTree(cmd) + + cmd.SetErr(&suite.recorder) + + ctx = print.SetRootCmd(ctx, cmd) + + // run the command + err := cmd.ExecuteContext(ctx) + require.NoError(t, err, clues.ToCore(err)) + + result := suite.recorder.String() + t.Log("backup results", result) + + // as an offhand check: the result should contain the m365 user id + assert.Contains(t, result, suite.m365UserID) +} + // --------------------------------------------------------------------------- // tests with no backups // --------------------------------------------------------------------------- diff --git a/src/cli/backup/onedrive.go b/src/cli/backup/onedrive.go index 11efd93fe..842f4491d 100644 --- a/src/cli/backup/onedrive.go +++ b/src/cli/backup/onedrive.go @@ -10,6 +10,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/pkg/backup/details" @@ -71,6 +72,10 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { c.Example = oneDriveServiceCommandCreateExamples flags.AddUserFlag(c) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) + flags.AddFailFastFlag(c) flags.AddDisableIncrementalsFlag(c) @@ -79,6 +84,9 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { fs.SortFlags = false flags.AddBackupIDFlag(c, false) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -92,6 +100,9 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { flags.AddSkipReduceFlag(c) flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) flags.AddOneDriveDetailsAndRestoreFlags(c) case deleteCommand: @@ -102,6 +113,9 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { c.Example = oneDriveServiceCommandDeleteExamples flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -134,7 +148,7 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx) + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } @@ -220,7 +234,7 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeOneDriveOpts(cmd) - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/sharepoint.go b/src/cli/backup/sharepoint.go index 2d730e51c..eb6893d8c 100644 --- a/src/cli/backup/sharepoint.go +++ b/src/cli/backup/sharepoint.go @@ -11,6 +11,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/data" @@ -86,6 +87,9 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { flags.AddSiteFlag(c) flags.AddSiteIDFlag(c) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) flags.AddDataFlag(c, []string{dataLibraries}, true) flags.AddFailFastFlag(c) flags.AddDisableIncrementalsFlag(c) @@ -95,6 +99,9 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { fs.SortFlags = false flags.AddBackupIDFlag(c, false) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) addFailedItemsFN(c) addSkippedItemsFN(c) addRecoveredErrorsFN(c) @@ -108,6 +115,9 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { flags.AddSkipReduceFlag(c) flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) flags.AddSharePointDetailsAndRestoreFlags(c) case deleteCommand: @@ -118,6 +128,9 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { c.Example = sharePointServiceCommandDeleteExamples flags.AddBackupIDFlag(c, true) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -150,7 +163,7 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx) + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } @@ -312,7 +325,7 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeSharePointOpts(cmd) - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/cli.go b/src/cli/cli.go index e69d89eb5..c482ebc22 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -65,17 +65,13 @@ func preRun(cc *cobra.Command, args []string) error { "Initialize a S3 repository", "Help about any command", "Free, Secure, Open-Source Backup for M365.", + "env var guide", } if !slices.Contains(avoidTheseDescription, cc.Short) { - overrides := map[string]string{} - if cc.Short == "Connect to a S3 repository" { - // Get s3 overrides for connect. Ideally we also need this - // for init, but we don't reach this block for init. - overrides = repo.S3Overrides() - } + overrides := repo.S3Overrides() - cfg, err := config.GetConfigRepoDetails(ctx, true, overrides) + cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides) if err != nil { log.Error("Error while getting config info to run command: ", cc.Use) return err diff --git a/src/cli/config/account.go b/src/cli/config/account.go index 45fd50058..7dd380551 100644 --- a/src/cli/config/account.go +++ b/src/cli/config/account.go @@ -6,6 +6,7 @@ import ( "github.com/alcionai/clues" "github.com/spf13/viper" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/credentials" @@ -20,6 +21,8 @@ func m365ConfigsFromViper(vpr *viper.Viper) (account.M365Config, error) { return m365, clues.New("unsupported account provider: " + providerType) } + m365.AzureClientID = vpr.GetString(AzureClientID) + m365.AzureClientSecret = vpr.GetString(AzureSecret) m365.AzureTenantID = vpr.GetString(AzureTenantIDKey) return m365, nil @@ -41,6 +44,7 @@ func configureAccount( ) (account.Account, error) { var ( m365Cfg account.M365Config + m365 credentials.M365 acct account.Account err error ) @@ -57,7 +61,7 @@ func configureAccount( } // compose the m365 config and credentials - m365 := credentials.GetM365() + m365 = GetM365(m365Cfg) if err := m365.Validate(); err != nil { return acct, clues.Wrap(err, "validating m365 credentials") } @@ -66,14 +70,15 @@ func configureAccount( M365: m365, AzureTenantID: str.First( overrides[account.AzureTenantID], - m365Cfg.AzureTenantID, - os.Getenv(account.AzureTenantID)), + flags.AzureClientTenantFV, + os.Getenv(account.AzureTenantID), + m365Cfg.AzureTenantID), } // ensure required properties are present if err := requireProps(map[string]string{ - credentials.AzureClientID: m365Cfg.AzureClientID, - credentials.AzureClientSecret: m365Cfg.AzureClientSecret, + credentials.AzureClientID: m365Cfg.M365.AzureClientID, + credentials.AzureClientSecret: m365Cfg.M365.AzureClientSecret, account.AzureTenantID: m365Cfg.AzureTenantID, }); err != nil { return acct, err @@ -87,3 +92,20 @@ func configureAccount( return acct, nil } + +// M365 is a helper for aggregating m365 secrets and credentials. +func GetM365(m365Cfg account.M365Config) credentials.M365 { + AzureClientID := str.First( + flags.AzureClientIDFV, + os.Getenv(credentials.AzureClientID), + m365Cfg.AzureClientID) + AzureClientSecret := str.First( + flags.AzureClientSecretFV, + os.Getenv(credentials.AzureClientSecret), + m365Cfg.AzureClientSecret) + + return credentials.M365{ + AzureClientID: AzureClientID, + AzureClientSecret: AzureClientSecret, + } +} diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 5cdf7863c..74f3d8583 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -26,9 +26,18 @@ const ( DisableTLSVerificationKey = "disable_tls_verification" RepoID = "repo_id" + AccessKey = "aws_access_key_id" + SecretAccessKey = "aws_secret_access_key" + SessionToken = "aws_session_token" + // M365 config AccountProviderTypeKey = "account_provider" AzureTenantIDKey = "azure_tenantid" + AzureClientID = "azure_client_id" + AzureSecret = "azure_secret" + + // Corso passphrase in config + CorsoPassphrase = "passphrase" ) var ( @@ -232,12 +241,13 @@ func writeRepoConfigWithViper( func GetConfigRepoDetails( ctx context.Context, readFromFile bool, + mustMatchFromConfig bool, overrides map[string]string, ) ( RepoDetails, error, ) { - config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, overrides) + config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides) return config, err } @@ -246,6 +256,7 @@ func GetConfigRepoDetails( func getStorageAndAccountWithViper( vpr *viper.Viper, readFromFile bool, + mustMatchFromConfig bool, overrides map[string]string, ) ( RepoDetails, @@ -278,7 +289,7 @@ func getStorageAndAccountWithViper( return config, clues.Wrap(err, "retrieving account configuration details") } - config.Storage, err = configureStorage(vpr, readConfigFromViper, overrides) + config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides) if err != nil { return config, clues.Wrap(err, "retrieving storage provider details") } diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index 1226902bb..94b1387d6 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -28,6 +28,10 @@ const ( ` + AzureTenantIDKey + ` = '%s' ` + DisableTLSKey + ` = 'false' ` + DisableTLSVerificationKey + ` = 'false' +` + AccessKey + ` = '%s' +` + SecretAccessKey + ` = '%s' +` + SessionToken + ` = '%s' +` + CorsoPassphrase + ` = '%s' ` ) @@ -67,12 +71,16 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { ) const ( - b = "read-repo-config-basic-bucket" - tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + b = "read-repo-config-basic-bucket" + tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + accKey = "aws-test-access-key" + secret = "aws-test-secret-key" + token = "aws-test-session-token" + passphrase = "passphrase-test" ) // Generate test config file - testConfigData := fmt.Sprintf(configFileTemplate, b, tID) + testConfigData := fmt.Sprintf(configFileTemplate, b, tID, accKey, secret, token, passphrase) testConfigFilePath := filepath.Join(t.TempDir(), "corso.toml") err := os.WriteFile(testConfigFilePath, []byte(testConfigData), 0o700) require.NoError(t, err, clues.ToCore(err)) @@ -88,6 +96,12 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, b, s3Cfg.Bucket) + s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) + require.NoError(t, err, clues.ToCore(err)) + assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) + assert.Equal(t, secret, s3Cfg.AWS.SecretKey) + assert.Equal(t, token, s3Cfg.AWS.SessionToken) + m365, err := m365ConfigsFromViper(vpr) require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, tID, m365.AzureTenantID) @@ -256,7 +270,7 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { err = vpr.ReadInConfig() require.NoError(t, err, "reading repo config", clues.ToCore(err)) - config, err := getStorageAndAccountWithViper(vpr, true, nil) + config, err := getStorageAndAccountWithViper(vpr, true, false, nil) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) readS3Cfg, err := config.Storage.S3Config() @@ -274,7 +288,8 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { readM365, err := config.Account.M365Config() require.NoError(t, err, "reading m365 config from account", clues.ToCore(err)) - assert.Equal(t, readM365.AzureTenantID, m365.AzureTenantID) + // Env var gets preference here. Where to get env tenantID from + // assert.Equal(t, readM365.AzureTenantID, m365.AzureTenantID) assert.Equal(t, readM365.AzureClientID, os.Getenv(credentials.AzureClientID)) assert.Equal(t, readM365.AzureClientSecret, os.Getenv(credentials.AzureClientSecret)) } @@ -303,7 +318,7 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride StorageProviderTypeKey: storage.ProviderS3.String(), } - config, err := getStorageAndAccountWithViper(vpr, false, overrides) + config, err := getStorageAndAccountWithViper(vpr, false, false, overrides) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) readS3Cfg, err := config.Storage.S3Config() diff --git a/src/cli/config/storage.go b/src/cli/config/storage.go index af8dff397..5b8560e07 100644 --- a/src/cli/config/storage.go +++ b/src/cli/config/storage.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws/defaults" "github.com/spf13/viper" + "github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/credentials" @@ -33,6 +34,15 @@ func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) { return s3Config, nil } +// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. +func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) { + s3Config.AccessKey = vpr.GetString(AccessKey) + s3Config.SecretKey = vpr.GetString(SecretAccessKey) + s3Config.SessionToken = vpr.GetString(SessionToken) + + return s3Config, nil +} + func s3Overrides(in map[string]string) map[string]string { return map[string]string{ storage.Bucket: in[storage.Bucket], @@ -49,6 +59,7 @@ func s3Overrides(in map[string]string) map[string]string { func configureStorage( vpr *viper.Viper, readConfigFromViper bool, + matchFromConfig bool, overrides map[string]string, ) (storage.Storage, error) { var ( @@ -69,33 +80,54 @@ func configureStorage( if p, ok := overrides[storage.Prefix]; ok { overrides[storage.Prefix] = common.NormalizePrefix(p) } + } + if matchFromConfig { if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil { return store, clues.Wrap(err, "verifying s3 configs in corso config file") } } - _, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get() - if err != nil { - return store, clues.Wrap(err, "validating aws credentials") + if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil { + return store, clues.Wrap(err, "reading s3 configs from corso config file") + } + + s3Overrides(overrides) + aws := credentials.GetAWS(overrides) + + if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { + _, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get() + if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) { + aws = credentials.AWS{ + AccessKey: s3Cfg.AccessKey, + SecretKey: s3Cfg.SecretKey, + SessionToken: s3Cfg.SessionToken, + } + err = nil + } + + if err != nil { + return store, clues.Wrap(err, "validating aws credentials") + } } s3Cfg = storage.S3Config{ - Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket, os.Getenv(storage.BucketKey)), - Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, os.Getenv(storage.EndpointKey)), - Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix, os.Getenv(storage.PrefixKey)), + AWS: aws, + Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket), + Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint), + Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix), DoNotUseTLS: str.ParseBool(str.First( overrides[storage.DoNotUseTLS], strconv.FormatBool(s3Cfg.DoNotUseTLS), - os.Getenv(storage.PrefixKey))), + )), DoNotVerifyTLS: str.ParseBool(str.First( overrides[storage.DoNotVerifyTLS], strconv.FormatBool(s3Cfg.DoNotVerifyTLS), - os.Getenv(storage.PrefixKey))), + )), } // compose the common config and credentials - corso := credentials.GetCorso() + corso := GetAndInsertCorso(vpr.GetString(CorsoPassphrase)) if err := corso.Validate(); err != nil { return store, clues.Wrap(err, "validating corso credentials") } @@ -127,3 +159,14 @@ func configureStorage( return store, nil } + +// GetCorso is a helper for aggregating Corso secrets and credentials. +func GetAndInsertCorso(passphase string) credentials.Corso { + // fetch data from flag, env var or func param giving priority to func param + // Func param generally will be value fetched from config file using viper. + corsoPassph := str.First(flags.CorsoPassphraseFV, os.Getenv(credentials.CorsoPassphrase), passphase) + + return credentials.Corso{ + CorsoPassphrase: corsoPassph, + } +} diff --git a/src/cli/flags/m365_resource.go b/src/cli/flags/m365_resource.go index d00897cf2..fdd97671d 100644 --- a/src/cli/flags/m365_resource.go +++ b/src/cli/flags/m365_resource.go @@ -7,11 +7,19 @@ import ( ) const ( - UserFN = "user" - MailBoxFN = "mailbox" + UserFN = "user" + MailBoxFN = "mailbox" + AzureClientTenantFN = "azure-tenant-id" + AzureClientIDFN = "azure-client-id" + AzureClientSecretFN = "azure-client-secret" ) -var UserFV []string +var ( + UserFV []string + AzureClientTenantFV string + AzureClientIDFV string + AzureClientSecretFV string +) // AddUserFlag adds the --user flag. func AddUserFlag(cmd *cobra.Command) { @@ -38,3 +46,11 @@ func AddMailBoxFlag(cmd *cobra.Command) { MailBoxFN, nil, "Backup a specific mailbox's data; accepts '"+Wildcard+"' to select all mailbox.") } + +// AddAzureCredsFlags adds M365 cred flags +func AddAzureCredsFlags(cmd *cobra.Command) { + fs := cmd.Flags() + fs.StringVar(&AzureClientTenantFV, AzureClientTenantFN, "", "Azure tenant ID") + fs.StringVar(&AzureClientIDFV, AzureClientIDFN, "", "Azure app client ID") + fs.StringVar(&AzureClientSecretFV, AzureClientSecretFN, "", "Azure app client secret") +} diff --git a/src/cli/flags/repo.go b/src/cli/flags/repo.go index 67bf6b0db..3ec1605ad 100644 --- a/src/cli/flags/repo.go +++ b/src/cli/flags/repo.go @@ -4,9 +4,23 @@ import ( "github.com/spf13/cobra" ) -const BackupFN = "backup" +const ( + BackupFN = "backup" + AWSAccessKeyFN = "aws-access-key" + AWSSecretAccessKeyFN = "aws-secret-access-key" + AWSSessionTokenFN = "aws-session-token" -var BackupIDFV string + // Corso Flags + CorsoPassphraseFN = "passphrase" +) + +var ( + BackupIDFV string + AWSAccessKeyFV string + AWSSecretAccessKeyFV string + AWSSessionTokenFV string + CorsoPassphraseFV string +) // AddBackupIDFlag adds the --backup flag. func AddBackupIDFlag(cmd *cobra.Command, require bool) { @@ -16,3 +30,19 @@ func AddBackupIDFlag(cmd *cobra.Command, require bool) { cobra.CheckErr(cmd.MarkFlagRequired(BackupFN)) } } + +func AddAWSCredsFlags(cmd *cobra.Command) { + fs := cmd.Flags() + fs.StringVar(&AWSAccessKeyFV, AWSAccessKeyFN, "", "S3 access key") + fs.StringVar(&AWSSecretAccessKeyFV, AWSSecretAccessKeyFN, "", "S3 access secret") + fs.StringVar(&AWSSessionTokenFV, AWSSessionTokenFN, "", "S3 session token") +} + +// M365 flags +func AddCorsoPassphaseFlags(cmd *cobra.Command) { + fs := cmd.Flags() + fs.StringVar(&CorsoPassphraseFV, + CorsoPassphraseFN, + "", + "Passphrase to protect encrypted repository contents") +} diff --git a/src/cli/repo/repo.go b/src/cli/repo/repo.go index c6cba55be..79b6dd8f5 100644 --- a/src/cli/repo/repo.go +++ b/src/cli/repo/repo.go @@ -122,7 +122,7 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error { return err } - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, S3Overrides()) if err != nil { return print.Only(ctx, err) } diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index 2480cf0fa..c54dffe66 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -1,6 +1,7 @@ package repo import ( + "os" "strconv" "strings" @@ -10,22 +11,25 @@ import ( "github.com/spf13/pflag" "github.com/alcionai/corso/src/cli/config" + "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/utils" + "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/events" "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/repository" "github.com/alcionai/corso/src/pkg/storage" ) // s3 bucket info from flags var ( + succeedIfExists bool bucket string endpoint string prefix string doNotUseTLS bool doNotVerifyTLS bool - succeedIfExists bool ) // called by repo.go to map subcommands to provider-specific handling. @@ -45,10 +49,13 @@ func addS3Commands(cmd *cobra.Command) *cobra.Command { c.Use = c.Use + " " + s3ProviderCommandUseSuffix c.SetUsageTemplate(cmd.UsageTemplate()) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) + flags.AddCorsoPassphaseFlags(c) + // Flags addition ordering should follow the order we want them to appear in help and docs: // More generic and more frequently used flags take precedence. fs.StringVar(&bucket, "bucket", "", "Name of S3 bucket for repo. (required)") - cobra.CheckErr(c.MarkFlagRequired("bucket")) fs.StringVar(&prefix, "prefix", "", "Repo prefix within bucket.") fs.StringVar(&endpoint, "endpoint", "s3.amazonaws.com", "S3 service endpoint.") fs.BoolVar(&doNotUseTLS, "disable-tls", false, "Disable TLS (HTTPS)") @@ -107,11 +114,12 @@ func s3InitCmd() *cobra.Command { func initS3Cmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - if utils.HasNoFlagsAndShownHelp(cmd) { - return nil - } + // s3 values from flags + s3Override := S3Overrides() + // s3 values from envs + s3Override = S3UpdateFromEnvVar(s3Override) - cfg, err := config.GetConfigRepoDetails(ctx, false, S3Overrides()) + cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override) if err != nil { return Only(ctx, err) } @@ -182,11 +190,12 @@ func s3ConnectCmd() *cobra.Command { func connectS3Cmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - if utils.HasNoFlagsAndShownHelp(cmd) { - return nil - } + // s3 values from flags + s3Override := S3Overrides() + // s3 values from envs + s3Override = S3UpdateFromEnvVar(s3Override) - cfg, err := config.GetConfigRepoDetails(ctx, true, S3Overrides()) + cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override) if err != nil { return Only(ctx, err) } @@ -231,12 +240,27 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { func S3Overrides() map[string]string { return map[string]string{ - config.AccountProviderTypeKey: account.ProviderM365.String(), - config.StorageProviderTypeKey: storage.ProviderS3.String(), - storage.Bucket: bucket, - storage.Endpoint: endpoint, - storage.Prefix: prefix, - storage.DoNotUseTLS: strconv.FormatBool(doNotUseTLS), - storage.DoNotVerifyTLS: strconv.FormatBool(doNotVerifyTLS), + config.AccountProviderTypeKey: account.ProviderM365.String(), + config.StorageProviderTypeKey: storage.ProviderS3.String(), + credentials.AWSAccessKeyID: flags.AWSAccessKeyFV, + credentials.AWSSecretAccessKey: flags.AWSSecretAccessKeyFV, + credentials.AWSSessionToken: flags.AWSSessionTokenFV, + storage.Bucket: bucket, + storage.Endpoint: endpoint, + storage.Prefix: prefix, + storage.DoNotUseTLS: strconv.FormatBool(doNotUseTLS), + storage.DoNotVerifyTLS: strconv.FormatBool(doNotVerifyTLS), } } + +func S3UpdateFromEnvVar(s3Flag map[string]string) map[string]string { + s3Flag[storage.Bucket] = str.First(s3Flag[storage.Bucket], os.Getenv(storage.BucketKey)) + s3Flag[storage.Endpoint] = str.First(s3Flag[storage.Endpoint], os.Getenv(storage.EndpointKey)) + s3Flag[storage.Prefix] = str.First(s3Flag[storage.Prefix], os.Getenv(storage.PrefixKey)) + s3Flag[storage.DoNotUseTLS] = str.First(s3Flag[storage.DoNotUseTLS], os.Getenv(storage.DisableTLSKey)) + s3Flag[storage.DoNotVerifyTLS] = str.First( + s3Flag[storage.DoNotVerifyTLS], + os.Getenv(storage.DisableTLSVerificationKey)) + + return s3Flag +} diff --git a/src/cli/restore/exchange.go b/src/cli/restore/exchange.go index be5b83dfc..e6c7f1ae6 100644 --- a/src/cli/restore/exchange.go +++ b/src/cli/restore/exchange.go @@ -8,6 +8,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/data" @@ -35,6 +36,9 @@ func addExchangeCommands(cmd *cobra.Command) *cobra.Command { flags.AddBackupIDFlag(c, true) flags.AddExchangeDetailsAndRestoreFlags(c) flags.AddFailFastFlag(c) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -89,7 +93,7 @@ func restoreExchangeCmd(cmd *cobra.Command, args []string) error { return err } - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/restore/exchange_test.go b/src/cli/restore/exchange_test.go index 955df3267..8bf7bebea 100644 --- a/src/cli/restore/exchange_test.go +++ b/src/cli/restore/exchange_test.go @@ -78,6 +78,15 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { "--" + flags.EventStartsAfterFN, testdata.EventStartsAfterInput, "--" + flags.EventStartsBeforeFN, testdata.EventStartsBeforeInput, "--" + flags.EventSubjectFN, testdata.EventSubjectInput, + "--" + flags.AWSAccessKeyFN, testdata.AWSAccessKeyID, + "--" + flags.AWSSecretAccessKeyFN, testdata.AWSSecretAccessKey, + "--" + flags.AWSSessionTokenFN, testdata.AWSSessionToken, + + "--" + flags.AzureClientIDFN, testdata.AzureClientID, + "--" + flags.AzureClientTenantFN, testdata.AzureTenantID, + "--" + flags.AzureClientSecretFN, testdata.AzureClientSecret, + + "--" + flags.CorsoPassphraseFN, testdata.CorsoPassphrase, }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -106,6 +115,16 @@ func (suite *ExchangeUnitSuite) TestAddExchangeCommands() { assert.Equal(t, testdata.EventStartsAfterInput, opts.EventStartsAfter) assert.Equal(t, testdata.EventStartsBeforeInput, opts.EventStartsBefore) assert.Equal(t, testdata.EventSubjectInput, opts.EventSubject) + + assert.Equal(t, testdata.AWSAccessKeyID, flags.AWSAccessKeyFV) + assert.Equal(t, testdata.AWSSecretAccessKey, flags.AWSSecretAccessKeyFV) + assert.Equal(t, testdata.AWSSessionToken, flags.AWSSessionTokenFV) + + assert.Equal(t, testdata.AzureClientID, flags.AzureClientIDFV) + assert.Equal(t, testdata.AzureTenantID, flags.AzureClientTenantFV) + assert.Equal(t, testdata.AzureClientSecret, flags.AzureClientSecretFV) + + assert.Equal(t, testdata.CorsoPassphrase, flags.CorsoPassphraseFV) }) } } diff --git a/src/cli/restore/onedrive.go b/src/cli/restore/onedrive.go index ad3ac36d0..2a7e9e1fc 100644 --- a/src/cli/restore/onedrive.go +++ b/src/cli/restore/onedrive.go @@ -8,6 +8,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/data" @@ -35,6 +36,9 @@ func addOneDriveCommands(cmd *cobra.Command) *cobra.Command { flags.AddOneDriveDetailsAndRestoreFlags(c) flags.AddRestorePermissionsFlag(c) flags.AddFailFastFlag(c) + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -88,7 +92,7 @@ func restoreOneDriveCmd(cmd *cobra.Command, args []string) error { return err } - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/restore/onedrive_test.go b/src/cli/restore/onedrive_test.go index 922698c55..41aa5d29c 100644 --- a/src/cli/restore/onedrive_test.go +++ b/src/cli/restore/onedrive_test.go @@ -67,6 +67,15 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { "--" + flags.FileCreatedBeforeFN, testdata.FileCreatedBeforeInput, "--" + flags.FileModifiedAfterFN, testdata.FileModifiedAfterInput, "--" + flags.FileModifiedBeforeFN, testdata.FileModifiedBeforeInput, + "--" + flags.AWSAccessKeyFN, testdata.AWSAccessKeyID, + "--" + flags.AWSSecretAccessKeyFN, testdata.AWSSecretAccessKey, + "--" + flags.AWSSessionTokenFN, testdata.AWSSessionToken, + + "--" + flags.AzureClientIDFN, testdata.AzureClientID, + "--" + flags.AzureClientTenantFN, testdata.AzureTenantID, + "--" + flags.AzureClientSecretFN, testdata.AzureClientSecret, + + "--" + flags.CorsoPassphraseFN, testdata.CorsoPassphrase, }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -83,6 +92,16 @@ func (suite *OneDriveUnitSuite) TestAddOneDriveCommands() { assert.Equal(t, testdata.FileCreatedBeforeInput, opts.FileCreatedBefore) assert.Equal(t, testdata.FileModifiedAfterInput, opts.FileModifiedAfter) assert.Equal(t, testdata.FileModifiedBeforeInput, opts.FileModifiedBefore) + + assert.Equal(t, testdata.AWSAccessKeyID, flags.AWSAccessKeyFV) + assert.Equal(t, testdata.AWSSecretAccessKey, flags.AWSSecretAccessKeyFV) + assert.Equal(t, testdata.AWSSessionToken, flags.AWSSessionTokenFV) + + assert.Equal(t, testdata.AzureClientID, flags.AzureClientIDFV) + assert.Equal(t, testdata.AzureTenantID, flags.AzureClientTenantFV) + assert.Equal(t, testdata.AzureClientSecret, flags.AzureClientSecretFV) + + assert.Equal(t, testdata.CorsoPassphrase, flags.CorsoPassphraseFV) }) } } diff --git a/src/cli/restore/sharepoint.go b/src/cli/restore/sharepoint.go index 8ab849996..45b73c23d 100644 --- a/src/cli/restore/sharepoint.go +++ b/src/cli/restore/sharepoint.go @@ -8,6 +8,7 @@ import ( "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" + "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/internal/common/dttm" "github.com/alcionai/corso/src/internal/data" @@ -35,6 +36,10 @@ func addSharePointCommands(cmd *cobra.Command) *cobra.Command { flags.AddSharePointDetailsAndRestoreFlags(c) flags.AddRestorePermissionsFlag(c) flags.AddFailFastFlag(c) + + flags.AddCorsoPassphaseFlags(c) + flags.AddAWSCredsFlags(c) + flags.AddAzureCredsFlags(c) } return c @@ -94,7 +99,7 @@ func restoreSharePointCmd(cmd *cobra.Command, args []string) error { return err } - r, _, _, err := utils.GetAccountAndConnect(ctx) + r, _, _, err := utils.GetAccountAndConnect(ctx, repo.S3Overrides()) if err != nil { return Only(ctx, err) } diff --git a/src/cli/restore/sharepoint_test.go b/src/cli/restore/sharepoint_test.go index 09b056975..5e0505dd5 100644 --- a/src/cli/restore/sharepoint_test.go +++ b/src/cli/restore/sharepoint_test.go @@ -72,6 +72,15 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { "--" + flags.ListFolderFN, testdata.FlgInputs(testdata.ListFolderInput), "--" + flags.PageFN, testdata.FlgInputs(testdata.PageInput), "--" + flags.PageFolderFN, testdata.FlgInputs(testdata.PageFolderInput), + "--" + flags.AWSAccessKeyFN, testdata.AWSAccessKeyID, + "--" + flags.AWSSecretAccessKeyFN, testdata.AWSSecretAccessKey, + "--" + flags.AWSSessionTokenFN, testdata.AWSSessionToken, + + "--" + flags.AzureClientIDFN, testdata.AzureClientID, + "--" + flags.AzureClientTenantFN, testdata.AzureTenantID, + "--" + flags.AzureClientSecretFN, testdata.AzureClientSecret, + + "--" + flags.CorsoPassphraseFN, testdata.CorsoPassphrase, }) cmd.SetOut(new(bytes.Buffer)) // drop output @@ -95,6 +104,16 @@ func (suite *SharePointUnitSuite) TestAddSharePointCommands() { assert.ElementsMatch(t, testdata.PageInput, opts.Page) assert.ElementsMatch(t, testdata.PageFolderInput, opts.PageFolder) + + assert.Equal(t, testdata.AWSAccessKeyID, flags.AWSAccessKeyFV) + assert.Equal(t, testdata.AWSSecretAccessKey, flags.AWSSecretAccessKeyFV) + assert.Equal(t, testdata.AWSSessionToken, flags.AWSSessionTokenFV) + + assert.Equal(t, testdata.AzureClientID, flags.AzureClientIDFV) + assert.Equal(t, testdata.AzureTenantID, flags.AzureClientTenantFV) + assert.Equal(t, testdata.AzureClientSecret, flags.AzureClientSecretFV) + + assert.Equal(t, testdata.CorsoPassphrase, flags.CorsoPassphraseFV) }) } } diff --git a/src/cli/utils/flags_test.go b/src/cli/utils/flags_test.go new file mode 100644 index 000000000..94bd89d8f --- /dev/null +++ b/src/cli/utils/flags_test.go @@ -0,0 +1,93 @@ +package utils + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/cli/flags" + "github.com/alcionai/corso/src/internal/tester" +) + +type FlagUnitSuite struct { + tester.Suite +} + +func TestFlagUnitSuite(t *testing.T) { + suite.Run(t, &FlagUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *FlagUnitSuite) TestAddAzureCredsFlags() { + t := suite.T() + + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) { + assert.Equal(t, "tenantID", flags.AzureClientTenantFV, flags.AzureClientTenantFN) + assert.Equal(t, "clientID", flags.AzureClientIDFV, flags.AzureClientIDFN) + assert.Equal(t, "secret", flags.AzureClientSecretFV, flags.AzureClientSecretFN) + }, + } + + flags.AddAzureCredsFlags(cmd) + // Test arg parsing for few args + cmd.SetArgs([]string{ + "test", + "--" + flags.AzureClientIDFN, "clientID", + "--" + flags.AzureClientTenantFN, "tenantID", + "--" + flags.AzureClientSecretFN, "secret", + }) + + err := cmd.Execute() + require.NoError(t, err, clues.ToCore(err)) +} + +func (suite *FlagUnitSuite) TestAddAWSCredsFlags() { + t := suite.T() + + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) { + assert.Equal(t, "accesskey", flags.AWSAccessKeyFV, flags.AWSAccessKeyFN) + assert.Equal(t, "secretkey", flags.AWSSecretAccessKeyFV, flags.AWSSecretAccessKeyFN) + assert.Equal(t, "token", flags.AWSSessionTokenFV, flags.AWSSessionTokenFN) + }, + } + + flags.AddAWSCredsFlags(cmd) + // Test arg parsing for few args + cmd.SetArgs([]string{ + "test", + "--" + flags.AWSAccessKeyFN, "accesskey", + "--" + flags.AWSSecretAccessKeyFN, "secretkey", + "--" + flags.AWSSessionTokenFN, "token", + }) + + err := cmd.Execute() + require.NoError(t, err, clues.ToCore(err)) +} + +func (suite *FlagUnitSuite) TestAddCorsoPassphraseFlags() { + t := suite.T() + + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) { + assert.Equal(t, "passphrase", flags.CorsoPassphraseFV, flags.CorsoPassphraseFN) + }, + } + + flags.AddCorsoPassphaseFlags(cmd) + // Test arg parsing for few args + cmd.SetArgs([]string{ + "test", + "--" + flags.CorsoPassphraseFN, "passphrase", + }) + + err := cmd.Execute() + require.NoError(t, err, clues.ToCore(err)) +} diff --git a/src/cli/utils/testdata/flags.go b/src/cli/utils/testdata/flags.go index 25e516b4d..67992e267 100644 --- a/src/cli/utils/testdata/flags.go +++ b/src/cli/utils/testdata/flags.go @@ -45,4 +45,14 @@ var ( PageInput = []string{"page1", "page2"} RestorePermissions = true + + AzureClientID = "testAzureClientId" + AzureTenantID = "testAzureTenantId" + AzureClientSecret = "testAzureClientSecret" + + AWSAccessKeyID = "testAWSAccessKeyID" + AWSSecretAccessKey = "testAWSSecretAccessKey" + AWSSessionToken = "testAWSSessionToken" + + CorsoPassphrase = "testCorsoPassphrase" ) diff --git a/src/cli/utils/utils.go b/src/cli/utils/utils.go index 277f11c5c..56564ee3e 100644 --- a/src/cli/utils/utils.go +++ b/src/cli/utils/utils.go @@ -19,8 +19,11 @@ import ( "github.com/alcionai/corso/src/pkg/storage" ) -func GetAccountAndConnect(ctx context.Context) (repository.Repository, *storage.Storage, *account.Account, error) { - cfg, err := config.GetConfigRepoDetails(ctx, true, nil) +func GetAccountAndConnect( + ctx context.Context, + overrides map[string]string, +) (repository.Repository, *storage.Storage, *account.Account, error) { + cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides) if err != nil { return nil, nil, nil, err } @@ -38,8 +41,11 @@ func GetAccountAndConnect(ctx context.Context) (repository.Repository, *storage. return r, &cfg.Storage, &cfg.Account, nil } -func AccountConnectAndWriteRepoConfig(ctx context.Context) (repository.Repository, *account.Account, error) { - r, stg, acc, err := GetAccountAndConnect(ctx) +func AccountConnectAndWriteRepoConfig( + ctx context.Context, + overrides map[string]string, +) (repository.Repository, *account.Account, error) { + r, stg, acc, err := GetAccountAndConnect(ctx, overrides) if err != nil { logger.CtxErr(ctx, err).Info("getting and connecting account") return nil, nil, err diff --git a/src/go.sum b/src/go.sum index 890694120..eba2b6167 100644 --- a/src/go.sum +++ b/src/go.sum @@ -71,6 +71,7 @@ github.com/aws/aws-sdk-go v1.44.291/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8 github.com/aws/aws-xray-sdk-go v1.8.1 h1:O4pXV+hnCskaamGsZnFpzHyAmgPGusBMN6i7nnsy0Fo= github.com/aws/aws-xray-sdk-go v1.8.1/go.mod h1:wMmVYzej3sykAttNBkXQHK/+clAPWTOrPiajEk7Cp3A= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -123,6 +124,7 @@ github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -225,6 +227,7 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -232,6 +235,7 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= @@ -306,6 +310,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= @@ -438,6 +443,7 @@ go.opentelemetry.io/otel/trace v1.15.1/go.mod h1:IWdQG/5N1x7f6YUlmdLeJvH9yxtuJAf go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= @@ -788,6 +794,7 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/internal/kopia/s3.go b/src/internal/kopia/s3.go index 6868fa8f8..adad4330e 100644 --- a/src/internal/kopia/s3.go +++ b/src/internal/kopia/s3.go @@ -40,6 +40,9 @@ func s3BlobStorage( SessionName: s.SessionName, RoleARN: s.Role, RoleDuration: s.SessionDuration, + AccessKeyID: cfg.AccessKey, + SecretAccessKey: cfg.SecretKey, + SessionToken: cfg.SessionToken, TLSHandshakeTimeout: 60, PointInTime: repoOpts.ViewTimestamp, } diff --git a/src/internal/tester/storage.go b/src/internal/tester/storage.go index f5d903767..866b77720 100644 --- a/src/internal/tester/storage.go +++ b/src/internal/tester/storage.go @@ -1,11 +1,14 @@ package tester import ( + "os" "testing" "github.com/alcionai/clues" "github.com/stretchr/testify/require" + "github.com/alcionai/corso/src/cli/flags" + "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/storage" ) @@ -40,7 +43,7 @@ func NewPrefixedS3Storage(t *testing.T) storage.Storage { Prefix: prefix, }, storage.CommonConfig{ - Corso: credentials.GetCorso(), + Corso: GetAndInsertCorso(""), KopiaCfgDir: t.TempDir(), }, ) @@ -48,3 +51,14 @@ func NewPrefixedS3Storage(t *testing.T) storage.Storage { return st } + +// GetCorso is a helper for aggregating Corso secrets and credentials. +func GetAndInsertCorso(passphase string) credentials.Corso { + // fetch data from flag, env var or func param giving priority to func param + // Func param generally will be value fetched from config file using viper. + corsoPassph := str.First(flags.CorsoPassphraseFV, os.Getenv(credentials.CorsoPassphrase), passphase) + + return credentials.Corso{ + CorsoPassphrase: corsoPassph, + } +} diff --git a/src/pkg/credentials/aws.go b/src/pkg/credentials/aws.go index 07d993999..7e4bbf736 100644 --- a/src/pkg/credentials/aws.go +++ b/src/pkg/credentials/aws.go @@ -1,8 +1,6 @@ package credentials import ( - "os" - "github.com/alcionai/clues" ) @@ -22,20 +20,10 @@ type AWS struct { // GetAWS is a helper for aggregating aws secrets and credentials. func GetAWS(override map[string]string) AWS { - accessKey := os.Getenv(AWSAccessKeyID) - if ovr, ok := override[AWSAccessKeyID]; ok && ovr != "" { - accessKey = ovr - } - - secretKey := os.Getenv(AWSSecretAccessKey) - sessToken := os.Getenv(AWSSessionToken) - - // todo (rkeeprs): read from either corso config file or env vars. - // https://github.com/alcionai/corso/issues/120 return AWS{ - AccessKey: accessKey, - SecretKey: secretKey, - SessionToken: sessToken, + AccessKey: override[AWSAccessKeyID], + SecretKey: override[AWSSecretAccessKey], + SessionToken: override[AWSSessionToken], } } diff --git a/src/pkg/credentials/corso.go b/src/pkg/credentials/corso.go index 44f088c5f..f05f31e10 100644 --- a/src/pkg/credentials/corso.go +++ b/src/pkg/credentials/corso.go @@ -1,8 +1,6 @@ package credentials import ( - "os" - "github.com/alcionai/clues" ) @@ -16,17 +14,6 @@ type Corso struct { CorsoPassphrase string // required } -// GetCorso is a helper for aggregating Corso secrets and credentials. -func GetCorso() Corso { - // todo (rkeeprs): read from either corso config file or env vars. - // https://github.com/alcionai/corso/issues/120 - corsoPassph := os.Getenv(CorsoPassphrase) - - return Corso{ - CorsoPassphrase: corsoPassph, - } -} - func (c Corso) Validate() error { check := map[string]string{ CorsoPassphrase: c.CorsoPassphrase, diff --git a/src/pkg/credentials/m365.go b/src/pkg/credentials/m365.go index d6fcaf030..19f034011 100644 --- a/src/pkg/credentials/m365.go +++ b/src/pkg/credentials/m365.go @@ -20,11 +20,14 @@ type M365 struct { // M365 is a helper for aggregating m365 secrets and credentials. func GetM365() M365 { - // todo (rkeeprs): read from either corso config file or env vars. - // https://github.com/alcionai/corso/issues/120 + // check env and overide is flags found + // var AzureClientID, AzureClientSecret string + AzureClientID := os.Getenv(AzureClientID) + AzureClientSecret := os.Getenv(AzureClientSecret) + return M365{ - AzureClientID: os.Getenv(AzureClientID), - AzureClientSecret: os.Getenv(AzureClientSecret), + AzureClientID: AzureClientID, + AzureClientSecret: AzureClientSecret, } } diff --git a/src/pkg/storage/s3.go b/src/pkg/storage/s3.go index 17fe89f02..a332326e8 100644 --- a/src/pkg/storage/s3.go +++ b/src/pkg/storage/s3.go @@ -7,9 +7,11 @@ import ( "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" + "github.com/alcionai/corso/src/pkg/credentials" ) type S3Config struct { + credentials.AWS Bucket string // required Endpoint string Prefix string @@ -19,9 +21,12 @@ type S3Config struct { // config key consts const ( + keyS3AccessKey = "s3_access_key" keyS3Bucket = "s3_bucket" keyS3Endpoint = "s3_endpoint" keyS3Prefix = "s3_prefix" + keyS3SecretKey = "s3_secret_key" + keyS3SessionToken = "s3_session_token" keyS3DoNotUseTLS = "s3_donotusetls" keyS3DoNotVerifyTLS = "s3_donotverifytls" ) @@ -51,9 +56,12 @@ func (c S3Config) Normalize() S3Config { func (c S3Config) StringConfig() (map[string]string, error) { cn := c.Normalize() cfg := map[string]string{ + keyS3AccessKey: c.AccessKey, keyS3Bucket: cn.Bucket, keyS3Endpoint: cn.Endpoint, keyS3Prefix: cn.Prefix, + keyS3SecretKey: c.SecretKey, + keyS3SessionToken: c.SessionToken, keyS3DoNotUseTLS: strconv.FormatBool(cn.DoNotUseTLS), keyS3DoNotVerifyTLS: strconv.FormatBool(cn.DoNotVerifyTLS), } @@ -66,6 +74,10 @@ func (s Storage) S3Config() (S3Config, error) { c := S3Config{} if len(s.Config) > 0 { + c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) + c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) + c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) + c.Bucket = orEmptyString(s.Config[keyS3Bucket]) c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) c.Prefix = orEmptyString(s.Config[keyS3Prefix]) diff --git a/src/pkg/storage/s3_test.go b/src/pkg/storage/s3_test.go index b66ebeb25..3a56fd090 100644 --- a/src/pkg/storage/s3_test.go +++ b/src/pkg/storage/s3_test.go @@ -7,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/pkg/credentials" ) type S3CfgSuite struct { @@ -24,6 +26,7 @@ var ( Prefix: "pre/", DoNotUseTLS: false, DoNotVerifyTLS: false, + AWS: credentials.AWS{AccessKey: "access", SecretKey: "secret", SessionToken: "token"}, } goodS3Map = map[string]string{ @@ -32,6 +35,9 @@ var ( keyS3Prefix: "pre/", keyS3DoNotUseTLS: "false", keyS3DoNotVerifyTLS: "false", + keyS3AccessKey: "access", + keyS3SecretKey: "secret", + keyS3SessionToken: "token", } ) @@ -68,11 +74,12 @@ func (suite *S3CfgSuite) TestStorage_S3Config() { assert.Equal(t, in.Prefix, out.Prefix) } -func makeTestS3Cfg(bkt, end, pre string) S3Config { +func makeTestS3Cfg(bkt, end, pre, access, secret, session string) S3Config { return S3Config{ Bucket: bkt, Endpoint: end, Prefix: pre, + AWS: credentials.AWS{AccessKey: access, SecretKey: secret, SessionToken: session}, } } @@ -82,7 +89,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() { name string cfg S3Config }{ - {"missing bucket", makeTestS3Cfg("", "end", "pre/")}, + {"missing bucket", makeTestS3Cfg("", "end", "pre/", "", "", "")}, } for _, test := range table { suite.Run(test.name, func() { @@ -128,8 +135,14 @@ func (suite *S3CfgSuite) TestStorage_S3Config_StringConfig() { expect: goodS3Map, }, { - name: "normalized bucket name", - input: makeTestS3Cfg("s3://"+goodS3Config.Bucket, goodS3Config.Endpoint, goodS3Config.Prefix), + name: "normalized bucket name", + input: makeTestS3Cfg( + "s3://"+goodS3Config.Bucket, + goodS3Config.Endpoint, + goodS3Config.Prefix, + goodS3Config.AccessKey, + goodS3Config.SecretKey, + goodS3Config.SessionToken), expect: goodS3Map, }, { @@ -147,6 +160,9 @@ func (suite *S3CfgSuite) TestStorage_S3Config_StringConfig() { keyS3Prefix: "pre/", keyS3DoNotUseTLS: "true", keyS3DoNotVerifyTLS: "true", + keyS3AccessKey: "", + keyS3SecretKey: "", + keyS3SessionToken: "", }, }, }