diff --git a/src/cli/config/config_test.go b/src/cli/config/config_test.go index 444f6108d..a49eb95ae 100644 --- a/src/cli/config/config_test.go +++ b/src/cli/config/config_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "testing" "github.com/alcionai/clues" @@ -12,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/cli/flags" + "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/credentials" @@ -26,12 +29,14 @@ const ( ` + StorageProviderTypeKey + ` = 'S3' ` + AccountProviderTypeKey + ` = 'M365' ` + AzureTenantIDKey + ` = '%s' -` + DisableTLSKey + ` = 'false' -` + DisableTLSVerificationKey + ` = 'false' ` + AccessKey + ` = '%s' ` + SecretAccessKey + ` = '%s' ` + SessionToken + ` = '%s' ` + CorsoPassphrase + ` = '%s' +` + AzureClientID + ` = '%s' +` + AzureSecret + ` = '%s' +` + DisableTLSKey + ` = '%s' +` + DisableTLSVerificationKey + ` = '%s' ` ) @@ -71,16 +76,23 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { ) const ( - b = "read-repo-config-basic-bucket" - tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" - accKey = "aws-test-access-key" - secret = "aws-test-secret-key" - token = "aws-test-session-token" - passphrase = "passphrase-test" + b = "read-repo-config-basic-bucket" + tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + accKey = "aws-test-access-key" + secret = "aws-test-secret-key" + token = "aws-test-session-token" + passphrase = "passphrase-test" + azureClientID = "azure-client-id-test" + azureSecret = "azure-secret-test" + endpoint = "s3-test" + disableTLS = "true" + disableTLSVerification = "true" ) // Generate test config file - testConfigData := fmt.Sprintf(configFileTemplate, b, tID, accKey, secret, token, passphrase) + testConfigData := fmt.Sprintf(configFileTemplate, b, tID, accKey, secret, + token, passphrase, azureClientID, azureSecret, + disableTLS, disableTLSVerification) testConfigFilePath := filepath.Join(t.TempDir(), "corso.toml") err := os.WriteFile(testConfigFilePath, []byte(testConfigData), 0o700) require.NoError(t, err, clues.ToCore(err)) @@ -95,6 +107,9 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { s3Cfg, err := s3ConfigsFromViper(vpr) require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, b, s3Cfg.Bucket) + assert.Equal(t, "test-prefix/", s3Cfg.Prefix) + assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS)) + assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS)) s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) require.NoError(t, err, clues.ToCore(err)) @@ -104,6 +119,8 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() { m365, err := m365ConfigsFromViper(vpr) require.NoError(t, err, clues.ToCore(err)) + assert.Equal(t, azureClientID, m365.AzureClientID) + assert.Equal(t, azureSecret, m365.AzureClientSecret) assert.Equal(t, tID, m365.AzureTenantID) } @@ -223,6 +240,106 @@ func (suite *ConfigSuite) TestMustMatchConfig() { } } +func (suite *ConfigSuite) TestReadFromFlags() { + var ( + t = suite.T() + vpr = viper.New() + ) + + const ( + b = "read-repo-config-basic-bucket" + tID = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + accKey = "aws-test-access-key" + secret = "aws-test-secret-key" + token = "aws-test-session-token" + passphrase = "passphrase-test" + azureClientID = "azure-client-id-test" + azureSecret = "azure-secret-test" + prefix = "prefix-test" + disableTLS = "true" + disableTLSVerification = "true" + ) + + t.Cleanup(func() { + // reset values + flags.AzureClientTenantFV = "" + flags.AzureClientIDFV = "" + flags.AzureClientSecretFV = "" + + flags.AWSAccessKeyFV = "" + flags.AWSSecretAccessKeyFV = "" + flags.AWSSessionTokenFV = "" + + flags.CorsoPassphraseFV = "" + }) + + // Generate test config file + testConfigData := fmt.Sprintf(configFileTemplate, b, tID, accKey, secret, token, + passphrase, azureClientID, azureSecret, + disableTLS, disableTLSVerification) + + testConfigFilePath := filepath.Join(t.TempDir(), "corso.toml") + err := os.WriteFile(testConfigFilePath, []byte(testConfigData), 0o700) + require.NoError(t, err, clues.ToCore(err)) + + // Configure viper to read test config file + vpr.SetConfigFile(testConfigFilePath) + + // Read and validate config + err = vpr.ReadInConfig() + require.NoError(t, err, "reading repo config", clues.ToCore(err)) + + overrides := map[string]string{} + flags.AzureClientTenantFV = "6f34ac30-8196-469b-bf8f-d83deadbbbba" + flags.AzureClientIDFV = "azure-id-flag-value" + flags.AzureClientSecretFV = "azure-secret-flag-value" + + flags.AWSAccessKeyFV = "aws-access-key" + flags.AWSSecretAccessKeyFV = "aws-access-secret-flag-value" + flags.AWSSessionTokenFV = "aws-access-session-flag-value" + + overrides[storage.Bucket] = "flag-bucket" + overrides[storage.Endpoint] = "flag-endpoint" + overrides[storage.Prefix] = "flag-prefix" + overrides[storage.DoNotUseTLS] = "true" + overrides[storage.DoNotVerifyTLS] = "true" + overrides[credentials.AWSAccessKeyID] = flags.AWSAccessKeyFV + overrides[credentials.AWSSecretAccessKey] = flags.AWSSecretAccessKeyFV + overrides[credentials.AWSSessionToken] = flags.AWSSessionTokenFV + + flags.CorsoPassphraseFV = "passphrase-flags" + + repoDetails, err := getStorageAndAccountWithViper( + vpr, + true, + false, + overrides, + ) + + m365Config, _ := repoDetails.Account.M365Config() + s3Cfg, _ := repoDetails.Storage.S3Config() + commonConfig, _ := repoDetails.Storage.CommonConfig() + pass := commonConfig.Corso.CorsoPassphrase + + require.NoError(t, err, "reading repo config", clues.ToCore(err)) + + assert.Equal(t, flags.AWSAccessKeyFV, s3Cfg.AWS.AccessKey) + assert.Equal(t, flags.AWSSecretAccessKeyFV, s3Cfg.AWS.SecretKey) + assert.Equal(t, flags.AWSSessionTokenFV, s3Cfg.AWS.SessionToken) + + assert.Equal(t, overrides[storage.Bucket], s3Cfg.Bucket) + assert.Equal(t, overrides[storage.Endpoint], s3Cfg.Endpoint) + assert.Equal(t, overrides[storage.Prefix], s3Cfg.Prefix) + assert.Equal(t, str.ParseBool(overrides[storage.DoNotUseTLS]), s3Cfg.DoNotUseTLS) + assert.Equal(t, str.ParseBool(overrides[storage.DoNotVerifyTLS]), s3Cfg.DoNotVerifyTLS) + + assert.Equal(t, flags.AzureClientIDFV, m365Config.AzureClientID) + assert.Equal(t, flags.AzureClientSecretFV, m365Config.AzureClientSecret) + assert.Equal(t, flags.AzureClientTenantFV, m365Config.AzureTenantID) + + assert.Equal(t, flags.CorsoPassphraseFV, pass) +} + // ------------------------------------------------------------ // integration tests // ------------------------------------------------------------