Compare commits

...

3 Commits

Author SHA1 Message Date
Abhishek Pandey
bc1e9b3161 Fix rebase issues 2023-09-13 19:31:06 +05:30
Abhishek Pandey
9705c9d212 More changes before piecemealing 2023-09-13 19:27:08 +05:30
Abhishek Pandey
8060227bf9 Add new prototype for storage config refactor 2023-09-13 19:25:22 +05:30
23 changed files with 492 additions and 192 deletions

View File

@ -290,7 +290,12 @@ func genericDeleteCommand(
ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) ctx := clues.Add(cmd.Context(), "delete_backup_id", bID)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, repo.S3Overrides(cmd)) storageProvider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, storageProvider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -316,7 +321,12 @@ func genericListCommand(
) error { ) error {
ctx := cmd.Context() ctx := cmd.Context()
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -168,7 +168,12 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -277,7 +282,12 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeExchangeOpts(cmd) opts := utils.MakeExchangeOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -154,7 +154,12 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -226,7 +231,12 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeGroupsOpts(cmd) opts := utils.MakeGroupsOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -140,9 +140,11 @@ func prepM365Test(
recorder = strings.Builder{} recorder = strings.Builder{}
) )
cfg, err := st.S3Config() c, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := c.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -149,7 +149,12 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -235,7 +240,12 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeOneDriveOpts(cmd) opts := utils.MakeOneDriveOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -159,7 +159,12 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -319,7 +324,12 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeSharePointOpts(cmd) opts := utils.MakeSharePointOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -70,9 +70,12 @@ func preRun(cc *cobra.Command, args []string) error {
} }
if !slices.Contains(avoidTheseDescription, cc.Short) { if !slices.Contains(avoidTheseDescription, cc.Short) {
overrides := repo.S3Overrides(cc) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cc)
if err != nil {
return err
}
cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides) cfg, err := config.GetConfigRepoDetails(ctx, provider, true, false, overrides)
if err != nil { if err != nil {
log.Error("Error while getting config info to run command: ", cc.Use) log.Error("Error while getting config info to run command: ", cc.Use)
return err return err

View File

@ -203,14 +203,14 @@ func Read(ctx context.Context) error {
// It does not check for conflicts or existing data. // It does not check for conflicts or existing data.
func WriteRepoConfig( func WriteRepoConfig(
ctx context.Context, ctx context.Context,
s3Config storage.S3Config, scfg storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
return writeRepoConfigWithViper( return writeRepoConfigWithViper(
GetViper(ctx), GetViper(ctx),
s3Config, scfg,
m365Config, m365Config,
repoOpts, repoOpts,
repoID) repoID)
@ -220,20 +220,12 @@ func WriteRepoConfig(
// struct for testing. // struct for testing.
func writeRepoConfigWithViper( func writeRepoConfigWithViper(
vpr *viper.Viper, vpr *viper.Viper,
s3Config storage.S3Config, scfg storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
s3Config = s3Config.Normalize() scfg.WriteConfigToStore(vpr)
// Rudimentary support for persisting repo config
// TODO: Handle conflicts, support other config types
vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String())
vpr.Set(BucketNameKey, s3Config.Bucket)
vpr.Set(EndpointKey, s3Config.Endpoint)
vpr.Set(PrefixKey, s3Config.Prefix)
vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS)
vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
vpr.Set(RepoID, repoID) vpr.Set(RepoID, repoID)
// Need if-checks as Viper will write empty values otherwise. // Need if-checks as Viper will write empty values otherwise.
@ -263,6 +255,7 @@ func writeRepoConfigWithViper(
// data sources (config file, env vars, flag overrides) and the config file. // data sources (config file, env vars, flag overrides) and the config file.
func GetConfigRepoDetails( func GetConfigRepoDetails(
ctx context.Context, ctx context.Context,
provider storage.ProviderType,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -270,7 +263,7 @@ func GetConfigRepoDetails(
RepoDetails, RepoDetails,
error, error,
) { ) {
config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides) config, err := getStorageAndAccountWithViper(GetViper(ctx), provider, readFromFile, mustMatchFromConfig, overrides)
return config, err return config, err
} }
@ -278,6 +271,7 @@ func GetConfigRepoDetails(
// struct for testing. // struct for testing.
func getStorageAndAccountWithViper( func getStorageAndAccountWithViper(
vpr *viper.Viper, vpr *viper.Viper,
provider storage.ProviderType,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -312,7 +306,7 @@ func getStorageAndAccountWithViper(
return config, clues.Wrap(err, "retrieving account configuration details") return config, clues.Wrap(err, "retrieving account configuration details")
} }
config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides) config.Storage, err = configureStorage(vpr, provider, readConfigFromViper, mustMatchFromConfig, overrides)
if err != nil { if err != nil {
return config, clues.Wrap(err, "retrieving storage provider details") return config, clues.Wrap(err, "retrieving storage provider details")
} }
@ -378,3 +372,17 @@ func requireProps(props map[string]string) error {
return nil return nil
} }
// Storage provider is not a flag. It can only be sourced from config file.
// Only exceptions are the commands that create a new repo.
// This is needed to figure out which storage overrides to use.
func GetStorageProviderFromConfigFile(ctx context.Context) (storage.ProviderType, error) {
vpr := GetViper(ctx)
provider := vpr.GetString(StorageProviderTypeKey)
if provider != storage.ProviderS3.String() {
return storage.ProviderUnknown, clues.New("unsupported storage provider: " + provider)
}
return storage.StringToProviderType[provider], nil
}

View File

@ -107,14 +107,23 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
s3Cfg, err := s3ConfigsFromViper(vpr) // Unset AWS env vars so that we can test reading creds from config file
os.Unsetenv(credentials.AWSAccessKeyID)
os.Unsetenv(credentials.AWSSecretAccessKey)
os.Unsetenv(credentials.AWSSessionToken)
sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
s3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, b, s3Cfg.Bucket) assert.Equal(t, b, s3Cfg.Bucket)
assert.Equal(t, "test-prefix/", s3Cfg.Prefix) assert.Equal(t, "test-prefix/", s3Cfg.Prefix)
assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS)) assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS))
assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS)) assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS))
s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) // s3Cfg, err = s3CredsFromViper(vpr, s3Cfg)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) assert.Equal(t, accKey, s3Cfg.AWS.AccessKey)
assert.Equal(t, secret, s3Cfg.AWS.SecretKey) assert.Equal(t, secret, s3Cfg.AWS.SecretKey)
@ -160,7 +169,11 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
readS3Cfg, err := s3ConfigsFromViper(vpr) sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS) assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS)
@ -326,12 +339,14 @@ func (suite *ConfigSuite) TestReadFromFlags() {
repoDetails, err := getStorageAndAccountWithViper( repoDetails, err := getStorageAndAccountWithViper(
vpr, vpr,
storage.ProviderS3,
true, true,
false, false,
overrides) overrides)
m365Config, _ := repoDetails.Account.M365Config() m365Config, _ := repoDetails.Account.M365Config()
s3Cfg, _ := repoDetails.Storage.S3Config() cfg, _ := repoDetails.Storage.StorageConfig()
s3Cfg, _ := cfg.(storage.S3Config)
commonConfig, _ := repoDetails.Storage.CommonConfig() commonConfig, _ := repoDetails.Storage.CommonConfig()
pass := commonConfig.Corso.CorsoPassphrase pass := commonConfig.Corso.CorsoPassphrase
@ -400,11 +415,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
cfg, err := getStorageAndAccountWithViper(vpr, true, true, nil) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint)
assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix)
@ -448,11 +465,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride
StorageProviderTypeKey: storage.ProviderS3.String(), StorageProviderTypeKey: storage.ProviderS3.String(),
} }
cfg, err := getStorageAndAccountWithViper(vpr, false, true, overrides) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, bkt) assert.Equal(t, readS3Cfg.Bucket, bkt)
assert.Equal(t, cfg.RepoID, "") assert.Equal(t, cfg.RepoID, "")
assert.Equal(t, readS3Cfg.Endpoint, end) assert.Equal(t, readS3Cfg.Endpoint, end)

