diff --git a/src/internal/kopia/model_store.go b/src/internal/kopia/model_store.go index 2851f555d..dee614218 100644 --- a/src/internal/kopia/model_store.go +++ b/src/internal/kopia/model_store.go @@ -14,9 +14,10 @@ import ( const stableIDKey = "stableID" var ( - errNoModelStoreID = errors.New("model has no ModelStoreID") - errNoStableID = errors.New("model has no StableID") - errBadTagKey = errors.New("tag key overlaps with required key") + errNoModelStoreID = errors.New("model has no ModelStoreID") + errNoStableID = errors.New("model has no StableID") + errBadTagKey = errors.New("tag key overlaps with required key") + errModelTypeMismatch = errors.New("model type doesn't match request") ) type modelType int @@ -139,22 +140,71 @@ func (ms *ModelStore) Put( return errors.Wrap(err, "putting model") } -// GetIDsForType returns all IDs for models that match the given type and have -// the given tags. Returned IDs can be used in subsequent calls to Get, Update, -// or Delete. +func stripHiddenTags(tags map[string]string) { + delete(tags, stableIDKey) + delete(tags, manifest.TypeLabelKey) +} + +func baseModelFromMetadata(m *manifest.EntryMetadata) (*model.BaseModel, error) { + id, ok := m.Labels[stableIDKey] + if !ok { + return nil, errors.WithStack(errNoStableID) + } + + res := &model.BaseModel{ + ModelStoreID: m.ID, + StableID: model.ID(id), + Tags: m.Labels, + } + + stripHiddenTags(res.Tags) + return res, nil +} + +// GetIDsForType returns metadata for all models that match the given type and +// have the given tags. Returned IDs can be used in subsequent calls to Get, +// Update, or Delete. func (ms *ModelStore) GetIDsForType( ctx context.Context, t modelType, tags map[string]string, -) ([]model.ID, error) { - return nil, nil +) ([]*model.BaseModel, error) { + if _, ok := tags[stableIDKey]; ok { + return nil, errors.WithStack(errBadTagKey) + } + + tmpTags, err := tagsForModel(t, tags) + if err != nil { + return nil, errors.Wrap(err, "getting model metadata") + } + + metadata, err := ms.wrapper.rep.FindManifests(ctx, tmpTags) + if err != nil { + return nil, errors.Wrap(err, "getting model metadata") + } + + res := make([]*model.BaseModel, 0, len(metadata)) + for _, m := range metadata { + bm, err := baseModelFromMetadata(m) + if err != nil { + return nil, errors.Wrap(err, "parsing model metadata") + } + + res = append(res, bm) + } + + return res, nil } // getModelStoreID gets the ModelStoreID of the model with the given // StableID. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no // model was found. Returns an error if the given StableID is empty or more than // one model has the same StableID. -func (ms *ModelStore) getModelStoreID(ctx context.Context, id model.ID) (manifest.ID, error) { +func (ms *ModelStore) getModelStoreID( + ctx context.Context, + t modelType, + id model.ID, +) (manifest.ID, error) { if len(id) == 0 { return "", errors.WithStack(errNoStableID) } @@ -171,20 +221,38 @@ func (ms *ModelStore) getModelStoreID(ctx context.Context, id model.ID) (manifes if len(metadata) != 1 { return "", errors.New("multiple models with same StableID") } + if metadata[0].Labels[manifest.TypeLabelKey] != t.String() { + return "", errors.WithStack(errModelTypeMismatch) + } return metadata[0].ID, nil } -// Get deserializes the model with the given ID into data. -func (ms *ModelStore) Get(ctx context.Context, id model.ID, data any) error { - return nil +// Get deserializes the model with the given StableID into data. Returns +// github.com/kopia/kopia/repo/manifest.ErrNotFound if no model was found. +// Returns and error if the persisted model has a different type than expected +// or if multiple models have the same StableID. +func (ms *ModelStore) Get( + ctx context.Context, + t modelType, + id model.ID, + data model.Model, +) error { + modelID, err := ms.getModelStoreID(ctx, t, id) + if err != nil { + return err + } + + return ms.GetWithModelStoreID(ctx, t, modelID, data) } // GetWithModelStoreID deserializes the model with the given ModelStoreID into // data. Returns github.com/kopia/kopia/repo/manifest.ErrNotFound if no model -// was found. +// was found. Returns and error if the persisted model has a different type than +// expected. func (ms *ModelStore) GetWithModelStoreID( ctx context.Context, + t modelType, id manifest.ID, data model.Model, ) error { @@ -199,11 +267,13 @@ func (ms *ModelStore) GetWithModelStoreID( return errors.Wrap(err, "getting model data") } + if metadata.Labels[manifest.TypeLabelKey] != t.String() { + return errors.WithStack(errModelTypeMismatch) + } + base := data.Base() base.Tags = metadata.Labels - // Hide the fact that StableID and modelType are just a tag from the user. - delete(base.Tags, stableIDKey) - delete(base.Tags, manifest.TypeLabelKey) + stripHiddenTags(base.Tags) base.ModelStoreID = id return nil } @@ -218,7 +288,7 @@ func (ms *ModelStore) checkPrevModelVersion( t modelType, b *model.BaseModel, ) error { - id, err := ms.getModelStoreID(ctx, b.StableID) + id, err := ms.getModelStoreID(ctx, t, b.StableID) if err != nil { return err } @@ -297,9 +367,10 @@ func (ms *ModelStore) Update( } // Delete deletes the model with the given StableID. Turns into a noop if id is -// not empty but the model does not exist. -func (ms *ModelStore) Delete(ctx context.Context, id model.ID) error { - latest, err := ms.getModelStoreID(ctx, id) +// not empty but the model does not exist. Returns an error if multiple models +// have the same StableID. +func (ms *ModelStore) Delete(ctx context.Context, t modelType, id model.ID) error { + latest, err := ms.getModelStoreID(ctx, t, id) if err != nil { if errors.Is(err, manifest.ErrNotFound) { return nil diff --git a/src/internal/kopia/model_store_test.go b/src/internal/kopia/model_store_test.go index 9af3832ee..bc64b7fe2 100644 --- a/src/internal/kopia/model_store_test.go +++ b/src/internal/kopia/model_store_test.go @@ -84,6 +84,9 @@ func (suite *ModelStoreIntegrationSuite) TestBadTagsErrors() { assert.Error(t, m.Put(ctx, BackupOpModel, foo)) assert.Error(t, m.Update(ctx, BackupOpModel, foo)) + + _, err := m.GetIDsForType(ctx, BackupOpModel, test.tags) + assert.Error(t, err) }) } } @@ -91,6 +94,7 @@ func (suite *ModelStoreIntegrationSuite) TestBadTagsErrors() { func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() { ctx := context.Background() t := suite.T() + theModelType := BackupOpModel m := getModelStore(t, ctx) defer func() { @@ -105,15 +109,56 @@ func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() { noModelStoreID.StableID = model.ID(uuid.NewString()) noModelStoreID.ModelStoreID = "" - assert.Error(t, m.Update(ctx, BackupOpModel, noStableID)) - assert.Error(t, m.Update(ctx, BackupOpModel, noModelStoreID)) + assert.Error(t, m.Update(ctx, theModelType, noStableID)) + assert.Error(t, m.Update(ctx, theModelType, noModelStoreID)) - assert.Error(t, m.GetWithModelStoreID(ctx, "", nil)) + assert.Error(t, m.Get(ctx, theModelType, "", nil)) + assert.Error(t, m.GetWithModelStoreID(ctx, theModelType, "", nil)) - assert.Error(t, m.Delete(ctx, "")) + assert.Error(t, m.Delete(ctx, theModelType, "")) assert.Error(t, m.DeleteWithModelStoreID(ctx, "")) } +func (suite *ModelStoreIntegrationSuite) TestBadModelTypeErrors() { + ctx := context.Background() + t := suite.T() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + foo := &fooModel{Bar: uuid.NewString()} + + assert.Error(t, m.Put(ctx, UnknownModel, foo)) + + require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + + _, err := m.GetIDsForType(ctx, UnknownModel, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "model type") +} + +func (suite *ModelStoreIntegrationSuite) TestBadTypeErrors() { + ctx := context.Background() + t := suite.T() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + foo := &fooModel{Bar: uuid.NewString()} + + require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + + returned := &fooModel{} + assert.Error(t, m.Get(ctx, RestoreOpModel, foo.StableID, returned)) + assert.Error(t, m.GetWithModelStoreID(ctx, RestoreOpModel, foo.ModelStoreID, returned)) + + assert.Error(t, m.Delete(ctx, RestoreOpModel, foo.StableID)) +} + func (suite *ModelStoreIntegrationSuite) TestPutGet() { table := []struct { t modelType @@ -167,7 +212,11 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet() { require.NotEmpty(t, foo.StableID) returned := &fooModel{} - err = m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) + err = m.Get(ctx, test.t, foo.StableID, returned) + require.NoError(t, err) + assert.Equal(t, foo, returned) + + err = m.GetWithModelStoreID(ctx, test.t, foo.ModelStoreID, returned) require.NoError(t, err) assert.Equal(t, foo, returned) }) @@ -177,6 +226,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet() { func (suite *ModelStoreIntegrationSuite) TestPutGet_WithTags() { ctx := context.Background() t := suite.T() + theModelType := BackupOpModel m := getModelStore(t, ctx) defer func() { @@ -188,13 +238,17 @@ func (suite *ModelStoreIntegrationSuite) TestPutGet_WithTags() { "bar": "baz", } - require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + require.NoError(t, m.Put(ctx, theModelType, foo)) require.NotEmpty(t, foo.ModelStoreID) require.NotEmpty(t, foo.StableID) returned := &fooModel{} - err := m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) + err := m.Get(ctx, theModelType, foo.StableID, returned) + require.NoError(t, err) + assert.Equal(t, foo, returned) + + err = m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned) require.NoError(t, err) assert.Equal(t, foo, returned) } @@ -208,7 +262,8 @@ func (suite *ModelStoreIntegrationSuite) TestGet_NotFoundErrors() { assert.NoError(t, m.wrapper.Close(ctx)) }() - assert.ErrorIs(t, m.GetWithModelStoreID(ctx, "baz", nil), manifest.ErrNotFound) + assert.ErrorIs(t, m.Get(ctx, BackupOpModel, "baz", nil), manifest.ErrNotFound) + assert.ErrorIs(t, m.GetWithModelStoreID(ctx, BackupOpModel, "baz", nil), manifest.ErrNotFound) } func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { @@ -236,6 +291,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { for _, test := range table { suite.T().Run(test.name, func(t *testing.T) { ctx := context.Background() + theModelType := BackupOpModel m := getModelStore(t, ctx) defer func() { @@ -246,24 +302,24 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { // Avoid some silly test errors from comparing nil to empty map. foo.Tags = map[string]string{} - require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + require.NoError(t, m.Put(ctx, theModelType, foo)) oldModelID := foo.ModelStoreID oldStableID := foo.StableID test.mutator(foo) - require.NoError(t, m.Update(ctx, BackupOpModel, foo)) + require.NoError(t, m.Update(ctx, theModelType, foo)) assert.Equal(t, oldStableID, foo.StableID) returned := &fooModel{} - require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) + require.NoError( + t, m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned)) assert.Equal(t, foo, returned) - // TODO(ashmrtn): Uncomment once GetIDsForType is implemented. - //ids, err := m.GetIDsForType(ctx, BackupOpModel, nil) - //require.NoError(t, err) - //assert.Len(t, ids, 1) + ids, err := m.GetIDsForType(ctx, theModelType, nil) + require.NoError(t, err) + assert.Len(t, ids, 1) if oldModelID == foo.ModelStoreID { // Unlikely, but we don't control ModelStoreID generation and can't @@ -271,12 +327,67 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { return } - err := m.GetWithModelStoreID(ctx, oldModelID, nil) + err = m.GetWithModelStoreID(ctx, theModelType, oldModelID, nil) assert.ErrorIs(t, err, manifest.ErrNotFound) }) } } +func (suite *ModelStoreIntegrationSuite) TestPutGetOfType() { + table := []struct { + t modelType + check require.ErrorAssertionFunc + hasErr bool + }{ + { + t: UnknownModel, + check: require.Error, + hasErr: true, + }, + { + t: BackupOpModel, + check: require.NoError, + hasErr: false, + }, + { + t: RestoreOpModel, + check: require.NoError, + hasErr: false, + }, + { + t: RestorePointModel, + check: require.NoError, + hasErr: false, + }, + } + + ctx := context.Background() + t := suite.T() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + for _, test := range table { + suite.T().Run(test.t.String(), func(t *testing.T) { + foo := &fooModel{Bar: uuid.NewString()} + + err := m.Put(ctx, test.t, foo) + test.check(t, err) + + if test.hasErr { + return + } + + ids, err := m.GetIDsForType(ctx, test.t, nil) + require.NoError(t, err) + + assert.Len(t, ids, 1) + }) + } +} + func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() { startModelType := BackupOpModel @@ -323,6 +434,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() { func (suite *ModelStoreIntegrationSuite) TestPutDelete() { ctx := context.Background() t := suite.T() + theModelType := BackupOpModel m := getModelStore(t, ctx) defer func() { @@ -331,12 +443,12 @@ func (suite *ModelStoreIntegrationSuite) TestPutDelete() { foo := &fooModel{Bar: uuid.NewString()} - require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + require.NoError(t, m.Put(ctx, theModelType, foo)) - require.NoError(t, m.Delete(ctx, foo.StableID)) + require.NoError(t, m.Delete(ctx, theModelType, foo.StableID)) returned := &fooModel{} - err := m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned) + err := m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned) assert.ErrorIs(t, err, manifest.ErrNotFound) } @@ -349,7 +461,7 @@ func (suite *ModelStoreIntegrationSuite) TestPutDelete_BadIDsNoop() { assert.NoError(t, m.wrapper.Close(ctx)) }() - assert.NoError(t, m.Delete(ctx, "foo")) + assert.NoError(t, m.Delete(ctx, BackupOpModel, "foo")) assert.NoError(t, m.DeleteWithModelStoreID(ctx, "foo")) } @@ -427,10 +539,11 @@ func (suite *ModelStoreRegressionSuite) TestFailDuringWriteSessionHasNoVisibleEf assert.ErrorIs(t, err, assert.AnError) - err = m.GetWithModelStoreID(ctx, newID, nil) + err = m.GetWithModelStoreID(ctx, theModelType, newID, nil) assert.ErrorIs(t, err, manifest.ErrNotFound) returned := &fooModel{} - require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) + require.NoError( + t, m.GetWithModelStoreID(ctx, theModelType, foo.ModelStoreID, returned)) assert.Equal(t, foo, returned) }