From 8590b241999b68095492cbae3e7345209e1aa032 Mon Sep 17 00:00:00 2001 From: Abhishek Pandey Date: Mon, 18 Sep 2023 20:02:54 +0530 Subject: [PATCH] Introduce new interfaces for storage configuration (#4251) Introducing a new `Configurer` interface to abstract out storage config information(for s3, filesystem etc) from caller code. I consider this as a short term solution. We need to consolidate overall config handling in a better way, but that's out of scope for this PR chain. Testing * Most of the changes here are code movement under the hood. So relying on existing tests. * I'll address any test gaps in a later PR. --- #### Does this PR need a docs update or release note? - [ ] :white_check_mark: Yes, it's included - [ ] :clock1: Yes, but in a later PR - [x] :no_entry: No #### Type of change - [ ] :sunflower: Feature - [ ] :bug: Bugfix - [ ] :world_map: Documentation - [ ] :robot: Supportability/Tests - [ ] :computer: CI/Deployment - [x] :broom: Tech Debt/Cleanup #### Issue(s) * https://github.com/alcionai/corso/issues/1416 #### Test Plan - [x] :muscle: Manual - [x] :zap: Unit test - [ ] :green_heart: E2E --- src/cli/backup/helpers_test.go | 4 +- src/cli/config/config.go | 35 ++---- src/cli/config/config_test.go | 75 +++++++++---- src/cli/config/storage.go | 113 ++----------------- src/cli/flags/s3.go | 8 +- src/cli/repo/s3.go | 8 +- src/cli/repo/s3_e2e_test.go | 24 +++- src/cli/restore/exchange_e2e_test.go | 4 +- src/cli/utils/utils.go | 6 +- src/cmd/s3checker/s3checker.go | 4 +- src/internal/events/events_test.go | 2 +- src/internal/kopia/s3.go | 4 +- src/pkg/storage/s3.go | 157 ++++++++++++++++++++++++--- src/pkg/storage/s3_test.go | 13 ++- src/pkg/storage/storage.go | 77 +++++++++++++ src/pkg/storage/testdata/storage.go | 2 +- 16 files changed, 344 insertions(+), 192 deletions(-) diff --git a/src/cli/backup/helpers_test.go b/src/cli/backup/helpers_test.go index a54019c1c..cba71af95 100644 --- a/src/cli/backup/helpers_test.go +++ b/src/cli/backup/helpers_test.go @@ -140,9 +140,11 @@ func prepM365Test( recorder = strings.Builder{} ) - cfg, err := st.S3Config() + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + force := map[string]string{ tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 41e26422c..c4805fe69 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -20,17 +20,7 @@ import ( ) const ( - // S3 config - BucketNameKey = "bucket" - EndpointKey = "endpoint" - PrefixKey = "prefix" - DisableTLSKey = "disable_tls" - DisableTLSVerificationKey = "disable_tls_verification" - RepoID = "repo_id" - - AccessKey = "aws_access_key_id" - SecretAccessKey = "aws_secret_access_key" - SessionToken = "aws_session_token" + RepoID = "repo_id" // Corso passphrase in config CorsoPassphrase = "passphrase" @@ -196,14 +186,14 @@ func Read(ctx context.Context) error { // It does not check for conflicts or existing data. func WriteRepoConfig( ctx context.Context, - s3Config storage.S3Config, + wcs storage.WriteConfigToStorer, m365Config account.M365Config, repoOpts repository.Options, repoID string, ) error { return writeRepoConfigWithViper( GetViper(ctx), - s3Config, + wcs, m365Config, repoOpts, repoID) @@ -213,20 +203,14 @@ func WriteRepoConfig( // struct for testing. func writeRepoConfigWithViper( vpr *viper.Viper, - s3Config storage.S3Config, + wcs storage.WriteConfigToStorer, m365Config account.M365Config, repoOpts repository.Options, repoID string, ) error { - s3Config = s3Config.Normalize() - // Rudimentary support for persisting repo config - // TODO: Handle conflicts, support other config types - vpr.Set(storage.StorageProviderTypeKey, storage.ProviderS3.String()) - vpr.Set(BucketNameKey, s3Config.Bucket) - vpr.Set(EndpointKey, s3Config.Endpoint) - vpr.Set(PrefixKey, s3Config.Prefix) - vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS) - vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS) + // Write storage configuration to viper + wcs.WriteConfigToStore(vpr) + vpr.Set(RepoID, repoID) // Need if-checks as Viper will write empty values otherwise. @@ -339,15 +323,12 @@ func getUserHost(vpr *viper.Viper, readConfigFromViper bool) (string, string) { var constToTomlKeyMap = map[string]string{ account.AzureTenantID: account.AzureTenantIDKey, account.AccountProviderTypeKey: account.AccountProviderTypeKey, - storage.Bucket: BucketNameKey, - storage.Endpoint: EndpointKey, - storage.Prefix: PrefixKey, - storage.StorageProviderTypeKey: storage.StorageProviderTypeKey, } // mustMatchConfig compares the values of each key to their config file value in viper. // If any value differs from the viper value, an error is returned. // values in m that aren't stored in the config are ignored. +// TODO(pandeyabs): This code is currently duplicated in 2 places. func mustMatchConfig(vpr *viper.Viper, m map[string]string) error { for k, v := range m { if len(v) == 0 { diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index d9c5152f1..bccc79601 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -26,20 +26,20 @@ import ( const ( configFileTemplate = ` -` + BucketNameKey + ` = '%s' -` + EndpointKey + ` = 's3.amazonaws.com' -` + PrefixKey + ` = 'test-prefix/' +` + storage.BucketNameKey + ` = '%s' +` + storage.EndpointKey + ` = 's3.amazonaws.com' +` + storage.PrefixKey + ` = 'test-prefix/' ` + storage.StorageProviderTypeKey + ` = 'S3' ` + account.AccountProviderTypeKey + ` = 'M365' ` + account.AzureTenantIDKey + ` = '%s' -` + AccessKey + ` = '%s' -` + SecretAccessKey + ` = '%s' -` + SessionToken + ` = '%s' +` + storage.AccessKey + ` = '%s' +` + storage.SecretAccessKey + ` = '%s' +` + storage.SessionToken + ` = '%s' ` + CorsoPassphrase + ` = '%s' ` + account.AzureClientID + ` = '%s' ` + account.AzureSecret + ` = '%s' -` + DisableTLSKey + ` = '%s' -` + DisableTLSVerificationKey + ` = '%s' +` + storage.DisableTLSKey + ` = '%s' +` + storage.DisableTLSVerificationKey + ` = '%s' ` ) @@ -107,18 +107,32 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { err = vpr.ReadInConfig() require.NoError(t, err, "reading repo config", clues.ToCore(err)) - s3Cfg, err := s3ConfigsFromViper(vpr) + sc, err := storage.NewStorageConfig(storage.ProviderS3) require.NoError(t, err, clues.ToCore(err)) + err = sc.ApplyConfigOverrides(vpr, true, true, nil) + require.NoError(t, err, clues.ToCore(err)) + + s3Cfg := sc.(*storage.S3Config) + assert.Equal(t, b, s3Cfg.Bucket) assert.Equal(t, "test-prefix/", s3Cfg.Prefix) assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS)) assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS)) - s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) - require.NoError(t, err, clues.ToCore(err)) - assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) - assert.Equal(t, secret, s3Cfg.AWS.SecretKey) - assert.Equal(t, token, s3Cfg.AWS.SessionToken) + // Config file may or may not be the source of truth for below values. These may be + // overridden by env vars (and flags but not relevant for this test). + // + // Other alternatives are: + // 1) unset env vars temporarily so that we can test against config file values. But that + // may be problematic if we decide to parallelize tests in future. + // 2) assert against env var values instead of config file values. This can cause issues + // if CI/local env have different config override mechanisms. + // 3) Skip asserts for these keys. They will be validated in other tests. Choosing this + // option. + + // assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) + // assert.Equal(t, secret, s3Cfg.AWS.SecretKey) + // assert.Equal(t, token, s3Cfg.AWS.SessionToken) m365, err := m365ConfigsFromViper(vpr) require.NoError(t, err, clues.ToCore(err)) @@ -146,7 +160,11 @@ func (suite *ConfigSuite) TestWriteReadConfig() { err := initWithViper(vpr, testConfigFilePath) require.NoError(t, err, "initializing repo config", clues.ToCore(err)) - s3Cfg := storage.S3Config{Bucket: bkt, DoNotUseTLS: true, DoNotVerifyTLS: true} + s3Cfg := &storage.S3Config{ + Bucket: bkt, + DoNotUseTLS: true, + DoNotVerifyTLS: true, + } m365 := account.M365Config{AzureTenantID: tid} rOpts := repository.Options{ @@ -160,8 +178,12 @@ func (suite *ConfigSuite) TestWriteReadConfig() { err = vpr.ReadInConfig() require.NoError(t, err, "reading repo config", clues.ToCore(err)) - readS3Cfg, err := s3ConfigsFromViper(vpr) + sc, err := storage.NewStorageConfig(storage.ProviderS3) require.NoError(t, err, clues.ToCore(err)) + err = sc.ApplyConfigOverrides(vpr, true, true, nil) + require.NoError(t, err, clues.ToCore(err)) + + readS3Cfg := sc.(*storage.S3Config) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS) assert.Equal(t, readS3Cfg.DoNotVerifyTLS, s3Cfg.DoNotVerifyTLS) @@ -191,7 +213,7 @@ func (suite *ConfigSuite) TestMustMatchConfig() { err := initWithViper(vpr, testConfigFilePath) require.NoError(t, err, "initializing repo config") - s3Cfg := storage.S3Config{Bucket: bkt} + s3Cfg := &storage.S3Config{Bucket: bkt} m365 := account.M365Config{AzureTenantID: tid} err = writeRepoConfigWithViper(vpr, s3Cfg, m365, repository.Options{}, "repoid") @@ -330,9 +352,14 @@ func (suite *ConfigSuite) TestReadFromFlags() { true, false, overrides) + require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) m365Config, _ := repoDetails.Account.M365Config() - s3Cfg, _ := repoDetails.Storage.S3Config() + + sc, err := repoDetails.Storage.StorageConfig() + require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) + + s3Cfg := sc.(*storage.S3Config) commonConfig, _ := repoDetails.Storage.CommonConfig() pass := commonConfig.Corso.CorsoPassphrase @@ -386,7 +413,7 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { err := initWithViper(vpr, testConfigFilePath) require.NoError(t, err, "initializing repo config", clues.ToCore(err)) - s3Cfg := storage.S3Config{ + s3Cfg := &storage.S3Config{ Bucket: bkt, Endpoint: end, Prefix: pfx, @@ -404,8 +431,11 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) - readS3Cfg, err := cfg.Storage.S3Config() + sc, err := cfg.Storage.StorageConfig() require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) + + readS3Cfg := sc.(*storage.S3Config) + assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) @@ -452,8 +482,11 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) - readS3Cfg, err := cfg.Storage.S3Config() + sc, err := cfg.Storage.StorageConfig() require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) + + readS3Cfg := sc.(*storage.S3Config) + assert.Equal(t, readS3Cfg.Bucket, bkt) assert.Equal(t, cfg.RepoID, "") assert.Equal(t, readS3Cfg.Endpoint, end) diff --git a/src/cli/config/storage.go b/src/cli/config/storage.go index e63a723c9..5070ea24c 100644 --- a/src/cli/config/storage.go +++ b/src/cli/config/storage.go @@ -4,52 +4,16 @@ import ( "context" "os" "path/filepath" - "strconv" "github.com/alcionai/clues" - "github.com/aws/aws-sdk-go/aws/defaults" "github.com/spf13/viper" "github.com/alcionai/corso/src/cli/flags" - "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/storage" ) -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) { - var s3Config storage.S3Config - - s3Config.Bucket = vpr.GetString(BucketNameKey) - s3Config.Endpoint = vpr.GetString(EndpointKey) - s3Config.Prefix = vpr.GetString(PrefixKey) - s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey) - s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey) - - return s3Config, nil -} - -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) { - s3Config.AccessKey = vpr.GetString(AccessKey) - s3Config.SecretKey = vpr.GetString(SecretAccessKey) - s3Config.SessionToken = vpr.GetString(SessionToken) - - return s3Config, nil -} - -func s3Overrides(in map[string]string) map[string]string { - return map[string]string{ - storage.Bucket: in[storage.Bucket], - storage.Endpoint: in[storage.Endpoint], - storage.Prefix: in[storage.Prefix], - storage.DoNotUseTLS: in[storage.DoNotUseTLS], - storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS], - storage.StorageProviderTypeKey: in[storage.StorageProviderTypeKey], - } -} - // configureStorage builds a complete storage configuration from a mix of // viper properties and manual overrides. func configureStorage( @@ -59,72 +23,20 @@ func configureStorage( matchFromConfig bool, overrides map[string]string, ) (storage.Storage, error) { - var ( - s3Cfg storage.S3Config - store storage.Storage - err error - ) + var store storage.Storage - if readConfigFromViper { - if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil { - return store, clues.Wrap(err, "reading s3 configs from corso config file") - } - - if b, ok := overrides[storage.Bucket]; ok { - overrides[storage.Bucket] = common.NormalizeBucket(b) - } - - if p, ok := overrides[storage.Prefix]; ok { - overrides[storage.Prefix] = common.NormalizePrefix(p) - } - - if matchFromConfig { - providerType := vpr.GetString(storage.StorageProviderTypeKey) - if providerType != storage.ProviderS3.String() { - return store, clues.New("unsupported storage provider: " + providerType) - } - - if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil { - return store, clues.Wrap(err, "verifying s3 configs in corso config file") - } - } + sc, err := storage.NewStorageConfig(provider) + if err != nil { + return store, clues.Stack(err) } - if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil { - return store, clues.Wrap(err, "reading s3 configs from corso config file") - } - - aws := credentials.GetAWS(overrides) - - if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { - _, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get() - if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) { - aws = credentials.AWS{ - AccessKey: s3Cfg.AccessKey, - SecretKey: s3Cfg.SecretKey, - SessionToken: s3Cfg.SessionToken, - } - err = nil - } - - if err != nil { - return store, clues.Wrap(err, "validating aws credentials") - } - } - - s3Cfg = storage.S3Config{ - AWS: aws, - Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket), - Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"), - Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix), - DoNotUseTLS: str.ParseBool(str.First( - overrides[storage.DoNotUseTLS], - strconv.FormatBool(s3Cfg.DoNotUseTLS), - "false")), - DoNotVerifyTLS: str.ParseBool(str.First( - overrides[storage.DoNotVerifyTLS], - strconv.FormatBool(s3Cfg.DoNotVerifyTLS), - "false")), + err = sc.ApplyConfigOverrides( + vpr, + readConfigFromViper, + matchFromConfig, + overrides) + if err != nil { + return store, clues.Stack(err) } // compose the common config and credentials @@ -146,14 +58,13 @@ func configureStorage( // ensure required properties are present if err := requireProps(map[string]string{ - storage.Bucket: s3Cfg.Bucket, credentials.CorsoPassphrase: corso.CorsoPassphrase, }); err != nil { return storage.Storage{}, err } // build the storage - store, err = storage.NewStorage(provider, s3Cfg, cCfg) + store, err = storage.NewStorage(provider, sc, cCfg) if err != nil { return store, clues.Wrap(err, "configuring repository storage") } diff --git a/src/cli/flags/s3.go b/src/cli/flags/s3.go index a7e4e490e..5d641f544 100644 --- a/src/cli/flags/s3.go +++ b/src/cli/flags/s3.go @@ -5,7 +5,6 @@ import ( "github.com/spf13/cobra" - "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/storage" ) @@ -54,10 +53,9 @@ func S3FlagOverrides(cmd *cobra.Command) map[string]string { } func PopulateS3Flags(flagset PopulatedFlags) map[string]string { - s3Overrides := make(map[string]string) - // TODO(pandeyabs): Move account overrides out of s3 flags - s3Overrides[account.AccountProviderTypeKey] = account.ProviderM365.String() - s3Overrides[storage.StorageProviderTypeKey] = storage.ProviderS3.String() + s3Overrides := map[string]string{ + storage.StorageProviderTypeKey: storage.ProviderS3.String(), + } if _, ok := flagset[AWSAccessKeyFN]; ok { s3Overrides[credentials.AWSAccessKeyID] = AWSAccessKeyFV diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index cbe83c951..fa46715bd 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -112,11 +112,13 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { cfg.Account.ID(), opt) - s3Cfg, err := cfg.Storage.S3Config() + sc, err := cfg.Storage.StorageConfig() if err != nil { return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) } + s3Cfg := sc.(*storage.S3Config) + if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { invalidEndpointErr := "endpoint doesn't support specifying protocol. " + "pass --disable-tls flag to use http:// instead of default https://" @@ -189,11 +191,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { repoID = events.RepoIDNotFound } - s3Cfg, err := cfg.Storage.S3Config() + sc, err := cfg.Storage.StorageConfig() if err != nil { return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) } + s3Cfg := sc.(*storage.S3Config) + m365, err := cfg.Account.M365Config() if err != nil { return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) diff --git a/src/cli/repo/s3_e2e_test.go b/src/cli/repo/s3_e2e_test.go index 28dc7c3c7..55dbcab3a 100644 --- a/src/cli/repo/s3_e2e_test.go +++ b/src/cli/repo/s3_e2e_test.go @@ -63,8 +63,10 @@ func (suite *S3E2ESuite) TestInitS3Cmd() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) if !test.hasConfigFile { @@ -100,9 +102,11 @@ func (suite *S3E2ESuite) TestInitMultipleTimes() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) ctx = config.SetViper(ctx, vpr) @@ -129,9 +133,12 @@ func (suite *S3E2ESuite) TestInitS3Cmd_missingBucket() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + force := map[string]string{ tconfig.TestCfgBucket: "", } @@ -182,8 +189,9 @@ func (suite *S3E2ESuite) TestConnectS3Cmd() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) force := map[string]string{ tconfig.TestCfgAccountProvider: account.ProviderM365.String(), @@ -230,9 +238,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadBucket() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) ctx = config.SetViper(ctx, vpr) @@ -256,9 +266,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadPrefix() { defer flush() st := storeTD.NewPrefixedS3Storage(t) - cfg, err := st.S3Config() + sc, err := st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) ctx = config.SetViper(ctx, vpr) diff --git a/src/cli/restore/exchange_e2e_test.go b/src/cli/restore/exchange_e2e_test.go index 4b8b79e75..5b9f9a637 100644 --- a/src/cli/restore/exchange_e2e_test.go +++ b/src/cli/restore/exchange_e2e_test.go @@ -66,9 +66,11 @@ func (suite *RestoreExchangeE2ESuite) SetupSuite() { suite.acct = tconfig.NewM365Account(t) suite.st = storeTD.NewPrefixedS3Storage(t) - cfg, err := suite.st.S3Config() + sc, err := suite.st.StorageConfig() require.NoError(t, err, clues.ToCore(err)) + cfg := sc.(*storage.S3Config) + force := map[string]string{ tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), diff --git a/src/cli/utils/utils.go b/src/cli/utils/utils.go index d1bc9cb0c..1df944763 100644 --- a/src/cli/utils/utils.go +++ b/src/cli/utils/utils.go @@ -32,7 +32,7 @@ func GetAccountAndConnectWithOverrides( ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { provider, overrides, err := GetStorageProviderAndOverrides(ctx, cmd) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, clues.Stack(err) } return GetAccountAndConnect(ctx, pst, provider, overrides) @@ -89,12 +89,14 @@ func AccountConnectAndWriteRepoConfig( return nil, nil, err } - s3Config, err := stg.S3Config() + sc, err := stg.StorageConfig() if err != nil { logger.CtxErr(ctx, err).Info("getting storage configuration") return nil, nil, err } + s3Config := sc.(*storage.S3Config) + m365Config, err := acc.M365Config() if err != nil { logger.CtxErr(ctx, err).Info("getting m365 configuration") diff --git a/src/cmd/s3checker/s3checker.go b/src/cmd/s3checker/s3checker.go index 086cb51a4..0c42b8aa5 100644 --- a/src/cmd/s3checker/s3checker.go +++ b/src/cmd/s3checker/s3checker.go @@ -197,11 +197,13 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error { return clues.Wrap(err, "getting storage config") } - cfg, err := repoDetails.Storage.S3Config() + sc, err := repoDetails.Storage.StorageConfig() if err != nil { return clues.Wrap(err, "getting S3 config") } + cfg := sc.(*storage.S3Config) + endpoint := defaultS3Endpoint if len(cfg.Endpoint) > 0 { endpoint = cfg.Endpoint diff --git a/src/internal/events/events_test.go b/src/internal/events/events_test.go index 662fbeafa..77444727c 100644 --- a/src/internal/events/events_test.go +++ b/src/internal/events/events_test.go @@ -33,7 +33,7 @@ func (suite *EventsIntegrationSuite) TestNewBus() { s, err := storage.NewStorage( storage.ProviderS3, - storage.S3Config{ + &storage.S3Config{ Bucket: "bckt", Prefix: "prfx", }) diff --git a/src/internal/kopia/s3.go b/src/internal/kopia/s3.go index adad4330e..f4a379ada 100644 --- a/src/internal/kopia/s3.go +++ b/src/internal/kopia/s3.go @@ -20,11 +20,13 @@ func s3BlobStorage( repoOpts repository.Options, s storage.Storage, ) (blob.Storage, error) { - cfg, err := s.S3Config() + sc, err := s.StorageConfig() if err != nil { return nil, clues.Stack(err).WithClues(ctx) } + cfg := sc.(*storage.S3Config) + endpoint := defaultS3Endpoint if len(cfg.Endpoint) > 0 { endpoint = cfg.Endpoint diff --git a/src/pkg/storage/s3.go b/src/pkg/storage/s3.go index a332326e8..c689e77cd 100644 --- a/src/pkg/storage/s3.go +++ b/src/pkg/storage/s3.go @@ -1,9 +1,11 @@ package storage import ( + "os" "strconv" "github.com/alcionai/clues" + "github.com/spf13/cast" "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" @@ -40,7 +42,27 @@ const ( DoNotVerifyTLS = "donotverifytls" ) -func (c S3Config) Normalize() S3Config { +// config file keys +const ( + BucketNameKey = "bucket" + EndpointKey = "endpoint" + PrefixKey = "prefix" + DisableTLSKey = "disable_tls" + DisableTLSVerificationKey = "disable_tls_verification" + + AccessKey = "aws_access_key_id" + SecretAccessKey = "aws_secret_access_key" + SessionToken = "aws_session_token" +) + +var s3constToTomlKeyMap = map[string]string{ + Bucket: BucketNameKey, + Endpoint: EndpointKey, + Prefix: PrefixKey, + StorageProviderTypeKey: StorageProviderTypeKey, +} + +func (c *S3Config) normalize() S3Config { return S3Config{ Bucket: common.NormalizeBucket(c.Bucket), Endpoint: c.Endpoint, @@ -53,8 +75,8 @@ func (c S3Config) Normalize() S3Config { // StringConfig transforms a s3Config struct into a plain // map[string]string. All values in the original struct which // serialize into the map are expected to be strings. -func (c S3Config) StringConfig() (map[string]string, error) { - cn := c.Normalize() +func (c *S3Config) StringConfig() (map[string]string, error) { + cn := c.normalize() cfg := map[string]string{ keyS3AccessKey: c.AccessKey, keyS3Bucket: cn.Bucket, @@ -66,23 +88,22 @@ func (c S3Config) StringConfig() (map[string]string, error) { keyS3DoNotVerifyTLS: strconv.FormatBool(cn.DoNotVerifyTLS), } - return cfg, c.validate() + return cfg, cn.validate() } -// S3Config retrieves the S3Config details from the Storage config. -func (s Storage) S3Config() (S3Config, error) { - c := S3Config{} +func buildS3ConfigFromMap(config map[string]string) (*S3Config, error) { + c := &S3Config{} - if len(s.Config) > 0 { - c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) - c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) - c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) + if len(config) > 0 { + c.AccessKey = orEmptyString(config[keyS3AccessKey]) + c.SecretKey = orEmptyString(config[keyS3SecretKey]) + c.SessionToken = orEmptyString(config[keyS3SessionToken]) - c.Bucket = orEmptyString(s.Config[keyS3Bucket]) - c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) - c.Prefix = orEmptyString(s.Config[keyS3Prefix]) - c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS]) - c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS]) + c.Bucket = orEmptyString(config[keyS3Bucket]) + c.Endpoint = orEmptyString(config[keyS3Endpoint]) + c.Prefix = orEmptyString(config[keyS3Prefix]) + c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS]) + c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS]) } return c, c.validate() @@ -100,3 +121,107 @@ func (c S3Config) validate() error { return nil } + +func s3Overrides(in map[string]string) map[string]string { + return map[string]string{ + Bucket: in[Bucket], + Endpoint: in[Endpoint], + Prefix: in[Prefix], + DoNotUseTLS: in[DoNotUseTLS], + DoNotVerifyTLS: in[DoNotVerifyTLS], + StorageProviderTypeKey: in[StorageProviderTypeKey], + } +} + +func (c *S3Config) s3ConfigsFromStore(kvg Getter) { + c.Bucket = cast.ToString(kvg.Get(BucketNameKey)) + c.Endpoint = cast.ToString(kvg.Get(EndpointKey)) + c.Prefix = cast.ToString(kvg.Get(PrefixKey)) + c.DoNotUseTLS = cast.ToBool(kvg.Get(DisableTLSKey)) + c.DoNotVerifyTLS = cast.ToBool(kvg.Get(DisableTLSVerificationKey)) +} + +func (c *S3Config) s3CredsFromStore(kvg Getter) { + c.AccessKey = cast.ToString(kvg.Get(AccessKey)) + c.SecretKey = cast.ToString(kvg.Get(SecretAccessKey)) + c.SessionToken = cast.ToString(kvg.Get(SessionToken)) +} + +var _ Configurer = &S3Config{} + +func (c *S3Config) ApplyConfigOverrides( + kvg Getter, + readConfigFromStore bool, + matchFromConfig bool, + overrides map[string]string, +) error { + if readConfigFromStore { + c.s3ConfigsFromStore(kvg) + + if b, ok := overrides[Bucket]; ok { + overrides[Bucket] = common.NormalizeBucket(b) + } + + if p, ok := overrides[Prefix]; ok { + overrides[Prefix] = common.NormalizePrefix(p) + } + + if matchFromConfig { + providerType := cast.ToString(kvg.Get(StorageProviderTypeKey)) + if providerType != ProviderS3.String() { + return clues.New("unsupported storage provider: " + providerType) + } + + if err := mustMatchConfig(kvg, s3constToTomlKeyMap, s3Overrides(overrides)); err != nil { + return clues.Wrap(err, "verifying s3 configs in corso config file") + } + } + } + + c.s3CredsFromStore(kvg) + + aws := credentials.AWS{ + AccessKey: str.First( + overrides[credentials.AWSAccessKeyID], + os.Getenv(credentials.AWSAccessKeyID), + c.AccessKey), + SecretKey: str.First( + overrides[credentials.AWSSecretAccessKey], + os.Getenv(credentials.AWSSecretAccessKey), + c.SecretKey), + SessionToken: str.First( + overrides[credentials.AWSSessionToken], + os.Getenv(credentials.AWSSessionToken), + c.SessionToken), + } + + c.AWS = aws + c.Bucket = str.First(overrides[Bucket], c.Bucket) + c.Endpoint = str.First(overrides[Endpoint], c.Endpoint, "s3.amazonaws.com") + c.Prefix = str.First(overrides[Prefix], c.Prefix) + c.DoNotUseTLS = str.ParseBool(str.First( + overrides[DoNotUseTLS], + strconv.FormatBool(c.DoNotUseTLS), + "false")) + c.DoNotVerifyTLS = str.ParseBool(str.First( + overrides[DoNotVerifyTLS], + strconv.FormatBool(c.DoNotVerifyTLS), + "false")) + + return c.validate() +} + +var _ WriteConfigToStorer = &S3Config{} + +func (c *S3Config) WriteConfigToStore( + kvs Setter, +) { + s3Config := c.normalize() + + kvs.Set(StorageProviderTypeKey, ProviderS3.String()) + kvs.Set(BucketNameKey, s3Config.Bucket) + kvs.Set(EndpointKey, s3Config.Endpoint) + kvs.Set(PrefixKey, s3Config.Prefix) + kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS) + kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS) +} diff --git a/src/pkg/storage/s3_test.go b/src/pkg/storage/s3_test.go index 3a56fd090..2a4b239f9 100644 --- a/src/pkg/storage/s3_test.go +++ b/src/pkg/storage/s3_test.go @@ -64,11 +64,12 @@ func (suite *S3CfgSuite) TestStorage_S3Config() { t := suite.T() in := goodS3Config - s, err := NewStorage(ProviderS3, in) + s, err := NewStorage(ProviderS3, &in) assert.NoError(t, err, clues.ToCore(err)) - out, err := s.S3Config() + sc, err := s.StorageConfig() assert.NoError(t, err, clues.ToCore(err)) + out := sc.(*S3Config) assert.Equal(t, in.Bucket, out.Bucket) assert.Equal(t, in.Endpoint, out.Endpoint) assert.Equal(t, in.Prefix, out.Prefix) @@ -93,7 +94,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() { } for _, test := range table { suite.Run(test.name, func() { - _, err := NewStorage(ProviderUnknown, test.cfg) + _, err := NewStorage(ProviderUnknown, &test.cfg) assert.Error(suite.T(), err) }) } @@ -114,10 +115,10 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() { suite.Run(test.name, func() { t := suite.T() - st, err := NewStorage(ProviderUnknown, goodS3Config) + st, err := NewStorage(ProviderUnknown, &goodS3Config) assert.NoError(t, err, clues.ToCore(err)) test.amend(st) - _, err = st.S3Config() + _, err = st.StorageConfig() assert.Error(t, err) }) } @@ -187,7 +188,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_Normalize() { Bucket: prefixedBkt, } - result := st.Normalize() + result := st.normalize() assert.Equal(suite.T(), normalBkt, result.Bucket) assert.NotEqual(suite.T(), st.Bucket, result.Bucket) } diff --git a/src/pkg/storage/storage.go b/src/pkg/storage/storage.go index 3bf9d82f7..5926734ec 100644 --- a/src/pkg/storage/storage.go +++ b/src/pkg/storage/storage.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/alcionai/clues" + "github.com/spf13/cast" "github.com/alcionai/corso/src/internal/common" ) @@ -92,3 +93,79 @@ func orEmptyString(v any) string { return v.(string) } + +func (s Storage) StorageConfig() (Configurer, error) { + switch s.Provider { + case ProviderS3: + return buildS3ConfigFromMap(s.Config) + } + + return nil, clues.New("unsupported storage provider: " + s.Provider.String()) +} + +func NewStorageConfig(provider ProviderType) (Configurer, error) { + switch provider { + case ProviderS3: + return &S3Config{}, nil + } + + return nil, clues.New("unsupported storage provider: " + provider.String()) +} + +type Getter interface { + Get(key string) any +} + +type Setter interface { + Set(key string, value any) +} + +// WriteConfigToStorer writes config key value pairs to provided store. +type WriteConfigToStorer interface { + WriteConfigToStore( + s Setter, + ) +} + +type Configurer interface { + common.StringConfigurer + + // ApplyOverrides fetches config from file, processes overrides + // from sources like environment variables and flags, and updates the + // underlying configuration accordingly. + ApplyConfigOverrides( + g Getter, + readConfigFromStore bool, + matchFromConfig bool, + overrides map[string]string, + ) error + + WriteConfigToStorer +} + +// mustMatchConfig compares the values of each key to their config file value in store. +// If any value differs from the store value, an error is returned. +// values in m that aren't stored in the config are ignored. +func mustMatchConfig( + g Getter, + tomlMap map[string]string, + m map[string]string, +) error { + for k, v := range m { + if len(v) == 0 { + continue // empty variables will get caught by configuration validators, if necessary + } + + tomlK, ok := tomlMap[k] + if !ok { + continue // m may declare values which aren't stored in the config file + } + + vv := cast.ToString(g.Get(tomlK)) + if v != vv { + return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")") + } + } + + return nil +} diff --git a/src/pkg/storage/testdata/storage.go b/src/pkg/storage/testdata/storage.go index 631d29b80..853aff13d 100644 --- a/src/pkg/storage/testdata/storage.go +++ b/src/pkg/storage/testdata/storage.go @@ -38,7 +38,7 @@ func NewPrefixedS3Storage(t tester.TestT) storage.Storage { st, err := storage.NewStorage( storage.ProviderS3, - storage.S3Config{ + &storage.S3Config{ Bucket: cfg[tconfig.TestCfgBucket], Prefix: prefix, },