View File

@ -3,127 +3,40 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) {
var s3Config storage.S3Config
s3Config.Bucket = vpr.GetString(BucketNameKey)
s3Config.Endpoint = vpr.GetString(EndpointKey)
s3Config.Prefix = vpr.GetString(PrefixKey)
s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey)
s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey)
return s3Config, nil
}
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) {
s3Config.AccessKey = vpr.GetString(AccessKey)
s3Config.SecretKey = vpr.GetString(SecretAccessKey)
s3Config.SessionToken = vpr.GetString(SessionToken)
return s3Config, nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
storage.Bucket: in[storage.Bucket],
storage.Endpoint: in[storage.Endpoint],
storage.Prefix: in[storage.Prefix],
storage.DoNotUseTLS: in[storage.DoNotUseTLS],
storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}
// configureStorage builds a complete storage configuration from a mix of // configureStorage builds a complete storage configuration from a mix of
// viper properties and manual overrides. // viper properties and manual overrides.
func configureStorage( func configureStorage(
vpr *viper.Viper, vpr *viper.Viper,
provider storage.ProviderType,
readConfigFromViper bool, readConfigFromViper bool,
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
) (storage.Storage, error) { ) (storage.Storage, error) {
var ( var (
s3Cfg storage.S3Config
store storage.Storage store storage.Storage
err error err error
) )
if readConfigFromViper { storageCfg, _ := storage.NewStorageConfig(provider)
if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
if b, ok := overrides[storage.Bucket]; ok {
overrides[storage.Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[storage.Prefix]; ok {
overrides[storage.Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := vpr.GetString(StorageProviderTypeKey)
if providerType != storage.ProviderS3.String() {
return store, clues.New("unsupported storage provider: " + providerType)
}
if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil {
return store, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
s3Overrides(overrides)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
// Rename this. It's not just fetch config from store.
storageCfg, err = storageCfg.FetchConfigFromStore(
vpr,
readConfigFromViper,
matchFromConfig,
overrides)
if err != nil { if err != nil {
return store, clues.Wrap(err, "validating aws credentials") return store, clues.Wrap(err, "fetching storage config from store")
}
}
s3Cfg = storage.S3Config{
AWS: aws,
Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[storage.DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[storage.DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
} }
// compose the common config and credentials // compose the common config and credentials
@ -145,14 +58,13 @@ func configureStorage(
// ensure required properties are present // ensure required properties are present
if err := requireProps(map[string]string{ if err := requireProps(map[string]string{
storage.Bucket: s3Cfg.Bucket,
credentials.CorsoPassphrase: corso.CorsoPassphrase, credentials.CorsoPassphrase: corso.CorsoPassphrase,
}); err != nil { }); err != nil {
return storage.Storage{}, err return storage.Storage{}, err
} }
// build the storage // build the storage
store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) store, err = storage.NewStorage(provider, storageCfg, cCfg)
if err != nil { if err != nil {
return store, clues.Wrap(err, "configuring repository storage") return store, clues.Wrap(err, "configuring repository storage")
} }

View File

@ -70,7 +70,12 @@ func runExport(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -1,17 +1,20 @@
package repo package repo
import ( import (
"context"
"strings" "strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils"
"github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/storage"
) )
const ( const (
@ -121,7 +124,9 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, _, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.UnknownService, S3Overrides(cmd)) // Change this to override too?
r, _, err := utils.AccountConnectAndWriteRepoConfig(
ctx, path.UnknownService, storage.ProviderS3, S3Overrides(cmd))
if err != nil { if err != nil {
return print.Only(ctx, err) return print.Only(ctx, err)
} }
@ -164,3 +169,37 @@ func getMaintenanceType(t string) (repository.MaintenanceType, error) {
return res, nil return res, nil
} }
func GetStorageOverrides(
ctx context.Context,
cmd *cobra.Command,
storageProvider string,
) (map[string]string, error) {
overrides := map[string]string{}
switch storageProvider {
case storage.ProviderS3.String():
overrides = S3Overrides(cmd)
}
return overrides, nil
}
func GetStorageProviderAndOverrides(
ctx context.Context,
cmd *cobra.Command,
) (storage.ProviderType, map[string]string, error) {
provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return provider, nil, clues.Stack(err)
}
overrides := map[string]string{}
switch provider {
case storage.ProviderS3:
overrides = S3Overrides(cmd)
}
return provider, overrides, nil
}

View File

@ -2,7 +2,6 @@ package repo
import ( import (
"strconv" "strconv"
"strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -92,7 +91,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, false, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -113,17 +112,18 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
cfg.Account.ID(), cfg.Account.ID(),
opt) opt)
s3Cfg, err := cfg.Storage.S3Config() storageCfg, err := cfg.Storage.StorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { // BUG: This should be moved to validate()
invalidEndpointErr := "endpoint doesn't support specifying protocol. " + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
"pass --disable-tls flag to use http:// instead of default https://" // invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr)) // return Only(ctx, clues.New(invalidEndpointErr))
} // }
m365, err := cfg.Account.M365Config() m365, err := cfg.Account.M365Config()
if err != nil { if err != nil {
@ -146,9 +146,10 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r) defer utils.CloseRepo(ctx, r)
Infof(ctx, "Initialized a S3 repository within bucket %s.", s3Cfg.Bucket) // Strong typecast?
Infof(ctx, "Initialized a S3 repository within bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opt.Repo, r.GetID()); err != nil { if err = config.WriteRepoConfig(ctx, storageCfg, m365, opt.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))
} }
@ -178,7 +179,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -188,7 +189,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
repoID = events.RepoIDNotFound repoID = events.RepoIDNotFound
} }
s3Cfg, err := cfg.Storage.S3Config() s3Cfg, err := cfg.Storage.StorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }
@ -198,12 +199,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config"))
} }
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { // Move these to validate()?
invalidEndpointErr := "endpoint doesn't support specifying protocol. " + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
"pass --disable-tls flag to use http:// instead of default https://" // invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr)) // return Only(ctx, clues.New(invalidEndpointErr))
} // }
opts := utils.ControlWithConfig(cfg) opts := utils.ControlWithConfig(cfg)
@ -219,7 +221,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r) defer utils.CloseRepo(ctx, r)
Infof(ctx, "Connected to S3 bucket %s.", s3Cfg.Bucket) Infof(ctx, "Connected to S3 bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil { if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))

View File

@ -63,9 +63,11 @@ func (suite *S3E2ESuite) TestInitS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
if !test.hasConfigFile { if !test.hasConfigFile {
// Ideally we could use `/dev/null`, but you need a // Ideally we could use `/dev/null`, but you need a
@ -100,9 +102,11 @@ func (suite *S3E2ESuite) TestInitMultipleTimes() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -129,9 +133,10 @@ func (suite *S3E2ESuite) TestInitS3Cmd_missingBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgBucket: "", tconfig.TestCfgBucket: "",
} }
@ -182,9 +187,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),
@ -230,9 +237,10 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -256,9 +264,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadPrefix() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)

View File

@ -66,9 +66,11 @@ func (suite *RestoreExchangeE2ESuite) SetupSuite() {
suite.acct = tconfig.NewM365Account(t) suite.acct = tconfig.NewM365Account(t)
suite.st = storeTD.NewPrefixedS3Storage(t) suite.st = storeTD.NewPrefixedS3Storage(t)
cfg, err := suite.st.S3Config() sc, err := suite.st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -103,7 +103,12 @@ func runRestore(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -24,9 +24,10 @@ var ErrNotYetImplemented = clues.New("not yet implemented")
func GetAccountAndConnect( func GetAccountAndConnect(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider storage.ProviderType,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) {
cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides) cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
@ -55,17 +56,19 @@ func GetAccountAndConnect(
func AccountConnectAndWriteRepoConfig( func AccountConnectAndWriteRepoConfig(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider storage.ProviderType,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *account.Account, error) { ) (repository.Repository, *account.Account, error) {
r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, overrides) r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides)
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting and connecting account") logger.CtxErr(ctx, err).Info("getting and connecting account")
return nil, nil, err return nil, nil, err
} }
s3Config, err := stg.S3Config() storageCfg, err := stg.StorageConfig()
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting storage configuration") logger.CtxErr(ctx, err).Info("getting storage configuration")
return nil, nil, err return nil, nil, err
} }
@ -77,7 +80,7 @@ func AccountConnectAndWriteRepoConfig(
// repo config gets set during repo connect and init. // repo config gets set during repo connect and init.
// This call confirms we have the correct values. // This call confirms we have the correct values.
err = config.WriteRepoConfig(ctx, s3Config, m365Config, opts.Repo, r.GetID()) err = config.WriteRepoConfig(ctx, storageCfg, m365Config, opts.Repo, r.GetID())
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("writing to repository configuration") logger.CtxErr(ctx, err).Info("writing to repository configuration")
return nil, nil, err return nil, nil, err

View File

@ -16,6 +16,7 @@ import (
"github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/repository" "github.com/alcionai/corso/src/pkg/repository"
"github.com/alcionai/corso/src/pkg/storage"
"github.com/alcionai/corso/src/pkg/store" "github.com/alcionai/corso/src/pkg/store"
) )
@ -29,7 +30,12 @@ func deleteBackups(
) ([]string, error) { ) ([]string, error) {
ctx = clues.Add(ctx, "cutoff_days", deletionDays) ctx = clues.Add(ctx, "cutoff_days", deletionDays)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, nil) provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return nil, err
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "connecting to account").WithClues(ctx) return nil, clues.Wrap(err, "connecting to account").WithClues(ctx)
} }
@ -87,7 +93,7 @@ func pitrListBackups(
// TODO(ashmrtn): This may be moved into CLI layer at some point when we add // TODO(ashmrtn): This may be moved into CLI layer at some point when we add
// flags for opening a repo at a point in time. // flags for opening a repo at a point in time.
cfg, err := config.GetConfigRepoDetails(ctx, true, true, nil) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, nil)
if err != nil { if err != nil {
return clues.Wrap(err, "getting config info") return clues.Wrap(err, "getting config info")
} }

