Add new prototype for storage config refactor

This commit is contained in:
Abhishek Pandey 2023-09-13 00:58:43 +05:30
parent de062cd5de
commit 8060227bf9
17 changed files with 442 additions and 174 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -290,7 +291,16 @@ func genericDeleteCommand(
ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) ctx := clues.Add(cmd.Context(), "delete_backup_id", bID)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, repo.S3Overrides(cmd)) // Let it return both provider and overrides for now?
// That way we can stop config pkg from being included everywhere.
provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -316,7 +326,14 @@ func genericListCommand(
) error { ) error {
ctx := cmd.Context() ctx := cmd.Context()
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -168,7 +169,13 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -277,7 +284,13 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeExchangeOpts(cmd) opts := utils.MakeExchangeOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -154,7 +155,13 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -226,7 +233,13 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeGroupsOpts(cmd) opts := utils.MakeGroupsOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -149,7 +150,13 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -235,7 +242,13 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeOneDriveOpts(cmd) opts := utils.MakeOneDriveOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -159,7 +160,14 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -319,7 +327,13 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
opts := utils.MakeSharePointOpts(cmd) opts := utils.MakeSharePointOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -70,9 +70,13 @@ func preRun(cc *cobra.Command, args []string) error {
} }
if !slices.Contains(avoidTheseDescription, cc.Short) { if !slices.Contains(avoidTheseDescription, cc.Short) {
overrides := repo.S3Overrides(cc) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cc, provider)
if err != nil {
return err
}
cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides) cfg, err := config.GetConfigRepoDetails(ctx, provider, true, false, overrides)
if err != nil { if err != nil {
log.Error("Error while getting config info to run command: ", cc.Use) log.Error("Error while getting config info to run command: ", cc.Use)
return err return err

View File

@ -203,14 +203,14 @@ func Read(ctx context.Context) error {
// It does not check for conflicts or existing data. // It does not check for conflicts or existing data.
func WriteRepoConfig( func WriteRepoConfig(
ctx context.Context, ctx context.Context,
s3Config storage.S3Config, scfg storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
return writeRepoConfigWithViper( return writeRepoConfigWithViper(
GetViper(ctx), GetViper(ctx),
s3Config, scfg,
m365Config, m365Config,
repoOpts, repoOpts,
repoID) repoID)
@ -220,20 +220,12 @@ func WriteRepoConfig(
// struct for testing. // struct for testing.
func writeRepoConfigWithViper( func writeRepoConfigWithViper(
vpr *viper.Viper, vpr *viper.Viper,
s3Config storage.S3Config, scfg storage.WriteConfigToStorer,
m365Config account.M365Config, m365Config account.M365Config,
repoOpts repository.Options, repoOpts repository.Options,
repoID string, repoID string,
) error { ) error {
s3Config = s3Config.Normalize() scfg.WriteConfigToStore(vpr)
// Rudimentary support for persisting repo config
// TODO: Handle conflicts, support other config types
vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String())
vpr.Set(BucketNameKey, s3Config.Bucket)
vpr.Set(EndpointKey, s3Config.Endpoint)
vpr.Set(PrefixKey, s3Config.Prefix)
vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS)
vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
vpr.Set(RepoID, repoID) vpr.Set(RepoID, repoID)
// Need if-checks as Viper will write empty values otherwise. // Need if-checks as Viper will write empty values otherwise.
@ -263,6 +255,7 @@ func writeRepoConfigWithViper(
// data sources (config file, env vars, flag overrides) and the config file. // data sources (config file, env vars, flag overrides) and the config file.
func GetConfigRepoDetails( func GetConfigRepoDetails(
ctx context.Context, ctx context.Context,
provider string,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -270,7 +263,7 @@ func GetConfigRepoDetails(
RepoDetails, RepoDetails,
error, error,
) { ) {
config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides) config, err := getStorageAndAccountWithViper(GetViper(ctx), provider, readFromFile, mustMatchFromConfig, overrides)
return config, err return config, err
} }
@ -278,6 +271,7 @@ func GetConfigRepoDetails(
// struct for testing. // struct for testing.
func getStorageAndAccountWithViper( func getStorageAndAccountWithViper(
vpr *viper.Viper, vpr *viper.Viper,
provider string,
readFromFile bool, readFromFile bool,
mustMatchFromConfig bool, mustMatchFromConfig bool,
overrides map[string]string, overrides map[string]string,
@ -312,7 +306,7 @@ func getStorageAndAccountWithViper(
return config, clues.Wrap(err, "retrieving account configuration details") return config, clues.Wrap(err, "retrieving account configuration details")
} }
config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides) config.Storage, err = configureStorage(vpr, provider, readConfigFromViper, mustMatchFromConfig, overrides)
if err != nil { if err != nil {
return config, clues.Wrap(err, "retrieving storage provider details") return config, clues.Wrap(err, "retrieving storage provider details")
} }
@ -378,3 +372,17 @@ func requireProps(props map[string]string) error {
return nil return nil
} }
// Storage provider is not a flag. It can only be sourced from config file.
// Only exceptions are the commands that create a new repo.
// This is needed to figure out which storage overrides to use.
func GetStorageProviderFromConfigFile(ctx context.Context) (string, error) {
vpr := GetViper(ctx)
provider := vpr.GetString(StorageProviderTypeKey)
if provider != storage.ProviderS3.String() {
return storage.ProviderUnknown.String(), clues.New("unsupported storage provider: " + provider)
}
return provider, nil
}

View File

@ -3,127 +3,39 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
"github.com/alcionai/corso/src/pkg/storage" "github.com/alcionai/corso/src/pkg/storage"
) )
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) {
var s3Config storage.S3Config
s3Config.Bucket = vpr.GetString(BucketNameKey)
s3Config.Endpoint = vpr.GetString(EndpointKey)
s3Config.Prefix = vpr.GetString(PrefixKey)
s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey)
s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey)
return s3Config, nil
}
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) {
s3Config.AccessKey = vpr.GetString(AccessKey)
s3Config.SecretKey = vpr.GetString(SecretAccessKey)
s3Config.SessionToken = vpr.GetString(SessionToken)
return s3Config, nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
storage.Bucket: in[storage.Bucket],
storage.Endpoint: in[storage.Endpoint],
storage.Prefix: in[storage.Prefix],
storage.DoNotUseTLS: in[storage.DoNotUseTLS],
storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}
// configureStorage builds a complete storage configuration from a mix of // configureStorage builds a complete storage configuration from a mix of
// viper properties and manual overrides. // viper properties and manual overrides.
func configureStorage( func configureStorage(
vpr *viper.Viper, vpr *viper.Viper,
provider string,
readConfigFromViper bool, readConfigFromViper bool,
matchFromConfig bool, matchFromConfig bool,
overrides map[string]string, overrides map[string]string,
) (storage.Storage, error) { ) (storage.Storage, error) {
var ( var (
s3Cfg storage.S3Config
store storage.Storage store storage.Storage
err error err error
) )
if readConfigFromViper { storageCfg, _ := storage.NewStorageConfig(provider)
if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
if b, ok := overrides[storage.Bucket]; ok { err = storageCfg.FetchConfigFromStore(
overrides[storage.Bucket] = common.NormalizeBucket(b) vpr,
} readConfigFromViper,
matchFromConfig,
if p, ok := overrides[storage.Prefix]; ok { overrides)
overrides[storage.Prefix] = common.NormalizePrefix(p) if err != nil {
} return store, clues.Wrap(err, "fetching storage config from store")
if matchFromConfig {
providerType := vpr.GetString(StorageProviderTypeKey)
if providerType != storage.ProviderS3.String() {
return store, clues.New("unsupported storage provider: " + providerType)
}
if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil {
return store, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
s3Overrides(overrides)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return store, clues.Wrap(err, "validating aws credentials")
}
}
s3Cfg = storage.S3Config{
AWS: aws,
Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[storage.DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[storage.DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
} }
// compose the common config and credentials // compose the common config and credentials
@ -145,14 +57,14 @@ func configureStorage(
// ensure required properties are present // ensure required properties are present
if err := requireProps(map[string]string{ if err := requireProps(map[string]string{
storage.Bucket: s3Cfg.Bucket,
credentials.CorsoPassphrase: corso.CorsoPassphrase, credentials.CorsoPassphrase: corso.CorsoPassphrase,
}); err != nil { }); err != nil {
return storage.Storage{}, err return storage.Storage{}, err
} }
// build the storage // build the storage
store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) store, err = storage.NewStorage(
storage.StringToEnum(provider), storageCfg, cCfg)
if err != nil { if err != nil {
return store, clues.Wrap(err, "configuring repository storage") return store, clues.Wrap(err, "configuring repository storage")
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
"github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils"
@ -70,7 +71,13 @@ func runExport(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -1,6 +1,7 @@
package repo package repo
import ( import (
"context"
"strings" "strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
@ -12,6 +13,7 @@ import (
"github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/cli/utils"
"github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/storage"
) )
const ( const (
@ -121,7 +123,9 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error {
return err return err
} }
r, _, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.UnknownService, S3Overrides(cmd)) // Change this to override too?
r, _, err := utils.AccountConnectAndWriteRepoConfig(
ctx, path.UnknownService, storage.ProviderS3.String(), S3Overrides(cmd))
if err != nil { if err != nil {
return print.Only(ctx, err) return print.Only(ctx, err)
} }
@ -164,3 +168,18 @@ func getMaintenanceType(t string) (repository.MaintenanceType, error) {
return res, nil return res, nil
} }
func GetStorageOverrides(
ctx context.Context,
cmd *cobra.Command,
storageProvider string,
) (map[string]string, error) {
overrides := map[string]string{}
switch storageProvider {
case storage.ProviderS3.String():
overrides = S3Overrides(cmd)
}
return overrides, nil
}

View File

@ -2,7 +2,6 @@ package repo
import ( import (
"strconv" "strconv"
"strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -92,7 +91,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, false, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -113,17 +112,19 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
cfg.Account.ID(), cfg.Account.ID(),
opt) opt)
s3Cfg, err := cfg.Storage.S3Config() //s3Cfg, err := cfg.Storage.S3Config() // why not let it return configurer?
if err != nil { storageCfg, err := cfg.Storage.GetStorageConfig()
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) // if err != nil {
} // return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
// }
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { // BUG: This should be moved to validate()
invalidEndpointErr := "endpoint doesn't support specifying protocol. " + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
"pass --disable-tls flag to use http:// instead of default https://" // invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr)) // return Only(ctx, clues.New(invalidEndpointErr))
} // }
m365, err := cfg.Account.M365Config() m365, err := cfg.Account.M365Config()
if err != nil { if err != nil {
@ -146,9 +147,10 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r) defer utils.CloseRepo(ctx, r)
Infof(ctx, "Initialized a S3 repository within bucket %s.", s3Cfg.Bucket) // Strong typecast?
Infof(ctx, "Initialized a S3 repository within bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opt.Repo, r.GetID()); err != nil { if err = config.WriteRepoConfig(ctx, storageCfg, m365, opt.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))
} }
@ -178,7 +180,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags // s3 values from flags
s3Override := S3Overrides(cmd) s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override) cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, true, s3Override)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }
@ -188,7 +190,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
repoID = events.RepoIDNotFound repoID = events.RepoIDNotFound
} }
s3Cfg, err := cfg.Storage.S3Config() s3Cfg, err := cfg.Storage.GetStorageConfig()
if err != nil { if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
} }
@ -198,12 +200,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config"))
} }
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { // Move these to validate()?
invalidEndpointErr := "endpoint doesn't support specifying protocol. " + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
"pass --disable-tls flag to use http:// instead of default https://" // invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr)) // return Only(ctx, clues.New(invalidEndpointErr))
} // }
opts := utils.ControlWithConfig(cfg) opts := utils.ControlWithConfig(cfg)
@ -219,7 +222,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r) defer utils.CloseRepo(ctx, r)
Infof(ctx, "Connected to S3 bucket %s.", s3Cfg.Bucket) Infof(ctx, "Connected to S3 bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil { if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))

View File

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags" "github.com/alcionai/corso/src/cli/flags"
. "github.com/alcionai/corso/src/cli/print" . "github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/repo"
@ -103,7 +104,13 @@ func runRestore(
sel selectors.Selector, sel selectors.Selector,
backupID, serviceName string, backupID, serviceName string,
) error { ) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
overrides, err := repo.GetStorageOverrides(ctx, cmd, provider)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil { if err != nil {
return Only(ctx, err) return Only(ctx, err)
} }

View File

@ -24,9 +24,10 @@ var ErrNotYetImplemented = clues.New("not yet implemented")
func GetAccountAndConnect( func GetAccountAndConnect(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider string,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) {
cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides) cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
@ -55,17 +56,19 @@ func GetAccountAndConnect(
func AccountConnectAndWriteRepoConfig( func AccountConnectAndWriteRepoConfig(
ctx context.Context, ctx context.Context,
pst path.ServiceType, pst path.ServiceType,
provider string,
overrides map[string]string, overrides map[string]string,
) (repository.Repository, *account.Account, error) { ) (repository.Repository, *account.Account, error) {
r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, overrides) r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides)
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting and connecting account") logger.CtxErr(ctx, err).Info("getting and connecting account")
return nil, nil, err return nil, nil, err
} }
s3Config, err := stg.S3Config() storageCfg, err := stg.GetStorageConfig()
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("getting storage configuration") logger.CtxErr(ctx, err).Info("getting storage configuration")
return nil, nil, err return nil, nil, err
} }
@ -77,7 +80,7 @@ func AccountConnectAndWriteRepoConfig(
// repo config gets set during repo connect and init. // repo config gets set during repo connect and init.
// This call confirms we have the correct values. // This call confirms we have the correct values.
err = config.WriteRepoConfig(ctx, s3Config, m365Config, opts.Repo, r.GetID()) err = config.WriteRepoConfig(ctx, storageCfg, m365Config, opts.Repo, r.GetID())
if err != nil { if err != nil {
logger.CtxErr(ctx, err).Info("writing to repository configuration") logger.CtxErr(ctx, err).Info("writing to repository configuration")
return nil, nil, err return nil, nil, err

View File

@ -29,7 +29,8 @@ func deleteBackups(
) ([]string, error) { ) ([]string, error) {
ctx = clues.Add(ctx, "cutoff_days", deletionDays) ctx = clues.Add(ctx, "cutoff_days", deletionDays)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, nil) provider, _ := config.GetStorageProviderFromConfigFile(ctx)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "connecting to account").WithClues(ctx) return nil, clues.Wrap(err, "connecting to account").WithClues(ctx)
} }

