More changes before piecemealing

This commit is contained in:
Abhishek Pandey 2023-09-13 17:49:50 +05:30
parent 8060227bf9
commit 9705c9d212
23 changed files with 164 additions and 129 deletions

View File

@ -10,7 +10,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -291,16 +290,12 @@ func genericDeleteCommand(
ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) ctx := clues.Add(cmd.Context(), "delete_backup_id", bID)
// Let it return both provider and overrides for now? storageProvider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
// That way we can stop config pkg from being included everywhere.
provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, provider, overrides) r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, storageProvider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -326,9 +321,7 @@ func genericListCommand(
) error { ) error {
ctx := cmd.Context() ctx := cmd.Context()
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -169,8 +168,7 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -284,8 +282,7 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeExchangeOpts(cmd) opts := utils.MakeExchangeOpts(cmd)
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -10,7 +10,6 @@ import (
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -155,8 +154,7 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -233,8 +231,7 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeGroupsOpts(cmd) opts := utils.MakeGroupsOpts(cmd)
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -140,9 +140,11 @@ func prepM365Test(
recorder = strings.Builder{} recorder = strings.Builder{}
) )
cfg, err := st.S3Config() c, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := c.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -150,8 +149,7 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -242,8 +240,7 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeOneDriveOpts(cmd) opts := utils.MakeOneDriveOpts(cmd)
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -160,9 +159,7 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -327,8 +324,7 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeSharePointOpts(cmd) opts := utils.MakeSharePointOpts(cmd)
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -70,8 +70,7 @@ func preRun(cc *cobra.Command, args []string) error {
} }
if !slices.Contains(avoidTheseDescription, cc.Short) { if !slices.Contains(avoidTheseDescription, cc.Short) {
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cc)
overrides, err := repo.GetStorageOverrides(ctx, cc, provider)
if err != nil { if err != nil {
return err return err
} }

View File

@ -255,7 +255,7 @@ func writeRepoConfigWithViper(
// data sources (config file, env vars, flag overrides) and the config file. // data sources (config file, env vars, flag overrides) and the config file.
func GetConfigRepoDetails( func GetConfigRepoDetails(
ctx context.Context, ctx context.Context,
provider string, provider storage.ProviderType,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -271,7 +271,7 @@ func GetConfigRepoDetails(
// struct for testing. // struct for testing.
func getStorageAndAccountWithViper( func getStorageAndAccountWithViper(
vpr *viper.Viper, vpr *viper.Viper,
provider string, provider storage.ProviderType,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -376,13 +376,13 @@ func requireProps(props map[string]string) error {
// Storage provider is not a flag. It can only be sourced from config file. // Storage provider is not a flag. It can only be sourced from config file.
// Only exceptions are the commands that create a new repo. // Only exceptions are the commands that create a new repo.
// This is needed to figure out which storage overrides to use. // This is needed to figure out which storage overrides to use.
func GetStorageProviderFromConfigFile(ctx context.Context) (string, error) { func GetStorageProviderFromConfigFile(ctx context.Context) (storage.ProviderType, error) {
vpr := GetViper(ctx) vpr := GetViper(ctx)
provider := vpr.GetString(StorageProviderTypeKey) provider := vpr.GetString(StorageProviderTypeKey)
if provider != storage.ProviderS3.String() { if provider != storage.ProviderS3.String() {
return storage.ProviderUnknown.String(), clues.New("unsupported storage provider: " + provider) return storage.ProviderUnknown, clues.New("unsupported storage provider: " + provider)
} }
return provider, nil return storage.StringToProviderType[provider], nil
} }

View File

@ -107,14 +107,23 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
s3Cfg, err := s3ConfigsFromViper(vpr) // Unset AWS env vars so that we can test reading creds from config file
os.Unsetenv(credentials.AWSAccessKeyID)
os.Unsetenv(credentials.AWSSecretAccessKey)
os.Unsetenv(credentials.AWSSessionToken)
sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
s3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, b, s3Cfg.Bucket) assert.Equal(t, b, s3Cfg.Bucket)
assert.Equal(t, "test-prefix/", s3Cfg.Prefix) assert.Equal(t, "test-prefix/", s3Cfg.Prefix)
assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS)) assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS))
assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS)) assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS))
s3Cfg, err = s3CredsFromViper(vpr, s3Cfg) // s3Cfg, err = s3CredsFromViper(vpr, s3Cfg)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, accKey, s3Cfg.AWS.AccessKey) assert.Equal(t, accKey, s3Cfg.AWS.AccessKey)
assert.Equal(t, secret, s3Cfg.AWS.SecretKey) assert.Equal(t, secret, s3Cfg.AWS.SecretKey)
@ -160,7 +169,11 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
readS3Cfg, err := s3ConfigsFromViper(vpr) sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS) assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS)
@ -326,12 +339,14 @@ func (suite *ConfigSuite) TestReadFromFlags() {
repoDetails, err := getStorageAndAccountWithViper( repoDetails, err := getStorageAndAccountWithViper(
vpr, vpr,
storage.ProviderS3,
true, true,
false, false,
overrides) overrides)
m365Config, _ := repoDetails.Account.M365Config() m365Config, _ := repoDetails.Account.M365Config()
s3Cfg, _ := repoDetails.Storage.S3Config() cfg, _ := repoDetails.Storage.StorageConfig()
s3Cfg, _ := cfg.(storage.S3Config)
commonConfig, _ := repoDetails.Storage.CommonConfig() commonConfig, _ := repoDetails.Storage.CommonConfig()
pass := commonConfig.Corso.CorsoPassphrase pass := commonConfig.Corso.CorsoPassphrase
@ -400,11 +415,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() {
err = vpr.ReadInConfig() err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err)) require.NoError(t, err, "reading repo config", clues.ToCore(err))
cfg, err := getStorageAndAccountWithViper(vpr, true, true, nil) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket) assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint) assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint)
assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix) assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix)
@ -448,11 +465,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride
StorageProviderTypeKey: storage.ProviderS3.String(), StorageProviderTypeKey: storage.ProviderS3.String(),
} }
cfg, err := getStorageAndAccountWithViper(vpr, false, true, overrides) cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err)) require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config() sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err)) require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, bkt) assert.Equal(t, readS3Cfg.Bucket, bkt)
assert.Equal(t, cfg.RepoID, "") assert.Equal(t, cfg.RepoID, "")
assert.Equal(t, readS3Cfg.Endpoint, end) assert.Equal(t, readS3Cfg.Endpoint, end)

