diff --git a/src/cli/config/account.go b/src/cli/config/account.go new file mode 100644 index 000000000..ded37d6a6 --- /dev/null +++ b/src/cli/config/account.go @@ -0,0 +1,80 @@ +package config + +import ( + "os" + + "github.com/alcionai/corso/cli/utils" + "github.com/alcionai/corso/pkg/account" + "github.com/alcionai/corso/pkg/credentials" + "github.com/pkg/errors" + "github.com/spf13/viper" +) + +// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. +func m365ConfigsFromViper(vpr *viper.Viper) (account.M365Config, error) { + var m365 account.M365Config + + providerType := vpr.GetString(AccountProviderTypeKey) + if providerType != account.ProviderM365.String() { + return m365, errors.New("unsupported account provider: " + providerType) + } + + m365.TenantID = vpr.GetString(TenantIDKey) + return m365, nil +} + +func m365Overrides(in map[string]string) map[string]string { + return map[string]string{ + account.TenantID: in[account.TenantID], + AccountProviderTypeKey: in[AccountProviderTypeKey], + } +} + +// configureAccount builds a complete account configuration from a mix of +// viper properties and manual overrides. +func configureAccount(vpr *viper.Viper, readConfigFromViper bool, overrides map[string]string) (account.Account, error) { + var ( + m365Cfg account.M365Config + acct account.Account + err error + ) + + if readConfigFromViper { + m365Cfg, err = m365ConfigsFromViper(vpr) + if err != nil { + return acct, errors.Wrap(err, "reading m365 configs from corso config file") + } + + if err := mustMatchConfig(vpr, m365Overrides(overrides)); err != nil { + return acct, errors.Wrap(err, "verifying m365 configs in corso config file") + } + } + + // compose the m365 config and credentials + m365 := credentials.GetM365() + if err := m365.Validate(); err != nil { + return acct, errors.Wrap(err, "validating m365 credentials") + } + + m365Cfg = account.M365Config{ + M365: m365, + TenantID: first(overrides[account.TenantID], m365Cfg.TenantID, os.Getenv(account.TenantID)), + } + + // ensure requried properties are present + if err := utils.RequireProps(map[string]string{ + credentials.ClientID: m365Cfg.ClientID, + credentials.ClientSecret: m365Cfg.ClientSecret, + account.TenantID: m365Cfg.TenantID, + }); err != nil { + return acct, err + } + + // build the account + acct, err = account.NewAccount(account.ProviderM365, m365Cfg) + if err != nil { + return acct, errors.Wrap(err, "retrieving m365 account configuration") + } + + return acct, nil +} diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 911477865..9ac65545a 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -8,9 +8,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/viper" - "github.com/alcionai/corso/cli/utils" "github.com/alcionai/corso/pkg/account" - "github.com/alcionai/corso/pkg/credentials" "github.com/alcionai/corso/pkg/storage" ) @@ -27,6 +25,12 @@ const ( ) func InitConfig(configFilePath string) error { + return initConfigWithViper(viper.GetViper(), configFilePath) +} + +// initConfigWithViper implements InitConfig, but takes in a viper +// struct for testing. +func initConfigWithViper(vpr *viper.Viper, configFilePath string) error { // Configure default config file location if configFilePath == "" { // Find home directory. @@ -36,18 +40,17 @@ func InitConfig(configFilePath string) error { } // Search config in home directory with name ".corso" (without extension). - viper.AddConfigPath(home) - viper.SetConfigType("toml") - viper.SetConfigName(".corso") + vpr.AddConfigPath(home) + vpr.SetConfigType("toml") + vpr.SetConfigName(".corso") return nil } - // Use a custom file location - viper.SetConfigFile(configFilePath) + vpr.SetConfigFile(configFilePath) // We also configure the path, type and filename - // because `viper.SafeWriteConfig` needs these set to + // because `vpr.SafeWriteConfig` needs these set to // work correctly (it does not use the configured file) - viper.AddConfigPath(path.Dir(configFilePath)) + vpr.AddConfigPath(path.Dir(configFilePath)) fileName := path.Base(configFilePath) ext := path.Ext(configFilePath) @@ -55,147 +58,85 @@ func InitConfig(configFilePath string) error { return errors.New("config file requires an extension e.g. `toml`") } fileName = strings.TrimSuffix(fileName, ext) - viper.SetConfigType(ext[1:]) - viper.SetConfigName(fileName) + vpr.SetConfigType(ext[1:]) + vpr.SetConfigName(fileName) + return nil } // WriteRepoConfig currently just persists corso config to the config file // It does not check for conflicts or existing data. -func WriteRepoConfig(s3Config storage.S3Config, account account.M365Config) error { +func WriteRepoConfig(s3Config storage.S3Config, m365Config account.M365Config) error { + return writeRepoConfigWithViper(viper.GetViper(), s3Config, m365Config) +} + +// writeRepoConfigWithViper implements WriteRepoConfig, but takes in a viper +// struct for testing. +func writeRepoConfigWithViper(vpr *viper.Viper, s3Config storage.S3Config, m365Config account.M365Config) error { // Rudimentary support for persisting repo config // TODO: Handle conflicts, support other config types - viper.Set(StorageProviderTypeKey, storage.ProviderS3.String()) - viper.Set(BucketNameKey, s3Config.Bucket) - viper.Set(EndpointKey, s3Config.Endpoint) - viper.Set(PrefixKey, s3Config.Prefix) - viper.Set(TenantIDKey, account.TenantID) + vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String()) + vpr.Set(BucketNameKey, s3Config.Bucket) + vpr.Set(EndpointKey, s3Config.Endpoint) + vpr.Set(PrefixKey, s3Config.Prefix) - if err := viper.SafeWriteConfig(); err != nil { + vpr.Set(AccountProviderTypeKey, account.ProviderM365.String()) + vpr.Set(TenantIDKey, m365Config.TenantID) + + if err := vpr.SafeWriteConfig(); err != nil { if _, ok := err.(viper.ConfigFileAlreadyExistsError); ok { - return viper.GetViper().WriteConfig() + return vpr.WriteConfig() } return err } return nil } -func readRepoConfig() error { - var err error - - if err = viper.ReadInConfig(); err != nil { - return errors.Wrap(err, "reading config file: "+viper.ConfigFileUsed()) - } - - return err -} - -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func s3ConfigsFromViper() (storage.S3Config, error) { - var s3Config storage.S3Config - - providerType := viper.GetString(StorageProviderTypeKey) - if providerType != storage.ProviderS3.String() { - return s3Config, errors.New("unsupported storage provider: " + providerType) - } - - s3Config.Bucket = viper.GetString(BucketNameKey) - s3Config.Endpoint = viper.GetString(EndpointKey) - s3Config.Prefix = viper.GetString(PrefixKey) - return s3Config, nil -} - -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func m365ConfigsFromViper() (account.M365Config, error) { - var m365 account.M365Config - - providerType := viper.GetString(AccountProviderTypeKey) - if providerType != account.ProviderM365.String() { - return m365, errors.New("unsupported account provider: " + providerType) - } - - m365.TenantID = first(viper.GetString(TenantIDKey), os.Getenv(account.TenantID)) - return m365, nil -} - // GetStorageAndAccount creates a storage and account instance by mediating all the possible // data sources (config file, env vars, flag overrides) and the config file. func GetStorageAndAccount(readFromFile bool, overrides map[string]string) (storage.Storage, account.Account, error) { + return getStorageAndAccountWithViper(viper.GetViper(), readFromFile, overrides) +} + +// getSorageAndAccountWithViper implements GetSorageAndAccount, but takes in a viper +// struct for testing. +func getStorageAndAccountWithViper(vpr *viper.Viper, readFromFile bool, overrides map[string]string) (storage.Storage, account.Account, error) { var ( - s3Cfg storage.S3Config - store storage.Storage - m365Cfg account.M365Config - acct account.Account - err error + store storage.Storage + acct account.Account + err error ) + readConfigFromViper := readFromFile + // possibly read the prior config from a .corso file if readFromFile { - if err = readRepoConfig(); err != nil { - return store, acct, errors.Wrap(err, "reading corso config file") - } - - s3Cfg, err = s3ConfigsFromViper() + err = vpr.ReadInConfig() if err != nil { - return store, acct, errors.Wrap(err, "reading s3 configs from corso config file") - } - - m365Cfg, err = m365ConfigsFromViper() - if err != nil { - return store, acct, errors.Wrap(err, "reading m365 configs from corso config file") + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return store, acct, errors.Wrap(err, "reading corso config file: "+vpr.ConfigFileUsed()) + } + readConfigFromViper = false } } - // compose the m365 account config and credentials - m365Cfg = account.M365Config{ - M365: credentials.GetM365(), - TenantID: first(overrides[account.TenantID], m365Cfg.TenantID, os.Getenv(account.TenantID)), - } - acct, err = account.NewAccount(account.ProviderM365, m365Cfg) + acct, err = configureAccount(vpr, readConfigFromViper, overrides) if err != nil { - return store, acct, errors.Wrap(err, "retrieving m365 account configuration") + return store, acct, errors.Wrap(err, "retrieving account configuration details") } - // compose the s3 storage config and credentials - aws := credentials.GetAWS(overrides) - if err := aws.Validate(); err != nil { - return storage.Storage{}, acct, errors.Wrap(err, "validating aws credentials") - } - s3Cfg = storage.S3Config{ - AWS: aws, - Bucket: first(overrides[storage.Bucket], s3Cfg.Bucket), - Endpoint: first(overrides[storage.Endpoint], s3Cfg.Endpoint), - Prefix: first(overrides[storage.Prefix], s3Cfg.Prefix), - } - - // compose the common config and credentials - corso := credentials.GetCorso() - if err := corso.Validate(); err != nil { - return storage.Storage{}, acct, errors.Wrap(err, "validating corso credentials") - } - cCfg := storage.CommonConfig{ - Corso: corso, - } - - // ensure requried properties are present - if err := utils.RequireProps(map[string]string{ - credentials.AWSAccessKeyID: aws.AccessKey, - storage.Bucket: s3Cfg.Bucket, - credentials.AWSSecretAccessKey: aws.SecretKey, - credentials.AWSSessionToken: aws.SessionToken, - credentials.CorsoPassword: corso.CorsoPassword, - }); err != nil { - return storage.Storage{}, acct, err - } - - // return a complete storage - s, err := storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) + store, err = configureStorage(vpr, readConfigFromViper, overrides) if err != nil { - return storage.Storage{}, acct, errors.Wrap(err, "configuring repository storage") + return store, acct, errors.Wrap(err, "retrieving storage provider details") } - return s, acct, nil + + return store, acct, nil } +// --------------------------------------------------------------------------- +// Helper funcs +// --------------------------------------------------------------------------- + // returns the first non-zero valued string func first(vs ...string) string { for _, v := range vs { @@ -205,3 +146,32 @@ func first(vs ...string) string { } return "" } + +var constToTomlKeyMap = map[string]string{ + account.TenantID: TenantIDKey, + AccountProviderTypeKey: AccountProviderTypeKey, + storage.Bucket: BucketNameKey, + storage.Endpoint: EndpointKey, + storage.Prefix: PrefixKey, + StorageProviderTypeKey: StorageProviderTypeKey, +} + +// mustMatchConfig compares the values of each key to their config file value in viper. +// If any value differs from the viper value, an error is returned. +// values in m that aren't stored in the config are ignored. +func mustMatchConfig(vpr *viper.Viper, m map[string]string) error { + for k, v := range m { + if len(v) == 0 { + continue // empty variables will get caught by configuration validators, if necessary + } + tomlK, ok := constToTomlKeyMap[k] + if !ok { + continue // m may declare values which aren't stored in the config file + } + vv := vpr.GetString(tomlK) + if v != vv { + return errors.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")") + } + } + return nil +} diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index acad314b3..4700fc95a 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "fmt" "io/ioutil" + "os" "path" "testing" @@ -11,18 +12,20 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + ctesting "github.com/alcionai/corso/internal/testing" "github.com/alcionai/corso/pkg/account" + "github.com/alcionai/corso/pkg/credentials" "github.com/alcionai/corso/pkg/storage" ) const ( configFileTemplate = ` -bucket = '%s' -endpoint = 's3.amazonaws.com' -prefix = 'test-prefix' -provider = 'S3' -account_provider = 'M365' -tenantid = '%s' +` + BucketNameKey + ` = '%s' +` + EndpointKey + ` = 's3.amazonaws.com' +` + PrefixKey + ` = 'test-prefix' +` + StorageProviderTypeKey + ` = 'S3' +` + AccountProviderTypeKey + ` = 'M365' +` + TenantIDKey + ` = '%s' ` ) @@ -35,51 +38,268 @@ func TestConfigSuite(t *testing.T) { } func (suite *ConfigSuite) TestReadRepoConfigBasic() { + var ( + t = suite.T() + vpr = viper.New() + ) + + const ( + b = "read-repo-config-basic-bucket" + tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + ) + // Generate test config file - b := "read-repo-config-basic-bucket" - tID := "6f34ac30-8196-469b-bf8f-d83deadbbbba" testConfigData := fmt.Sprintf(configFileTemplate, b, tID) - testConfigFilePath := path.Join(suite.T().TempDir(), "corso.toml") + testConfigFilePath := path.Join(t.TempDir(), "corso.toml") err := ioutil.WriteFile(testConfigFilePath, []byte(testConfigData), 0700) - assert.NoError(suite.T(), err) + require.NoError(t, err) // Configure viper to read test config file - viper.SetConfigFile(testConfigFilePath) + vpr.SetConfigFile(testConfigFilePath) // Read and validate config - err = readRepoConfig() - require.NoError(suite.T(), err) + require.NoError(t, vpr.ReadInConfig(), "reading repo config") - s3Cfg, err := s3ConfigsFromViper() - require.NoError(suite.T(), err) - assert.Equal(suite.T(), b, s3Cfg.Bucket) + s3Cfg, err := s3ConfigsFromViper(vpr) + require.NoError(t, err) + assert.Equal(t, b, s3Cfg.Bucket) - m365, err := m365ConfigsFromViper() - require.NoError(suite.T(), err) - assert.Equal(suite.T(), tID, m365.TenantID) + m365, err := m365ConfigsFromViper(vpr) + require.NoError(t, err) + assert.Equal(t, tID, m365.TenantID) } func (suite *ConfigSuite) TestWriteReadConfig() { + var ( + t = suite.T() + vpr = viper.New() + ) + + const ( + bkt = "write-read-config-bucket" + tid = "3c0748d2-470e-444c-9064-1268e52609d5" + ) + // Configure viper to read test config file - tempDir := suite.T().TempDir() - testConfigFilePath := path.Join(tempDir, "corso.toml") - err := InitConfig(testConfigFilePath) - assert.NoError(suite.T(), err) + testConfigFilePath := path.Join(t.TempDir(), "corso.toml") + require.NoError(t, initConfigWithViper(vpr, testConfigFilePath), "initializing repo config") - s3Cfg := storage.S3Config{Bucket: "write-read-config-bucket"} - m365 := account.M365Config{TenantID: "3c0748d2-470e-444c-9064-1268e52609d5"} + s3Cfg := storage.S3Config{Bucket: bkt} + m365 := account.M365Config{TenantID: tid} - err = WriteRepoConfig(s3Cfg, m365) - require.NoError(suite.T(), err) + require.NoError(t, writeRepoConfigWithViper(vpr, s3Cfg, m365), "writing repo config") + require.NoError(t, vpr.ReadInConfig(), "reading repo config") - err = readRepoConfig() - require.NoError(suite.T(), err) + readS3Cfg, err := s3ConfigsFromViper(vpr) + require.NoError(t, err) + assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) - readS3Cfg, err := s3ConfigsFromViper() - require.NoError(suite.T(), err) - assert.Equal(suite.T(), readS3Cfg.Bucket, s3Cfg.Bucket) - - readM365, err := m365ConfigsFromViper() - require.NoError(suite.T(), err) - assert.Equal(suite.T(), readM365.TenantID, m365.TenantID) + readM365, err := m365ConfigsFromViper(vpr) + require.NoError(t, err) + assert.Equal(t, readM365.TenantID, m365.TenantID) +} + +func (suite *ConfigSuite) TestMustMatchConfig() { + var ( + t = suite.T() + vpr = viper.New() + ) + + const ( + bkt = "must-match-config-bucket" + tid = "dfb12063-7598-458b-85ab-42352c5c25e2" + ) + + // Configure viper to read test config file + testConfigFilePath := path.Join(t.TempDir(), "corso.toml") + require.NoError(t, initConfigWithViper(vpr, testConfigFilePath), "initializing repo config") + + s3Cfg := storage.S3Config{Bucket: bkt} + m365 := account.M365Config{TenantID: tid} + + require.NoError(t, writeRepoConfigWithViper(vpr, s3Cfg, m365), "writing repo config") + require.NoError(t, vpr.ReadInConfig(), "reading repo config") + + table := []struct { + name string + input map[string]string + errCheck assert.ErrorAssertionFunc + }{ + { + name: "full match", + input: map[string]string{ + storage.Bucket: bkt, + account.TenantID: tid, + }, + errCheck: assert.NoError, + }, + { + name: "empty values", + input: map[string]string{ + storage.Bucket: "", + account.TenantID: "", + }, + errCheck: assert.NoError, + }, + { + name: "no overrides", + input: map[string]string{}, + errCheck: assert.NoError, + }, + { + name: "nil map", + input: nil, + errCheck: assert.NoError, + }, + { + name: "no recognized keys", + input: map[string]string{ + "fnords": "smurfs", + "nonsense": "", + }, + errCheck: assert.NoError, + }, + { + name: "mismatch", + input: map[string]string{ + storage.Bucket: tid, + account.TenantID: bkt, + }, + errCheck: assert.Error, + }, + } + for _, test := range table { + t.Run(test.name, func(t *testing.T) { + test.errCheck(t, mustMatchConfig(vpr, test.input)) + }) + } +} + +// ------------------------------------------------------------ +// integration tests +// ------------------------------------------------------------ + +type ConfigIntegrationSuite struct { + suite.Suite +} + +func TestConfigIntegrationSuite(t *testing.T) { + if err := ctesting.RunOnAny( + ctesting.CorsoCITests, + ctesting.CorsoCLIConfigTests, + ); err != nil { + t.Skip(err) + } + suite.Run(t, new(ConfigIntegrationSuite)) +} + +func (suite *ConfigIntegrationSuite) SetupSuite() { + _, err := ctesting.GetRequiredEnvVars( + append( + ctesting.AWSStorageCredEnvs, + ctesting.M365AcctCredEnvs..., + )..., + ) + require.NoError(suite.T(), err) +} + +func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() { + t := suite.T() + vpr := viper.New() + + const ( + bkt = "get-storage-and-account-bucket" + end = "https://get-storage-and-account.com" + pfx = "get-storage-and-account-prefix" + tid = "3a2faa4e-a882-445c-9d27-f552ef189381" + ) + + // Configure viper to read test config file + testConfigFilePath := path.Join(t.TempDir(), "corso.toml") + require.NoError(t, initConfigWithViper(vpr, testConfigFilePath), "initializing repo config") + + s3Cfg := storage.S3Config{ + Bucket: bkt, + Endpoint: end, + Prefix: pfx, + } + m365 := account.M365Config{TenantID: tid} + + require.NoError(t, writeRepoConfigWithViper(vpr, s3Cfg, m365), "writing repo config") + require.NoError(t, vpr.ReadInConfig(), "reading repo config") + + st, ac, err := getStorageAndAccountWithViper(vpr, true, nil) + require.NoError(t, err, "getting storage and account from config") + + readS3Cfg, err := st.S3Config() + require.NoError(t, err, "reading s3 config from storage") + assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) + assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) + assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) + + assert.Equal(t, readS3Cfg.AccessKey, os.Getenv(credentials.AWSAccessKeyID)) + assert.Equal(t, readS3Cfg.SecretKey, os.Getenv(credentials.AWSSecretAccessKey)) + assert.Equal(t, readS3Cfg.SessionToken, os.Getenv(credentials.AWSSessionToken)) + + common, err := st.CommonConfig() + require.NoError(t, err, "reading common config from storage") + assert.Equal(t, common.CorsoPassword, os.Getenv(credentials.CorsoPassword)) + + readM365, err := ac.M365Config() + require.NoError(t, err, "reading m365 config from account") + assert.Equal(t, readM365.TenantID, m365.TenantID) + assert.Equal(t, readM365.ClientID, os.Getenv(credentials.ClientID)) + assert.Equal(t, readM365.ClientSecret, os.Getenv(credentials.ClientSecret)) +} + +func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverrides() { + t := suite.T() + vpr := viper.New() + + const ( + bkt = "get-storage-and-account-no-file-bucket" + end = "https://get-storage-and-account.com/no-file" + pfx = "get-storage-and-account-no-file-prefix" + tid = "88f8522b-18e4-4d0f-b514-2d7b34d4c5a1" + ) + + // Configure viper to read test config file + s3Cfg := storage.S3Config{ + Bucket: bkt, + Endpoint: end, + Prefix: pfx, + } + m365 := account.M365Config{TenantID: tid} + + overrides := map[string]string{ + account.TenantID: tid, + AccountProviderTypeKey: account.ProviderM365.String(), + storage.Bucket: bkt, + storage.Endpoint: end, + storage.Prefix: pfx, + StorageProviderTypeKey: storage.ProviderS3.String(), + } + + st, ac, err := getStorageAndAccountWithViper(vpr, false, overrides) + require.NoError(t, err, "getting storage and account from config") + + readS3Cfg, err := st.S3Config() + require.NoError(t, err, "reading s3 config from storage") + assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) + assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) + assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) + + assert.Equal(t, readS3Cfg.AccessKey, os.Getenv(credentials.AWSAccessKeyID)) + assert.Equal(t, readS3Cfg.SecretKey, os.Getenv(credentials.AWSSecretAccessKey)) + assert.Equal(t, readS3Cfg.SessionToken, os.Getenv(credentials.AWSSessionToken)) + + common, err := st.CommonConfig() + require.NoError(t, err, "reading common config from storage") + assert.Equal(t, common.CorsoPassword, os.Getenv(credentials.CorsoPassword)) + + readM365, err := ac.M365Config() + require.NoError(t, err, "reading m365 config from account") + assert.Equal(t, readM365.TenantID, m365.TenantID) + assert.Equal(t, readM365.ClientID, os.Getenv(credentials.ClientID)) + assert.Equal(t, readM365.ClientSecret, os.Getenv(credentials.ClientSecret)) } diff --git a/src/cli/config/storage.go b/src/cli/config/storage.go new file mode 100644 index 000000000..e173a83bf --- /dev/null +++ b/src/cli/config/storage.go @@ -0,0 +1,94 @@ +package config + +import ( + "github.com/alcionai/corso/cli/utils" + "github.com/alcionai/corso/pkg/credentials" + "github.com/alcionai/corso/pkg/storage" + "github.com/pkg/errors" + "github.com/spf13/viper" +) + +// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. +func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) { + var s3Config storage.S3Config + + providerType := vpr.GetString(StorageProviderTypeKey) + if providerType != storage.ProviderS3.String() { + return s3Config, errors.New("unsupported storage provider: " + providerType) + } + + s3Config.Bucket = vpr.GetString(BucketNameKey) + s3Config.Endpoint = vpr.GetString(EndpointKey) + s3Config.Prefix = vpr.GetString(PrefixKey) + return s3Config, nil +} + +func s3Overrides(in map[string]string) map[string]string { + return map[string]string{ + storage.Bucket: in[storage.Bucket], + storage.Endpoint: in[storage.Endpoint], + storage.Prefix: in[storage.Prefix], + StorageProviderTypeKey: in[StorageProviderTypeKey], + } +} + +// configureStorage builds a complete storage configuration from a mix of +// viper properties and manual overrides. +func configureStorage(vpr *viper.Viper, readConfigFromViper bool, overrides map[string]string) (storage.Storage, error) { + var ( + s3Cfg storage.S3Config + store storage.Storage + err error + ) + + if readConfigFromViper { + if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil { + return store, errors.Wrap(err, "reading s3 configs from corso config file") + } + + if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil { + return store, errors.Wrap(err, "verifying s3 configs in corso config file") + } + } + + // compose the s3 storage config and credentials + aws := credentials.GetAWS(overrides) + if err := aws.Validate(); err != nil { + return store, errors.Wrap(err, "validating aws credentials") + } + + s3Cfg = storage.S3Config{ + AWS: aws, + Bucket: first(overrides[storage.Bucket], s3Cfg.Bucket), + Endpoint: first(overrides[storage.Endpoint], s3Cfg.Endpoint), + Prefix: first(overrides[storage.Prefix], s3Cfg.Prefix), + } + + // compose the common config and credentials + corso := credentials.GetCorso() + if err := corso.Validate(); err != nil { + return store, errors.Wrap(err, "validating corso credentials") + } + cCfg := storage.CommonConfig{ + Corso: corso, + } + + // ensure requried properties are present + if err := utils.RequireProps(map[string]string{ + credentials.AWSAccessKeyID: aws.AccessKey, + storage.Bucket: s3Cfg.Bucket, + credentials.AWSSecretAccessKey: aws.SecretKey, + credentials.AWSSessionToken: aws.SessionToken, + credentials.CorsoPassword: corso.CorsoPassword, + }); err != nil { + return storage.Storage{}, err + } + + // build the storage + store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) + if err != nil { + return store, errors.Wrap(err, "configuring repository storage") + } + + return store, nil +} diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index 18b9f647f..d7389379d 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -9,6 +9,7 @@ import ( "github.com/alcionai/corso/cli/config" "github.com/alcionai/corso/cli/utils" + "github.com/alcionai/corso/pkg/account" "github.com/alcionai/corso/pkg/credentials" "github.com/alcionai/corso/pkg/logger" "github.com/alcionai/corso/pkg/repository" @@ -63,13 +64,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { return nil } - overrides := map[string]string{ - credentials.AWSAccessKeyID: accessKey, - storage.Bucket: bucket, - storage.Endpoint: endpoint, - storage.Prefix: prefix, - } - s, a, err := config.GetStorageAndAccount(false, overrides) + s, a, err := config.GetStorageAndAccount(false, s3Overrides()) if err != nil { return err } @@ -123,13 +118,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { return nil } - overrides := map[string]string{ - credentials.AWSAccessKeyID: accessKey, - storage.Bucket: bucket, - storage.Endpoint: endpoint, - storage.Prefix: prefix, - } - s, a, err := config.GetStorageAndAccount(true, overrides) + s, a, err := config.GetStorageAndAccount(true, s3Overrides()) if err != nil { return err } @@ -163,3 +152,14 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { } return nil } + +func s3Overrides() map[string]string { + return map[string]string{ + config.AccountProviderTypeKey: account.ProviderM365.String(), + config.StorageProviderTypeKey: storage.ProviderS3.String(), + credentials.AWSAccessKeyID: accessKey, + storage.Bucket: bucket, + storage.Endpoint: endpoint, + storage.Prefix: prefix, + } +} diff --git a/src/internal/testing/integration_runners.go b/src/internal/testing/integration_runners.go index c7dc26006..bd29981cf 100644 --- a/src/internal/testing/integration_runners.go +++ b/src/internal/testing/integration_runners.go @@ -10,6 +10,7 @@ import ( const ( CorsoCITests = "CORSO_CI_TESTS" + CorsoCLIConfigTests = "CORSO_CLI_CONFIG_TESTS" CorsoGraphConnectorTests = "CORSO_GRAPH_CONNECTOR_TESTS" CorsoKopiaWrapperTests = "CORSO_KOPIA_WRAPPER_TESTS" CorsoRepositoryTests = "CORSO_REPOSITORY_TESTS"