Compare commits

...

3 Commits

Author SHA1 Message Date
Abhishek Pandey
bc1e9b3161 Fix rebase issues 2023-09-13 19:31:06 +05:30
Abhishek Pandey
9705c9d212 More changes before piecemealing 2023-09-13 19:27:08 +05:30
Abhishek Pandey
8060227bf9 Add new prototype for storage config refactor 2023-09-13 19:25:22 +05:30
23 changed files with 492 additions and 192 deletions

View File

@ -290,7 +290,12 @@ func genericDeleteCommand(
ctx := clues.Add(cmd.Context(), "delete_backup_id", bID)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, repo.S3Overrides(cmd))
storageProvider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, pst, storageProvider, overrides)
if err != nil {
return Only(ctx, err)
}
@ -316,7 +321,12 @@ func genericListCommand(
) error {
ctx := cmd.Context()
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -168,7 +168,12 @@ func createExchangeCmd(cmd *cobra.Command, args []string) error {
return err
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.ExchangeService, provider, overrides)
if err != nil {
return Only(ctx, err)
}
@ -277,7 +282,12 @@ func detailsExchangeCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
opts := utils.MakeExchangeOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.ExchangeService, provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -154,7 +154,12 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error {
return err
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.GroupsService, provider, overrides)
if err != nil {
return Only(ctx, err)
}
@ -226,7 +231,12 @@ func detailsGroupsCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
opts := utils.MakeGroupsOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.GroupsService, provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -140,9 +140,11 @@ func prepM365Test(
recorder = strings.Builder{}
)
cfg, err := st.S3Config()
c, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := c.(storage.S3Config)
force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -149,7 +149,12 @@ func createOneDriveCmd(cmd *cobra.Command, args []string) error {
return err
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.OneDriveService, provider, overrides)
if err != nil {
return Only(ctx, err)
}
@ -235,7 +240,12 @@ func detailsOneDriveCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
opts := utils.MakeOneDriveOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.OneDriveService, provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -159,7 +159,12 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error {
return err
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, acct, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.SharePointService, provider, overrides)
if err != nil {
return Only(ctx, err)
}
@ -319,7 +324,12 @@ func detailsSharePointCmd(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
opts := utils.MakeSharePointOpts(cmd)
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, ctrlOpts, err := utils.GetAccountAndConnect(ctx, path.SharePointService, provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -70,9 +70,12 @@ func preRun(cc *cobra.Command, args []string) error {
}
if !slices.Contains(avoidTheseDescription, cc.Short) {
overrides := repo.S3Overrides(cc)
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cc)
if err != nil {
return err
}
cfg, err := config.GetConfigRepoDetails(ctx, true, false, overrides)
cfg, err := config.GetConfigRepoDetails(ctx, provider, true, false, overrides)
if err != nil {
log.Error("Error while getting config info to run command: ", cc.Use)
return err

View File

@ -203,14 +203,14 @@ func Read(ctx context.Context) error {
// It does not check for conflicts or existing data.
func WriteRepoConfig(
ctx context.Context,
s3Config storage.S3Config,
scfg storage.WriteConfigToStorer,
m365Config account.M365Config,
repoOpts repository.Options,
repoID string,
) error {
return writeRepoConfigWithViper(
GetViper(ctx),
s3Config,
scfg,
m365Config,
repoOpts,
repoID)
@ -220,20 +220,12 @@ func WriteRepoConfig(
// struct for testing.
func writeRepoConfigWithViper(
vpr *viper.Viper,
s3Config storage.S3Config,
scfg storage.WriteConfigToStorer,
m365Config account.M365Config,
repoOpts repository.Options,
repoID string,
) error {
s3Config = s3Config.Normalize()
// Rudimentary support for persisting repo config
// TODO: Handle conflicts, support other config types
vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String())
vpr.Set(BucketNameKey, s3Config.Bucket)
vpr.Set(EndpointKey, s3Config.Endpoint)
vpr.Set(PrefixKey, s3Config.Prefix)
vpr.Set(DisableTLSKey, s3Config.DoNotUseTLS)
vpr.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
scfg.WriteConfigToStore(vpr)
vpr.Set(RepoID, repoID)
// Need if-checks as Viper will write empty values otherwise.
@ -263,6 +255,7 @@ func writeRepoConfigWithViper(
// data sources (config file, env vars, flag overrides) and the config file.
func GetConfigRepoDetails(
ctx context.Context,
provider storage.ProviderType,
readFromFile bool,
mustMatchFromConfig bool,
overrides map[string]string,
@ -270,7 +263,7 @@ func GetConfigRepoDetails(
RepoDetails,
error,
) {
config, err := getStorageAndAccountWithViper(GetViper(ctx), readFromFile, mustMatchFromConfig, overrides)
config, err := getStorageAndAccountWithViper(GetViper(ctx), provider, readFromFile, mustMatchFromConfig, overrides)
return config, err
}
@ -278,6 +271,7 @@ func GetConfigRepoDetails(
// struct for testing.
func getStorageAndAccountWithViper(
vpr *viper.Viper,
provider storage.ProviderType,
readFromFile bool,
mustMatchFromConfig bool,
overrides map[string]string,
@ -312,7 +306,7 @@ func getStorageAndAccountWithViper(
return config, clues.Wrap(err, "retrieving account configuration details")
}
config.Storage, err = configureStorage(vpr, readConfigFromViper, mustMatchFromConfig, overrides)
config.Storage, err = configureStorage(vpr, provider, readConfigFromViper, mustMatchFromConfig, overrides)
if err != nil {
return config, clues.Wrap(err, "retrieving storage provider details")
}
@ -378,3 +372,17 @@ func requireProps(props map[string]string) error {
return nil
}
// Storage provider is not a flag. It can only be sourced from config file.
// Only exceptions are the commands that create a new repo.
// This is needed to figure out which storage overrides to use.
func GetStorageProviderFromConfigFile(ctx context.Context) (storage.ProviderType, error) {
vpr := GetViper(ctx)
provider := vpr.GetString(StorageProviderTypeKey)
if provider != storage.ProviderS3.String() {
return storage.ProviderUnknown, clues.New("unsupported storage provider: " + provider)
}
return storage.StringToProviderType[provider], nil
}

View File

@ -107,14 +107,23 @@ func (suite *ConfigSuite) TestReadRepoConfigBasic() {
err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err))
s3Cfg, err := s3ConfigsFromViper(vpr)
// Unset AWS env vars so that we can test reading creds from config file
os.Unsetenv(credentials.AWSAccessKeyID)
os.Unsetenv(credentials.AWSSecretAccessKey)
os.Unsetenv(credentials.AWSSessionToken)
sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
s3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, b, s3Cfg.Bucket)
assert.Equal(t, "test-prefix/", s3Cfg.Prefix)
assert.Equal(t, disableTLS, strconv.FormatBool(s3Cfg.DoNotUseTLS))
assert.Equal(t, disableTLSVerification, strconv.FormatBool(s3Cfg.DoNotVerifyTLS))
s3Cfg, err = s3CredsFromViper(vpr, s3Cfg)
// s3Cfg, err = s3CredsFromViper(vpr, s3Cfg)
require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, accKey, s3Cfg.AWS.AccessKey)
assert.Equal(t, secret, s3Cfg.AWS.SecretKey)
@ -160,7 +169,11 @@ func (suite *ConfigSuite) TestWriteReadConfig() {
err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err))
readS3Cfg, err := s3ConfigsFromViper(vpr)
sc, err := storage.S3Config{}.FetchConfigFromStore(vpr, true, true, nil)
require.NoError(t, err, clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.DoNotUseTLS, s3Cfg.DoNotUseTLS)
@ -326,12 +339,14 @@ func (suite *ConfigSuite) TestReadFromFlags() {
repoDetails, err := getStorageAndAccountWithViper(
vpr,
storage.ProviderS3,
true,
false,
overrides)
m365Config, _ := repoDetails.Account.M365Config()
s3Cfg, _ := repoDetails.Storage.S3Config()
cfg, _ := repoDetails.Storage.StorageConfig()
s3Cfg, _ := cfg.(storage.S3Config)
commonConfig, _ := repoDetails.Storage.CommonConfig()
pass := commonConfig.Corso.CorsoPassphrase
@ -400,11 +415,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount() {
err = vpr.ReadInConfig()
require.NoError(t, err, "reading repo config", clues.ToCore(err))
cfg, err := getStorageAndAccountWithViper(vpr, true, true, nil)
cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, true, true, nil)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config()
sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, s3Cfg.Bucket)
assert.Equal(t, readS3Cfg.Endpoint, s3Cfg.Endpoint)
assert.Equal(t, readS3Cfg.Prefix, s3Cfg.Prefix)
@ -448,11 +465,13 @@ func (suite *ConfigIntegrationSuite) TestGetStorageAndAccount_noFileOnlyOverride
StorageProviderTypeKey: storage.ProviderS3.String(),
}
cfg, err := getStorageAndAccountWithViper(vpr, false, true, overrides)
cfg, err := getStorageAndAccountWithViper(vpr, storage.ProviderS3, false, true, overrides)
require.NoError(t, err, "getting storage and account from config", clues.ToCore(err))
readS3Cfg, err := cfg.Storage.S3Config()
sc, err := cfg.Storage.StorageConfig()
require.NoError(t, err, "reading s3 config from storage", clues.ToCore(err))
readS3Cfg := sc.(storage.S3Config)
assert.Equal(t, readS3Cfg.Bucket, bkt)
assert.Equal(t, cfg.RepoID, "")
assert.Equal(t, readS3Cfg.Endpoint, end)

View File

@ -3,127 +3,40 @@ package config
import (
"os"
"path/filepath"
"strconv"
"github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/viper"
"github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials"
"github.com/alcionai/corso/src/pkg/storage"
)
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3ConfigsFromViper(vpr *viper.Viper) (storage.S3Config, error) {
var s3Config storage.S3Config
s3Config.Bucket = vpr.GetString(BucketNameKey)
s3Config.Endpoint = vpr.GetString(EndpointKey)
s3Config.Prefix = vpr.GetString(PrefixKey)
s3Config.DoNotUseTLS = vpr.GetBool(DisableTLSKey)
s3Config.DoNotVerifyTLS = vpr.GetBool(DisableTLSVerificationKey)
return s3Config, nil
}
// prerequisite: readRepoConfig must have been run prior to this to populate the global viper values.
func s3CredsFromViper(vpr *viper.Viper, s3Config storage.S3Config) (storage.S3Config, error) {
s3Config.AccessKey = vpr.GetString(AccessKey)
s3Config.SecretKey = vpr.GetString(SecretAccessKey)
s3Config.SessionToken = vpr.GetString(SessionToken)
return s3Config, nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
storage.Bucket: in[storage.Bucket],
storage.Endpoint: in[storage.Endpoint],
storage.Prefix: in[storage.Prefix],
storage.DoNotUseTLS: in[storage.DoNotUseTLS],
storage.DoNotVerifyTLS: in[storage.DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}
// configureStorage builds a complete storage configuration from a mix of
// viper properties and manual overrides.
func configureStorage(
vpr *viper.Viper,
provider storage.ProviderType,
readConfigFromViper bool,
matchFromConfig bool,
overrides map[string]string,
) (storage.Storage, error) {
var (
s3Cfg storage.S3Config
store storage.Storage
err error
)
if readConfigFromViper {
if s3Cfg, err = s3ConfigsFromViper(vpr); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
storageCfg, _ := storage.NewStorageConfig(provider)
if b, ok := overrides[storage.Bucket]; ok {
overrides[storage.Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[storage.Prefix]; ok {
overrides[storage.Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := vpr.GetString(StorageProviderTypeKey)
if providerType != storage.ProviderS3.String() {
return store, clues.New("unsupported storage provider: " + providerType)
}
if err := mustMatchConfig(vpr, s3Overrides(overrides)); err != nil {
return store, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
if s3Cfg, err = s3CredsFromViper(vpr, s3Cfg); err != nil {
return store, clues.Wrap(err, "reading s3 configs from corso config file")
}
s3Overrides(overrides)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(defaults.Config().WithCredentialsChainVerboseErrors(true), defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return store, clues.Wrap(err, "validating aws credentials")
}
}
s3Cfg = storage.S3Config{
AWS: aws,
Bucket: str.First(overrides[storage.Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[storage.Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[storage.Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[storage.DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[storage.DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
// Rename this. It's not just fetch config from store.
storageCfg, err = storageCfg.FetchConfigFromStore(
vpr,
readConfigFromViper,
matchFromConfig,
overrides)
if err != nil {
return store, clues.Wrap(err, "fetching storage config from store")
}
// compose the common config and credentials
@ -145,14 +58,13 @@ func configureStorage(
// ensure required properties are present
if err := requireProps(map[string]string{
storage.Bucket: s3Cfg.Bucket,
credentials.CorsoPassphrase: corso.CorsoPassphrase,
}); err != nil {
return storage.Storage{}, err
}
// build the storage
store, err = storage.NewStorage(storage.ProviderS3, s3Cfg, cCfg)
store, err = storage.NewStorage(provider, storageCfg, cCfg)
if err != nil {
return store, clues.Wrap(err, "configuring repository storage")
}

View File

@ -70,7 +70,12 @@ func runExport(
sel selectors.Selector,
backupID, serviceName string,
) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -1,17 +1,20 @@
package repo
import (
"context"
"strings"
"github.com/alcionai/clues"
"github.com/spf13/cobra"
"golang.org/x/exp/maps"
"github.com/alcionai/corso/src/cli/config"
"github.com/alcionai/corso/src/cli/flags"
"github.com/alcionai/corso/src/cli/print"
"github.com/alcionai/corso/src/cli/utils"
"github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/storage"
)
const (
@ -121,7 +124,9 @@ func handleMaintenanceCmd(cmd *cobra.Command, args []string) error {
return err
}
r, _, err := utils.AccountConnectAndWriteRepoConfig(ctx, path.UnknownService, S3Overrides(cmd))
// Change this to override too?
r, _, err := utils.AccountConnectAndWriteRepoConfig(
ctx, path.UnknownService, storage.ProviderS3, S3Overrides(cmd))
if err != nil {
return print.Only(ctx, err)
}
@ -164,3 +169,37 @@ func getMaintenanceType(t string) (repository.MaintenanceType, error) {
return res, nil
}
func GetStorageOverrides(
ctx context.Context,
cmd *cobra.Command,
storageProvider string,
) (map[string]string, error) {
overrides := map[string]string{}
switch storageProvider {
case storage.ProviderS3.String():
overrides = S3Overrides(cmd)
}
return overrides, nil
}
func GetStorageProviderAndOverrides(
ctx context.Context,
cmd *cobra.Command,
) (storage.ProviderType, map[string]string, error) {
provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return provider, nil, clues.Stack(err)
}
overrides := map[string]string{}
switch provider {
case storage.ProviderS3:
overrides = S3Overrides(cmd)
}
return provider, overrides, nil
}

View File

@ -2,7 +2,6 @@ package repo
import (
"strconv"
"strings"
"github.com/alcionai/clues"
"github.com/pkg/errors"
@ -92,7 +91,7 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags
s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, false, s3Override)
cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, false, s3Override)
if err != nil {
return Only(ctx, err)
}
@ -113,17 +112,18 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
cfg.Account.ID(),
opt)
s3Cfg, err := cfg.Storage.S3Config()
storageCfg, err := cfg.Storage.StorageConfig()
if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
}
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
"pass --disable-tls flag to use http:// instead of default https://"
// BUG: This should be moved to validate()
// if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
// invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr))
}
// return Only(ctx, clues.New(invalidEndpointErr))
// }
m365, err := cfg.Account.M365Config()
if err != nil {
@ -146,9 +146,10 @@ func initS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r)
Infof(ctx, "Initialized a S3 repository within bucket %s.", s3Cfg.Bucket)
// Strong typecast?
Infof(ctx, "Initialized a S3 repository within bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opt.Repo, r.GetID()); err != nil {
if err = config.WriteRepoConfig(ctx, storageCfg, m365, opt.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))
}
@ -178,7 +179,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
// s3 values from flags
s3Override := S3Overrides(cmd)
cfg, err := config.GetConfigRepoDetails(ctx, true, true, s3Override)
cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, s3Override)
if err != nil {
return Only(ctx, err)
}
@ -188,7 +189,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
repoID = events.RepoIDNotFound
}
s3Cfg, err := cfg.Storage.S3Config()
s3Cfg, err := cfg.Storage.StorageConfig()
if err != nil {
return Only(ctx, clues.Wrap(err, "Retrieving s3 configuration"))
}
@ -198,12 +199,13 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
return Only(ctx, clues.Wrap(err, "Failed to parse m365 account config"))
}
if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
"pass --disable-tls flag to use http:// instead of default https://"
// Move these to validate()?
// if strings.HasPrefix(s3Cfg.Endpoint, "http://") || strings.HasPrefix(s3Cfg.Endpoint, "https://") {
// invalidEndpointErr := "endpoint doesn't support specifying protocol. " +
// "pass --disable-tls flag to use http:// instead of default https://"
return Only(ctx, clues.New(invalidEndpointErr))
}
// return Only(ctx, clues.New(invalidEndpointErr))
// }
opts := utils.ControlWithConfig(cfg)
@ -219,7 +221,7 @@ func connectS3Cmd(cmd *cobra.Command, args []string) error {
defer utils.CloseRepo(ctx, r)
Infof(ctx, "Connected to S3 bucket %s.", s3Cfg.Bucket)
Infof(ctx, "Connected to S3 bucket %s.", cfg.Storage.Config[storage.Bucket])
if err = config.WriteRepoConfig(ctx, s3Cfg, m365, opts.Repo, r.GetID()); err != nil {
return Only(ctx, clues.Wrap(err, "Failed to write repository configuration"))

View File

@ -63,9 +63,11 @@ func (suite *S3E2ESuite) TestInitS3Cmd() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
if !test.hasConfigFile {
// Ideally we could use `/dev/null`, but you need a
@ -100,9 +102,11 @@ func (suite *S3E2ESuite) TestInitMultipleTimes() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr)
@ -129,9 +133,10 @@ func (suite *S3E2ESuite) TestInitS3Cmd_missingBucket() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{
tconfig.TestCfgBucket: "",
}
@ -182,9 +187,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),
@ -230,9 +237,10 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadBucket() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr)
@ -256,9 +264,11 @@ func (suite *S3E2ESuite) TestConnectS3Cmd_BadPrefix() {
defer flush()
st := storeTD.NewPrefixedS3Storage(t)
cfg, err := st.S3Config()
sc, err := st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
vpr, configFP := tconfig.MakeTempTestConfigClone(t, nil)
ctx = config.SetViper(ctx, vpr)

View File

@ -66,9 +66,11 @@ func (suite *RestoreExchangeE2ESuite) SetupSuite() {
suite.acct = tconfig.NewM365Account(t)
suite.st = storeTD.NewPrefixedS3Storage(t)
cfg, err := suite.st.S3Config()
sc, err := suite.st.StorageConfig()
require.NoError(t, err, clues.ToCore(err))
cfg := sc.(storage.S3Config)
force := map[string]string{
tconfig.TestCfgAccountProvider: account.ProviderM365.String(),
tconfig.TestCfgStorageProvider: storage.ProviderS3.String(),

View File

@ -103,7 +103,12 @@ func runRestore(
sel selectors.Selector,
backupID, serviceName string,
) error {
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), repo.S3Overrides(cmd))
provider, overrides, err := repo.GetStorageProviderAndOverrides(ctx, cmd)
if err != nil {
return Only(ctx, err)
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, sel.PathService(), provider, overrides)
if err != nil {
return Only(ctx, err)
}

View File

@ -24,9 +24,10 @@ var ErrNotYetImplemented = clues.New("not yet implemented")
func GetAccountAndConnect(
ctx context.Context,
pst path.ServiceType,
provider storage.ProviderType,
overrides map[string]string,
) (repository.Repository, *storage.Storage, *account.Account, *control.Options, error) {
cfg, err := config.GetConfigRepoDetails(ctx, true, true, overrides)
cfg, err := config.GetConfigRepoDetails(ctx, provider, true, true, overrides)
if err != nil {
return nil, nil, nil, nil, err
}
@ -55,17 +56,19 @@ func GetAccountAndConnect(
func AccountConnectAndWriteRepoConfig(
ctx context.Context,
pst path.ServiceType,
provider storage.ProviderType,
overrides map[string]string,
) (repository.Repository, *account.Account, error) {
r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, overrides)
r, stg, acc, opts, err := GetAccountAndConnect(ctx, pst, provider, overrides)
if err != nil {
logger.CtxErr(ctx, err).Info("getting and connecting account")
return nil, nil, err
}
s3Config, err := stg.S3Config()
storageCfg, err := stg.StorageConfig()
if err != nil {
logger.CtxErr(ctx, err).Info("getting storage configuration")
return nil, nil, err
}
@ -77,7 +80,7 @@ func AccountConnectAndWriteRepoConfig(
// repo config gets set during repo connect and init.
// This call confirms we have the correct values.
err = config.WriteRepoConfig(ctx, s3Config, m365Config, opts.Repo, r.GetID())
err = config.WriteRepoConfig(ctx, storageCfg, m365Config, opts.Repo, r.GetID())
if err != nil {
logger.CtxErr(ctx, err).Info("writing to repository configuration")
return nil, nil, err

View File

@ -16,6 +16,7 @@ import (
"github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/path"
"github.com/alcionai/corso/src/pkg/repository"
"github.com/alcionai/corso/src/pkg/storage"
"github.com/alcionai/corso/src/pkg/store"
)
@ -29,7 +30,12 @@ func deleteBackups(
) ([]string, error) {
ctx = clues.Add(ctx, "cutoff_days", deletionDays)
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, nil)
provider, err := config.GetStorageProviderFromConfigFile(ctx)
if err != nil {
return nil, err
}
r, _, _, _, err := utils.GetAccountAndConnect(ctx, service, provider, nil)
if err != nil {
return nil, clues.Wrap(err, "connecting to account").WithClues(ctx)
}
@ -87,7 +93,7 @@ func pitrListBackups(
// TODO(ashmrtn): This may be moved into CLI layer at some point when we add
// flags for opening a repo at a point in time.
cfg, err := config.GetConfigRepoDetails(ctx, true, true, nil)
cfg, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, true, true, nil)
if err != nil {
return clues.Wrap(err, "getting config info")
}

View File

@ -187,16 +187,18 @@ func handleCheckerCommand(cmd *cobra.Command, args []string, f flags) error {
storage.Prefix: f.bucketPrefix,
}
repoDetails, err := config.GetConfigRepoDetails(ctx, false, false, overrides)
repoDetails, err := config.GetConfigRepoDetails(ctx, storage.ProviderS3, false, false, overrides)
if err != nil {
return clues.Wrap(err, "getting storage config")
}
cfg, err := repoDetails.Storage.S3Config()
c, err := repoDetails.Storage.StorageConfig()
if err != nil {
return clues.Wrap(err, "getting S3 config")
}
cfg := c.(storage.S3Config)
endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint

View File

@ -8,6 +8,7 @@ import (
"github.com/kopia/kopia/repo/blob/s3"
"github.com/alcionai/corso/src/pkg/control/repository"
"github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/storage"
)
@ -20,29 +21,33 @@ func s3BlobStorage(
repoOpts repository.Options,
s storage.Storage,
) (blob.Storage, error) {
cfg, err := s.S3Config()
cfg, err := s.StorageConfig()
if err != nil {
return nil, clues.Stack(err).WithClues(ctx)
}
// Cast to S3Config
s3Cfg := cfg.(storage.S3Config)
endpoint := defaultS3Endpoint
if len(cfg.Endpoint) > 0 {
endpoint = cfg.Endpoint
if len(s3Cfg.Endpoint) > 0 {
endpoint = s3Cfg.Endpoint
}
logger.Ctx(ctx).Infow("aws creds", "key", s3Cfg.AccessKey)
opts := s3.Options{
BucketName: cfg.Bucket,
BucketName: s3Cfg.Bucket,
Endpoint: endpoint,
Prefix: cfg.Prefix,
DoNotUseTLS: cfg.DoNotUseTLS,
DoNotVerifyTLS: cfg.DoNotVerifyTLS,
Prefix: s3Cfg.Prefix,
DoNotUseTLS: s3Cfg.DoNotUseTLS,
DoNotVerifyTLS: s3Cfg.DoNotVerifyTLS,
Tags: s.SessionTags,
SessionName: s.SessionName,
RoleARN: s.Role,
RoleDuration: s.SessionDuration,
AccessKeyID: cfg.AccessKey,
SecretAccessKey: cfg.SecretKey,
SessionToken: cfg.SessionToken,
AccessKeyID: s3Cfg.AccessKey,
SecretAccessKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
TLSHandshakeTimeout: 60,
PointInTime: repoOpts.ViewTimestamp,
}

View File

@ -4,12 +4,29 @@ import (
"strconv"
"github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials"
)
const (
// S3 config
StorageProviderTypeKey = "provider"
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
RepoID = "repo_id"
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
)
type S3Config struct {
credentials.AWS
Bucket string // required
@ -50,6 +67,121 @@ func (c S3Config) Normalize() S3Config {
}
}
// No need to return error here. Viper returns empty values.
func s3ConfigsFromStore(kvg KVStoreGetter) S3Config {
var s3Config S3Config
s3Config.Bucket = cast.ToString(kvg.Get(BucketNameKey))
s3Config.Endpoint = cast.ToString(kvg.Get(EndpointKey))
s3Config.Prefix = cast.ToString(kvg.Get(PrefixKey))
s3Config.DoNotUseTLS = cast.ToBool(kvg.Get(DisableTLSKey))
s3Config.DoNotVerifyTLS = cast.ToBool(kvg.Get(DisableTLSVerificationKey))
return s3Config
}
func s3CredsFromStore(
kvg KVStoreGetter,
s3Config S3Config,
) S3Config {
s3Config.AccessKey = cast.ToString(kvg.Get(AccessKey))
s3Config.SecretKey = cast.ToString(kvg.Get(SecretAccessKey))
s3Config.SessionToken = cast.ToString(kvg.Get(SessionToken))
return s3Config
}
var _ Configurer = S3Config{}
func (c S3Config) FetchConfigFromStore(
kvg KVStoreGetter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) (Configurer, error) {
var (
s3Cfg S3Config
err error
)
if readConfigFromStore {
s3Cfg = s3ConfigsFromStore(kvg)
if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[Prefix]; ok {
overrides[Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := cast.ToString(kvg.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() {
return S3Config{}, clues.New("unsupported storage provider: " + providerType)
}
// This is matching override values from config file.
if err := mustMatchConfig(kvg, s3Overrides(overrides)); err != nil {
return S3Config{}, clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
s3Cfg = s3CredsFromStore(kvg, s3Cfg)
aws := credentials.GetAWS(overrides)
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(
defaults.Config().WithCredentialsChainVerboseErrors(true),
defaults.Handlers()).Get()
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return S3Config{}, clues.Wrap(err, "validating aws credentials")
}
}
s3Cfg = S3Config{
AWS: aws,
Bucket: str.First(overrides[Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
}
return s3Cfg, s3Cfg.validate()
}
var _ WriteConfigToStorer = S3Config{}
func (c S3Config) WriteConfigToStore(
kvs KVStoreSetter,
) {
s3Config := c.Normalize()
kvs.Set(StorageProviderTypeKey, ProviderS3.String())
kvs.Set(BucketNameKey, s3Config.Bucket)
kvs.Set(EndpointKey, s3Config.Endpoint)
kvs.Set(PrefixKey, s3Config.Prefix)
kvs.Set(DisableTLSKey, s3Config.DoNotUseTLS)
kvs.Set(DisableTLSVerificationKey, s3Config.DoNotVerifyTLS)
}
// StringConfig transforms a s3Config struct into a plain
// map[string]string. All values in the original struct which
// serialize into the map are expected to be strings.
@ -70,19 +202,19 @@ func (c S3Config) StringConfig() (map[string]string, error) {
}
// S3Config retrieves the S3Config details from the Storage config.
func (s Storage) S3Config() (S3Config, error) {
func MakeS3ConfigFromMap(config map[string]string) (S3Config, error) {
c := S3Config{}
if len(s.Config) > 0 {
c.AccessKey = orEmptyString(s.Config[keyS3AccessKey])
c.SecretKey = orEmptyString(s.Config[keyS3SecretKey])
c.SessionToken = orEmptyString(s.Config[keyS3SessionToken])
if len(config) > 0 {
c.AccessKey = orEmptyString(config[keyS3AccessKey])
c.SecretKey = orEmptyString(config[keyS3SecretKey])
c.SessionToken = orEmptyString(config[keyS3SessionToken])
c.Bucket = orEmptyString(s.Config[keyS3Bucket])
c.Endpoint = orEmptyString(s.Config[keyS3Endpoint])
c.Prefix = orEmptyString(s.Config[keyS3Prefix])
c.DoNotUseTLS = str.ParseBool(s.Config[keyS3DoNotUseTLS])
c.DoNotVerifyTLS = str.ParseBool(s.Config[keyS3DoNotVerifyTLS])
c.Bucket = orEmptyString(config[keyS3Bucket])
c.Endpoint = orEmptyString(config[keyS3Endpoint])
c.Prefix = orEmptyString(config[keyS3Prefix])
c.DoNotUseTLS = str.ParseBool(config[keyS3DoNotUseTLS])
c.DoNotVerifyTLS = str.ParseBool(config[keyS3DoNotVerifyTLS])
}
return c, c.validate()
@ -100,3 +232,44 @@ func (c S3Config) validate() error {
return nil
}
var constToTomlKeyMap = map[string]string{
Bucket: BucketNameKey,
Endpoint: EndpointKey,
Prefix: PrefixKey,
StorageProviderTypeKey: StorageProviderTypeKey,
}
// mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned.
// values in m that aren't stored in the config are ignored.
func mustMatchConfig(kvg KVStoreGetter, m map[string]string) error {
for k, v := range m {
if len(v) == 0 {
continue // empty variables will get caught by configuration validators, if necessary
}
tomlK, ok := constToTomlKeyMap[k]
if !ok {
continue // m may declare values which aren't stored in the config file
}
vv := cast.ToString(kvg.Get(tomlK))
if v != vv {
return clues.New("value of " + k + " (" + v + ") does not match corso configuration value (" + vv + ")")
}
}
return nil
}
func s3Overrides(in map[string]string) map[string]string {
return map[string]string{
Bucket: in[Bucket],
Endpoint: in[Endpoint],
Prefix: in[Prefix],
DoNotUseTLS: in[DoNotUseTLS],
DoNotVerifyTLS: in[DoNotVerifyTLS],
StorageProviderTypeKey: in[StorageProviderTypeKey],
}
}

View File

@ -66,9 +66,11 @@ func (suite *S3CfgSuite) TestStorage_S3Config() {
in := goodS3Config
s, err := NewStorage(ProviderS3, in)
assert.NoError(t, err, clues.ToCore(err))
out, err := s.S3Config()
sc, err := s.StorageConfig()
assert.NoError(t, err, clues.ToCore(err))
out := sc.(S3Config)
assert.Equal(t, in.Bucket, out.Bucket)
assert.Equal(t, in.Endpoint, out.Endpoint)
assert.Equal(t, in.Prefix, out.Prefix)
@ -117,7 +119,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() {
st, err := NewStorage(ProviderUnknown, goodS3Config)
assert.NoError(t, err, clues.ToCore(err))
test.amend(st)
_, err = st.S3Config()
_, err = st.StorageConfig()
assert.Error(t, err)
})
}

View File

@ -17,6 +17,11 @@ const (
ProviderFilesystem ProviderType = 2 // Filesystem
)
var StringToProviderType = map[string]ProviderType{
ProviderUnknown.String(): ProviderUnknown,
ProviderS3.String(): ProviderS3,
}
// storage parsing errors
var (
errMissingRequired = clues.New("missing required storage configuration")
@ -82,3 +87,50 @@ func orEmptyString(v any) string {
return v.(string)
}
func (s Storage) StorageConfig() (Configurer, error) {
switch s.Provider {
case ProviderS3:
return MakeS3ConfigFromMap(s.Config)
}
return nil, clues.New("unsupported storage provider: " + s.Provider.String())
}
func NewStorageConfig(provider ProviderType) (Configurer, error) {
switch provider {
case ProviderS3:
return S3Config{}, nil
}
return nil, clues.New("unsupported storage provider: " + provider.String())
}
type KVStoreGetter interface {
Get(key string) any
}
type KVStoreSetter interface {
Set(key string, value any)
}
type WriteConfigToStorer interface {
WriteConfigToStore(
kvs KVStoreSetter,
)
}
type FetchConfigFromStorer interface {
FetchConfigFromStore(
kvg KVStoreGetter,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) (Configurer, error)
}
type Configurer interface {
common.StringConfigurer
FetchConfigFromStorer
WriteConfigToStorer
}