View File

@ -20,29 +20,32 @@ func s3BlobStorage(
repoOpts repository.Options, repoOpts repository.Options,
s storage.Storage, s storage.Storage,
) (blob.Storage, error) { ) (blob.Storage, error) {
cfg, err := s.S3Config() cfg, err := s.GetStorageConfig()
if err != nil { if err != nil {
return nil, clues.Stack(err).WithClues(ctx) return nil, clues.Stack(err).WithClues(ctx)
} }
// Cast to S3Config
s3Cfg := cfg.(storage.S3Config)
endpoint := defaultS3Endpoint endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 { if len(s3Cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint endpoint = s3Cfg.Endpoint
} }
opts := s3.Options{ opts := s3.Options{
BucketName: cfg.Bucket, BucketName: s3Cfg.Bucket,
Endpoint: endpoint, Endpoint: endpoint,
Prefix: cfg.Prefix, Prefix: s3Cfg.Prefix,
DoNotUseTLS: cfg.DoNotUseTLS, DoNotUseTLS: s3Cfg.DoNotUseTLS,
DoNotVerifyTLS: cfg.DoNotVerifyTLS, DoNotVerifyTLS: s3Cfg.DoNotVerifyTLS,
Tags: s.SessionTags, Tags: s.SessionTags,
SessionName: s.SessionName, SessionName: s.SessionName,
RoleARN: s.Role, RoleARN: s.Role,
RoleDuration: s.SessionDuration, RoleDuration: s.SessionDuration,
AccessKeyID: cfg.AccessKey, AccessKeyID: s3Cfg.AccessKey,
SecretAccessKey: cfg.SecretKey, SecretAccessKey: s3Cfg.SecretKey,
SessionToken: cfg.SessionToken, SessionToken: s3Cfg.SessionToken,
TLSHandshakeTimeout: 60, TLSHandshakeTimeout: 60,
PointInTime: repoOpts.ViewTimestamp, PointInTime: repoOpts.ViewTimestamp,
} }

View File

@ -4,12 +4,29 @@ import (
"strconv" "strconv"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/credentials"
) )
const (
// S3 config
StorageProviderTypeKey = "provider"
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
RepoID = "repo_id"
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
)
type S3Config struct { type S3Config struct {
credentials.AWS credentials.AWS
Bucket string // required Bucket string // required
@ -50,6 +67,120 @@ func (c S3Config) Normalize() S3Config {
} }
} }
// No need to return error here. Viper returns empty values.
func s3ConfigsFromStore(kvs KVStorer) S3Config {
var s3Config S3Config
s3Config.Bucket = cast.ToString(kvs.Get(BucketNameKey))
s3Config.Endpoint = cast.ToString(kvs.Get(EndpointKey))
s3Config.Prefix = cast.ToString(kvs.Get(PrefixKey))
s3Config.DoNotUseTLS = cast.ToBool(kvs.Get(DisableTLSKey))
s3Config.DoNotVerifyTLS = cast.ToBool(kvs.Get(DisableTLSVerificationKey))
return s3Config
}
func s3CredsFromStore(
kvs KVStorer,
s3Config S3Config,
) S3Config {
s3Config.AccessKey = cast.ToString(kvs.Get(AccessKey))
s3Config.SecretKey = cast.ToString(kvs.Get(SecretAccessKey))
s3Config.SessionToken = cast.ToString(kvs.Get(SessionToken))
return s3Config
}
var _ StorageConfigurer = S3Config{}
func (c S3Config) FetchConfigFromStore(
kvs KVStorer,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) error {
var (
s3Cfg S3Config
err error
)
if readConfigFromStore {
s3Cfg = s3ConfigsFromStore(kvs)
if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[Prefix]; ok {
overrides[Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := cast.ToString(kvs.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() {
return clues.New("unsupported storage provider: " + providerType)
}
// This is matching override values from config file.
if err := mustMatchConfig(kvs, s3Overrides(overrides)); err != nil {
return clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
s3Cfg = s3CredsFromStore(kvs, s3Cfg)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(
defaults.Config().WithCredentialsChainVerboseErrors(true),
defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return clues.Wrap(err, "validating aws credentials")
}
}
s3Cfg = S3Config{
AWS: aws,
Bucket: str.First(overrides[Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
}
return nil
}
var _ WriteConfigToStorer = S3Config{}
func (c S3Config) WriteConfigToStore(
kvs KVStoreSetter,
) {
s3Config := c.Normalize()
kvs.Set(StorageProviderTypeKey, ProviderS3.String())
kvs.Set(BucketNameKey, s3Config.Bucket)
kvs.Set(EndpointKey, s3Config.Endpoint)
kvs.Set(PrefixKey, s3Config.Prefix)
kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS)
kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
}
// StringConfig transforms a s3Config struct into a plain // StringConfig transforms a s3Config struct into a plain
// map[string]string. All values in the original struct which // map[string]string. All values in the original struct which
// serialize into the map are expected to be strings. // serialize into the map are expected to be strings.
@ -70,19 +201,19 @@ func (c S3Config) StringConfig() (map[string]string, error) {
} }
// S3Config retrieves the S3Config details from the Storage config. // S3Config retrieves the S3Config details from the Storage config.
func (s Storage) S3Config() (S3Config, error) { func MakeS3ConfigFromMap(config map[string]string) (S3Config, error) {
c := S3Config{} c := S3Config{}
if len(s.Config) > 0 { if len(config) > 0 {
c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) c.AccessKey = orEmptyString(config[keyS3AccessKey])
c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) c.SecretKey = orEmptyString(config[keyS3SecretKey])
c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) c.SessionToken = orEmptyString(config[keyS3SessionToken])
c.Bucket = orEmptyString(s.Config[keyS3Bucket]) c.Bucket = orEmptyString(config[keyS3Bucket])
c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) c.Endpoint = orEmptyString(config[keyS3Endpoint])
c.Prefix = orEmptyString(s.Config[keyS3Prefix]) c.Prefix = orEmptyString(config[keyS3Prefix])
c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS]) c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS])
c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS]) c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS])
} }
return c, c.validate() return c, c.validate()
@ -100,3 +231,44 @@ func (c S3Config) validate() error {
return nil return nil
} }
var constToTomlKeyMap = map[string]string{
Bucket: BucketNameKey,
Endpoint: EndpointKey,
Prefix: PrefixKey,
StorageProviderTypeKey: StorageProviderTypeKey,
}
// mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned.
// values in m that aren't stored in the config are ignored.
func mustMatchConfig(kvs KVStorer, m map[string]string) error {
for k, v := range m {
if len(v) == 0 {
continue // empty variables will get caught by configuration validators, if necessary
}
tomlK, ok := constToTomlKeyMap[k]
if !ok {
continue // m may declare values which aren't stored in the config file
}
vv := cast.ToString(kvs.Get(tomlK))
if v != vv {
return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")")
}
}
return nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
Bucket: in[Bucket],
Endpoint: in[Endpoint],
Prefix: in[Prefix],
DoNotUseTLS: in[DoNotUseTLS],
DoNotVerifyTLS: in[DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}

