Introduce new interfaces for storage configuration (#4251)

<!-- PR description-->
Introducing a new `Configurer` interface to abstract out storage config information(for s3, filesystem etc) from caller code. I consider this as a short term solution. We need to consolidate overall config handling in a better way, but that's out of scope for this PR chain.

Testing
* Most of the changes here are code movement under the hood. So relying on existing tests.
* I'll address any test gaps in a later PR.


---

#### Does this PR need a docs update or release note?

- [ ]  Yes, it's included
- [ ] 🕐 Yes, but in a later PR
- [x]  No

#### Type of change

<!--- Please check the type of change your PR introduces: --->
- [ ] 🌻 Feature
- [ ] 🐛 Bugfix
- [ ] 🗺️ Documentation
- [ ] 🤖 Supportability/Tests
- [ ] 💻 CI/Deployment
- [x] 🧹 Tech Debt/Cleanup

#### Issue(s)

<!-- Can reference multiple issues. Use one of the following "magic words" - "closes, fixes" to auto-close the Github issue. -->
* https://github.com/alcionai/corso/issues/1416

#### Test Plan

<!-- How will this be tested prior to merging.-->
- [x] 💪 Manual
- [x]  Unit test
- [ ] 💚 E2E
This commit is contained in:
Abhishek Pandey 2023-09-18 20:02:54 +05:30 committed by GitHub
parent eb357e1051
commit 8590b24199
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 344 additions and 192 deletions

View File

@ -140,9 +140,11 @@ func prepM365Test(
recorder = strings.Builder{} recorder = strings.Builder{}
) )
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -20,18 +20,8 @@ import (
) )
const ( const (
// S3 config
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
RepoID = "repo_id" RepoID = "repo_id"
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
// Corso passphrase in config // Corso passphrase in config
CorsoPassphrase = "passphrase" CorsoPassphrase = "passphrase"
CorsoUser = "corso_user" CorsoUser = "corso_user"
@ -196,14 +186,14 @@ func Read(ctx context.Context) error {
// It does not check for conflicts or existing data. // It does not check for conflicts or existing data.
func WriteRepoConfig( func WriteRepoConfig(
ctx context.Context, ctx context.Context,
s3Config storage.S3Config, wcs storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
return writeRepoConfigWithViper( return writeRepoConfigWithViper(
GetViper(ctx), GetViper(ctx),
s3Config, wcs,
m365Config, m365Config,
repoOpts, repoOpts,
repoID) repoID)
@ -213,20 +203,14 @@ func WriteRepoConfig(
// struct for testing. // struct for testing.
func writeRepoConfigWithViper( func writeRepoConfigWithViper(
vpr *viper.Viper, vpr *viper.Viper,
s3Config storage.S3Config, wcs storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
s3Config = s3Config.Normalize() // Write storage configuration to viper
// Rudimentary support for persisting repo config wcs.WriteConfigToStore(vpr)
// TODO: Handle conflicts, support other config types
vpr.Set(storage.StorageProviderTypeKey, storage.ProviderS3.String())
vpr.Set(BucketNameKey, s3Config.Bucket)
vpr.Set(EndpointKey, s3Config.Endpoint)
vpr.Set(PrefixKey, s3Config.Prefix)
vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS)
vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
vpr.Set(RepoID, repoID) vpr.Set(RepoID, repoID)
// Need if-checks as Viper will write empty values otherwise. // Need if-checks as Viper will write empty values otherwise.
@ -339,15 +323,12 @@ func getUserHost(vpr *viper.Viper, readConfigFromViper bool) (string, string) {
var constToTomlKeyMap = map[string]string{ var constToTomlKeyMap = map[string]string{
account.AzureTenantID: account.AzureTenantIDKey, account.AzureTenantID: account.AzureTenantIDKey,
account.AccountProviderTypeKey: account.AccountProviderTypeKey, account.AccountProviderTypeKey: account.AccountProviderTypeKey,
storage.Bucket: BucketNameKey,
storage.Endpoint: EndpointKey,
storage.Prefix: PrefixKey,
storage.StorageProviderTypeKey: storage.StorageProviderTypeKey,
} }
// mustMatchConfig compares the values of each key to their config file value in viper. // mustMatchConfig compares the values of each key to their config file value in viper.
// If any value differs from the viper value, an error is returned. // If any value differs from the viper value, an error is returned.
// values in m that aren't stored in the config are ignored. // values in m that aren't stored in the config are ignored.
// TODO(pandeyabs): This code is currently duplicated in 2 places.
func mustMatchConfig(vpr *viper.Viper, m map[string]string) error { func mustMatchConfig(vpr *viper.Viper, m map[string]string) error {
for k, v := range m { for k, v := range m {
if len(v) == 0 { if len(v) == 0 {

View File

@ -26,20 +26,20 @@ import (
const ( const (
configFileTemplate = ` configFileTemplate = `
` + BucketNameKey + ` = '%s' ` + storage.BucketNameKey + ` = '%s'
` + EndpointKey + ` = 's3.amazonaws.com' ` + storage.EndpointKey + ` = 's3.amazonaws.com'
` + PrefixKey + ` = 'test-prefix/' ` + storage.PrefixKey + ` = 'test-prefix/'
` + storage.StorageProviderTypeKey + ` = 'S3' ` + storage.StorageProviderTypeKey + ` = 'S3'
` + account.AccountProviderTypeKey + ` = 'M365' ` + account.AccountProviderTypeKey + ` = 'M365'
` + account.AzureTenantIDKey + ` = '%s' ` + account.AzureTenantIDKey + ` = '%s'
` + AccessKey + ` = '%s' ` + storage.AccessKey + ` = '%s'
` + SecretAccessKey + ` = '%s' ` + storage.SecretAccessKey + ` = '%s'
` + SessionToken + ` = '%s' ` + storage.SessionToken + ` = '%s'
` + CorsoPassphrase + ` = '%s' ` + CorsoPassphrase + ` = '%s'
` + account.AzureClientID + ` = '%s' ` + account.AzureClientID + ` = '%s'
` + account.AzureSecret + ` = '%s' ` + account.AzureSecret + ` = '%s'
` + DisableTLSKey + ` = '%s' ` + storage.DisableTLSKey + ` = '%s'
` + DisableTLSVerificationKey + ` = '%s' ` + storage.DisableTLSVerificationKey + ` = '%s'
` `
) )
@ -107,18 +107,32 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
s3Cfg, err := s3ConfigsFromViper(vpr) sc, err := storage.NewStorageConfig(storage.ProviderS3)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
err = sc.ApplyConfigOverrides(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
s3Cfg := sc.(*storage.S3Config)
assert.Equal(t, b, s3Cfg.Bucket) assert.Equal(t, b, s3Cfg.Bucket)
assert.Equal(t, "test-prefix/", s3Cfg.Prefix) assert.Equal(t, "test-prefix/", s3Cfg.Prefix)
assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS)) assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS))
assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS)) assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS))
s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) // Config file may or may not be the source of truth for below values. These may be
require.NoError(t, err, clues.ToCore(err)) // overridden by env vars (and flags but not relevant for this test).
assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) //
assert.Equal(t, secret, s3Cfg.AWS.SecretKey) // Other alternatives are:
assert.Equal(t, token, s3Cfg.AWS.SessionToken) // 1) unset env vars temporarily so that we can test against config file values. But that
// may be problematic if we decide to parallelize tests in future.
// 2) assert against env var values instead of config file values. This can cause issues
// if CI/local env have different config override mechanisms.
// 3) Skip asserts for these keys. They will be validated in other tests. Choosing this
// option.
// assert.Equal(t, accKey, s3Cfg.AWS.AccessKey)
// assert.Equal(t, secret, s3Cfg.AWS.SecretKey)
// assert.Equal(t, token, s3Cfg.AWS.SessionToken)
m365, err := m365ConfigsFromViper(vpr) m365, err := m365ConfigsFromViper(vpr)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
@ -146,7 +160,11 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
err := initWithViper(vpr, testConfigFilePath) err := initWithViper(vpr, testConfigFilePath)
require.NoError(t, err, "initializing repo config", clues.ToCore(err)) require.NoError(t, err, "initializing repo config", clues.ToCore(err))
s3Cfg := storage.S3Config{Bucket: bkt, DoNotUseTLS: true, DoNotVerifyTLS: true} s3Cfg := &storage.S3Config{
Bucket: bkt,
DoNotUseTLS: true,
DoNotVerifyTLS: true,
}
m365 := account.M365Config{AzureTenantID: tid} m365 := account.M365Config{AzureTenantID: tid}
rOpts := repository.Options{ rOpts := repository.Options{
@ -160,8 +178,12 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
readS3Cfg, err := s3ConfigsFromViper(vpr) sc, err := storage.NewStorageConfig(storage.ProviderS3)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
err = sc.ApplyConfigOverrides(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
readS3Cfg := sc.(*storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS) assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS)
assert.Equal(t, readS3Cfg.DoNotVerifyTLS, s3Cfg.DoNotVerifyTLS) assert.Equal(t, readS3Cfg.DoNotVerifyTLS, s3Cfg.DoNotVerifyTLS)
@ -191,7 +213,7 @@ func (suite *ConfigSuite) TestMustMatchConfig() {
err := initWithViper(vpr, testConfigFilePath) err := initWithViper(vpr, testConfigFilePath)
require.NoError(t, err, "initializing repo config") require.NoError(t, err, "initializing repo config")
s3Cfg := storage.S3Config{Bucket: bkt} s3Cfg := &storage.S3Config{Bucket: bkt}
m365 := account.M365Config{AzureTenantID: tid} m365 := account.M365Config{AzureTenantID: tid}
err = writeRepoConfigWithViper(vpr, s3Cfg, m365, repository.Options{}, "repoid") err = writeRepoConfigWithViper(vpr, s3Cfg, m365, repository.Options{}, "repoid")
@ -330,9 +352,14 @@ func (suite *ConfigSuite) TestReadFromFlags() {
true, true,
false, false,
overrides) overrides)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
m365Config, _ := repoDetails.Account.M365Config() m365Config, _ := repoDetails.Account.M365Config()
s3Cfg, _ := repoDetails.Storage.S3Config()
sc, err := repoDetails.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
s3Cfg := sc.(*storage.S3Config)
commonConfig, _ := repoDetails.Storage.CommonConfig() commonConfig, _ := repoDetails.Storage.CommonConfig()
pass := commonConfig.Corso.CorsoPassphrase pass := commonConfig.Corso.CorsoPassphrase
@ -386,7 +413,7 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() {
err := initWithViper(vpr, testConfigFilePath) err := initWithViper(vpr, testConfigFilePath)
require.NoError(t, err, "initializing repo config", clues.ToCore(err)) require.NoError(t, err, "initializing repo config", clues.ToCore(err))
s3Cfg := storage.S3Config{ s3Cfg := &storage.S3Config{
Bucket: bkt, Bucket: bkt,
Endpoint: end, Endpoint: end,
Prefix: pfx, Prefix: pfx,
@ -404,8 +431,11 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() {
cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(*storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint)
assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix)
@ -452,8 +482,11 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride
cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(*storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, bkt) assert.Equal(t, readS3Cfg.Bucket, bkt)
assert.Equal(t, cfg.RepoID, "") assert.Equal(t, cfg.RepoID, "")
assert.Equal(t, readS3Cfg.Endpoint, end) assert.Equal(t, readS3Cfg.Endpoint, end)

View File

@ -4,52 +4,16 @@ import (
"context" "context"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) {
var s3Config storage.S3Config
s3Config.Bucket = vpr.GetString(BucketNameKey)
s3Config.Endpoint = vpr.GetString(EndpointKey)
s3Config.Prefix = vpr.GetString(PrefixKey)
s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey)
s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey)
return s3Config, nil
}
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) {
s3Config.AccessKey = vpr.GetString(AccessKey)
s3Config.SecretKey = vpr.GetString(SecretAccessKey)
s3Config.SessionToken = vpr.GetString(SessionToken)
return s3Config, nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
storage.Bucket: in[storage.Bucket],
storage.Endpoint: in[storage.Endpoint],
storage.Prefix: in[storage.Prefix],
storage.DoNotUseTLS: in[storage.DoNotUseTLS],
storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS],
storage.StorageProviderTypeKey: in[storage.StorageProviderTypeKey],
}
}
// configureStorage builds a complete storage configuration from a mix of // configureStorage builds a complete storage configuration from a mix of
// viper properties and manual overrides. // viper properties and manual overrides.
func configureStorage( func configureStorage(
@ -59,72 +23,20 @@ func configureStorage(
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
) (storage.Storage, error) { ) (storage.Storage, error) {
var ( var store storage.Storage
s3Cfg storage.S3Config
store storage.Storage
err error
)
if readConfigFromViper {
if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
if b, ok := overrides[storage.Bucket]; ok {
overrides[storage.Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[storage.Prefix]; ok {
overrides[storage.Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := vpr.GetString(storage.StorageProviderTypeKey)
if providerType != storage.ProviderS3.String() {
return store, clues.New("unsupported storage provider: " + providerType)
}
if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil {
return store, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
sc, err := storage.NewStorageConfig(provider)
if err != nil { if err != nil {
return store, clues.Wrap(err, "validating aws credentials") return store, clues.Stack(err)
}
} }
s3Cfg = storage.S3Config{ err = sc.ApplyConfigOverrides(
AWS: aws, vpr,
Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket), readConfigFromViper,
Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"), matchFromConfig,
Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix), overrides)
DoNotUseTLS: str.ParseBool(str.First( if err != nil {
overrides[storage.DoNotUseTLS], return store, clues.Stack(err)
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[storage.DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
} }
// compose the common config and credentials // compose the common config and credentials
@ -146,14 +58,13 @@ func configureStorage(
// ensure required properties are present // ensure required properties are present
if err := requireProps(map[string]string{ if err := requireProps(map[string]string{
storage.Bucket: s3Cfg.Bucket,
credentials.CorsoPassphrase: corso.CorsoPassphrase, credentials.CorsoPassphrase: corso.CorsoPassphrase,
}); err != nil { }); err != nil {
return storage.Storage{}, err return storage.Storage{}, err
} }
// build the storage // build the storage
store, err = storage.NewStorage(provider, s3Cfg, cCfg) store, err = storage.NewStorage(provider, sc, cCfg)
if err != nil { if err != nil {
return store, clues.Wrap(err, "configuring repository storage") return store, clues.Wrap(err, "configuring repository storage")
} }

View File

@ -5,7 +5,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/pkg/account"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
@ -54,10 +53,9 @@ func S3FlagOverrides(cmd *cobra.Command) map[string]string {
} }
func PopulateS3Flags(flagset PopulatedFlags) map[string]string { func PopulateS3Flags(flagset PopulatedFlags) map[string]string {
s3Overrides := make(map[string]string) s3Overrides := map[string]string{
// TODO(pandeyabs): Move account overrides out of s3 flags storage.StorageProviderTypeKey: storage.ProviderS3.String(),
s3Overrides[account.AccountProviderTypeKey] = account.ProviderM365.String() }
s3Overrides[storage.StorageProviderTypeKey] = storage.ProviderS3.String()
if _, ok := flagset[AWSAccessKeyFN]; ok { if _, ok := flagset[AWSAccessKeyFN]; ok {
s3Overrides[credentials.AWSAccessKeyID] = AWSAccessKeyFV s3Overrides[credentials.AWSAccessKeyID] = AWSAccessKeyFV

View File

@ -112,11 +112,13 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
cfg.Account.ID(), cfg.Account.ID(),
opt) opt)
s3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }
s3Cfg := sc.(*storage.S3Config)
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
invalidEndpointErr := "endpoint doesn't support specifying protocol. " + invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
"pass --disable-tls flag to use http:// instead of default https://" "pass --disable-tls flag to use http:// instead of default https://"
@ -189,11 +191,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
repoID = events.RepoIDNotFound repoID = events.RepoIDNotFound
} }
s3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }
s3Cfg := sc.(*storage.S3Config)
m365, err := cfg.Account.M365Config() m365, err := cfg.Account.M365Config()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config"))

View File

@ -63,8 +63,10 @@ func (suite *S3E2ESuite) TestInitS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
if !test.hasConfigFile { if !test.hasConfigFile {
@ -100,9 +102,11 @@ func (suite *S3E2ESuite) TestInitMultipleTimes() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -129,9 +133,12 @@ func (suite *S3E2ESuite) TestInitS3Cmd_missingBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgBucket: "", tconfig.TestCfgBucket: "",
} }
@ -182,8 +189,9 @@ func (suite *S3E2ESuite) TestConnectS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
@ -230,9 +238,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -256,9 +266,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadPrefix() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)

View File

@ -66,9 +66,11 @@ func (suite *RestoreExchangeE2ESuite) SetupSuite() {
suite.acct = tconfig.NewM365Account(t) suite.acct = tconfig.NewM365Account(t)
suite.st = storeTD.NewPrefixedS3Storage(t) suite.st = storeTD.NewPrefixedS3Storage(t)
cfg, err := suite.st.S3Config() sc, err := suite.st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(*storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -32,7 +32,7 @@ func GetAccountAndConnectWithOverrides(
) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) {
provider, overrides, err := GetStorageProviderAndOverrides(ctx, cmd) provider, overrides, err := GetStorageProviderAndOverrides(ctx, cmd)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, clues.Stack(err)
} }
return GetAccountAndConnect(ctx, pst, provider, overrides) return GetAccountAndConnect(ctx, pst, provider, overrides)
@ -89,12 +89,14 @@ func AccountConnectAndWriteRepoConfig(
return nil, nil, err return nil, nil, err
} }
s3Config, err := stg.S3Config() sc, err := stg.StorageConfig()
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting storage configuration") logger.CtxErr(ctx, err).Info("getting storage configuration")
return nil, nil, err return nil, nil, err
} }
s3Config := sc.(*storage.S3Config)
m365Config, err := acc.M365Config() m365Config, err := acc.M365Config()
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting m365 configuration") logger.CtxErr(ctx, err).Info("getting m365 configuration")

View File

@ -197,11 +197,13 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error {
return clues.Wrap(err, "getting storage config") return clues.Wrap(err, "getting storage config")
} }
cfg, err := repoDetails.Storage.S3Config() sc, err := repoDetails.Storage.StorageConfig()
if err != nil { if err != nil {
return clues.Wrap(err, "getting S3 config") return clues.Wrap(err, "getting S3 config")
} }
cfg := sc.(*storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = cfg.Endpoint

View File

@ -33,7 +33,7 @@ func (suite *EventsIntegrationSuite) TestNewBus() {
s, err := storage.NewStorage( s, err := storage.NewStorage(
storage.ProviderS3, storage.ProviderS3,
storage.S3Config{ &storage.S3Config{
Bucket: "bckt", Bucket: "bckt",
Prefix: "prfx", Prefix: "prfx",
}) })

View File

@ -20,11 +20,13 @@ func s3BlobStorage(
repoOpts repository.Options, repoOpts repository.Options,
s storage.Storage, s storage.Storage,
) (blob.Storage, error) { ) (blob.Storage, error) {
cfg, err := s.S3Config() sc, err := s.StorageConfig()
if err != nil { if err != nil {
return nil, clues.Stack(err).WithClues(ctx) return nil, clues.Stack(err).WithClues(ctx)
} }
cfg := sc.(*storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = cfg.Endpoint

View File

@ -1,9 +1,11 @@
package storage package storage
import ( import (
"os"
"strconv" "strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
@ -40,7 +42,27 @@ const (
DoNotVerifyTLS = "donotverifytls" DoNotVerifyTLS = "donotverifytls"
) )
func (c S3Config) Normalize() S3Config { // config file keys
const (
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
)
var s3constToTomlKeyMap = map[string]string{
Bucket: BucketNameKey,
Endpoint: EndpointKey,
Prefix: PrefixKey,
StorageProviderTypeKey: StorageProviderTypeKey,
}
func (c *S3Config) normalize() S3Config {
return S3Config{ return S3Config{
Bucket: common.NormalizeBucket(c.Bucket), Bucket: common.NormalizeBucket(c.Bucket),
Endpoint: c.Endpoint, Endpoint: c.Endpoint,
@ -53,8 +75,8 @@ func (c S3Config) Normalize() S3Config {
// StringConfig transforms a s3Config struct into a plain // StringConfig transforms a s3Config struct into a plain
// map[string]string. All values in the original struct which // map[string]string. All values in the original struct which
// serialize into the map are expected to be strings. // serialize into the map are expected to be strings.
func (c S3Config) StringConfig() (map[string]string, error) { func (c *S3Config) StringConfig() (map[string]string, error) {
cn := c.Normalize() cn := c.normalize()
cfg := map[string]string{ cfg := map[string]string{
keyS3AccessKey: c.AccessKey, keyS3AccessKey: c.AccessKey,
keyS3Bucket: cn.Bucket, keyS3Bucket: cn.Bucket,
@ -66,23 +88,22 @@ func (c S3Config) StringConfig() (map[string]string, error) {
keyS3DoNotVerifyTLS: strconv.FormatBool(cn.DoNotVerifyTLS), keyS3DoNotVerifyTLS: strconv.FormatBool(cn.DoNotVerifyTLS),
} }
return cfg, c.validate() return cfg, cn.validate()
} }
// S3Config retrieves the S3Config details from the Storage config. func buildS3ConfigFromMap(config map[string]string) (*S3Config, error) {
func (s Storage) S3Config() (S3Config, error) { c := &S3Config{}
c := S3Config{}
if len(s.Config) > 0 { if len(config) > 0 {
c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) c.AccessKey = orEmptyString(config[keyS3AccessKey])
c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) c.SecretKey = orEmptyString(config[keyS3SecretKey])
c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) c.SessionToken = orEmptyString(config[keyS3SessionToken])
c.Bucket = orEmptyString(s.Config[keyS3Bucket]) c.Bucket = orEmptyString(config[keyS3Bucket])
c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) c.Endpoint = orEmptyString(config[keyS3Endpoint])
c.Prefix = orEmptyString(s.Config[keyS3Prefix]) c.Prefix = orEmptyString(config[keyS3Prefix])
c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS]) c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS])
c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS]) c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS])
} }
return c, c.validate() return c, c.validate()
@ -100,3 +121,107 @@ func (c S3Config) validate() error {
return nil return nil
} }
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
Bucket: in[Bucket],
Endpoint: in[Endpoint],
Prefix: in[Prefix],
DoNotUseTLS: in[DoNotUseTLS],
DoNotVerifyTLS: in[DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}
func (c *S3Config) s3ConfigsFromStore(kvg Getter) {
c.Bucket = cast.ToString(kvg.Get(BucketNameKey))
c.Endpoint = cast.ToString(kvg.Get(EndpointKey))
c.Prefix = cast.ToString(kvg.Get(PrefixKey))
c.DoNotUseTLS = cast.ToBool(kvg.Get(DisableTLSKey))
c.DoNotVerifyTLS = cast.ToBool(kvg.Get(DisableTLSVerificationKey))
}
func (c *S3Config) s3CredsFromStore(kvg Getter) {
c.AccessKey = cast.ToString(kvg.Get(AccessKey))
c.SecretKey = cast.ToString(kvg.Get(SecretAccessKey))
c.SessionToken = cast.ToString(kvg.Get(SessionToken))
}
var _ Configurer = &S3Config{}
func (c *S3Config) ApplyConfigOverrides(
kvg Getter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) error {
if readConfigFromStore {
c.s3ConfigsFromStore(kvg)
if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[Prefix]; ok {
overrides[Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := cast.ToString(kvg.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() {
return clues.New("unsupported storage provider: " + providerType)
}
if err := mustMatchConfig(kvg, s3constToTomlKeyMap, s3Overrides(overrides)); err != nil {
return clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
c.s3CredsFromStore(kvg)
aws := credentials.AWS{
AccessKey: str.First(
overrides[credentials.AWSAccessKeyID],
os.Getenv(credentials.AWSAccessKeyID),
c.AccessKey),
SecretKey: str.First(
overrides[credentials.AWSSecretAccessKey],
os.Getenv(credentials.AWSSecretAccessKey),
c.SecretKey),
SessionToken: str.First(
overrides[credentials.AWSSessionToken],
os.Getenv(credentials.AWSSessionToken),
c.SessionToken),
}
c.AWS = aws
c.Bucket = str.First(overrides[Bucket], c.Bucket)
c.Endpoint = str.First(overrides[Endpoint], c.Endpoint, "s3.amazonaws.com")
c.Prefix = str.First(overrides[Prefix], c.Prefix)
c.DoNotUseTLS = str.ParseBool(str.First(
overrides[DoNotUseTLS],
strconv.FormatBool(c.DoNotUseTLS),
"false"))
c.DoNotVerifyTLS = str.ParseBool(str.First(
overrides[DoNotVerifyTLS],
strconv.FormatBool(c.DoNotVerifyTLS),
"false"))
return c.validate()
}
var _ WriteConfigToStorer = &S3Config{}
func (c *S3Config) WriteConfigToStore(
kvs Setter,
) {
s3Config := c.normalize()
kvs.Set(StorageProviderTypeKey, ProviderS3.String())
kvs.Set(BucketNameKey, s3Config.Bucket)
kvs.Set(EndpointKey, s3Config.Endpoint)
kvs.Set(PrefixKey, s3Config.Prefix)
kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS)
kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
}

View File

@ -64,11 +64,12 @@ func (suite *S3CfgSuite) TestStorage_S3Config() {
t := suite.T() t := suite.T()
in := goodS3Config in := goodS3Config
s, err := NewStorage(ProviderS3, in) s, err := NewStorage(ProviderS3, &in)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out, err := s.S3Config() sc, err := s.StorageConfig()
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out := sc.(*S3Config)
assert.Equal(t, in.Bucket, out.Bucket) assert.Equal(t, in.Bucket, out.Bucket)
assert.Equal(t, in.Endpoint, out.Endpoint) assert.Equal(t, in.Endpoint, out.Endpoint)
assert.Equal(t, in.Prefix, out.Prefix) assert.Equal(t, in.Prefix, out.Prefix)
@ -93,7 +94,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() {
} }
for _, test := range table { for _, test := range table {
suite.Run(test.name, func() { suite.Run(test.name, func() {
_, err := NewStorage(ProviderUnknown, test.cfg) _, err := NewStorage(ProviderUnknown, &test.cfg)
assert.Error(suite.T(), err) assert.Error(suite.T(), err)
}) })
} }
@ -114,10 +115,10 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() {
suite.Run(test.name, func() { suite.Run(test.name, func() {
t := suite.T() t := suite.T()
st, err := NewStorage(ProviderUnknown, goodS3Config) st, err := NewStorage(ProviderUnknown, &goodS3Config)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
test.amend(st) test.amend(st)
_, err = st.S3Config() _, err = st.StorageConfig()
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -187,7 +188,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_Normalize() {
Bucket: prefixedBkt, Bucket: prefixedBkt,
} }
result := st.Normalize() result := st.normalize()
assert.Equal(suite.T(), normalBkt, result.Bucket) assert.Equal(suite.T(), normalBkt, result.Bucket)
assert.NotEqual(suite.T(), st.Bucket, result.Bucket) assert.NotEqual(suite.T(), st.Bucket, result.Bucket)
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common"
) )
@ -92,3 +93,79 @@ func orEmptyString(v any) string {
return v.(string) return v.(string)
} }
func (s Storage) StorageConfig() (Configurer, error) {
switch s.Provider {
case ProviderS3:
return buildS3ConfigFromMap(s.Config)
}
return nil, clues.New("unsupported storage provider: " + s.Provider.String())
}
func NewStorageConfig(provider ProviderType) (Configurer, error) {
switch provider {
case ProviderS3:
return &S3Config{}, nil
}
return nil, clues.New("unsupported storage provider: " + provider.String())
}
type Getter interface {
Get(key string) any
}
type Setter interface {
Set(key string, value any)
}
// WriteConfigToStorer writes config key value pairs to provided store.
type WriteConfigToStorer interface {
WriteConfigToStore(
s Setter,
)
}
type Configurer interface {
common.StringConfigurer
// ApplyOverrides fetches config from file, processes overrides
// from sources like environment variables and flags, and updates the
// underlying configuration accordingly.
ApplyConfigOverrides(
g Getter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) error
WriteConfigToStorer
}
// mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned.
// values in m that aren't stored in the config are ignored.
func mustMatchConfig(
g Getter,
tomlMap map[string]string,
m map[string]string,
) error {
for k, v := range m {
if len(v) == 0 {
continue // empty variables will get caught by configuration validators, if necessary
}
tomlK, ok := tomlMap[k]
if !ok {
continue // m may declare values which aren't stored in the config file
}
vv := cast.ToString(g.Get(tomlK))
if v != vv {
return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")")
}
}
return nil
}

View File

@ -38,7 +38,7 @@ func NewPrefixedS3Storage(t tester.TestT) storage.Storage {
st, err := storage.NewStorage( st, err := storage.NewStorage(
storage.ProviderS3, storage.ProviderS3,
storage.S3Config{ &storage.S3Config{
Bucket: cfg[tconfig.TestCfgBucket], Bucket: cfg[tconfig.TestCfgBucket],
Prefix: prefix, Prefix: prefix,
}, },