View File

@ -17,7 +17,7 @@ import (
// viper properties and manual overrides. // viper properties and manual overrides.
func configureStorage( func configureStorage(
vpr *viper.Viper, vpr *viper.Viper,
provider string, provider storage.ProviderType,
readConfigFromViper bool, readConfigFromViper bool,
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -29,7 +29,8 @@ func configureStorage(
storageCfg, _ := storage.NewStorageConfig(provider) storageCfg, _ := storage.NewStorageConfig(provider)
err = storageCfg.FetchConfigFromStore( // Rename this. It's not just fetch config from store.
storageCfg, err = storageCfg.FetchConfigFromStore(
vpr, vpr,
readConfigFromViper, readConfigFromViper,
matchFromConfig, matchFromConfig,
@ -63,8 +64,7 @@ func configureStorage(
} }
// build the storage // build the storage
store, err = storage.NewStorage( store, err = storage.NewStorage(provider, storageCfg, cCfg)
storage.StringToEnum(provider), storageCfg, cCfg)
if err != nil { if err != nil {
return store, clues.Wrap(err, "configuring repository storage") return store, clues.Wrap(err, "configuring repository storage")
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
"github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils"
@ -71,8 +70,7 @@ func runExport(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils"
@ -125,7 +126,7 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error {
// Change this to override too? // Change this to override too?
r, _, err := utils.AccountConnectAndWriteRepoConfig( r, _, err := utils.AccountConnectAndWriteRepoConfig(
ctx, path.UnknownService, storage.ProviderS3.String(), S3Overrides(cmd)) ctx, path.UnknownService, storage.ProviderS3, S3Overrides(cmd))
if err != nil { if err != nil {
return print.Only(ctx, err) return print.Only(ctx, err)
} }
@ -183,3 +184,22 @@ func GetStorageOverrides(
return overrides, nil return overrides, nil
} }
func GetStorageProviderAndOverrides(
ctx context.Context,
cmd *cobra.Command,
) (storage.ProviderType, map[string]string, error) {
provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return provider, nil, clues.Stack(err)
}
overrides := map[string]string{}
switch provider {
case storage.ProviderS3:
overrides = S3Overrides(cmd)
}
return provider, overrides, nil
}

View File

@ -91,7 +91,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, false, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, false, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -112,11 +112,10 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
cfg.Account.ID(), cfg.Account.ID(),
opt) opt)
//s3Cfg, err := cfg.Storage.S3Config() // why not let it return configurer? storageCfg, err := cfg.Storage.StorageConfig()
storageCfg, err := cfg.Storage.GetStorageConfig() if err != nil {
// if err != nil { return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
// return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) }
// }
// BUG: This should be moved to validate() // BUG: This should be moved to validate()
// if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
@ -180,7 +179,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, true, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -190,7 +189,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
repoID = events.RepoIDNotFound repoID = events.RepoIDNotFound
} }
s3Cfg, err := cfg.Storage.GetStorageConfig() s3Cfg, err := cfg.Storage.StorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }

View File

@ -63,9 +63,11 @@ func (suite *S3E2ESuite) TestInitS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
if !test.hasConfigFile { if !test.hasConfigFile {
// Ideally we could use `/dev/null`, but you need a // Ideally we could use `/dev/null`, but you need a
@ -100,9 +102,11 @@ func (suite *S3E2ESuite) TestInitMultipleTimes() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -129,9 +133,10 @@ func (suite *S3E2ESuite) TestInitS3Cmd_missingBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgBucket: "", tconfig.TestCfgBucket: "",
} }
@ -182,9 +187,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),
@ -230,9 +237,10 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadBucket() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)
@ -256,9 +264,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadPrefix() {
defer flush() defer flush()
st := storeTD.NewPrefixedS3Storage(t) st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config() sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil) vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr) ctx = config.SetViper(ctx, vpr)