View File

@ -187,16 +187,18 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error {
storage.Prefix: f.bucketPrefix, storage.Prefix: f.bucketPrefix,
} }
repoDetails, err := config.GetConfigRepoDetails(ctx, false, false, overrides) repoDetails, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, false, false, overrides)
if err != nil { if err != nil {
return clues.Wrap(err, "getting storage config") return clues.Wrap(err, "getting storage config")
} }
cfg, err := repoDetails.Storage.S3Config() c, err := repoDetails.Storage.StorageConfig()
if err != nil { if err != nil {
return clues.Wrap(err, "getting S3 config") return clues.Wrap(err, "getting S3 config")
} }
cfg := c.(storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = cfg.Endpoint

View File

@ -8,6 +8,7 @@ import (
"github.com/kopia/kopia/repo/blob/s3" "github.com/kopia/kopia/repo/blob/s3"
"github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
@ -20,29 +21,33 @@ func s3BlobStorage(
repoOpts repository.Options, repoOpts repository.Options,
s storage.Storage, s storage.Storage,
) (blob.Storage, error) { ) (blob.Storage, error) {
cfg, err := s.S3Config() cfg, err := s.StorageConfig()
if err != nil { if err != nil {
return nil, clues.Stack(err).WithClues(ctx) return nil, clues.Stack(err).WithClues(ctx)
} }
// Cast to S3Config
s3Cfg := cfg.(storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(s3Cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = s3Cfg.Endpoint
} }
logger.Ctx(ctx).Infow("aws creds", "key", s3Cfg.AccessKey)
opts := s3.Options{ opts := s3.Options{
BucketName: cfg.Bucket, BucketName: s3Cfg.Bucket,
Endpoint: endpoint, Endpoint: endpoint,
Prefix: cfg.Prefix, Prefix: s3Cfg.Prefix,
DoNotUseTLS: cfg.DoNotUseTLS, DoNotUseTLS: s3Cfg.DoNotUseTLS,
DoNotVerifyTLS: cfg.DoNotVerifyTLS, DoNotVerifyTLS: s3Cfg.DoNotVerifyTLS,
Tags: s.SessionTags, Tags: s.SessionTags,
SessionName: s.SessionName, SessionName: s.SessionName,
RoleARN: s.Role, RoleARN: s.Role,
RoleDuration: s.SessionDuration, RoleDuration: s.SessionDuration,
AccessKeyID: cfg.AccessKey, AccessKeyID: s3Cfg.AccessKey,
SecretAccessKey: cfg.SecretKey, SecretAccessKey: s3Cfg.SecretKey,
SessionToken: cfg.SessionToken, SessionToken: s3Cfg.SessionToken,
TLSHandshakeTimeout: 60, TLSHandshakeTimeout: 60,
PointInTime: repoOpts.ViewTimestamp, PointInTime: repoOpts.ViewTimestamp,
} }

View File

@ -4,12 +4,29 @@ import (
"strconv" "strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
) )
const (
// S3 config
StorageProviderTypeKey = "provider"
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
RepoID = "repo_id"
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
)
type S3Config struct { type S3Config struct {
credentials.AWS credentials.AWS
Bucket string // required Bucket string // required
@ -50,6 +67,121 @@ func (c S3Config) Normalize() S3Config {
} }
} }
// No need to return error here. Viper returns empty values.
func s3ConfigsFromStore(kvg KVStoreGetter) S3Config {
var s3Config S3Config
s3Config.Bucket = cast.ToString(kvg.Get(BucketNameKey))
s3Config.Endpoint = cast.ToString(kvg.Get(EndpointKey))
s3Config.Prefix = cast.ToString(kvg.Get(PrefixKey))
s3Config.DoNotUseTLS = cast.ToBool(kvg.Get(DisableTLSKey))
s3Config.DoNotVerifyTLS = cast.ToBool(kvg.Get(DisableTLSVerificationKey))
return s3Config
}
func s3CredsFromStore(
kvg KVStoreGetter,
s3Config S3Config,
) S3Config {
s3Config.AccessKey = cast.ToString(kvg.Get(AccessKey))
s3Config.SecretKey = cast.ToString(kvg.Get(SecretAccessKey))
s3Config.SessionToken = cast.ToString(kvg.Get(SessionToken))
return s3Config
}
var _ Configurer = S3Config{}
func (c S3Config) FetchConfigFromStore(
kvg KVStoreGetter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) (Configurer, error) {
var (
s3Cfg S3Config
err error
)
if readConfigFromStore {
s3Cfg = s3ConfigsFromStore(kvg)
if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[Prefix]; ok {
overrides[Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := cast.ToString(kvg.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() {
return S3Config{}, clues.New("unsupported storage provider: " + providerType)
}
// This is matching override values from config file.
if err := mustMatchConfig(kvg, s3Overrides(overrides)); err != nil {
return S3Config{}, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
s3Cfg = s3CredsFromStore(kvg, s3Cfg)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(
defaults.Config().WithCredentialsChainVerboseErrors(true),
defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return S3Config{}, clues.Wrap(err, "validating aws credentials")
}
}
s3Cfg = S3Config{
AWS: aws,
Bucket: str.First(overrides[Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
}
return s3Cfg, s3Cfg.validate()
}
var _ WriteConfigToStorer = S3Config{}
func (c S3Config) WriteConfigToStore(
kvs KVStoreSetter,
) {
s3Config := c.Normalize()
kvs.Set(StorageProviderTypeKey, ProviderS3.String())
kvs.Set(BucketNameKey, s3Config.Bucket)
kvs.Set(EndpointKey, s3Config.Endpoint)
kvs.Set(PrefixKey, s3Config.Prefix)
kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS)
kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
}
// StringConfig transforms a s3Config struct into a plain // StringConfig transforms a s3Config struct into a plain
// map[string]string. All values in the original struct which // map[string]string. All values in the original struct which
// serialize into the map are expected to be strings. // serialize into the map are expected to be strings.
@ -70,19 +202,19 @@ func (c S3Config) StringConfig() (map[string]string, error) {
} }
// S3Config retrieves the S3Config details from the Storage config. // S3Config retrieves the S3Config details from the Storage config.
func (s Storage) S3Config() (S3Config, error) { func MakeS3ConfigFromMap(config map[string]string) (S3Config, error) {
c := S3Config{} c := S3Config{}
if len(s.Config) > 0 { if len(config) > 0 {
c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) c.AccessKey = orEmptyString(config[keyS3AccessKey])
c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) c.SecretKey = orEmptyString(config[keyS3SecretKey])
c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) c.SessionToken = orEmptyString(config[keyS3SessionToken])
c.Bucket = orEmptyString(s.Config[keyS3Bucket]) c.Bucket = orEmptyString(config[keyS3Bucket])
c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) c.Endpoint = orEmptyString(config[keyS3Endpoint])
c.Prefix = orEmptyString(s.Config[keyS3Prefix]) c.Prefix = orEmptyString(config[keyS3Prefix])
c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS]) c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS])
c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS]) c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS])
} }
return c, c.validate() return c, c.validate()
@ -100,3 +232,44 @@ func (c S3Config) validate() error {
return nil return nil
} }
var constToTomlKeyMap = map[string]string{
Bucket: BucketNameKey,
Endpoint: EndpointKey,
Prefix: PrefixKey,
StorageProviderTypeKey: StorageProviderTypeKey,
}
// mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned.
// values in m that aren't stored in the config are ignored.
func mustMatchConfig(kvg KVStoreGetter, m map[string]string) error {
for k, v := range m {
if len(v) == 0 {
continue // empty variables will get caught by configuration validators, if necessary
}
tomlK, ok := constToTomlKeyMap[k]
if !ok {
continue // m may declare values which aren't stored in the config file
}
vv := cast.ToString(kvg.Get(tomlK))
if v != vv {
return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")")
}
}
return nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
Bucket: in[Bucket],
Endpoint: in[Endpoint],
Prefix: in[Prefix],
DoNotUseTLS: in[DoNotUseTLS],
DoNotVerifyTLS: in[DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}

