From cc3306e5e011df5b50798fc2191550d4fbae9b81 Mon Sep 17 00:00:00 2001 From: Keepers <104464746+ryanfkeepers@users.noreply.github.com> Date: Thu, 2 Jun 2022 15:33:19 -0600 Subject: [PATCH] validate required storage props (#85) (#108) * validate required storage props (#85) Centralizes validation of required storage config properties within the storage package. Requiremens are checked eagerly at configuration creation, and lazily at config retrieval. Additionally, updates /pkg/storage tests to use suites and assert funcs. * add validation failure tests to storage --- src/cli/repo/s3.go | 12 +++- src/internal/kopia/kopia.go | 19 +++--- src/internal/kopia/s3.go | 6 +- src/pkg/repository/repository_test.go | 60 ++++++++++------- src/pkg/storage/common.go | 21 ++++-- src/pkg/storage/common_test.go | 57 ++++++++++++++-- src/pkg/storage/s3.go | 34 +++++++--- src/pkg/storage/s3_test.go | 78 ++++++++++++++++++++-- src/pkg/storage/storage.go | 33 +++++++--- src/pkg/storage/storage_test.go | 93 ++++++++++++++++++++------- 10 files changed, 319 insertions(+), 94 deletions(-) diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index 625c48423..947d8788b 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -69,7 +69,11 @@ func initS3Cmd(cmd *cobra.Command, args []string) { ClientID: mv.clientID, ClientSecret: mv.clientSecret, } - s := storage.NewStorage(storage.ProviderS3, s3Cfg, commonCfg) + s, err := storage.NewStorage(storage.ProviderS3, s3Cfg, commonCfg) + if err != nil { + fmt.Printf("Failed to configure storage provider: %v", err) + os.Exit(1) + } if _, err := repository.Initialize(cmd.Context(), a, s); err != nil { fmt.Printf("Failed to initialize a new S3 repository: %v", err) @@ -111,7 +115,11 @@ func connectS3Cmd(cmd *cobra.Command, args []string) { ClientID: mv.clientID, ClientSecret: mv.clientSecret, } - s := storage.NewStorage(storage.ProviderS3, s3Cfg, commonCfg) + s, err := storage.NewStorage(storage.ProviderS3, s3Cfg, commonCfg) + if err != nil { + fmt.Printf("Failed to configure storage provider: %v", err) + os.Exit(1) + } if _, err := repository.Connect(cmd.Context(), a, s); err != nil { fmt.Printf("Failed to connect to the S3 repository: %v", err) diff --git a/src/internal/kopia/kopia.go b/src/internal/kopia/kopia.go index ea44ddc8c..a24e7470d 100644 --- a/src/internal/kopia/kopia.go +++ b/src/internal/kopia/kopia.go @@ -15,9 +15,8 @@ const ( ) var ( - errInit = errors.New("initializing repo") - errConnect = errors.New("connecting repo") - errRequriesPassword = errors.New("corso password required") + errInit = errors.New("initializing repo") + errConnect = errors.New("connecting repo") ) type kopiaWrapper struct { @@ -35,9 +34,9 @@ func (kw kopiaWrapper) Initialize(ctx context.Context) error { } defer bst.Close(ctx) - cfg := kw.storage.CommonConfig() - if len(cfg.CorsoPassword) == 0 { - return errRequriesPassword + cfg, err := kw.storage.CommonConfig() + if err != nil { + return err } // todo - issue #75: nil here should be a storage.NewRepoOptions() @@ -66,9 +65,9 @@ func (kw kopiaWrapper) Connect(ctx context.Context) error { } defer bst.Close(ctx) - cfg := kw.storage.CommonConfig() - if len(cfg.CorsoPassword) == 0 { - return errRequriesPassword + cfg, err := kw.storage.CommonConfig() + if err != nil { + return err } // todo - issue #75: nil here should be storage.ConnectOptions() @@ -87,7 +86,7 @@ func (kw kopiaWrapper) Connect(ctx context.Context) error { func blobStoreByProvider(ctx context.Context, s storage.Storage) (blob.Storage, error) { switch s.Provider { case storage.ProviderS3: - return s3BlobStorage(ctx, s.S3Config()) + return s3BlobStorage(ctx, s) default: return nil, errors.New("storage provider details are required") } diff --git a/src/internal/kopia/s3.go b/src/internal/kopia/s3.go index 2990cb27c..556bee52c 100644 --- a/src/internal/kopia/s3.go +++ b/src/internal/kopia/s3.go @@ -13,7 +13,11 @@ const ( defaultS3Endpoint = "s3.amazonaws.com" // matches kopia's default value ) -func s3BlobStorage(ctx context.Context, cfg storage.S3Config) (blob.Storage, error) { +func s3BlobStorage(ctx context.Context, s storage.Storage) (blob.Storage, error) { + cfg, err := s.S3Config() + if err != nil { + return nil, err + } endpoint := defaultS3Endpoint if len(cfg.Endpoint) > 0 { endpoint = cfg.Endpoint diff --git a/src/pkg/repository/repository_test.go b/src/pkg/repository/repository_test.go index 215928891..974d00998 100644 --- a/src/pkg/repository/repository_test.go +++ b/src/pkg/repository/repository_test.go @@ -29,21 +29,25 @@ func TestRepositorySuite(t *testing.T) { func (suite *RepositorySuite) TestInitialize() { table := []struct { name string - storage storage.Storage + storage func() (storage.Storage, error) account repository.Account errCheck assert.ErrorAssertionFunc }{ { storage.ProviderUnknown.String(), - storage.NewStorage(storage.ProviderUnknown), + func() (storage.Storage, error) { + return storage.NewStorage(storage.ProviderUnknown) + }, repository.Account{}, assert.Error, }, } for _, test := range table { suite.T().Run(test.name, func(t *testing.T) { - _, err := repository.Initialize(context.Background(), test.account, test.storage) - test.errCheck(suite.T(), err, "") + st, err := test.storage() + assert.NoError(t, err) + _, err = repository.Initialize(context.Background(), test.account, st) + test.errCheck(t, err, "") }) } } @@ -53,21 +57,25 @@ func (suite *RepositorySuite) TestInitialize() { func (suite *RepositorySuite) TestConnect() { table := []struct { name string - storage storage.Storage + storage func() (storage.Storage, error) account repository.Account errCheck assert.ErrorAssertionFunc }{ { storage.ProviderUnknown.String(), - storage.NewStorage(storage.ProviderUnknown), + func() (storage.Storage, error) { + return storage.NewStorage(storage.ProviderUnknown) + }, repository.Account{}, assert.Error, }, } for _, test := range table { suite.T().Run(test.name, func(t *testing.T) { - _, err := repository.Connect(context.Background(), test.account, test.storage) - test.errCheck(suite.T(), err) + st, err := test.storage() + assert.NoError(t, err) + _, err = repository.Connect(context.Background(), test.account, st) + test.errCheck(t, err) }) } } @@ -108,31 +116,35 @@ func (suite *RepositoryIntegrationSuite) TestInitialize() { table := []struct { prefix string account repository.Account - storage storage.Storage + storage func() (storage.Storage, error) errCheck assert.ErrorAssertionFunc }{ { prefix: "init-s3-" + timeOfTest, - storage: storage.NewStorage( - storage.ProviderS3, - storage.S3Config{ - AccessKey: os.Getenv(storage.AWS_ACCESS_KEY_ID), - Bucket: "test-corso-repo-init", - Prefix: "init-s3-" + timeOfTest, - SecretKey: os.Getenv(storage.AWS_SECRET_ACCESS_KEY), - SessionToken: os.Getenv(storage.AWS_SESSION_TOKEN), - }, - storage.CommonConfig{ - CorsoPassword: os.Getenv(storage.CORSO_PASSWORD), - }, - ), + storage: func() (storage.Storage, error) { + return storage.NewStorage( + storage.ProviderS3, + storage.S3Config{ + AccessKey: os.Getenv(storage.AWS_ACCESS_KEY_ID), + Bucket: "test-corso-repo-init", + Prefix: "init-s3-" + timeOfTest, + SecretKey: os.Getenv(storage.AWS_SECRET_ACCESS_KEY), + SessionToken: os.Getenv(storage.AWS_SESSION_TOKEN), + }, + storage.CommonConfig{ + CorsoPassword: os.Getenv(storage.CORSO_PASSWORD), + }, + ) + }, errCheck: assert.NoError, }, } for _, test := range table { suite.T().Run(test.prefix, func(t *testing.T) { - _, err := repository.Initialize(ctx, test.account, test.storage) - test.errCheck(suite.T(), err) + st, err := test.storage() + assert.NoError(t, err) + _, err = repository.Initialize(ctx, test.account, st) + test.errCheck(t, err) }) } } diff --git a/src/pkg/storage/common.go b/src/pkg/storage/common.go index 039a5a130..c89550147 100644 --- a/src/pkg/storage/common.go +++ b/src/pkg/storage/common.go @@ -1,7 +1,9 @@ package storage +import "github.com/pkg/errors" + type CommonConfig struct { - CorsoPassword string + CorsoPassword string // required } // envvar consts @@ -14,17 +16,26 @@ const ( keyCommonCorsoPassword = "common_corsoPassword" ) -func (c CommonConfig) Config() config { - return config{ +func (c CommonConfig) Config() (config, error) { + cfg := config{ keyCommonCorsoPassword: c.CorsoPassword, } + return cfg, c.validate() } // CommonConfig retrieves the CommonConfig details from the Storage config. -func (s Storage) CommonConfig() CommonConfig { +func (s Storage) CommonConfig() (CommonConfig, error) { c := CommonConfig{} if len(s.Config) > 0 { c.CorsoPassword = orEmptyString(s.Config[keyCommonCorsoPassword]) } - return c + return c, c.validate() +} + +// ensures all required properties are present +func (c CommonConfig) validate() error { + if len(c.CorsoPassword) == 0 { + return errors.Wrap(errMissingRequired, CORSO_PASSWORD) + } + return nil } diff --git a/src/pkg/storage/common_test.go b/src/pkg/storage/common_test.go index ee7b98310..c3f908576 100644 --- a/src/pkg/storage/common_test.go +++ b/src/pkg/storage/common_test.go @@ -17,9 +17,13 @@ func TestCommonCfgSuite(t *testing.T) { suite.Run(t, new(CommonCfgSuite)) } +var goodCommonConfig = storage.CommonConfig{"passwd"} + func (suite *CommonCfgSuite) TestCommonConfig_Config() { - cfg := storage.CommonConfig{"passwd"} - c := cfg.Config() + cfg := goodCommonConfig + c, err := cfg.Config() + assert.NoError(suite.T(), err) + table := []struct { key string expect string @@ -28,14 +32,57 @@ func (suite *CommonCfgSuite) TestCommonConfig_Config() { } for _, test := range table { suite.T().Run(test.key, func(t *testing.T) { - assert.Equal(t, c[test.key], test.expect) + assert.Equal(t, test.expect, c[test.key]) }) } } func (suite *CommonCfgSuite) TestStorage_CommonConfig() { - in := storage.CommonConfig{"passwd"} - out := storage.NewStorage(storage.ProviderUnknown, in).CommonConfig() t := suite.T() + + in := goodCommonConfig + s, err := storage.NewStorage(storage.ProviderUnknown, in) + assert.NoError(t, err) + out, err := s.CommonConfig() + assert.NoError(t, err) + assert.Equal(t, in.CorsoPassword, out.CorsoPassword) } + +func (suite *CommonCfgSuite) TestStorage_CommonConfig_InvalidCases() { + // missing required properties + table := []struct { + name string + cfg storage.CommonConfig + }{ + {"missing password", storage.CommonConfig{}}, + } + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + _, err := storage.NewStorage(storage.ProviderUnknown, test.cfg) + assert.Error(t, err) + }) + } + + // required property not populated in storage + table2 := []struct { + name string + amend func(storage.Storage) + }{ + { + "missing password", + func(s storage.Storage) { + s.Config["common_corsoPassword"] = "" + }, + }, + } + for _, test := range table2 { + suite.T().Run(test.name, func(t *testing.T) { + st, err := storage.NewStorage(storage.ProviderUnknown, goodCommonConfig) + assert.NoError(t, err) + test.amend(st) + _, err = st.CommonConfig() + assert.Error(t, err) + }) + } +} diff --git a/src/pkg/storage/s3.go b/src/pkg/storage/s3.go index 8dd71b1f6..a6b3553d3 100644 --- a/src/pkg/storage/s3.go +++ b/src/pkg/storage/s3.go @@ -1,12 +1,14 @@ package storage +import "github.com/pkg/errors" + type S3Config struct { - AccessKey string - Bucket string + AccessKey string // required + Bucket string // required Endpoint string Prefix string - SecretKey string - SessionToken string + SecretKey string // required + SessionToken string // required } // envvar consts @@ -26,8 +28,8 @@ const ( keyS3SessionToken = "s3_sessionToken" ) -func (c S3Config) Config() config { - return config{ +func (c S3Config) Config() (config, error) { + cfg := config{ keyS3AccessKey: c.AccessKey, keyS3Bucket: c.Bucket, keyS3Endpoint: c.Endpoint, @@ -35,10 +37,11 @@ func (c S3Config) Config() config { keyS3SecretKey: c.SecretKey, keyS3SessionToken: c.SessionToken, } + return cfg, c.validate() } // S3Config retrieves the S3Config details from the Storage config. -func (s Storage) S3Config() S3Config { +func (s Storage) S3Config() (S3Config, error) { c := S3Config{} if len(s.Config) > 0 { c.AccessKey = orEmptyString(s.Config[keyS3AccessKey]) @@ -48,5 +51,20 @@ func (s Storage) S3Config() S3Config { c.SecretKey = orEmptyString(s.Config[keyS3SecretKey]) c.SessionToken = orEmptyString(s.Config[keyS3SessionToken]) } - return c + return c, c.validate() +} + +func (c S3Config) validate() error { + check := map[string]string{ + AWS_ACCESS_KEY_ID: c.AccessKey, + AWS_SECRET_ACCESS_KEY: c.SecretKey, + AWS_SESSION_TOKEN: c.SessionToken, + "bucket": c.Bucket, + } + for k, v := range check { + if len(v) == 0 { + return errors.Wrap(errMissingRequired, k) + } + } + return nil } diff --git a/src/pkg/storage/s3_test.go b/src/pkg/storage/s3_test.go index 80ef358d2..724cc5657 100644 --- a/src/pkg/storage/s3_test.go +++ b/src/pkg/storage/s3_test.go @@ -17,9 +17,13 @@ func TestS3CfgSuite(t *testing.T) { suite.Run(t, new(S3CfgSuite)) } +var goodS3Config = storage.S3Config{"ak", "bkt", "end", "pre", "sk", "tkn"} + func (suite *S3CfgSuite) TestS3Config_Config() { - s3 := storage.S3Config{"ak", "bkt", "end", "pre", "sk", "tkn"} - c := s3.Config() + s3 := goodS3Config + c, err := s3.Config() + assert.NoError(suite.T(), err) + table := []struct { key string expect string @@ -32,14 +36,19 @@ func (suite *S3CfgSuite) TestS3Config_Config() { {"s3_sessionToken", s3.SessionToken}, } for _, test := range table { - assert.Equal(suite.T(), c[test.key], test.expect) + assert.Equal(suite.T(), test.expect, c[test.key]) } } func (suite *S3CfgSuite) TestStorage_S3Config() { - in := storage.S3Config{"ak", "bkt", "end", "pre", "sk", "tkn"} - out := storage.NewStorage(storage.ProviderS3, in).S3Config() t := suite.T() + + in := goodS3Config + s, err := storage.NewStorage(storage.ProviderS3, in) + assert.NoError(t, err) + out, err := s.S3Config() + assert.NoError(t, err) + assert.Equal(t, in.Bucket, out.Bucket) assert.Equal(t, in.AccessKey, out.AccessKey) assert.Equal(t, in.Endpoint, out.Endpoint) @@ -47,3 +56,62 @@ func (suite *S3CfgSuite) TestStorage_S3Config() { assert.Equal(t, in.SecretKey, out.SecretKey) assert.Equal(t, in.SessionToken, out.SessionToken) } + +func (suite *S3CfgSuite) TestStorage_S3Config_InvalidCases() { + // missing required properties + table := []struct { + name string + cfg storage.S3Config + }{ + {"missing access key", storage.S3Config{"", "bkt", "end", "pre", "sk", "tkn"}}, + {"missing bucket", storage.S3Config{"ak", "", "end", "pre", "sk", "tkn"}}, + {"missing secret key", storage.S3Config{"ak", "bkt", "end", "pre", "", "tkn"}}, + {"missing session token", storage.S3Config{"ak", "bkt", "end", "pre", "sk", ""}}, + } + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + _, err := storage.NewStorage(storage.ProviderUnknown, test.cfg) + assert.Error(t, err) + }) + } + + // required property not populated in storage + table2 := []struct { + name string + amend func(storage.Storage) + }{ + { + "missing access key", + func(s storage.Storage) { + s.Config["s3_accessKey"] = "" + }, + }, + { + "missing bucket", + func(s storage.Storage) { + s.Config["s3_bucket"] = "" + }, + }, + { + "missing secret key", + func(s storage.Storage) { + s.Config["s3_secretKey"] = "" + }, + }, + { + "missing session token", + func(s storage.Storage) { + s.Config["s3_sessionToken"] = "" + }, + }, + } + for _, test := range table2 { + suite.T().Run(test.name, func(t *testing.T) { + st, err := storage.NewStorage(storage.ProviderUnknown, goodS3Config) + assert.NoError(t, err) + test.amend(st) + _, err = st.CommonConfig() + assert.Error(t, err) + }) + } +} diff --git a/src/pkg/storage/storage.go b/src/pkg/storage/storage.go index f9cd1c955..12a86d05c 100644 --- a/src/pkg/storage/storage.go +++ b/src/pkg/storage/storage.go @@ -1,6 +1,9 @@ package storage -import "fmt" +import ( + "errors" + "fmt" +) type storageProvider int @@ -10,10 +13,15 @@ const ( ProviderS3 // S3 ) +// storage parsing errors +var ( + errMissingRequired = errors.New("missing required storage configuration") +) + type ( config map[string]any configurer interface { - Config() config + Config() (config, error) } ) @@ -25,21 +33,26 @@ type Storage struct { } // NewStorage aggregates all the supplied configurations into a single configuration. -func NewStorage(p storageProvider, cfgs ...configurer) Storage { +func NewStorage(p storageProvider, cfgs ...configurer) (Storage, error) { + cs, err := unionConfigs(cfgs...) return Storage{ Provider: p, - Config: unionConfigs(cfgs...), - } + Config: cs, + }, err } -func unionConfigs(cfgs ...configurer) config { - c := config{} +func unionConfigs(cfgs ...configurer) (config, error) { + union := config{} for _, cfg := range cfgs { - for k, v := range cfg.Config() { - c[k] = v + c, err := cfg.Config() + if err != nil { + return nil, err + } + for k, v := range c { + union[k] = v } } - return c + return union, nil } // Helper for parsing the values in a config object. diff --git a/src/pkg/storage/storage_test.go b/src/pkg/storage/storage_test.go index 93ca464c4..0960ecbea 100644 --- a/src/pkg/storage/storage_test.go +++ b/src/pkg/storage/storage_test.go @@ -2,51 +2,96 @@ package storage import ( "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) type testConfig struct { expect string + err error } -func (c testConfig) Config() config { - return config{"expect": c.expect} +func (c testConfig) Config() (config, error) { + return config{"expect": c.expect}, c.err } -func TestNewStorage(t *testing.T) { +type StorageSuite struct { + suite.Suite +} + +func TestStorageSuite(t *testing.T) { + suite.Run(t, new(StorageSuite)) +} + +func (suite *StorageSuite) TestNewStorage() { table := []struct { - p storageProvider - c testConfig + name string + p storageProvider + c testConfig + errCheck assert.ErrorAssertionFunc }{ - {ProviderUnknown, testConfig{"unknown"}}, - {ProviderS3, testConfig{"s3"}}, + {"unknown no error", ProviderUnknown, testConfig{"configVal", nil}, assert.NoError}, + {"s3 no error", ProviderS3, testConfig{"configVal", nil}, assert.NoError}, + {"unknown w/ error", ProviderUnknown, testConfig{"configVal", assert.AnError}, assert.Error}, + {"s3 w/ error", ProviderS3, testConfig{"configVal", assert.AnError}, assert.Error}, } for _, test := range table { - s := NewStorage(test.p, test.c) - if s.Provider != test.p { - t.Errorf("expected storage provider [%s], got [%s]", test.p, s.Provider) - } - if s.Config["expect"] != test.c.expect { - t.Errorf("expected storage config [%s], got [%s]", test.c.expect, s.Config["expect"]) - } + suite.T().Run(test.name, func(t *testing.T) { + s, err := NewStorage(test.p, test.c) + test.errCheck(t, err) + // remaining tests are dependent upon error-free state + if test.c.err != nil { + return + } + assert.Equalf(t, + test.p, + s.Provider, + "expected storage provider [%s], got [%s]", test.p, s.Provider) + assert.Equalf(t, + test.c.expect, + s.Config["expect"], + "expected storage config [%s], got [%s]", test.c.expect, s.Config["expect"]) + }) } } type fooConfig struct { foo string + err error } -func (c fooConfig) Config() config { - return config{"foo": c.foo} +func (c fooConfig) Config() (config, error) { + return config{"foo": c.foo}, c.err } -func TestUnionConfigs(t *testing.T) { - te := testConfig{"test"} - f := fooConfig{"foo"} - cs := unionConfigs(te, f) - if cs["expect"] != te.expect { - t.Errorf("expected unioned config to have value [%s] at key [expect], got [%s]", te.expect, cs["expect"]) +func (suite *StorageSuite) TestUnionConfigs() { + table := []struct { + name string + tc testConfig + fc fooConfig + errCheck assert.ErrorAssertionFunc + }{ + {"no error", testConfig{"test", nil}, fooConfig{"foo", nil}, assert.NoError}, + {"tc error", testConfig{"test", assert.AnError}, fooConfig{"foo", nil}, assert.Error}, + {"fc error", testConfig{"test", nil}, fooConfig{"foo", assert.AnError}, assert.Error}, } - if cs["foo"] != f.foo { - t.Errorf("expected unioned config to have value [%s] at key [foo], got [%s]", f.foo, cs["foo"]) + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + cs, err := unionConfigs(test.tc, test.fc) + test.errCheck(t, err) + // remaining tests depend on error-free state + if test.tc.err != nil || test.fc.err != nil { + return + } + assert.Equalf(t, + test.tc.expect, + cs["expect"], + "expected unioned config to have value [%s] at key [expect], got [%s]", test.tc.expect, cs["expect"]) + assert.Equalf(t, + test.fc.foo, + cs["foo"], + "expected unioned config to have value [%s] at key [foo], got [%s]", test.fc.foo, cs["foo"]) + }) } }