From 60eb8eec083d2b008daebe18873f0c54cad8a982 Mon Sep 17 00:00:00 2001 From: Keepers <104464746+ryanfkeepers@users.noreply.github.com> Date: Mon, 27 Jun 2022 15:25:06 -0600 Subject: [PATCH] refactor cli/config for better local testing (#245) --- src/cli/config/config.go | 92 ++++++++++++++--------- src/cli/config/config_test.go | 38 ++++++---- src/pkg/account/accountprovider_string.go | 24 ++++++ 3 files changed, 104 insertions(+), 50 deletions(-) create mode 100644 src/pkg/account/accountprovider_string.go diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 39dabad6c..911477865 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -16,13 +16,14 @@ import ( const ( // S3 config - ProviderTypeKey = "provider" - BucketNameKey = "bucket" - EndpointKey = "endpoint" - PrefixKey = "prefix" + StorageProviderTypeKey = "provider" + BucketNameKey = "bucket" + EndpointKey = "endpoint" + PrefixKey = "prefix" // M365 config - TenantIDKey = "tenantid" + AccountProviderTypeKey = "account_provider" + TenantIDKey = "tenantid" ) func InitConfig(configFilePath string) error { @@ -64,7 +65,7 @@ func InitConfig(configFilePath string) error { func WriteRepoConfig(s3Config storage.S3Config, account account.M365Config) error { // Rudimentary support for persisting repo config // TODO: Handle conflicts, support other config types - viper.Set(ProviderTypeKey, storage.ProviderS3.String()) + viper.Set(StorageProviderTypeKey, storage.ProviderS3.String()) viper.Set(BucketNameKey, s3Config.Bucket) viper.Set(EndpointKey, s3Config.Endpoint) viper.Set(PrefixKey, s3Config.Prefix) @@ -79,57 +80,80 @@ func WriteRepoConfig(s3Config storage.S3Config, account account.M365Config) erro return nil } -func ReadRepoConfig() (storage.S3Config, account.Account, error) { - var ( - s3Config storage.S3Config - acct account.Account - err error - ) +func readRepoConfig() error { + var err error if err = viper.ReadInConfig(); err != nil { - return s3Config, acct, errors.Wrap(err, "reading config file: "+viper.ConfigFileUsed()) + return errors.Wrap(err, "reading config file: "+viper.ConfigFileUsed()) } - if providerType := viper.GetString(ProviderTypeKey); providerType != storage.ProviderS3.String() { - return s3Config, acct, errors.New("Unsupported storage provider: " + providerType) + return err +} + +// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. +func s3ConfigsFromViper() (storage.S3Config, error) { + var s3Config storage.S3Config + + providerType := viper.GetString(StorageProviderTypeKey) + if providerType != storage.ProviderS3.String() { + return s3Config, errors.New("unsupported storage provider: " + providerType) } s3Config.Bucket = viper.GetString(BucketNameKey) s3Config.Endpoint = viper.GetString(EndpointKey) s3Config.Prefix = viper.GetString(PrefixKey) + return s3Config, nil +} - m365Creds := credentials.GetM365() - tenantID := os.Getenv(account.TenantID) - cfgTenantID := viper.GetString(TenantIDKey) - if len(cfgTenantID) > 0 { - tenantID = cfgTenantID +// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. +func m365ConfigsFromViper() (account.M365Config, error) { + var m365 account.M365Config + + providerType := viper.GetString(AccountProviderTypeKey) + if providerType != account.ProviderM365.String() { + return m365, errors.New("unsupported account provider: " + providerType) } - acct, err = account.NewAccount( - account.ProviderM365, - account.M365Config{ - M365: m365Creds, - TenantID: tenantID, - }, - ) - return s3Config, acct, err + m365.TenantID = first(viper.GetString(TenantIDKey), os.Getenv(account.TenantID)) + return m365, nil } // GetStorageAndAccount creates a storage and account instance by mediating all the possible // data sources (config file, env vars, flag overrides) and the config file. func GetStorageAndAccount(readFromFile bool, overrides map[string]string) (storage.Storage, account.Account, error) { var ( - s3Cfg storage.S3Config - acct account.Account - err error + s3Cfg storage.S3Config + store storage.Storage + m365Cfg account.M365Config + acct account.Account + err error ) // possibly read the prior config from a .corso file if readFromFile { - s3Cfg, acct, err = ReadRepoConfig() - if err != nil { - return storage.Storage{}, acct, errors.Wrap(err, "reading corso config file") + if err = readRepoConfig(); err != nil { + return store, acct, errors.Wrap(err, "reading corso config file") } + + s3Cfg, err = s3ConfigsFromViper() + if err != nil { + return store, acct, errors.Wrap(err, "reading s3 configs from corso config file") + } + + m365Cfg, err = m365ConfigsFromViper() + if err != nil { + return store, acct, errors.Wrap(err, "reading m365 configs from corso config file") + } + } + + // compose the m365 account config and credentials + m365Cfg = account.M365Config{ + M365: credentials.GetM365(), + TenantID: first(overrides[account.TenantID], m365Cfg.TenantID, os.Getenv(account.TenantID)), + } + acct, err = account.NewAccount(account.ProviderM365, m365Cfg) + if err != nil { + return store, acct, errors.Wrap(err, "retrieving m365 account configuration") } // compose the s3 storage config and credentials diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index 8942407e6..acad314b3 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -1,4 +1,4 @@ -package config_test +package config import ( "fmt" @@ -11,8 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/alcionai/corso/cli/config" - ctesting "github.com/alcionai/corso/internal/testing" + "github.com/alcionai/corso/pkg/account" "github.com/alcionai/corso/pkg/storage" ) @@ -22,6 +21,7 @@ bucket = '%s' endpoint = 's3.amazonaws.com' prefix = 'test-prefix' provider = 'S3' +account_provider = 'M365' tenantid = '%s' ` ) @@ -47,11 +47,14 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { viper.SetConfigFile(testConfigFilePath) // Read and validate config - s3Cfg, account, err := config.ReadRepoConfig() - assert.NoError(suite.T(), err) + err = readRepoConfig() + require.NoError(suite.T(), err) + + s3Cfg, err := s3ConfigsFromViper() + require.NoError(suite.T(), err) assert.Equal(suite.T(), b, s3Cfg.Bucket) - m365, err := account.M365Config() + m365, err := m365ConfigsFromViper() require.NoError(suite.T(), err) assert.Equal(suite.T(), tID, m365.TenantID) } @@ -60,20 +63,23 @@ func (suite *ConfigSuite) TestWriteReadConfig() { // Configure viper to read test config file tempDir := suite.T().TempDir() testConfigFilePath := path.Join(tempDir, "corso.toml") - err := config.InitConfig(testConfigFilePath) + err := InitConfig(testConfigFilePath) assert.NoError(suite.T(), err) s3Cfg := storage.S3Config{Bucket: "write-read-config-bucket"} - acct, err := ctesting.NewM365Account() - require.NoError(suite.T(), err) - m365, err := acct.M365Config() + m365 := account.M365Config{TenantID: "3c0748d2-470e-444c-9064-1268e52609d5"} + + err = WriteRepoConfig(s3Cfg, m365) require.NoError(suite.T(), err) - err = config.WriteRepoConfig(s3Cfg, m365) - assert.NoError(suite.T(), err) + err = readRepoConfig() + require.NoError(suite.T(), err) - readS3Cfg, readAccount, err := config.ReadRepoConfig() - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), s3Cfg, readS3Cfg) - assert.Equal(suite.T(), acct, readAccount) + readS3Cfg, err := s3ConfigsFromViper() + require.NoError(suite.T(), err) + assert.Equal(suite.T(), readS3Cfg.Bucket, s3Cfg.Bucket) + + readM365, err := m365ConfigsFromViper() + require.NoError(suite.T(), err) + assert.Equal(suite.T(), readM365.TenantID, m365.TenantID) } diff --git a/src/pkg/account/accountprovider_string.go b/src/pkg/account/accountprovider_string.go new file mode 100644 index 000000000..f055843ea --- /dev/null +++ b/src/pkg/account/accountprovider_string.go @@ -0,0 +1,24 @@ +// Code generated by "stringer -type=accountProvider -linecomment"; DO NOT EDIT. + +package account + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ProviderUnknown-0] + _ = x[ProviderM365-1] +} + +const _accountProvider_name = "Unknown ProviderM365" + +var _accountProvider_index = [...]uint8{0, 16, 20} + +func (i accountProvider) String() string { + if i < 0 || i >= accountProvider(len(_accountProvider_index)-1) { + return "accountProvider(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _accountProvider_name[_accountProvider_index[i]:_accountProvider_index[i+1]] +}