View File

@ -17,6 +17,15 @@ const (
ProviderFilesystem ProviderType = 2 // Filesystem ProviderFilesystem ProviderType = 2 // Filesystem
) )
func StringToEnum(s string) StorageProvider {
switch s {
case ProviderS3.String():
return ProviderS3
}
return ProviderUnknown
}
// storage parsing errors // storage parsing errors
var ( var (
errMissingRequired = clues.New("missing required storage configuration") errMissingRequired = clues.New("missing required storage configuration")
@ -82,3 +91,53 @@ func orEmptyString(v any) string {
return v.(string) return v.(string)
} }
func (s Storage) GetStorageConfig() (StorageConfigurer, error) {
switch s.Provider {
case ProviderS3:
return MakeS3ConfigFromMap(s.Config)
}
return nil, clues.New("unsupported storage provider: " + s.Provider.String())
}
func NewStorageConfig(provider string) (StorageConfigurer, error) {
switch provider {
case ProviderS3.String():
return S3Config{}, nil
}
return nil, clues.New("unsupported storage provider: " + provider)
}
// Change it to just getter
type KVStorer interface {
Get(key string) any
Set(key string, value any)
}
type KVStoreSetter interface {
Set(key string, value any)
}
// Call it configurer if necessary.
type StorageConfigurer interface {
common.StringConfigurer
FetchConfigFromStorer
WriteConfigToStorer
}
type WriteConfigToStorer interface {
WriteConfigToStore(
kvs KVStoreSetter,
)
}
type FetchConfigFromStorer interface {
FetchConfigFromStore(
kv KVStorer,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) error
}