diff --git a/src/cli/backup/backup.go b/src/cli/backup/backup.go index 93fcce74f..d141d22ba 100644 --- a/src/cli/backup/backup.go +++ b/src/cli/backup/backup.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -290,7 +291,16 @@ func genericDeleteCommand( ctx := clues.Add(cmd.Context(), "delete_backup_id", bID) - r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, repo.S3Overrides(cmd)) + // Let it return both provider and overrides for now? + // That way we can stop config pkg from being included everywhere. + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, provider, overrides) if err != nil { return Only(ctx, err) } @@ -316,7 +326,14 @@ func genericListCommand( ) error { ctx := cmd.Context() - r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/exchange.go b/src/cli/backup/exchange.go index 298569da6..5c650a547 100644 --- a/src/cli/backup/exchange.go +++ b/src/cli/backup/exchange.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -168,7 +169,13 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, provider, overrides) if err != nil { return Only(ctx, err) } @@ -277,7 +284,13 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeExchangeOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/groups.go b/src/cli/backup/groups.go index 9f7f928c8..a93ac0c87 100644 --- a/src/cli/backup/groups.go +++ b/src/cli/backup/groups.go @@ -10,6 +10,7 @@ import ( "github.com/spf13/pflag" "golang.org/x/exp/slices" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -154,7 +155,13 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, provider, overrides) if err != nil { return Only(ctx, err) } @@ -226,7 +233,13 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeGroupsOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/onedrive.go b/src/cli/backup/onedrive.go index 87a8a2236..cab9eb2e1 100644 --- a/src/cli/backup/onedrive.go +++ b/src/cli/backup/onedrive.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -149,7 +150,13 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, provider, overrides) if err != nil { return Only(ctx, err) } @@ -235,7 +242,13 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeOneDriveOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/backup/sharepoint.go b/src/cli/backup/sharepoint.go index c80076512..fc0b1bc21 100644 --- a/src/cli/backup/sharepoint.go +++ b/src/cli/backup/sharepoint.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/pflag" "golang.org/x/exp/slices" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -159,7 +160,14 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { return err } - r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, provider, overrides) if err != nil { return Only(ctx, err) } @@ -319,7 +327,13 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error { ctx := cmd.Context() opts := utils.MakeSharePointOpts(cmd) - r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/cli.go b/src/cli/cli.go index 230433479..6a82ad75f 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -70,9 +70,13 @@ func preRun(cc *cobra.Command, args []string) error { } if !slices.Contains(avoidTheseDescription, cc.Short) { - overrides := repo.S3Overrides(cc) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cc, provider) + if err != nil { + return err + } - cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides) + cfg, err := config.GetConfigRepoDetails(ctx, provider, true, false, overrides) if err != nil { log.Error("Error while getting config info to run command: ", cc.Use) return err diff --git a/src/cli/config/config.go b/src/cli/config/config.go index 39a34eb9c..d8e731353 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -203,14 +203,14 @@ func Read(ctx context.Context) error { // It does not check for conflicts or existing data. func WriteRepoConfig( ctx context.Context, - s3Config storage.S3Config, + scfg storage.WriteConfigToStorer, m365Config account.M365Config, repoOpts repository.Options, repoID string, ) error { return writeRepoConfigWithViper( GetViper(ctx), - s3Config, + scfg, m365Config, repoOpts, repoID) @@ -220,20 +220,12 @@ func WriteRepoConfig( // struct for testing. func writeRepoConfigWithViper( vpr *viper.Viper, - s3Config storage.S3Config, + scfg storage.WriteConfigToStorer, m365Config account.M365Config, repoOpts repository.Options, repoID string, ) error { - s3Config = s3Config.Normalize() - // Rudimentary support for persisting repo config - // TODO: Handle conflicts, support other config types - vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String()) - vpr.Set(BucketNameKey, s3Config.Bucket) - vpr.Set(EndpointKey, s3Config.Endpoint) - vpr.Set(PrefixKey, s3Config.Prefix) - vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS) - vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS) + scfg.WriteConfigToStore(vpr) vpr.Set(RepoID, repoID) // Need if-checks as Viper will write empty values otherwise. @@ -263,6 +255,7 @@ func writeRepoConfigWithViper( // data sources (config file, env vars, flag overrides) and the config file. func GetConfigRepoDetails( ctx context.Context, + provider string, readFromFile bool, mustMatchFromConfig bool, overrides map[string]string, @@ -270,7 +263,7 @@ func GetConfigRepoDetails( RepoDetails, error, ) { - config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides) + config, err := getStorageAndAccountWithViper(GetViper(ctx), provider, readFromFile, mustMatchFromConfig, overrides) return config, err } @@ -278,6 +271,7 @@ func GetConfigRepoDetails( // struct for testing. func getStorageAndAccountWithViper( vpr *viper.Viper, + provider string, readFromFile bool, mustMatchFromConfig bool, overrides map[string]string, @@ -312,7 +306,7 @@ func getStorageAndAccountWithViper( return config, clues.Wrap(err, "retrieving account configuration details") } - config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides) + config.Storage, err = configureStorage(vpr, provider, readConfigFromViper, mustMatchFromConfig, overrides) if err != nil { return config, clues.Wrap(err, "retrieving storage provider details") } @@ -378,3 +372,17 @@ func requireProps(props map[string]string) error { return nil } + +// Storage provider is not a flag. It can only be sourced from config file. +// Only exceptions are the commands that create a new repo. +// This is needed to figure out which storage overrides to use. +func GetStorageProviderFromConfigFile(ctx context.Context) (string, error) { + vpr := GetViper(ctx) + + provider := vpr.GetString(StorageProviderTypeKey) + if provider != storage.ProviderS3.String() { + return storage.ProviderUnknown.String(), clues.New("unsupported storage provider: " + provider) + } + + return provider, nil +} diff --git a/src/cli/config/storage.go b/src/cli/config/storage.go index 964c740fd..bad20d514 100644 --- a/src/cli/config/storage.go +++ b/src/cli/config/storage.go @@ -3,127 +3,39 @@ package config import ( "os" "path/filepath" - "strconv" "github.com/alcionai/clues" - "github.com/aws/aws-sdk-go/aws/defaults" "github.com/spf13/viper" "github.com/alcionai/corso/src/cli/flags" - "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/storage" ) -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) { - var s3Config storage.S3Config - - s3Config.Bucket = vpr.GetString(BucketNameKey) - s3Config.Endpoint = vpr.GetString(EndpointKey) - s3Config.Prefix = vpr.GetString(PrefixKey) - s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey) - s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey) - - return s3Config, nil -} - -// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values. -func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) { - s3Config.AccessKey = vpr.GetString(AccessKey) - s3Config.SecretKey = vpr.GetString(SecretAccessKey) - s3Config.SessionToken = vpr.GetString(SessionToken) - - return s3Config, nil -} - -func s3Overrides(in map[string]string) map[string]string { - return map[string]string{ - storage.Bucket: in[storage.Bucket], - storage.Endpoint: in[storage.Endpoint], - storage.Prefix: in[storage.Prefix], - storage.DoNotUseTLS: in[storage.DoNotUseTLS], - storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS], - StorageProviderTypeKey: in[StorageProviderTypeKey], - } -} - // configureStorage builds a complete storage configuration from a mix of // viper properties and manual overrides. func configureStorage( vpr *viper.Viper, + provider string, readConfigFromViper bool, matchFromConfig bool, overrides map[string]string, ) (storage.Storage, error) { var ( - s3Cfg storage.S3Config store storage.Storage err error ) - if readConfigFromViper { - if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil { - return store, clues.Wrap(err, "reading s3 configs from corso config file") - } + storageCfg, _ := storage.NewStorageConfig(provider) - if b, ok := overrides[storage.Bucket]; ok { - overrides[storage.Bucket] = common.NormalizeBucket(b) - } - - if p, ok := overrides[storage.Prefix]; ok { - overrides[storage.Prefix] = common.NormalizePrefix(p) - } - - if matchFromConfig { - providerType := vpr.GetString(StorageProviderTypeKey) - if providerType != storage.ProviderS3.String() { - return store, clues.New("unsupported storage provider: " + providerType) - } - - if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil { - return store, clues.Wrap(err, "verifying s3 configs in corso config file") - } - } - } - - if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil { - return store, clues.Wrap(err, "reading s3 configs from corso config file") - } - - s3Overrides(overrides) - aws := credentials.GetAWS(overrides) - - if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { - _, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get() - if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) { - aws = credentials.AWS{ - AccessKey: s3Cfg.AccessKey, - SecretKey: s3Cfg.SecretKey, - SessionToken: s3Cfg.SessionToken, - } - err = nil - } - - if err != nil { - return store, clues.Wrap(err, "validating aws credentials") - } - } - - s3Cfg = storage.S3Config{ - AWS: aws, - Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket), - Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"), - Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix), - DoNotUseTLS: str.ParseBool(str.First( - overrides[storage.DoNotUseTLS], - strconv.FormatBool(s3Cfg.DoNotUseTLS), - "false")), - DoNotVerifyTLS: str.ParseBool(str.First( - overrides[storage.DoNotVerifyTLS], - strconv.FormatBool(s3Cfg.DoNotVerifyTLS), - "false")), + err = storageCfg.FetchConfigFromStore( + vpr, + readConfigFromViper, + matchFromConfig, + overrides) + if err != nil { + return store, clues.Wrap(err, "fetching storage config from store") } // compose the common config and credentials @@ -145,14 +57,14 @@ func configureStorage( // ensure required properties are present if err := requireProps(map[string]string{ - storage.Bucket: s3Cfg.Bucket, credentials.CorsoPassphrase: corso.CorsoPassphrase, }); err != nil { return storage.Storage{}, err } // build the storage - store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg) + store, err = storage.NewStorage( + storage.StringToEnum(provider), storageCfg, cCfg) if err != nil { return store, clues.Wrap(err, "configuring repository storage") } diff --git a/src/cli/export/export.go b/src/cli/export/export.go index 89a2111fc..571d4f938 100644 --- a/src/cli/export/export.go +++ b/src/cli/export/export.go @@ -8,6 +8,7 @@ import ( "github.com/alcionai/clues" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/config" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" "github.com/alcionai/corso/src/cli/utils" @@ -70,7 +71,13 @@ func runExport( sel selectors.Selector, backupID, serviceName string, ) error { - r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/repo/repo.go b/src/cli/repo/repo.go index 378cce16a..fd83f7f97 100644 --- a/src/cli/repo/repo.go +++ b/src/cli/repo/repo.go @@ -1,6 +1,7 @@ package repo import ( + "context" "strings" "github.com/alcionai/clues" @@ -12,6 +13,7 @@ import ( "github.com/alcionai/corso/src/cli/utils" "github.com/alcionai/corso/src/pkg/control/repository" "github.com/alcionai/corso/src/pkg/path" + "github.com/alcionai/corso/src/pkg/storage" ) const ( @@ -121,7 +123,9 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error { return err } - r, _, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.UnknownService, S3Overrides(cmd)) + // Change this to override too? + r, _, err := utils.AccountConnectAndWriteRepoConfig( + ctx, path.UnknownService, storage.ProviderS3.String(), S3Overrides(cmd)) if err != nil { return print.Only(ctx, err) } @@ -164,3 +168,18 @@ func getMaintenanceType(t string) (repository.MaintenanceType, error) { return res, nil } + +func GetStorageOverrides( + ctx context.Context, + cmd *cobra.Command, + storageProvider string, +) (map[string]string, error) { + overrides := map[string]string{} + + switch storageProvider { + case storage.ProviderS3.String(): + overrides = S3Overrides(cmd) + } + + return overrides, nil +} diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index af18cb65e..e2a03d158 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -2,7 +2,6 @@ package repo import ( "strconv" - "strings" "github.com/alcionai/clues" "github.com/pkg/errors" @@ -92,7 +91,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { // s3 values from flags s3Override := S3Overrides(cmd) - cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override) + cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, false, s3Override) if err != nil { return Only(ctx, err) } @@ -113,17 +112,19 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { cfg.Account.ID(), opt) - s3Cfg, err := cfg.Storage.S3Config() - if err != nil { - return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) - } + //s3Cfg, err := cfg.Storage.S3Config() // why not let it return configurer? + storageCfg, err := cfg.Storage.GetStorageConfig() + // if err != nil { + // return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) + // } - if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { - invalidEndpointErr := "endpoint doesn't support specifying protocol. " + - "pass --disable-tls flag to use http:// instead of default https://" + // BUG: This should be moved to validate() + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { + // invalidEndpointErr := "endpoint doesn't support specifying protocol. " + + // "pass --disable-tls flag to use http:// instead of default https://" - return Only(ctx, clues.New(invalidEndpointErr)) - } + // return Only(ctx, clues.New(invalidEndpointErr)) + // } m365, err := cfg.Account.M365Config() if err != nil { @@ -146,9 +147,10 @@ func initS3Cmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - Infof(ctx, "Initialized a S3 repository within bucket %s.", s3Cfg.Bucket) + // Strong typecast? + Infof(ctx, "Initialized a S3 repository within bucket %s.", cfg.Storage.Config[storage.Bucket]) - if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opt.Repo, r.GetID()); err != nil { + if err = config.WriteRepoConfig(ctx, storageCfg, m365, opt.Repo, r.GetID()); err != nil { return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) } @@ -178,7 +180,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { // s3 values from flags s3Override := S3Overrides(cmd) - cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override) + cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3.String(), true, true, s3Override) if err != nil { return Only(ctx, err) } @@ -188,7 +190,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { repoID = events.RepoIDNotFound } - s3Cfg, err := cfg.Storage.S3Config() + s3Cfg, err := cfg.Storage.GetStorageConfig() if err != nil { return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration")) } @@ -198,12 +200,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config")) } - if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { - invalidEndpointErr := "endpoint doesn't support specifying protocol. " + - "pass --disable-tls flag to use http:// instead of default https://" + // Move these to validate()? + // if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") { + // invalidEndpointErr := "endpoint doesn't support specifying protocol. " + + // "pass --disable-tls flag to use http:// instead of default https://" - return Only(ctx, clues.New(invalidEndpointErr)) - } + // return Only(ctx, clues.New(invalidEndpointErr)) + // } opts := utils.ControlWithConfig(cfg) @@ -219,7 +222,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error { defer utils.CloseRepo(ctx, r) - Infof(ctx, "Connected to S3 bucket %s.", s3Cfg.Bucket) + Infof(ctx, "Connected to S3 bucket %s.", cfg.Storage.Config[storage.Bucket]) if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil { return Only(ctx, clues.Wrap(err, "Failed to write repository configuration")) diff --git a/src/cli/restore/restore.go b/src/cli/restore/restore.go index 3f62ab0ae..dc4e3d1b0 100644 --- a/src/cli/restore/restore.go +++ b/src/cli/restore/restore.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/alcionai/corso/src/cli/config" "github.com/alcionai/corso/src/cli/flags" . "github.com/alcionai/corso/src/cli/print" "github.com/alcionai/corso/src/cli/repo" @@ -103,7 +104,13 @@ func runRestore( sel selectors.Selector, backupID, serviceName string, ) error { - r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd)) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + overrides, err := repo.GetStorageOverrides(ctx, cmd, provider) + if err != nil { + return Only(ctx, err) + } + + r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides) if err != nil { return Only(ctx, err) } diff --git a/src/cli/utils/utils.go b/src/cli/utils/utils.go index 5a639474a..0ac27a261 100644 --- a/src/cli/utils/utils.go +++ b/src/cli/utils/utils.go @@ -24,9 +24,10 @@ var ErrNotYetImplemented = clues.New("not yet implemented") func GetAccountAndConnect( ctx context.Context, pst path.ServiceType, + provider string, overrides map[string]string, ) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) { - cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides) + cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides) if err != nil { return nil, nil, nil, nil, err } @@ -55,17 +56,19 @@ func GetAccountAndConnect( func AccountConnectAndWriteRepoConfig( ctx context.Context, pst path.ServiceType, + provider string, overrides map[string]string, ) (repository.Repository, *account.Account, error) { - r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, overrides) + r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides) if err != nil { logger.CtxErr(ctx, err).Info("getting and connecting account") return nil, nil, err } - s3Config, err := stg.S3Config() + storageCfg, err := stg.GetStorageConfig() if err != nil { logger.CtxErr(ctx, err).Info("getting storage configuration") + return nil, nil, err } @@ -77,7 +80,7 @@ func AccountConnectAndWriteRepoConfig( // repo config gets set during repo connect and init. // This call confirms we have the correct values. - err = config.WriteRepoConfig(ctx, s3Config, m365Config, opts.Repo, r.GetID()) + err = config.WriteRepoConfig(ctx, storageCfg, m365Config, opts.Repo, r.GetID()) if err != nil { logger.CtxErr(ctx, err).Info("writing to repository configuration") return nil, nil, err diff --git a/src/cmd/longevity_test/longevity.go b/src/cmd/longevity_test/longevity.go index b3d6f865d..6f36057cc 100644 --- a/src/cmd/longevity_test/longevity.go +++ b/src/cmd/longevity_test/longevity.go @@ -29,7 +29,8 @@ func deleteBackups( ) ([]string, error) { ctx = clues.Add(ctx, "cutoff_days", deletionDays) - r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, nil) + provider, _ := config.GetStorageProviderFromConfigFile(ctx) + r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil) if err != nil { return nil, clues.Wrap(err, "connecting to account").WithClues(ctx) } diff --git a/src/internal/kopia/s3.go b/src/internal/kopia/s3.go index adad4330e..a4f7524e2 100644 --- a/src/internal/kopia/s3.go +++ b/src/internal/kopia/s3.go @@ -20,29 +20,32 @@ func s3BlobStorage( repoOpts repository.Options, s storage.Storage, ) (blob.Storage, error) { - cfg, err := s.S3Config() + cfg, err := s.GetStorageConfig() if err != nil { return nil, clues.Stack(err).WithClues(ctx) } + // Cast to S3Config + s3Cfg := cfg.(storage.S3Config) + endpoint := defaultS3Endpoint - if len(cfg.Endpoint) > 0 { - endpoint = cfg.Endpoint + if len(s3Cfg.Endpoint) > 0 { + endpoint = s3Cfg.Endpoint } opts := s3.Options{ - BucketName: cfg.Bucket, + BucketName: s3Cfg.Bucket, Endpoint: endpoint, - Prefix: cfg.Prefix, - DoNotUseTLS: cfg.DoNotUseTLS, - DoNotVerifyTLS: cfg.DoNotVerifyTLS, + Prefix: s3Cfg.Prefix, + DoNotUseTLS: s3Cfg.DoNotUseTLS, + DoNotVerifyTLS: s3Cfg.DoNotVerifyTLS, Tags: s.SessionTags, SessionName: s.SessionName, RoleARN: s.Role, RoleDuration: s.SessionDuration, - AccessKeyID: cfg.AccessKey, - SecretAccessKey: cfg.SecretKey, - SessionToken: cfg.SessionToken, + AccessKeyID: s3Cfg.AccessKey, + SecretAccessKey: s3Cfg.SecretKey, + SessionToken: s3Cfg.SessionToken, TLSHandshakeTimeout: 60, PointInTime: repoOpts.ViewTimestamp, } diff --git a/src/pkg/storage/s3.go b/src/pkg/storage/s3.go index a332326e8..33aaf4592 100644 --- a/src/pkg/storage/s3.go +++ b/src/pkg/storage/s3.go @@ -4,12 +4,29 @@ import ( "strconv" "github.com/alcionai/clues" + "github.com/aws/aws-sdk-go/aws/defaults" + "github.com/spf13/cast" "github.com/alcionai/corso/src/internal/common" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/credentials" ) +const ( + // S3 config + StorageProviderTypeKey = "provider" + BucketNameKey = "bucket" + EndpointKey = "endpoint" + PrefixKey = "prefix" + DisableTLSKey = "disable_tls" + DisableTLSVerificationKey = "disable_tls_verification" + RepoID = "repo_id" + + AccessKey = "aws_access_key_id" + SecretAccessKey = "aws_secret_access_key" + SessionToken = "aws_session_token" +) + type S3Config struct { credentials.AWS Bucket string // required @@ -50,6 +67,120 @@ func (c S3Config) Normalize() S3Config { } } +// No need to return error here. Viper returns empty values. +func s3ConfigsFromStore(kvs KVStorer) S3Config { + var s3Config S3Config + + s3Config.Bucket = cast.ToString(kvs.Get(BucketNameKey)) + s3Config.Endpoint = cast.ToString(kvs.Get(EndpointKey)) + s3Config.Prefix = cast.ToString(kvs.Get(PrefixKey)) + s3Config.DoNotUseTLS = cast.ToBool(kvs.Get(DisableTLSKey)) + s3Config.DoNotVerifyTLS = cast.ToBool(kvs.Get(DisableTLSVerificationKey)) + + return s3Config +} + +func s3CredsFromStore( + kvs KVStorer, + s3Config S3Config, +) S3Config { + s3Config.AccessKey = cast.ToString(kvs.Get(AccessKey)) + s3Config.SecretKey = cast.ToString(kvs.Get(SecretAccessKey)) + s3Config.SessionToken = cast.ToString(kvs.Get(SessionToken)) + + return s3Config +} + +var _ StorageConfigurer = S3Config{} + +func (c S3Config) FetchConfigFromStore( + kvs KVStorer, + readConfigFromStore bool, + matchFromConfig bool, + overrides map[string]string, +) error { + var ( + s3Cfg S3Config + err error + ) + + if readConfigFromStore { + s3Cfg = s3ConfigsFromStore(kvs) + if b, ok := overrides[Bucket]; ok { + overrides[Bucket] = common.NormalizeBucket(b) + } + + if p, ok := overrides[Prefix]; ok { + overrides[Prefix] = common.NormalizePrefix(p) + } + + if matchFromConfig { + providerType := cast.ToString(kvs.Get(StorageProviderTypeKey)) + if providerType != ProviderS3.String() { + return clues.New("unsupported storage provider: " + providerType) + } + + // This is matching override values from config file. + if err := mustMatchConfig(kvs, s3Overrides(overrides)); err != nil { + return clues.Wrap(err, "verifying s3 configs in corso config file") + } + } + } + + s3Cfg = s3CredsFromStore(kvs, s3Cfg) + aws := credentials.GetAWS(overrides) + + if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 { + _, err = defaults.CredChain( + defaults.Config().WithCredentialsChainVerboseErrors(true), + defaults.Handlers()).Get() + if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) { + aws = credentials.AWS{ + AccessKey: s3Cfg.AccessKey, + SecretKey: s3Cfg.SecretKey, + SessionToken: s3Cfg.SessionToken, + } + err = nil + } + + if err != nil { + return clues.Wrap(err, "validating aws credentials") + } + } + + s3Cfg = S3Config{ + AWS: aws, + Bucket: str.First(overrides[Bucket], s3Cfg.Bucket), + Endpoint: str.First(overrides[Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"), + Prefix: str.First(overrides[Prefix], s3Cfg.Prefix), + DoNotUseTLS: str.ParseBool(str.First( + overrides[DoNotUseTLS], + strconv.FormatBool(s3Cfg.DoNotUseTLS), + "false")), + DoNotVerifyTLS: str.ParseBool(str.First( + overrides[DoNotVerifyTLS], + strconv.FormatBool(s3Cfg.DoNotVerifyTLS), + "false")), + } + + return nil +} + +var _ WriteConfigToStorer = S3Config{} + +func (c S3Config) WriteConfigToStore( + kvs KVStoreSetter, +) { + s3Config := c.Normalize() + + kvs.Set(StorageProviderTypeKey, ProviderS3.String()) + kvs.Set(BucketNameKey, s3Config.Bucket) + kvs.Set(EndpointKey, s3Config.Endpoint) + kvs.Set(PrefixKey, s3Config.Prefix) + kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS) + kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS) +} + // StringConfig transforms a s3Config struct into a plain // map[string]string. All values in the original struct which // serialize into the map are expected to be strings. @@ -70,19 +201,19 @@ func (c S3Config) StringConfig() (map[string]string, error) { } // S3Config retrieves the S3Config details from the Storage config. -func (s Storage) S3Config() (S3Config, error) { +func MakeS3ConfigFromMap(config map[string]string) (S3Config, error) { c := S3Config{} - if len(s.Config) > 0 { - c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) - c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) - c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) + if len(config) > 0 { + c.AccessKey = orEmptyString(config[keyS3AccessKey]) + c.SecretKey = orEmptyString(config[keyS3SecretKey]) + c.SessionToken = orEmptyString(config[keyS3SessionToken]) - c.Bucket = orEmptyString(s.Config[keyS3Bucket]) - c.Endpoint = orEmptyString(s.Config[keyS3Endpoint]) - c.Prefix = orEmptyString(s.Config[keyS3Prefix]) - c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS]) - c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS]) + c.Bucket = orEmptyString(config[keyS3Bucket]) + c.Endpoint = orEmptyString(config[keyS3Endpoint]) + c.Prefix = orEmptyString(config[keyS3Prefix]) + c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS]) + c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS]) } return c, c.validate() @@ -100,3 +231,44 @@ func (c S3Config) validate() error { return nil } + +var constToTomlKeyMap = map[string]string{ + Bucket: BucketNameKey, + Endpoint: EndpointKey, + Prefix: PrefixKey, + StorageProviderTypeKey: StorageProviderTypeKey, +} + +// mustMatchConfig compares the values of each key to their config file value in store. +// If any value differs from the store value, an error is returned. +// values in m that aren't stored in the config are ignored. +func mustMatchConfig(kvs KVStorer, m map[string]string) error { + for k, v := range m { + if len(v) == 0 { + continue // empty variables will get caught by configuration validators, if necessary + } + + tomlK, ok := constToTomlKeyMap[k] + if !ok { + continue // m may declare values which aren't stored in the config file + } + + vv := cast.ToString(kvs.Get(tomlK)) + if v != vv { + return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")") + } + } + + return nil +} + +func s3Overrides(in map[string]string) map[string]string { + return map[string]string{ + Bucket: in[Bucket], + Endpoint: in[Endpoint], + Prefix: in[Prefix], + DoNotUseTLS: in[DoNotUseTLS], + DoNotVerifyTLS: in[DoNotVerifyTLS], + StorageProviderTypeKey: in[StorageProviderTypeKey], + } +} diff --git a/src/pkg/storage/storage.go b/src/pkg/storage/storage.go index d1a1067a6..c55169be2 100644 --- a/src/pkg/storage/storage.go +++ b/src/pkg/storage/storage.go @@ -17,6 +17,15 @@ const ( ProviderFilesystem ProviderType = 2 // Filesystem ) +func StringToEnum(s string) StorageProvider { + switch s { + case ProviderS3.String(): + return ProviderS3 + } + + return ProviderUnknown +} + // storage parsing errors var ( errMissingRequired = clues.New("missing required storage configuration") @@ -82,3 +91,53 @@ func orEmptyString(v any) string { return v.(string) } + +func (s Storage) GetStorageConfig() (StorageConfigurer, error) { + switch s.Provider { + case ProviderS3: + return MakeS3ConfigFromMap(s.Config) + } + + return nil, clues.New("unsupported storage provider: " + s.Provider.String()) +} + +func NewStorageConfig(provider string) (StorageConfigurer, error) { + switch provider { + case ProviderS3.String(): + return S3Config{}, nil + } + + return nil, clues.New("unsupported storage provider: " + provider) +} + +// Change it to just getter +type KVStorer interface { + Get(key string) any + Set(key string, value any) +} + +type KVStoreSetter interface { + Set(key string, value any) +} + +// Call it configurer if necessary. +type StorageConfigurer interface { + common.StringConfigurer + FetchConfigFromStorer + WriteConfigToStorer +} + +type WriteConfigToStorer interface { + WriteConfigToStore( + kvs KVStoreSetter, + ) +} + +type FetchConfigFromStorer interface { + FetchConfigFromStore( + kv KVStorer, + readConfigFromStore bool, + matchFromConfig bool, + overrides map[string]string, + ) error +}