refactor cli/config for better local testing (#245)
This commit is contained in:
parent
689c5cc1e9
commit
60eb8eec08
@ -16,12 +16,13 @@ import (
|
||||
|
||||
const (
|
||||
// S3 config
|
||||
ProviderTypeKey = "provider"
|
||||
StorageProviderTypeKey = "provider"
|
||||
BucketNameKey = "bucket"
|
||||
EndpointKey = "endpoint"
|
||||
PrefixKey = "prefix"
|
||||
|
||||
// M365 config
|
||||
AccountProviderTypeKey = "account_provider"
|
||||
TenantIDKey = "tenantid"
|
||||
)
|
||||
|
||||
@ -64,7 +65,7 @@ func InitConfig(configFilePath string) error {
|
||||
func WriteRepoConfig(s3Config storage.S3Config, account account.M365Config) error {
|
||||
// Rudimentary support for persisting repo config
|
||||
// TODO: Handle conflicts, support other config types
|
||||
viper.Set(ProviderTypeKey, storage.ProviderS3.String())
|
||||
viper.Set(StorageProviderTypeKey, storage.ProviderS3.String())
|
||||
viper.Set(BucketNameKey, s3Config.Bucket)
|
||||
viper.Set(EndpointKey, s3Config.Endpoint)
|
||||
viper.Set(PrefixKey, s3Config.Prefix)
|
||||
@ -79,40 +80,42 @@ func WriteRepoConfig(s3Config storage.S3Config, account account.M365Config) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadRepoConfig() (storage.S3Config, account.Account, error) {
|
||||
var (
|
||||
s3Config storage.S3Config
|
||||
acct account.Account
|
||||
err error
|
||||
)
|
||||
func readRepoConfig() error {
|
||||
var err error
|
||||
|
||||
if err = viper.ReadInConfig(); err != nil {
|
||||
return s3Config, acct, errors.Wrap(err, "reading config file: "+viper.ConfigFileUsed())
|
||||
return errors.Wrap(err, "reading config file: "+viper.ConfigFileUsed())
|
||||
}
|
||||
|
||||
if providerType := viper.GetString(ProviderTypeKey); providerType != storage.ProviderS3.String() {
|
||||
return s3Config, acct, errors.New("Unsupported storage provider: " + providerType)
|
||||
return err
|
||||
}
|
||||
|
||||
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
|
||||
func s3ConfigsFromViper() (storage.S3Config, error) {
|
||||
var s3Config storage.S3Config
|
||||
|
||||
providerType := viper.GetString(StorageProviderTypeKey)
|
||||
if providerType != storage.ProviderS3.String() {
|
||||
return s3Config, errors.New("unsupported storage provider: " + providerType)
|
||||
}
|
||||
|
||||
s3Config.Bucket = viper.GetString(BucketNameKey)
|
||||
s3Config.Endpoint = viper.GetString(EndpointKey)
|
||||
s3Config.Prefix = viper.GetString(PrefixKey)
|
||||
|
||||
m365Creds := credentials.GetM365()
|
||||
tenantID := os.Getenv(account.TenantID)
|
||||
cfgTenantID := viper.GetString(TenantIDKey)
|
||||
if len(cfgTenantID) > 0 {
|
||||
tenantID = cfgTenantID
|
||||
return s3Config, nil
|
||||
}
|
||||
acct, err = account.NewAccount(
|
||||
account.ProviderM365,
|
||||
account.M365Config{
|
||||
M365: m365Creds,
|
||||
TenantID: tenantID,
|
||||
},
|
||||
)
|
||||
|
||||
return s3Config, acct, err
|
||||
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
|
||||
func m365ConfigsFromViper() (account.M365Config, error) {
|
||||
var m365 account.M365Config
|
||||
|
||||
providerType := viper.GetString(AccountProviderTypeKey)
|
||||
if providerType != account.ProviderM365.String() {
|
||||
return m365, errors.New("unsupported account provider: " + providerType)
|
||||
}
|
||||
|
||||
m365.TenantID = first(viper.GetString(TenantIDKey), os.Getenv(account.TenantID))
|
||||
return m365, nil
|
||||
}
|
||||
|
||||
// GetStorageAndAccount creates a storage and account instance by mediating all the possible
|
||||
@ -120,16 +123,37 @@ func ReadRepoConfig() (storage.S3Config, account.Account, error) {
|
||||
func GetStorageAndAccount(readFromFile bool, overrides map[string]string) (storage.Storage, account.Account, error) {
|
||||
var (
|
||||
s3Cfg storage.S3Config
|
||||
store storage.Storage
|
||||
m365Cfg account.M365Config
|
||||
acct account.Account
|
||||
err error
|
||||
)
|
||||
|
||||
// possibly read the prior config from a .corso file
|
||||
if readFromFile {
|
||||
s3Cfg, acct, err = ReadRepoConfig()
|
||||
if err != nil {
|
||||
return storage.Storage{}, acct, errors.Wrap(err, "reading corso config file")
|
||||
if err = readRepoConfig(); err != nil {
|
||||
return store, acct, errors.Wrap(err, "reading corso config file")
|
||||
}
|
||||
|
||||
s3Cfg, err = s3ConfigsFromViper()
|
||||
if err != nil {
|
||||
return store, acct, errors.Wrap(err, "reading s3 configs from corso config file")
|
||||
}
|
||||
|
||||
m365Cfg, err = m365ConfigsFromViper()
|
||||
if err != nil {
|
||||
return store, acct, errors.Wrap(err, "reading m365 configs from corso config file")
|
||||
}
|
||||
}
|
||||
|
||||
// compose the m365 account config and credentials
|
||||
m365Cfg = account.M365Config{
|
||||
M365: credentials.GetM365(),
|
||||
TenantID: first(overrides[account.TenantID], m365Cfg.TenantID, os.Getenv(account.TenantID)),
|
||||
}
|
||||
acct, err = account.NewAccount(account.ProviderM365, m365Cfg)
|
||||
if err != nil {
|
||||
return store, acct, errors.Wrap(err, "retrieving m365 account configuration")
|
||||
}
|
||||
|
||||
// compose the s3 storage config and credentials
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package config_test
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -11,8 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/alcionai/corso/cli/config"
|
||||
ctesting "github.com/alcionai/corso/internal/testing"
|
||||
"github.com/alcionai/corso/pkg/account"
|
||||
"github.com/alcionai/corso/pkg/storage"
|
||||
)
|
||||
|
||||
@ -22,6 +21,7 @@ bucket = '%s'
|
||||
endpoint = 's3.amazonaws.com'
|
||||
prefix = 'test-prefix'
|
||||
provider = 'S3'
|
||||
account_provider = 'M365'
|
||||
tenantid = '%s'
|
||||
`
|
||||
)
|
||||
@ -47,11 +47,14 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() {
|
||||
viper.SetConfigFile(testConfigFilePath)
|
||||
|
||||
// Read and validate config
|
||||
s3Cfg, account, err := config.ReadRepoConfig()
|
||||
assert.NoError(suite.T(), err)
|
||||
err = readRepoConfig()
|
||||
require.NoError(suite.T(), err)
|
||||
|
||||
s3Cfg, err := s3ConfigsFromViper()
|
||||
require.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), b, s3Cfg.Bucket)
|
||||
|
||||
m365, err := account.M365Config()
|
||||
m365, err := m365ConfigsFromViper()
|
||||
require.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), tID, m365.TenantID)
|
||||
}
|
||||
@ -60,20 +63,23 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
|
||||
// Configure viper to read test config file
|
||||
tempDir := suite.T().TempDir()
|
||||
testConfigFilePath := path.Join(tempDir, "corso.toml")
|
||||
err := config.InitConfig(testConfigFilePath)
|
||||
err := InitConfig(testConfigFilePath)
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
s3Cfg := storage.S3Config{Bucket: "write-read-config-bucket"}
|
||||
acct, err := ctesting.NewM365Account()
|
||||
require.NoError(suite.T(), err)
|
||||
m365, err := acct.M365Config()
|
||||
m365 := account.M365Config{TenantID: "3c0748d2-470e-444c-9064-1268e52609d5"}
|
||||
|
||||
err = WriteRepoConfig(s3Cfg, m365)
|
||||
require.NoError(suite.T(), err)
|
||||
|
||||
err = config.WriteRepoConfig(s3Cfg, m365)
|
||||
assert.NoError(suite.T(), err)
|
||||
err = readRepoConfig()
|
||||
require.NoError(suite.T(), err)
|
||||
|
||||
readS3Cfg, readAccount, err := config.ReadRepoConfig()
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), s3Cfg, readS3Cfg)
|
||||
assert.Equal(suite.T(), acct, readAccount)
|
||||
readS3Cfg, err := s3ConfigsFromViper()
|
||||
require.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), readS3Cfg.Bucket, s3Cfg.Bucket)
|
||||
|
||||
readM365, err := m365ConfigsFromViper()
|
||||
require.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), readM365.TenantID, m365.TenantID)
|
||||
}
|
||||
|
||||
24
src/pkg/account/accountprovider_string.go
Normal file
24
src/pkg/account/accountprovider_string.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Code generated by "stringer -type=accountProvider -linecomment"; DO NOT EDIT.
|
||||
|
||||
package account
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ProviderUnknown-0]
|
||||
_ = x[ProviderM365-1]
|
||||
}
|
||||
|
||||
const _accountProvider_name = "Unknown ProviderM365"
|
||||
|
||||
var _accountProvider_index = [...]uint8{0, 16, 20}
|
||||
|
||||
func (i accountProvider) String() string {
|
||||
if i < 0 || i >= accountProvider(len(_accountProvider_index)-1) {
|
||||
return "accountProvider(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _accountProvider_name[_accountProvider_index[i]:_accountProvider_index[i+1]]
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user