Unit tests for config file helpers (#160)

Adds unit tests for the config file read/write helpers.

It also uncovered a bug in how viper handles writing a config file when a config file path is set
directly. This required a workaround in our init logic when we are using a custom config file name.

This commit does the following:
- Refactors the init logic into a InitConfig helper
- Adds a unit test to validate basic ReadRepoConfig behavior
- Adds a unit test that uses WriteReadConfig to write config and ReadRepoConfig to read it
This commit is contained in:
Vaibhav Kamra 2022-06-08 11:13:25 -07:00 committed by GitHub
parent 13ca33fae0
commit 5dcb2b7579
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 16 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/alcionai/corso/cli/backup" "github.com/alcionai/corso/cli/backup"
"github.com/alcionai/corso/cli/config"
"github.com/alcionai/corso/cli/repo" "github.com/alcionai/corso/cli/repo"
) )
@ -31,20 +32,9 @@ func init() {
} }
func initConfig() { func initConfig() {
if cfgFile != "" { err := config.InitConfig(cfgFile)
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := os.UserHomeDir()
cobra.CheckErr(err) cobra.CheckErr(err)
// Search config in home directory with name ".corso" (without extension).
viper.AddConfigPath(home)
viper.SetConfigType("toml")
viper.SetConfigName(".corso")
}
if err := viper.ReadInConfig(); err == nil { if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed()) fmt.Println("Using config file:", viper.ConfigFileUsed())
} }

View File

@ -1,6 +1,10 @@
package config package config
import ( import (
"os"
"path"
"strings"
"github.com/alcionai/corso/pkg/repository" "github.com/alcionai/corso/pkg/repository"
"github.com/alcionai/corso/pkg/storage" "github.com/alcionai/corso/pkg/storage"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -18,6 +22,40 @@ const (
TenantIDKey = "tenantid" TenantIDKey = "tenantid"
) )
func InitConfig(configFilePath string) error {
// Configure default config file location
if configFilePath == "" {
// Find home directory.
home, err := os.UserHomeDir()
if err != nil {
return err
}
// Search config in home directory with name ".corso" (without extension).
viper.AddConfigPath(home)
viper.SetConfigType("toml")
viper.SetConfigName(".corso")
return nil
}
// Use a custom file location
viper.SetConfigFile(configFilePath)
// We also configure the path, type and filename
// because `viper.SafeWriteConfig` needs these set to
// work correctly (it does not use the configured file)
viper.AddConfigPath(path.Dir(configFilePath))
fileName := path.Base(configFilePath)
ext := path.Ext(configFilePath)
if len(ext) == 0 {
return errors.New("config file requires an extension e.g. `toml`")
}
fileName = strings.TrimSuffix(fileName, ext)
viper.SetConfigType(ext[1:])
viper.SetConfigName(fileName)
return nil
}
// WriteRepoConfig currently just persists corso config to the config file // WriteRepoConfig currently just persists corso config to the config file
// It does not check for conflicts or existing data. // It does not check for conflicts or existing data.
func WriteRepoConfig(s3Config storage.S3Config, account repository.Account) error { func WriteRepoConfig(s3Config storage.S3Config, account repository.Account) error {

View File

@ -0,0 +1,70 @@
package config_test
import (
"fmt"
"io/ioutil"
"path"
"testing"
"github.com/alcionai/corso/cli/config"
"github.com/alcionai/corso/pkg/repository"
"github.com/alcionai/corso/pkg/storage"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
const (
configFileTemplate = `
bucket = '%s'
endpoint = 's3.amazonaws.com'
prefix = 'test-prefix'
provider = 'S3'
tenantid = '%s'
`
)
type ConfigSuite struct {
suite.Suite
}
func TestConfigSuite(t *testing.T) {
suite.Run(t, new(ConfigSuite))
}
func (suite *ConfigSuite) TestReadRepoConfigBasic() {
// Generate test config file
b := "test-bucket"
tID := "6f34ac30-8196-469b-bf8f-d83deadbbbba"
testConfigData := fmt.Sprintf(configFileTemplate, b, tID)
testConfigFilePath := path.Join(suite.T().TempDir(), "corso.toml")
err := ioutil.WriteFile(testConfigFilePath, []byte(testConfigData), 0700)
assert.NoError(suite.T(), err)
// Configure viper to read test config file
viper.SetConfigFile(testConfigFilePath)
// Read and validate config
s3Cfg, account, err := config.ReadRepoConfig()
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), b, s3Cfg.Bucket)
assert.Equal(suite.T(), tID, account.TenantID)
}
func (suite *ConfigSuite) TestWriteReadConfig() {
// Configure viper to read test config file
tempDir := suite.T().TempDir()
testConfigFilePath := path.Join(tempDir, "corso.toml")
err := config.InitConfig(testConfigFilePath)
assert.NoError(suite.T(), err)
s3Cfg := storage.S3Config{Bucket: "bucket"}
account := repository.Account{TenantID: "6f34ac30-8196-469b-bf8f-d83deadbbbbd"}
err = config.WriteRepoConfig(s3Cfg, account)
assert.NoError(suite.T(), err)
readS3Cfg, readAccount, err := config.ReadRepoConfig()
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), s3Cfg, readS3Cfg)
assert.Equal(suite.T(), account, readAccount)
}

View File

@ -120,10 +120,9 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
// TODO: Merge/Validate any local configuration here to make sure there are no conflicts // TODO: Merge/Validate any local configuration here to make sure there are no conflicts
// For now - just reading/logging the local config here (a successful repo connect will overwrite) // For now - just reading/logging the local config here (a successful repo connect will overwrite)
localS3Cfg, localAccount, err := config.ReadRepoConfig() localS3Cfg, localAccount, err := config.ReadRepoConfig()
if err != nil { if err == nil {
return err
}
fmt.Printf("ConfigFile - %s\n\tbucket:\t%s\n\ttenantID:\t%s\n", viper.ConfigFileUsed(), localS3Cfg.Bucket, localAccount.TenantID) fmt.Printf("ConfigFile - %s\n\tbucket:\t%s\n\ttenantID:\t%s\n", viper.ConfigFileUsed(), localS3Cfg.Bucket, localAccount.TenantID)
}
a := repository.Account{ a := repository.Account{
TenantID: m365.TenantID, TenantID: m365.TenantID,