diff --git a/src/cli/cli.go b/src/cli/cli.go index d1f535e76..8f00c52a5 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/viper" "github.com/alcionai/corso/cli/backup" + "github.com/alcionai/corso/cli/config" "github.com/alcionai/corso/cli/repo" ) @@ -31,19 +32,8 @@ func init() { } func initConfig() { - if cfgFile != "" { - // Use config file from the flag. - viper.SetConfigFile(cfgFile) - } else { - // Find home directory. - home, err := os.UserHomeDir() - cobra.CheckErr(err) - - // Search config in home directory with name ".corso" (without extension). - viper.AddConfigPath(home) - viper.SetConfigType("toml") - viper.SetConfigName(".corso") - } + err := config.InitConfig(cfgFile) + cobra.CheckErr(err) if err := viper.ReadInConfig(); err == nil { fmt.Println("Using config file:", viper.ConfigFileUsed()) diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 44bf1a38e..239344b74 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -1,6 +1,10 @@ package config import ( + "os" + "path" + "strings" + "github.com/alcionai/corso/pkg/repository" "github.com/alcionai/corso/pkg/storage" "github.com/pkg/errors" @@ -18,6 +22,40 @@ const ( TenantIDKey = "tenantid" ) +func InitConfig(configFilePath string) error { + // Configure default config file location + if configFilePath == "" { + // Find home directory. + home, err := os.UserHomeDir() + if err != nil { + return err + } + + // Search config in home directory with name ".corso" (without extension). + viper.AddConfigPath(home) + viper.SetConfigType("toml") + viper.SetConfigName(".corso") + return nil + } + // Use a custom file location + + viper.SetConfigFile(configFilePath) + // We also configure the path, type and filename + // because `viper.SafeWriteConfig` needs these set to + // work correctly (it does not use the configured file) + viper.AddConfigPath(path.Dir(configFilePath)) + + fileName := path.Base(configFilePath) + ext := path.Ext(configFilePath) + if len(ext) == 0 { + return errors.New("config file requires an extension e.g. `toml`") + } + fileName = strings.TrimSuffix(fileName, ext) + viper.SetConfigType(ext[1:]) + viper.SetConfigName(fileName) + return nil +} + // WriteRepoConfig currently just persists corso config to the config file // It does not check for conflicts or existing data. func WriteRepoConfig(s3Config storage.S3Config, account repository.Account) error { diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go new file mode 100644 index 000000000..bdeb39950 --- /dev/null +++ b/src/cli/config/config_test.go @@ -0,0 +1,70 @@ +package config_test + +import ( + "fmt" + "io/ioutil" + "path" + "testing" + + "github.com/alcionai/corso/cli/config" + "github.com/alcionai/corso/pkg/repository" + "github.com/alcionai/corso/pkg/storage" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +const ( + configFileTemplate = ` +bucket = '%s' +endpoint = 's3.amazonaws.com' +prefix = 'test-prefix' +provider = 'S3' +tenantid = '%s' +` +) + +type ConfigSuite struct { + suite.Suite +} + +func TestConfigSuite(t *testing.T) { + suite.Run(t, new(ConfigSuite)) +} + +func (suite *ConfigSuite) TestReadRepoConfigBasic() { + // Generate test config file + b := "test-bucket" + tID := "6f34ac30-8196-469b-bf8f-d83deadbbbba" + testConfigData := fmt.Sprintf(configFileTemplate, b, tID) + testConfigFilePath := path.Join(suite.T().TempDir(), "corso.toml") + err := ioutil.WriteFile(testConfigFilePath, []byte(testConfigData), 0700) + assert.NoError(suite.T(), err) + + // Configure viper to read test config file + viper.SetConfigFile(testConfigFilePath) + + // Read and validate config + s3Cfg, account, err := config.ReadRepoConfig() + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), b, s3Cfg.Bucket) + assert.Equal(suite.T(), tID, account.TenantID) +} + +func (suite *ConfigSuite) TestWriteReadConfig() { + // Configure viper to read test config file + tempDir := suite.T().TempDir() + testConfigFilePath := path.Join(tempDir, "corso.toml") + err := config.InitConfig(testConfigFilePath) + assert.NoError(suite.T(), err) + + s3Cfg := storage.S3Config{Bucket: "bucket"} + account := repository.Account{TenantID: "6f34ac30-8196-469b-bf8f-d83deadbbbbd"} + err = config.WriteRepoConfig(s3Cfg, account) + assert.NoError(suite.T(), err) + + readS3Cfg, readAccount, err := config.ReadRepoConfig() + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), s3Cfg, readS3Cfg) + assert.Equal(suite.T(), account, readAccount) +} diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index b627f65be..251ac9b13 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -120,10 +120,9 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { // TODO: Merge/Validate any local configuration here to make sure there are no conflicts // For now - just reading/logging the local config here (a successful repo connect will overwrite) localS3Cfg, localAccount, err := config.ReadRepoConfig() - if err != nil { - return err + if err == nil { + fmt.Printf("ConfigFile - %s\n\tbucket:\t%s\n\ttenantID:\t%s\n", viper.ConfigFileUsed(), localS3Cfg.Bucket, localAccount.TenantID) } - fmt.Printf("ConfigFile - %s\n\tbucket:\t%s\n\ttenantID:\t%s\n", viper.ConfigFileUsed(), localS3Cfg.Bucket, localAccount.TenantID) a := repository.Account{ TenantID: m365.TenantID,