diff --git a/src/cli/config/config.go b/src/cli/config/config.go index c7aef8c91..2ad548cd2 100644 --- a/src/cli/config/config.go +++ b/src/cli/config/config.go @@ -184,6 +184,7 @@ func WriteRepoConfig(ctx context.Context, s3Config storage.S3Config, m365Config // writeRepoConfigWithViper implements WriteRepoConfig, but takes in a viper // struct for testing. func writeRepoConfigWithViper(vpr *viper.Viper, s3Config storage.S3Config, m365Config account.M365Config) error { + s3Config = s3Config.Normalize() // Rudimentary support for persisting repo config // TODO: Handle conflicts, support other config types vpr.Set(StorageProviderTypeKey, storage.ProviderS3.String()) diff --git a/src/pkg/storage/s3.go b/src/pkg/storage/s3.go index d63637334..367d3ce93 100644 --- a/src/pkg/storage/s3.go +++ b/src/pkg/storage/s3.go @@ -1,6 +1,8 @@ package storage import ( + "strings" + "github.com/pkg/errors" ) @@ -24,14 +26,23 @@ const ( Prefix = "prefix" ) +func (c S3Config) Normalize() S3Config { + return S3Config{ + Bucket: strings.TrimPrefix(c.Bucket, "s3://"), + Endpoint: c.Endpoint, + Prefix: c.Prefix, + } +} + // StringConfig transforms a s3Config struct into a plain // map[string]string. All values in the original struct which // serialize into the map are expected to be strings. func (c S3Config) StringConfig() (map[string]string, error) { + cn := c.Normalize() cfg := map[string]string{ - keyS3Bucket: c.Bucket, - keyS3Endpoint: c.Endpoint, - keyS3Prefix: c.Prefix, + keyS3Bucket: cn.Bucket, + keyS3Endpoint: cn.Endpoint, + keyS3Prefix: cn.Prefix, } return cfg, c.validate() diff --git a/src/pkg/storage/s3_test.go b/src/pkg/storage/s3_test.go index ebbb97efd..b74dadcd0 100644 --- a/src/pkg/storage/s3_test.go +++ b/src/pkg/storage/s3_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/alcionai/corso/src/pkg/storage" @@ -17,11 +18,19 @@ func TestS3CfgSuite(t *testing.T) { suite.Run(t, new(S3CfgSuite)) } -var goodS3Config = storage.S3Config{ - Bucket: "bkt", - Endpoint: "end", - Prefix: "pre", -} +var ( + goodS3Config = storage.S3Config{ + Bucket: "bkt", + Endpoint: "end", + Prefix: "pre", + } + + goodS3Map = map[string]string{ + "s3_bucket": "bkt", + "s3_endpoint": "end", + "s3_prefix": "pre", + } +) func (suite *S3CfgSuite) TestS3Config_Config() { s3 := goodS3Config @@ -55,7 +64,7 @@ func (suite *S3CfgSuite) TestStorage_S3Config() { assert.Equal(t, in.Prefix, out.Prefix) } -func makeTestS3Cfg(ak, bkt, end, pre, sk, tkn string) storage.S3Config { +func makeTestS3Cfg(bkt, end, pre string) storage.S3Config { return storage.S3Config{ Bucket: bkt, Endpoint: end, @@ -63,13 +72,13 @@ func makeTestS3Cfg(ak, bkt, end, pre, sk, tkn string) storage.S3Config { } } -func (suite *S3CfgSuite) TestStorage_S3Config_InvalidCases() { +func (suite *S3CfgSuite) TestStorage_S3Config_invalidCases() { // missing required properties table := []struct { name string cfg storage.S3Config }{ - {"missing bucket", makeTestS3Cfg("ak", "", "end", "pre", "sk", "tkn")}, + {"missing bucket", makeTestS3Cfg("", "end", "pre")}, } for _, test := range table { suite.T().Run(test.name, func(t *testing.T) { @@ -100,3 +109,44 @@ func (suite *S3CfgSuite) TestStorage_S3Config_InvalidCases() { }) } } + +func (suite *S3CfgSuite) TestStorage_S3Config_StringConfig() { + table := []struct { + name string + input storage.S3Config + expect map[string]string + }{ + { + name: "standard", + input: goodS3Config, + expect: goodS3Map, + }, + { + name: "normalized bucket name", + input: makeTestS3Cfg("s3://"+goodS3Config.Bucket, goodS3Config.Endpoint, goodS3Config.Prefix), + expect: goodS3Map, + }, + } + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + result, err := test.input.StringConfig() + require.NoError(t, err) + assert.Equal(t, test.expect, result) + }) + } +} + +func (suite *S3CfgSuite) TestStorage_S3Config_Normalize() { + const ( + prefixedBkt = "s3://bkt" + normalBkt = "bkt" + ) + + st := storage.S3Config{ + Bucket: prefixedBkt, + } + + result := st.Normalize() + assert.Equal(suite.T(), normalBkt, result.Bucket) + assert.NotEqual(suite.T(), st.Bucket, result.Bucket) +}