View File

@ -66,9 +66,11 @@ func (suite *RestoreExchangeE2ESuite) SetupSuite() {
suite.acct = tconfig.NewM365Account(t) suite.acct = tconfig.NewM365Account(t)
suite.st = storeTD.NewPrefixedS3Storage(t) suite.st = storeTD.NewPrefixedS3Storage(t)
cfg, err := suite.st.S3Config() sc, err := suite.st.StorageConfig()
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{ force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(), tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(), tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -104,8 +103,7 @@ func runRestore(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -24,7 +24,7 @@ var ErrNotYetImplemented = clues.New("not yet implemented")
func GetAccountAndConnect( func GetAccountAndConnect(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider string, provider storage.ProviderType,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) {
cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides) cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides)
@ -56,7 +56,7 @@ func GetAccountAndConnect(
func AccountConnectAndWriteRepoConfig( func AccountConnectAndWriteRepoConfig(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider string, provider storage.ProviderType,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *account.Account, error) { ) (repository.Repository, *account.Account, error) {
r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides) r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides)
@ -65,7 +65,7 @@ func AccountConnectAndWriteRepoConfig(
return nil, nil, err return nil, nil, err
} }
storageCfg, err := stg.GetStorageConfig() storageCfg, err := stg.StorageConfig()
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting storage configuration") logger.CtxErr(ctx, err).Info("getting storage configuration")

View File

@ -16,6 +16,7 @@ import (
"github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/repository" "github.com/alcionai/corso/src/pkg/repository"
"github.com/alcionai/corso/src/pkg/storage"
"github.com/alcionai/corso/src/pkg/store" "github.com/alcionai/corso/src/pkg/store"
) )
@ -29,7 +30,11 @@ func deleteBackups(
) ([]string, error) { ) ([]string, error) {
ctx = clues.Add(ctx, "cutoff_days", deletionDays) ctx = clues.Add(ctx, "cutoff_days", deletionDays)
provider, _ := config.GetStorageProviderFromConfigFile(ctx) provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return nil, err
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil) r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "connecting to account").WithClues(ctx) return nil, clues.Wrap(err, "connecting to account").WithClues(ctx)
@ -88,7 +93,7 @@ func pitrListBackups(
// TODO(ashmrtn): This may be moved into CLI layer at some point when we add // TODO(ashmrtn): This may be moved into CLI layer at some point when we add
// flags for opening a repo at a point in time. // flags for opening a repo at a point in time.
cfg, err := config.GetConfigRepoDetails(ctx, true, true, nil) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, nil)
if err != nil { if err != nil {
return clues.Wrap(err, "getting config info") return clues.Wrap(err, "getting config info")
} }

View File

@ -187,16 +187,18 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error {
storage.Prefix: f.bucketPrefix, storage.Prefix: f.bucketPrefix,
} }
repoDetails, err := config.GetConfigRepoDetails(ctx, false, false, overrides) repoDetails, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, false, false, overrides)
if err != nil { if err != nil {
return clues.Wrap(err, "getting storage config") return clues.Wrap(err, "getting storage config")
} }
cfg, err := repoDetails.Storage.S3Config() c, err := repoDetails.Storage.StorageConfig()
if err != nil { if err != nil {
return clues.Wrap(err, "getting S3 config") return clues.Wrap(err, "getting S3 config")
} }
cfg := c.(storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = cfg.Endpoint

View File

@ -8,6 +8,7 @@ import (
"github.com/kopia/kopia/repo/blob/s3" "github.com/kopia/kopia/repo/blob/s3"
"github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
@ -20,7 +21,7 @@ func s3BlobStorage(
repoOpts repository.Options, repoOpts repository.Options,
s storage.Storage, s storage.Storage,
) (blob.Storage, error) { ) (blob.Storage, error) {
cfg, err := s.GetStorageConfig() cfg, err := s.StorageConfig()
if err != nil { if err != nil {
return nil, clues.Stack(err).WithClues(ctx) return nil, clues.Stack(err).WithClues(ctx)
} }
@ -33,6 +34,7 @@ func s3BlobStorage(
endpoint = s3Cfg.Endpoint endpoint = s3Cfg.Endpoint
} }
logger.Ctx(ctx).Infow("aws creds", "key", s3Cfg.AccessKey)
opts := s3.Options{ opts := s3.Options{
BucketName: s3Cfg.Bucket, BucketName: s3Cfg.Bucket,
Endpoint: endpoint, Endpoint: endpoint,

View File

@ -68,44 +68,45 @@ func (c S3Config) Normalize() S3Config {
} }
// No need to return error here. Viper returns empty values. // No need to return error here. Viper returns empty values.
func s3ConfigsFromStore(kvs KVStorer) S3Config { func s3ConfigsFromStore(kvg KVStoreGetter) S3Config {
var s3Config S3Config var s3Config S3Config
s3Config.Bucket = cast.ToString(kvs.Get(BucketNameKey)) s3Config.Bucket = cast.ToString(kvg.Get(BucketNameKey))
s3Config.Endpoint = cast.ToString(kvs.Get(EndpointKey)) s3Config.Endpoint = cast.ToString(kvg.Get(EndpointKey))
s3Config.Prefix = cast.ToString(kvs.Get(PrefixKey)) s3Config.Prefix = cast.ToString(kvg.Get(PrefixKey))
s3Config.DoNotUseTLS = cast.ToBool(kvs.Get(DisableTLSKey)) s3Config.DoNotUseTLS = cast.ToBool(kvg.Get(DisableTLSKey))
s3Config.DoNotVerifyTLS = cast.ToBool(kvs.Get(DisableTLSVerificationKey)) s3Config.DoNotVerifyTLS = cast.ToBool(kvg.Get(DisableTLSVerificationKey))
return s3Config return s3Config
} }
func s3CredsFromStore( func s3CredsFromStore(
kvs KVStorer, kvg KVStoreGetter,
s3Config S3Config, s3Config S3Config,
) S3Config { ) S3Config {
s3Config.AccessKey = cast.ToString(kvs.Get(AccessKey)) s3Config.AccessKey = cast.ToString(kvg.Get(AccessKey))
s3Config.SecretKey = cast.ToString(kvs.Get(SecretAccessKey)) s3Config.SecretKey = cast.ToString(kvg.Get(SecretAccessKey))
s3Config.SessionToken = cast.ToString(kvs.Get(SessionToken)) s3Config.SessionToken = cast.ToString(kvg.Get(SessionToken))
return s3Config return s3Config
} }
var _ StorageConfigurer = S3Config{} var _ Configurer = S3Config{}
func (c S3Config) FetchConfigFromStore( func (c S3Config) FetchConfigFromStore(
kvs KVStorer, kvg KVStoreGetter,
readConfigFromStore bool, readConfigFromStore bool,
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
) error { ) (Configurer, error) {
var ( var (
s3Cfg S3Config s3Cfg S3Config
err error err error
) )
if readConfigFromStore { if readConfigFromStore {
s3Cfg = s3ConfigsFromStore(kvs) s3Cfg = s3ConfigsFromStore(kvg)
if b, ok := overrides[Bucket]; ok { if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b) overrides[Bucket] = common.NormalizeBucket(b)
} }
@ -115,19 +116,19 @@ func (c S3Config) FetchConfigFromStore(
} }
if matchFromConfig { if matchFromConfig {
providerType := cast.ToString(kvs.Get(StorageProviderTypeKey)) providerType := cast.ToString(kvg.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() { if providerType != ProviderS3.String() {
return clues.New("unsupported storage provider: " + providerType) return S3Config{}, clues.New("unsupported storage provider: " + providerType)
} }
// This is matching override values from config file. // This is matching override values from config file.
if err := mustMatchConfig(kvs, s3Overrides(overrides)); err != nil { if err := mustMatchConfig(kvg, s3Overrides(overrides)); err != nil {
return clues.Wrap(err, "verifying s3 configs in corso config file") return S3Config{}, clues.Wrap(err, "verifying s3 configs in corso config file")
} }
} }
} }
s3Cfg = s3CredsFromStore(kvs, s3Cfg) s3Cfg = s3CredsFromStore(kvg, s3Cfg)
aws := credentials.GetAWS(overrides) aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
@ -144,7 +145,7 @@ func (c S3Config) FetchConfigFromStore(
} }
if err != nil { if err != nil {
return clues.Wrap(err, "validating aws credentials") return S3Config{}, clues.Wrap(err, "validating aws credentials")
} }
} }
@ -163,7 +164,7 @@ func (c S3Config) FetchConfigFromStore(
"false")), "false")),
} }
return nil return s3Cfg, s3Cfg.validate()
} }
var _ WriteConfigToStorer = S3Config{} var _ WriteConfigToStorer = S3Config{}
@ -242,7 +243,7 @@ var constToTomlKeyMap = map[string]string{
// mustMatchConfig compares the values of each key to their config file value in store. // mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned. // If any value differs from the store value, an error is returned.
// values in m that aren't stored in the config are ignored. // values in m that aren't stored in the config are ignored.
func mustMatchConfig(kvs KVStorer, m map[string]string) error { func mustMatchConfig(kvg KVStoreGetter, m map[string]string) error {
for k, v := range m { for k, v := range m {
if len(v) == 0 { if len(v) == 0 {
continue // empty variables will get caught by configuration validators, if necessary continue // empty variables will get caught by configuration validators, if necessary
@ -253,7 +254,7 @@ func mustMatchConfig(kvs KVStorer, m map[string]string) error {
continue // m may declare values which aren't stored in the config file continue // m may declare values which aren't stored in the config file
} }
vv := cast.ToString(kvs.Get(tomlK)) vv := cast.ToString(kvg.Get(tomlK))
if v != vv { if v != vv {
return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")") return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")")
} }

View File

@ -66,9 +66,11 @@ func (suite *S3CfgSuite) TestStorage_S3Config() {
in := goodS3Config in := goodS3Config
s, err := NewStorage(ProviderS3, in) s, err := NewStorage(ProviderS3, in)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out, err := s.S3Config() sc, err := s.StorageConfig()
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
out := sc.(S3Config)
assert.Equal(t, in.Bucket, out.Bucket) assert.Equal(t, in.Bucket, out.Bucket)
assert.Equal(t, in.Endpoint, out.Endpoint) assert.Equal(t, in.Endpoint, out.Endpoint)
assert.Equal(t, in.Prefix, out.Prefix) assert.Equal(t, in.Prefix, out.Prefix)
@ -117,7 +119,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() {
st, err := NewStorage(ProviderUnknown, goodS3Config) st, err := NewStorage(ProviderUnknown, goodS3Config)
assert.NoError(t, err, clues.ToCore(err)) assert.NoError(t, err, clues.ToCore(err))
test.amend(st) test.amend(st)
_, err = st.S3Config() _, err = st.StorageConfig()
assert.Error(t, err) assert.Error(t, err)
}) })
} }

