diff --git a/src/internal/kopia/model_store.go b/src/internal/kopia/model_store.go index b75eddf09..2851f555d 100644 --- a/src/internal/kopia/model_store.go +++ b/src/internal/kopia/model_store.go @@ -208,17 +208,92 @@ func (ms *ModelStore) GetWithModelStoreID( return nil } -// Update adds the new version of the model to the model store and deletes the -// version of the model with oldID if the old and new IDs do not match. The new -// ID of the model is returned. +// checkPrevModelVersion compares the modelType and ModelStoreID in this model +// to model(s) previously stored in ModelStore that have the same StableID. +// Returns an error if no models or more than one model has the same StableID or +// the modelType or ModelStoreID differ between the stored model and the given +// model. +func (ms *ModelStore) checkPrevModelVersion( + ctx context.Context, + t modelType, + b *model.BaseModel, +) error { + id, err := ms.getModelStoreID(ctx, b.StableID) + if err != nil { + return err + } + + // We actually got something back during our lookup. + meta, err := ms.wrapper.rep.GetManifest(ctx, id, nil) + if err != nil { + return errors.Wrap(err, "getting previous model version") + } + + if meta.ID != b.ModelStoreID { + return errors.New("updated model has different ModelStoreID") + } + if meta.Labels[manifest.TypeLabelKey] != t.String() { + return errors.New("updated model has different model type") + } + + return nil +} + +// Update adds the new version of the model with the given StableID to the model +// store and deletes the version of the model with old ModelStoreID if the old +// and new ModelStoreIDs do not match. Returns an error if another model has +// the same StableID but a different modelType or ModelStoreID or there is no +// previous version of the model. If an error occurs no visible changes will be +// made to the stored model. func (ms *ModelStore) Update( ctx context.Context, t modelType, - oldID model.ID, - tags map[string]string, - m any, -) (model.ID, error) { - return "", nil + m model.Model, +) error { + base := m.Base() + if len(base.ModelStoreID) == 0 { + return errors.WithStack(errNoModelStoreID) + } + + // TODO(ashmrtnz): Can remove if bottleneck. + if err := ms.checkPrevModelVersion(ctx, t, base); err != nil { + return err + } + + err := repo.WriteSession( + ctx, + ms.wrapper.rep, + repo.WriteSessionOptions{Purpose: "ModelStoreUpdate"}, + func(innerCtx context.Context, w repo.RepositoryWriter) (innerErr error) { + oldID := base.ModelStoreID + + defer func() { + if innerErr != nil { + // Restore the old ID if we failed. + base.ModelStoreID = oldID + } + }() + + if innerErr = putInner(innerCtx, w, t, m, false); innerErr != nil { + return innerErr + } + + // If we fail at this point no changes will be made to the manifest store + // in kopia, making it appear like nothing ever happened. At worst some + // orphaned content blobs may be uploaded, but they should be garbage + // collected the next time kopia maintenance is run. + if oldID != base.ModelStoreID { + innerErr = w.DeleteManifest(innerCtx, oldID) + } + + return innerErr + }, + ) + if err != nil { + return errors.Wrap(err, "updating model") + } + + return nil } // Delete deletes the model with the given StableID. Turns into a noop if id is diff --git a/src/internal/kopia/model_store_test.go b/src/internal/kopia/model_store_test.go index 7418051cc..9af3832ee 100644 --- a/src/internal/kopia/model_store_test.go +++ b/src/internal/kopia/model_store_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/kopia/kopia/repo" "github.com/kopia/kopia/repo/manifest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -82,6 +83,7 @@ func (suite *ModelStoreIntegrationSuite) TestBadTagsErrors() { foo.Tags = test.tags assert.Error(t, m.Put(ctx, BackupOpModel, foo)) + assert.Error(t, m.Update(ctx, BackupOpModel, foo)) }) } } @@ -103,6 +105,9 @@ func (suite *ModelStoreIntegrationSuite) TestNoIDsErrors() { noModelStoreID.StableID = model.ID(uuid.NewString()) noModelStoreID.ModelStoreID = "" + assert.Error(t, m.Update(ctx, BackupOpModel, noStableID)) + assert.Error(t, m.Update(ctx, BackupOpModel, noModelStoreID)) + assert.Error(t, m.GetWithModelStoreID(ctx, "", nil)) assert.Error(t, m.Delete(ctx, "")) @@ -206,6 +211,115 @@ func (suite *ModelStoreIntegrationSuite) TestGet_NotFoundErrors() { assert.ErrorIs(t, m.GetWithModelStoreID(ctx, "baz", nil), manifest.ErrNotFound) } +func (suite *ModelStoreIntegrationSuite) TestPutUpdate() { + table := []struct { + name string + mutator func(m *fooModel) + }{ + { + name: "NoTags", + mutator: func(m *fooModel) { + m.Bar = "baz" + }, + }, + { + name: "WithTags", + mutator: func(m *fooModel) { + m.Bar = "baz" + m.Tags = map[string]string{ + "a": "42", + } + }, + }, + } + + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + ctx := context.Background() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + foo := &fooModel{Bar: uuid.NewString()} + // Avoid some silly test errors from comparing nil to empty map. + foo.Tags = map[string]string{} + + require.NoError(t, m.Put(ctx, BackupOpModel, foo)) + + oldModelID := foo.ModelStoreID + oldStableID := foo.StableID + + test.mutator(foo) + + require.NoError(t, m.Update(ctx, BackupOpModel, foo)) + assert.Equal(t, oldStableID, foo.StableID) + + returned := &fooModel{} + require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) + assert.Equal(t, foo, returned) + + // TODO(ashmrtn): Uncomment once GetIDsForType is implemented. + //ids, err := m.GetIDsForType(ctx, BackupOpModel, nil) + //require.NoError(t, err) + //assert.Len(t, ids, 1) + + if oldModelID == foo.ModelStoreID { + // Unlikely, but we don't control ModelStoreID generation and can't + // guarantee this won't happen. + return + } + + err := m.GetWithModelStoreID(ctx, oldModelID, nil) + assert.ErrorIs(t, err, manifest.ErrNotFound) + }) + } +} + +func (suite *ModelStoreIntegrationSuite) TestPutUpdate_FailsNotMatchingPrev() { + startModelType := BackupOpModel + + table := []struct { + name string + t modelType + mutator func(m *fooModel) + }{ + { + name: "DifferentModelStoreID", + t: startModelType, + mutator: func(m *fooModel) { + m.ModelStoreID = manifest.ID("bar") + }, + }, + { + name: "DifferentModelType", + t: RestoreOpModel, + mutator: func(m *fooModel) { + }, + }, + } + + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + ctx := context.Background() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + foo := &fooModel{Bar: uuid.NewString()} + + require.NoError(t, m.Put(ctx, startModelType, foo)) + + test.mutator(foo) + + assert.Error(t, m.Update(ctx, test.t, foo)) + }) + } +} + func (suite *ModelStoreIntegrationSuite) TestPutDelete() { ctx := context.Background() t := suite.T() @@ -238,3 +352,85 @@ func (suite *ModelStoreIntegrationSuite) TestPutDelete_BadIDsNoop() { assert.NoError(t, m.Delete(ctx, "foo")) assert.NoError(t, m.DeleteWithModelStoreID(ctx, "foo")) } + +// --------------- +// regression tests that use kopia +// --------------- +type ModelStoreRegressionSuite struct { + suite.Suite +} + +func TestModelStoreRegressionSuite(t *testing.T) { + if err := ctesting.RunOnAny( + ctesting.CorsoCITests, + ctesting.CorsoModelStoreTests, + ); err != nil { + t.Skip() + } + + suite.Run(t, new(ModelStoreRegressionSuite)) +} + +func (suite *ModelStoreRegressionSuite) SetupSuite() { + _, err := ctesting.GetRequiredEnvVars(ctesting.AWSStorageCredEnvs...) + require.NoError(suite.T(), err) +} + +// TODO(ashmrtn): Make a mock of whatever controls the handle to kopia so we can +// ask it to fail on arbitrary function, thus allowing us to directly test +// Update. +// Tests that if we get an error or crash while in the middle of an Update no +// results will be visible to higher layers. +func (suite *ModelStoreRegressionSuite) TestFailDuringWriteSessionHasNoVisibleEffect() { + ctx := context.Background() + t := suite.T() + + m := getModelStore(t, ctx) + defer func() { + assert.NoError(t, m.wrapper.Close(ctx)) + }() + + foo := &fooModel{Bar: uuid.NewString()} + foo.StableID = model.ID(uuid.NewString()) + foo.ModelStoreID = manifest.ID(uuid.NewString()) + // Avoid some silly test errors from comparing nil to empty map. + foo.Tags = map[string]string{} + + theModelType := BackupOpModel + + require.NoError(t, m.Put(ctx, theModelType, foo)) + + newID := manifest.ID("") + err := repo.WriteSession( + ctx, + m.wrapper.rep, + repo.WriteSessionOptions{Purpose: "WriteSessionFailureTest"}, + func(innerCtx context.Context, w repo.RepositoryWriter) (innerErr error) { + base := foo.Base() + oldID := base.ModelStoreID + + defer func() { + if innerErr != nil { + // Restore the old ID if we failed. + base.ModelStoreID = oldID + } + }() + + innerErr = putInner(innerCtx, w, theModelType, foo, false) + require.NoError(t, innerErr) + + newID = foo.ModelStoreID + + return assert.AnError + }, + ) + + assert.ErrorIs(t, err, assert.AnError) + + err = m.GetWithModelStoreID(ctx, newID, nil) + assert.ErrorIs(t, err, manifest.ErrNotFound) + + returned := &fooModel{} + require.NoError(t, m.GetWithModelStoreID(ctx, foo.ModelStoreID, returned)) + assert.Equal(t, foo, returned) +}