2023-09-13 19:25:22 +05:30

275 lines
7.4 KiB
Go

package storage
import (
"strconv"
"github.com/alcionai/clues"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/spf13/cast"
"github.com/alcionai/corso/src/internal/common"
"github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/pkg/credentials"
)
// Keys under which S3 storage settings are looked up in (kvs.Get) and
// written to (kvs.Set) the key-value config store used throughout this file.
const (
// S3 config
// Non-credential settings; read by s3ConfigsFromStore and written by
// WriteConfigToStore.
StorageProviderTypeKey = "provider"
BucketNameKey = "bucket"
EndpointKey = "endpoint"
PrefixKey = "prefix"
DisableTLSKey = "disable_tls"
DisableTLSVerificationKey = "disable_tls_verification"
// RepoID is not referenced anywhere in the visible part of this file.
RepoID = "repo_id"
// Credential keys; read by s3CredsFromStore. WriteConfigToStore does not
// write these.
AccessKey = "aws_access_key_id"
SecretAccessKey = "aws_secret_access_key"
SessionToken = "aws_session_token"
)
// S3Config describes an S3 storage target: the embedded AWS credentials
// plus the bucket, endpoint, key prefix, and TLS toggles.
type S3Config struct {
// Embedded credentials; AccessKey, SecretKey, and SessionToken are
// accessed as promoted fields in this file.
credentials.AWS
// Bucket is the only field checked by validate.
Bucket string // required
Endpoint string
Prefix string
// DoNotUseTLS is persisted under DisableTLSKey.
DoNotUseTLS bool
// DoNotVerifyTLS is persisted under DisableTLSVerificationKey.
DoNotVerifyTLS bool
}
// config key consts
// These are the map keys used by StringConfig when serializing an S3Config
// into a map[string]string, and by MakeS3ConfigFromMap when reading it back.
const (
keyS3AccessKey = "s3_access_key"
keyS3Bucket = "s3_bucket"
keyS3Endpoint = "s3_endpoint"
keyS3Prefix = "s3_prefix"
keyS3SecretKey = "s3_secret_key"
keyS3SessionToken = "s3_session_token"
keyS3DoNotUseTLS = "s3_donotusetls"
keyS3DoNotVerifyTLS = "s3_donotverifytls"
)
// config exported name consts
// These names index the overrides map in FetchConfigFromStore and
// s3Overrides, and are translated to config-store keys by constToTomlKeyMap
// (which has no entries for DoNotUseTLS and DoNotVerifyTLS).
const (
Bucket = "bucket"
Endpoint = "endpoint"
Prefix = "prefix"
DoNotUseTLS = "donotusetls"
DoNotVerifyTLS = "donotverifytls"
)
// Normalize returns a new S3Config whose Bucket and Prefix have been run
// through common.NormalizeBucket and common.NormalizePrefix. Endpoint and the
// two TLS flags are copied unchanged.
//
// The embedded AWS credentials are NOT carried over: the returned value holds
// zero-value credentials. Callers that need them (see StringConfig) read them
// from the receiver instead.
func (c S3Config) Normalize() S3Config {
	var normalized S3Config

	normalized.Bucket = common.NormalizeBucket(c.Bucket)
	normalized.Endpoint = c.Endpoint
	normalized.Prefix = common.NormalizePrefix(c.Prefix)
	normalized.DoNotUseTLS = c.DoNotUseTLS
	normalized.DoNotVerifyTLS = c.DoNotVerifyTLS

	return normalized
}
// s3ConfigsFromStore builds an S3Config from the non-credential settings held
// in kvs (bucket, endpoint, prefix, and the two TLS flags). Credentials are
// left at their zero value; see s3CredsFromStore.
//
// No error is returned here: Viper returns empty values for missing keys.
func s3ConfigsFromStore(kvs KVStorer) S3Config {
	return S3Config{
		Bucket:         cast.ToString(kvs.Get(BucketNameKey)),
		Endpoint:       cast.ToString(kvs.Get(EndpointKey)),
		Prefix:         cast.ToString(kvs.Get(PrefixKey)),
		DoNotUseTLS:    cast.ToBool(kvs.Get(DisableTLSKey)),
		DoNotVerifyTLS: cast.ToBool(kvs.Get(DisableTLSVerificationKey)),
	}
}
// s3CredsFromStore returns a copy of s3Config with its AccessKey, SecretKey,
// and SessionToken replaced by the values stored in kvs under AccessKey,
// SecretAccessKey, and SessionToken. All other fields are passed through
// untouched. The argument is received by value, so the caller's config is
// not modified.
func s3CredsFromStore(
	kvs KVStorer,
	s3Config S3Config,
) S3Config {
	withCreds := s3Config

	withCreds.AccessKey = cast.ToString(kvs.Get(AccessKey))
	withCreds.SecretKey = cast.ToString(kvs.Get(SecretAccessKey))
	withCreds.SessionToken = cast.ToString(kvs.Get(SessionToken))

	return withCreds
}
// Compile-time check that S3Config (by value) satisfies StorageConfigurer.
var _ StorageConfigurer = S3Config{}
// FetchConfigFromStore resolves the S3 configuration from three sources:
// the caller-supplied overrides, the key-value store kvs, and the AWS SDK's
// default credential chain.
//
// Parameters:
//   - kvs: the config store to read settings and credentials from.
//   - readConfigFromStore: when true, non-credential settings are loaded from
//     kvs and the Bucket/Prefix entries of overrides are normalized.
//   - matchFromConfig: only consulted when readConfigFromStore is true; when
//     set, the stored provider must be ProviderS3 and every non-empty override
//     must equal the stored value (see mustMatchConfig).
//   - overrides: values keyed by the exported name consts (Bucket, Endpoint,
//     ...). NOTE: this map is mutated in place when readConfigFromStore is
//     true and it contains Bucket or Prefix.
//
// Returns an error when the stored provider is not S3, when an override does
// not match the stored config, or when no usable credentials are found.
//
// NOTE(review): the receiver is passed by value and is never read or written,
// and the merged S3Config assembled at the end is only assigned to the local
// s3Cfg before returning nil. As written, callers observe only the error and
// the in-place normalization of overrides; the resolved config itself is
// discarded. Confirm whether this method was meant to use a pointer receiver
// or to return the config.
func (c S3Config) FetchConfigFromStore(
kvs KVStorer,
readConfigFromStore bool,
matchFromConfig bool,
overrides map[string]string,
) error {
var (
s3Cfg S3Config
err error
)
if readConfigFromStore {
s3Cfg = s3ConfigsFromStore(kvs)
// Normalize the override values in place so that they are compared (and
// later merged) in the same form that WriteConfigToStore persists.
// Only keys already present are rewritten, so a nil map is never written to.
if b, ok := overrides[Bucket]; ok {
overrides[Bucket] = common.NormalizeBucket(b)
}
if p, ok := overrides[Prefix]; ok {
overrides[Prefix] = common.NormalizePrefix(p)
}
if matchFromConfig {
providerType := cast.ToString(kvs.Get(StorageProviderTypeKey))
if providerType != ProviderS3.String() {
return clues.New("unsupported storage provider: " + providerType)
}
// Reject overrides whose values disagree with what is stored in the
// config file.
if err := mustMatchConfig(kvs, s3Overrides(overrides)); err != nil {
return clues.Wrap(err, "verifying s3 configs in corso config file")
}
}
}
// Credentials are always read from the store, regardless of
// readConfigFromStore.
s3Cfg = s3CredsFromStore(kvs, s3Cfg)
aws := credentials.GetAWS(overrides)
// If the overrides did not supply both an access key and a secret key, probe
// the AWS SDK default credential chain.
if len(aws.AccessKey) <= 0 || len(aws.SecretKey) <= 0 {
_, err = defaults.CredChain(
defaults.Config().WithCredentialsChainVerboseErrors(true),
defaults.Handlers()).Get()
// The default chain found nothing, but the store holds at least one of
// access key / secret key: fall back to the stored credentials and clear
// the error.
if err != nil && (len(s3Cfg.AccessKey) > 0 || len(s3Cfg.SecretKey) > 0) {
aws = credentials.AWS{
AccessKey: s3Cfg.AccessKey,
SecretKey: s3Cfg.SecretKey,
SessionToken: s3Cfg.SessionToken,
}
err = nil
}
if err != nil {
return clues.Wrap(err, "validating aws credentials")
}
}
// Merge with precedence: override, then stored value, then a literal default.
// (Presumably str.First returns its first non-empty argument - confirm against
// the str package.)
// NOTE(review): this value is never used after the assignment; see the
// function comment.
s3Cfg = S3Config{
AWS: aws,
Bucket: str.First(overrides[Bucket], s3Cfg.Bucket),
Endpoint: str.First(overrides[Endpoint], s3Cfg.Endpoint, "s3.amazonaws.com"),
Prefix: str.First(overrides[Prefix], s3Cfg.Prefix),
DoNotUseTLS: str.ParseBool(str.First(
overrides[DoNotUseTLS],
strconv.FormatBool(s3Cfg.DoNotUseTLS),
"false")),
DoNotVerifyTLS: str.ParseBool(str.First(
overrides[DoNotVerifyTLS],
strconv.FormatBool(s3Cfg.DoNotVerifyTLS),
"false")),
}
return nil
}
// Compile-time check that S3Config (by value) satisfies WriteConfigToStorer.
var _ WriteConfigToStorer = S3Config{}