View File

@ -17,13 +17,9 @@ const (
ProviderFilesystem ProviderType = 2 // Filesystem ProviderFilesystem ProviderType = 2 // Filesystem
) )
func StringToEnum(s string) StorageProvider { var StringToProviderType = map[string]ProviderType{
switch s { ProviderUnknown.String(): ProviderUnknown,
case ProviderS3.String(): ProviderS3.String(): ProviderS3,
return ProviderS3
}
return ProviderUnknown
} }
// storage parsing errors // storage parsing errors
@ -34,6 +30,7 @@ var (
// Storage defines a storage provider, along with any configuration // Storage defines a storage provider, along with any configuration
// required to set up or communicate with that provider. // required to set up or communicate with that provider.
type Storage struct { type Storage struct {
Provider ProviderType
Provider ProviderType Provider ProviderType
Config map[string]string Config map[string]string
// TODO: These are AWS S3 specific -> move these out // TODO: These are AWS S3 specific -> move these out
@ -44,6 +41,7 @@ type Storage struct {
} }
// NewStorage aggregates all the supplied configurations into a single configuration. // NewStorage aggregates all the supplied configurations into a single configuration.
func NewStorage(p ProviderType, cfgs ...common.StringConfigurer) (Storage, error) {
func NewStorage(p ProviderType, cfgs ...common.StringConfigurer) (Storage, error) { func NewStorage(p ProviderType, cfgs ...common.StringConfigurer) (Storage, error) {
cs, err := common.UnionStringConfigs(cfgs...) cs, err := common.UnionStringConfigs(cfgs...)
@ -56,6 +54,7 @@ func NewStorage(p ProviderType, cfgs ...common.StringConfigurer) (Storage, error
// NewStorageUsingRole supports specifying an AWS IAM role the storage provider // NewStorageUsingRole supports specifying an AWS IAM role the storage provider
// should assume. // should assume.
func NewStorageUsingRole( func NewStorageUsingRole(
p ProviderType,
p ProviderType, p ProviderType,
roleARN string, roleARN string,
sessionName string, sessionName string,
@ -92,7 +91,7 @@ func orEmptyString(v any) string {
return v.(string) return v.(string)
} }
func (s Storage) GetStorageConfig() (StorageConfigurer, error) { func (s Storage) StorageConfig() (Configurer, error) {
switch s.Provider { switch s.Provider {
case ProviderS3: case ProviderS3:
return MakeS3ConfigFromMap(s.Config) return MakeS3ConfigFromMap(s.Config)
@ -101,32 +100,23 @@ func (s Storage) GetStorageConfig() (StorageConfigurer, error) {
return nil, clues.New("unsupported storage provider: " + s.Provider.String()) return nil, clues.New("unsupported storage provider: " + s.Provider.String())
} }
func NewStorageConfig(provider string) (StorageConfigurer, error) { func NewStorageConfig(provider ProviderType) (Configurer, error) {
switch provider { switch provider {
case ProviderS3.String(): case ProviderS3:
return S3Config{}, nil return S3Config{}, nil
} }
return nil, clues.New("unsupported storage provider: " + provider) return nil, clues.New("unsupported storage provider: " + provider.String())
} }
// Change it to just getter type KVStoreGetter interface {
type KVStorer interface {
Get(key string) any Get(key string) any
Set(key string, value any)
} }
type KVStoreSetter interface { type KVStoreSetter interface {
Set(key string, value any) Set(key string, value any)
} }
// Call it configurer if necessary.
type StorageConfigurer interface {
common.StringConfigurer
FetchConfigFromStorer
WriteConfigToStorer
}
type WriteConfigToStorer interface { type WriteConfigToStorer interface {
WriteConfigToStore( WriteConfigToStore(
kvs KVStoreSetter, kvs KVStoreSetter,
@ -135,9 +125,15 @@ type WriteConfigToStorer interface {
type FetchConfigFromStorer interface { type FetchConfigFromStorer interface {
FetchConfigFromStore( FetchConfigFromStore(
kv KVStorer, kvg KVStoreGetter,
readConfigFromStore bool, readConfigFromStore bool,
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
) error ) (Configurer, error)
}
type Configurer interface {
common.StringConfigurer
FetchConfigFromStorer
WriteConfigToStorer
} }