View File

@ -66,9 +66,11 @@ func (suite *S3CfgSuite) TestStorage_S3Config() {
in := goodS3Config in := goodS3Config
s, err := NewStorage(ProviderS3, in) s, err := NewStorage(ProviderS3, in)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out, err := s.S3Config() sc, err := s.StorageConfig()
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out := sc.(S3Config)
assert.Equal(t, in.Bucket, out.Bucket) assert.Equal(t, in.Bucket, out.Bucket)
assert.Equal(t, in.Endpoint, out.Endpoint) assert.Equal(t, in.Endpoint, out.Endpoint)
assert.Equal(t, in.Prefix, out.Prefix) assert.Equal(t, in.Prefix, out.Prefix)
@ -117,7 +119,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() {
st, err := NewStorage(ProviderUnknown, goodS3Config) st, err := NewStorage(ProviderUnknown, goodS3Config)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
test.amend(st) test.amend(st)
_, err = st.S3Config() _, err = st.StorageConfig()
assert.Error(t, err) assert.Error(t, err)
}) })
} }

View File

@ -17,6 +17,11 @@ const (
ProviderFilesystem ProviderType = 2 // Filesystem ProviderFilesystem ProviderType = 2 // Filesystem
) )
var StringToProviderType = map[string]ProviderType{
ProviderUnknown.String(): ProviderUnknown,
ProviderS3.String(): ProviderS3,
}
// storage parsing errors // storage parsing errors
var ( var (
errMissingRequired = clues.New("missing required storage configuration") errMissingRequired = clues.New("missing required storage configuration")
@ -82,3 +87,50 @@ func orEmptyString(v any) string {
return v.(string) return v.(string)
} }
func (s Storage) StorageConfig() (Configurer, error) {
switch s.Provider {
case ProviderS3:
return MakeS3ConfigFromMap(s.Config)
}
return nil, clues.New("unsupported storage provider: " + s.Provider.String())
}
func NewStorageConfig(provider ProviderType) (Configurer, error) {
switch provider {
case ProviderS3:
return S3Config{}, nil
}
return nil, clues.New("unsupported storage provider: " + provider.String())
}
type KVStoreGetter interface {
Get(key string) any
}
type KVStoreSetter interface {
Set(key string, value any)
}
type WriteConfigToStorer interface {
WriteConfigToStore(
kvs KVStoreSetter,
)
}
type FetchConfigFromStorer interface {
FetchConfigFromStore(
kvg KVStoreGetter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) (Configurer, error)
}
type Configurer interface {
common.StringConfigurer
FetchConfigFromStorer
WriteConfigToStorer
}