Implement more ModelStore Get functions (#267)

* Implement ModelStore GetByType and Get

* Add tests for ModelStore Get functions

* Add stricter "type" checks for loaded models

Take modelType as parameter and check the model in question matches that
type. Adds a little extra layer of protection if models happen to have
the same struct layout.
This commit is contained in:
ashmrtn 2022-07-05 15:51:29 -07:00 committed by GitHub
parent 8725cacc22
commit ed4c71c093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 226 additions and 42 deletions

View File

@ -14,9 +14,10 @@ import (
const stableIDKey = "stableID" const stableIDKey = "stableID"
var ( var (
errNoModelStoreID = errors.New("model has no ModelStoreID") errNoModelStoreID = errors.New("model has no ModelStoreID")
errNoStableID = errors.New("model has no StableID") errNoStableID = errors.New("model has no StableID")
errBadTagKey = errors.New("tag key overlaps with required key") errBadTagKey = errors.New("tag key overlaps with required key")
errModelTypeMismatch = errors.New("model type doesn't match request")
) )
type modelType int type modelType int
@ -139,22 +140,71 @@ func (ms *ModelStore) Put(
return errors.Wrap(err, "putting model") return errors.Wrap(err, "putting model")
} }
// GetIDsForType returns all IDs for models that match the given type and have func stripHiddenTags(tags map[string]string) {
// the given tags. Returned IDs can be used in subsequent calls to Get, Update, delete(tags, stableIDKey)
// or Delete. delete(tags, manifest.TypeLabelKey)
}
func baseModelFromMetadata(m *manifest.EntryMetadata) (*model.BaseModel, error) {
id, ok := m.Labels[stableIDKey]
if !ok {
return nil, errors.WithStack(errNoStableID)
}
res := &model.BaseModel{
ModelStoreID: m.ID,
StableID: model.ID(id),
Tags: m.Labels,
}
stripHiddenTags(res.Tags)
return res, nil
}
// GetIDsForType returns metadata for all models that match the given type and
// have the given tags. Returned IDs can be used in subsequent calls to Get,
// Update, or Delete.
func (ms *ModelStore) GetIDsForType( func (ms *ModelStore) GetIDsForType(
ctx context.Context, ctx context.Context,
t modelType, t modelType,
tags map[string]string, tags map[string]string,
) ([]model.ID, error) { ) ([]*model.BaseModel, error) {
return nil, nil if _, ok := tags[stableIDKey]; ok {
return nil, errors.WithStack(errBadTagKey)
}
tmpTags, err := tagsForModel(t, tags)
if err != nil {
return nil, errors.Wrap(err, "getting model metadata")
}
metadata, err := ms.wrapper.rep.FindManifests(ctx, tmpTags)
if err != nil {
return nil, errors.Wrap(err, "getting model metadata")
}
res := make([]*model.BaseModel, 0, len(metadata))
for _, m := range metadata {
bm, err := baseModelFromMetadata(m)
if err != nil {
return nil, errors.Wrap(err, "parsing model metadata")
}
res = append(res, bm)
}
return res, nil
} }
// getModelStoreID gets the ModelStoreID of the model with the given // getModelStoreID gets the ModelStoreID of the model with the given
// StableID. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no // StableID. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no
// model was found. Returns an error if the given StableID is empty or more than // model was found. Returns an error if the given StableID is empty or more than
// one model has the same StableID. // one model has the same StableID.
func (ms *ModelStore) getModelStoreID(ctx context.Context, id model.ID) (manifest.ID, error) { func (ms *ModelStore) getModelStoreID(
ctx context.Context,
t modelType,
id model.ID,
) (manifest.ID, error) {
if len(id) == 0 { if len(id) == 0 {
return "", errors.WithStack(errNoStableID) return "", errors.WithStack(errNoStableID)
} }
@ -171,20 +221,38 @@ func (ms *ModelStore) getModelStoreID(ctx context.Context, id model.ID) (manifes
if len(metadata) != 1 { if len(metadata) != 1 {
return "", errors.New("multiple models with same StableID") return "", errors.New("multiple models with same StableID")
} }
if metadata[0].Labels[manifest.TypeLabelKey] != t.String() {
return "", errors.WithStack(errModelTypeMismatch)
}
return metadata[0].ID, nil return metadata[0].ID, nil
} }
// Get deserializes the model with the given ID into data. // Get deserializes the model with the given StableID into data. Returns
func (ms *ModelStore) Get(ctx context.Context, id model.ID, data any) error { // github.com/kopia/kopia/repo/manifest.ErrNotFound if no model was found.
return nil // Returns and error if the persisted model has a different type than expected
// or if multiple models have the same StableID.
func (ms *ModelStore) Get(
ctx context.Context,
t modelType,
id model.ID,
data model.Model,
) error {
modelID, err := ms.getModelStoreID(ctx, t, id)
if err != nil {
return err
}
return ms.GetWithModelStoreID(ctx, t, modelID, data)
} }
// GetWithModelStoreID deserializes the model with the given ModelStoreID into // GetWithModelStoreID deserializes the model with the given ModelStoreID into
// data. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no model // data. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no model
// was found. // was found. Returns and error if the persisted model has a different type than
// expected.
func (ms *ModelStore) GetWithModelStoreID( func (ms *ModelStore) GetWithModelStoreID(
ctx context.Context, ctx context.Context,
t modelType,
id manifest.ID, id manifest.ID,
data model.Model, data model.Model,
) error { ) error {
@ -199,11 +267,13 @@ func (ms *ModelStore) GetWithModelStoreID(
return errors.Wrap(err, "getting model data") return errors.Wrap(err, "getting model data")
} }
if metadata.Labels[manifest.TypeLabelKey] != t.String() {
return errors.WithStack(errModelTypeMismatch)
}
base := data.Base() base := data.Base()
base.Tags = metadata.Labels base.Tags = metadata.Labels
// Hide the fact that StableID and modelType are just a tag from the user. stripHiddenTags(base.Tags)
delete(base.Tags, stableIDKey)
delete(base.Tags, manifest.TypeLabelKey)
base.ModelStoreID = id base.ModelStoreID = id
return nil return nil
} }
@ -218,7 +288,7 @@ func (ms *ModelStore) checkPrevModelVersion(
t modelType, t modelType,
b *model.BaseModel, b *model.BaseModel,
) error { ) error {
id, err := ms.getModelStoreID(ctx, b.StableID) id, err := ms.getModelStoreID(ctx, t, b.StableID)
if err != nil { if err != nil {
return err return err
} }
@ -297,9 +367,10 @@ func (ms *ModelStore) Update(
} }
// Delete deletes the model with the given StableID. Turns into a noop if id is // Delete deletes the model with the given StableID. Turns into a noop if id is
// not empty but the model does not exist. // not empty but the model does not exist. Returns an error if multiple models
func (ms *ModelStore) Delete(ctx context.Context, id model.ID) error { // have the same StableID.
latest, err := ms.getModelStoreID(ctx, id) func (ms *ModelStore) Delete(ctx context.Context, t modelType, id model.ID) error {
latest, err := ms.getModelStoreID(ctx, t, id)
if err != nil { if err != nil {
if errors.Is(err, manifest.ErrNotFound) { if errors.Is(err, manifest.ErrNotFound) {
return nil return nil

View File

@ -84,6 +84,9 @@ func (suite *ModelStoreIntegrationSuite) TestBadTagsErrors() {
assert.Error(t, m.Put(ctx, BackupOpModel, foo)) assert.Error(t, m.Put(ctx, BackupOpModel, foo))
assert.Error(t, m.Update(ctx, BackupOpModel, foo)) assert.Error(t, m.Update(ctx, BackupOpModel, foo))
_, err := m.GetIDsForType(ctx, BackupOpModel, test.tags)
assert.Error(t, err)
}) })
} }
} }
@ -91,6 +94,7 @@ func (suite *ModelStoreIntegrationSuite) TestBadTagsErrors() {
func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() { func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() {
ctx := context.Background() ctx := context.Background()
t := suite.T() t := suite.T()
theModelType := BackupOpModel
m := getModelStore(t, ctx) m := getModelStore(t, ctx)
defer func() { defer func() {
@ -105,15 +109,56 @@ func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() {
noModelStoreID.StableID = model.ID(uuid.NewString()) noModelStoreID.StableID = model.ID(uuid.NewString())
noModelStoreID.ModelStoreID = "" noModelStoreID.ModelStoreID = ""
assert.Error(t, m.Update(ctx, BackupOpModel, noStableID)) assert.Error(t, m.Update(ctx, theModelType, noStableID))
assert.Error(t, m.Update(ctx, BackupOpModel, noModelStoreID)) assert.Error(t, m.Update(ctx, theModelType, noModelStoreID))
assert.Error(t, m.GetWithModelStoreID(ctx, "", nil)) assert.Error(t, m.Get(ctx, theModelType, "", nil))
assert.Error(t, m.GetWithModelStoreID(ctx, theModelType, "", nil))
assert.Error(t, m.Delete(ctx, "")) assert.Error(t, m.Delete(ctx, theModelType, ""))
assert.Error(t, m.DeleteWithModelStoreID(ctx, "")) assert.Error(t, m.DeleteWithModelStoreID(ctx, ""))
} }
func (suite *ModelStoreIntegrationSuite) TestBadModelTypeErrors() {
ctx := context.Background()
t := suite.T()
m := getModelStore(t, ctx)
defer func() {
assert.NoError(t, m.wrapper.Close(ctx))
}()
foo := &fooModel{Bar: uuid.NewString()}
assert.Error(t, m.Put(ctx, UnknownModel, foo))
require.NoError(t, m.Put(ctx, BackupOpModel, foo))
_, err := m.GetIDsForType(ctx, UnknownModel, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "model type")
}
func (suite *ModelStoreIntegrationSuite) TestBadTypeErrors() {
ctx := context.Background()
t := suite.T()
m := getModelStore(t, ctx)
defer func() {
assert.NoError(t, m.wrapper.Close(ctx))
}()
foo := &fooModel{Bar: uuid.NewString()}
require.NoError(t, m.Put(ctx, BackupOpModel, foo))
returned := &fooModel{}
assert.Error(t, m.Get(ctx, RestoreOpModel, foo.StableID, returned))
assert.Error(t, m.GetWithModelStoreID(ctx, RestoreOpModel, foo.ModelStoreID, returned))
assert.Error(t, m.Delete(ctx, RestoreOpModel, foo.StableID))
}
func (suite *ModelStoreIntegrationSuite) TestPutGet() { func (suite *ModelStoreIntegrationSuite) TestPutGet() {
table := []struct { table := []struct {
t modelType t modelType
@ -167,7 +212,11 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet() {
require.NotEmpty(t, foo.StableID) require.NotEmpty(t, foo.StableID)
returned := &fooModel{} returned := &fooModel{}
err = m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) err = m.Get(ctx, test.t, foo.StableID, returned)
require.NoError(t, err)
assert.Equal(t, foo, returned)
err = m.GetWithModelStoreID(ctx, test.t, foo.ModelStoreID, returned)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, foo, returned) assert.Equal(t, foo, returned)
}) })
@ -177,6 +226,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet() {
func (suite *ModelStoreIntegrationSuite) TestPutGet_WithTags() { func (suite *ModelStoreIntegrationSuite) TestPutGet_WithTags() {
ctx := context.Background() ctx := context.Background()
t := suite.T() t := suite.T()
theModelType := BackupOpModel
m := getModelStore(t, ctx) m := getModelStore(t, ctx)
defer func() { defer func() {
@ -188,13 +238,17 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet_WithTags() {
"bar": "baz", "bar": "baz",
} }
require.NoError(t, m.Put(ctx, BackupOpModel, foo)) require.NoError(t, m.Put(ctx, theModelType, foo))
require.NotEmpty(t, foo.ModelStoreID) require.NotEmpty(t, foo.ModelStoreID)
require.NotEmpty(t, foo.StableID) require.NotEmpty(t, foo.StableID)
returned := &fooModel{} returned := &fooModel{}
err := m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) err := m.Get(ctx, theModelType, foo.StableID, returned)
require.NoError(t, err)
assert.Equal(t, foo, returned)
err = m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, foo, returned) assert.Equal(t, foo, returned)
} }
@ -208,7 +262,8 @@ func (suite *ModelStoreIntegrationSuite) TestGet_NotFoundErrors() {
assert.NoError(t, m.wrapper.Close(ctx)) assert.NoError(t, m.wrapper.Close(ctx))
}() }()
assert.ErrorIs(t, m.GetWithModelStoreID(ctx, "baz", nil), manifest.ErrNotFound) assert.ErrorIs(t, m.Get(ctx, BackupOpModel, "baz", nil), manifest.ErrNotFound)
assert.ErrorIs(t, m.GetWithModelStoreID(ctx, BackupOpModel, "baz", nil), manifest.ErrNotFound)
} }
func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { func (suite *ModelStoreIntegrationSuite) TestPutUpdate() {
@ -236,6 +291,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() {
for _, test := range table { for _, test := range table {
suite.T().Run(test.name, func(t *testing.T) { suite.T().Run(test.name, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
theModelType := BackupOpModel
m := getModelStore(t, ctx) m := getModelStore(t, ctx)
defer func() { defer func() {
@ -246,24 +302,24 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() {
// Avoid some silly test errors from comparing nil to empty map. // Avoid some silly test errors from comparing nil to empty map.
foo.Tags = map[string]string{} foo.Tags = map[string]string{}
require.NoError(t, m.Put(ctx, BackupOpModel, foo)) require.NoError(t, m.Put(ctx, theModelType, foo))
oldModelID := foo.ModelStoreID oldModelID := foo.ModelStoreID
oldStableID := foo.StableID oldStableID := foo.StableID
test.mutator(foo) test.mutator(foo)
require.NoError(t, m.Update(ctx, BackupOpModel, foo)) require.NoError(t, m.Update(ctx, theModelType, foo))
assert.Equal(t, oldStableID, foo.StableID) assert.Equal(t, oldStableID, foo.StableID)
returned := &fooModel{} returned := &fooModel{}
require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) require.NoError(
t, m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned))
assert.Equal(t, foo, returned) assert.Equal(t, foo, returned)
// TODO(ashmrtn): Uncomment once GetIDsForType is implemented. ids, err := m.GetIDsForType(ctx, theModelType, nil)
//ids, err := m.GetIDsForType(ctx, BackupOpModel, nil) require.NoError(t, err)
//require.NoError(t, err) assert.Len(t, ids, 1)
//assert.Len(t, ids, 1)
if oldModelID == foo.ModelStoreID { if oldModelID == foo.ModelStoreID {
// Unlikely, but we don't control ModelStoreID generation and can't // Unlikely, but we don't control ModelStoreID generation and can't
@ -271,12 +327,67 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() {
return return
} }
err := m.GetWithModelStoreID(ctx, oldModelID, nil) err = m.GetWithModelStoreID(ctx, theModelType, oldModelID, nil)
assert.ErrorIs(t, err, manifest.ErrNotFound) assert.ErrorIs(t, err, manifest.ErrNotFound)
}) })
} }
} }
func (suite *ModelStoreIntegrationSuite) TestPutGetOfType() {
table := []struct {
t modelType
check require.ErrorAssertionFunc
hasErr bool
}{
{
t: UnknownModel,
check: require.Error,
hasErr: true,
},
{
t: BackupOpModel,
check: require.NoError,
hasErr: false,
},
{
t: RestoreOpModel,
check: require.NoError,
hasErr: false,
},
{
t: RestorePointModel,
check: require.NoError,
hasErr: false,
},
}
ctx := context.Background()
t := suite.T()
m := getModelStore(t, ctx)
defer func() {
assert.NoError(t, m.wrapper.Close(ctx))
}()
for _, test := range table {
suite.T().Run(test.t.String(), func(t *testing.T) {
foo := &fooModel{Bar: uuid.NewString()}
err := m.Put(ctx, test.t, foo)
test.check(t, err)
if test.hasErr {
return
}
ids, err := m.GetIDsForType(ctx, test.t, nil)
require.NoError(t, err)
assert.Len(t, ids, 1)
})
}
}
func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() { func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() {
startModelType := BackupOpModel startModelType := BackupOpModel
@ -323,6 +434,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() {
func (suite *ModelStoreIntegrationSuite) TestPutDelete() { func (suite *ModelStoreIntegrationSuite) TestPutDelete() {
ctx := context.Background() ctx := context.Background()
t := suite.T() t := suite.T()
theModelType := BackupOpModel
m := getModelStore(t, ctx) m := getModelStore(t, ctx)
defer func() { defer func() {
@ -331,12 +443,12 @@ func (suite *ModelStoreIntegrationSuite) TestPutDelete() {
foo := &fooModel{Bar: uuid.NewString()} foo := &fooModel{Bar: uuid.NewString()}
require.NoError(t, m.Put(ctx, BackupOpModel, foo)) require.NoError(t, m.Put(ctx, theModelType, foo))
require.NoError(t, m.Delete(ctx, foo.StableID)) require.NoError(t, m.Delete(ctx, theModelType, foo.StableID))
returned := &fooModel{} returned := &fooModel{}
err := m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) err := m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned)
assert.ErrorIs(t, err, manifest.ErrNotFound) assert.ErrorIs(t, err, manifest.ErrNotFound)
} }
@ -349,7 +461,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutDelete_BadIDsNoop() {
assert.NoError(t, m.wrapper.Close(ctx)) assert.NoError(t, m.wrapper.Close(ctx))
}() }()
assert.NoError(t, m.Delete(ctx, "foo")) assert.NoError(t, m.Delete(ctx, BackupOpModel, "foo"))
assert.NoError(t, m.DeleteWithModelStoreID(ctx, "foo")) assert.NoError(t, m.DeleteWithModelStoreID(ctx, "foo"))
} }
@ -427,10 +539,11 @@ func (suite *ModelStoreRegressionSuite) TestFailDuringWriteSessionHasNoVisibleEf
assert.ErrorIs(t, err, assert.AnError) assert.ErrorIs(t, err, assert.AnError)
err = m.GetWithModelStoreID(ctx, newID, nil) err = m.GetWithModelStoreID(ctx, theModelType, newID, nil)
assert.ErrorIs(t, err, manifest.ErrNotFound) assert.ErrorIs(t, err, manifest.ErrNotFound)
returned := &fooModel{} returned := &fooModel{}
require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) require.NoError(
t, m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned))
assert.Equal(t, foo, returned) assert.Equal(t, foo, returned)
} }