// WriteConfigToStore persists the non-credential S3 settings into kvs: the
// provider type (always ProviderS3), plus the normalized bucket, endpoint,
// prefix, and both TLS flags. Credentials are never written by this method.
func (c S3Config) WriteConfigToStore(
	kvs KVStoreSetter,
) {
	normalized := c.Normalize()

	kvs.Set(StorageProviderTypeKey, ProviderS3.String())

	// String-valued settings.
	kvs.Set(BucketNameKey, normalized.Bucket)
	kvs.Set(EndpointKey, normalized.Endpoint)
	kvs.Set(PrefixKey, normalized.Prefix)

	// Boolean TLS toggles are stored as bools, not strings.
	kvs.Set(DisableTLSKey, normalized.DoNotUseTLS)
	kvs.Set(DisableTLSVerificationKey, normalized.DoNotVerifyTLS)
}
// StringConfig flattens the S3Config into a map[string]string keyed by the
// keyS3* constants. Bucket, endpoint, prefix, and the TLS flags are taken
// from the normalized config; the TLS flags are rendered with
// strconv.FormatBool.
//
// The map is always returned, together with the result of validating the
// receiver (which fails when Bucket is empty).
func (c S3Config) StringConfig() (map[string]string, error) {
	normalized := c.Normalize()

	out := make(map[string]string, 8)

	// Credentials come from the receiver: Normalize does not carry them over.
	out[keyS3AccessKey] = c.AccessKey
	out[keyS3SecretKey] = c.SecretKey
	out[keyS3SessionToken] = c.SessionToken

	out[keyS3Bucket] = normalized.Bucket
	out[keyS3Endpoint] = normalized.Endpoint
	out[keyS3Prefix] = normalized.Prefix
	out[keyS3DoNotUseTLS] = strconv.FormatBool(normalized.DoNotUseTLS)
	out[keyS3DoNotVerifyTLS] = strconv.FormatBool(normalized.DoNotVerifyTLS)

	err := c.validate()

	return out, err
}
// MakeS3ConfigFromMap rebuilds an S3Config from a map keyed by the keyS3*
// constants (the format produced by StringConfig). A nil or empty map yields
// a zero-value config.
//
// The config is always returned, together with the result of validating it,
// so an input without a bucket produces both a config and an error.
func MakeS3ConfigFromMap(config map[string]string) (S3Config, error) {
	if len(config) == 0 {
		empty := S3Config{}
		return empty, empty.validate()
	}

	cfg := S3Config{
		AWS: credentials.AWS{
			AccessKey:    orEmptyString(config[keyS3AccessKey]),
			SecretKey:    orEmptyString(config[keyS3SecretKey]),
			SessionToken: orEmptyString(config[keyS3SessionToken]),
		},
		Bucket:         orEmptyString(config[keyS3Bucket]),
		Endpoint:       orEmptyString(config[keyS3Endpoint]),
		Prefix:         orEmptyString(config[keyS3Prefix]),
		DoNotUseTLS:    str.ParseBool(config[keyS3DoNotUseTLS]),
		DoNotVerifyTLS: str.ParseBool(config[keyS3DoNotVerifyTLS]),
	}

	return cfg, cfg.validate()
}
// validate reports whether every required field is populated. Bucket is
// currently the only required field. On failure it returns errMissingRequired
// stacked with an error naming the missing field.
func (c S3Config) validate() error {
	required := []struct {
		name  string
		value string
	}{
		{name: Bucket, value: c.Bucket},
	}

	for _, field := range required {
		if len(field.value) > 0 {
			continue
		}

		return clues.Stack(errMissingRequired, clues.New(field.name))
	}

	return nil
}
// constToTomlKeyMap translates an override name (the exported name consts)
// into the key under which the same setting is stored in the config store.
// mustMatchConfig skips any override name that has no entry here; note that
// DoNotUseTLS and DoNotVerifyTLS are not mapped, so those overrides are never
// compared against the stored values.
var constToTomlKeyMap = map[string]string{
Bucket: BucketNameKey,
Endpoint: EndpointKey,
Prefix: PrefixKey,
StorageProviderTypeKey: StorageProviderTypeKey,
}
// mustMatchConfig compares the values of each key to their config file value in store.
// If any value differs from the store value, an error is returned.
// Values in m that aren't stored in the config are ignored, as are empty values.
//
// When more than one key mismatches, the error always describes the
// lexicographically smallest mismatching key. Go randomizes map iteration
// order, so returning on the first mismatch encountered would make the
// reported key vary from run to run.
func mustMatchConfig(kvs KVStorer, m map[string]string) error {
	var (
		mismatch   bool
		badKey     string
		badValue   string
		storeValue string
	)

	for k, v := range m {
		if len(v) == 0 {
			continue // empty variables will get caught by configuration validators, if necessary
		}

		tomlK, ok := constToTomlKeyMap[k]
		if !ok {
			continue // m may declare values which aren't stored in the config file
		}

		vv := cast.ToString(kvs.Get(tomlK))
		if v == vv {
			continue
		}

		// Remember only the smallest mismatching key so the result does not
		// depend on iteration order.
		if !mismatch || k < badKey {
			mismatch = true
			badKey, badValue, storeValue = k, v, vv
		}
	}

	if !mismatch {
		return nil
	}

	return clues.New(
		"value of " + badKey + " (" + badValue +
			") does not match corso configuration value (" + storeValue + ")")
}
// s3Overrides extracts the S3-relevant entries from in into a fresh map.
// The result always contains every S3 override name (Bucket, Endpoint,
// Prefix, DoNotUseTLS, DoNotVerifyTLS, StorageProviderTypeKey); names absent
// from in map to the empty string. Any other entries of in are dropped.
func s3Overrides(in map[string]string) map[string]string {
	names := []string{
		Bucket,
		Endpoint,
		Prefix,
		DoNotUseTLS,
		DoNotVerifyTLS,
		StorageProviderTypeKey,
	}

	out := make(map[string]string, len(names))
	for _, name := range names {
		out[name] = in[name]
	}

	return out
}