diff --git a/src/internal/kopia/wrapper.go b/src/internal/kopia/wrapper.go index 5801917aa..7c9acabfe 100644 --- a/src/internal/kopia/wrapper.go +++ b/src/internal/kopia/wrapper.go @@ -2,9 +2,12 @@ package kopia import ( "context" + "encoding/binary" + "io" "runtime/trace" "sync" "sync/atomic" + "unsafe" "github.com/hashicorp/go-multierror" "github.com/kopia/kopia/fs" @@ -28,13 +31,93 @@ const ( // possibly corresponding to who is making the backup. corsoHost = "corso-host" corsoUser = "corso" + + serializationVersion uint32 = 1 ) var ( errNotConnected = errors.New("not connected to repo") errNoRestorePath = errors.New("no restore path given") + + versionSize = int(unsafe.Sizeof(serializationVersion)) ) +// backupStreamReader is a wrapper around the io.Reader that other Corso +// components return when backing up information. It injects a version number at +// the start of the data stream. Future versions of Corso may not need this if +// they use more complex serialization logic as serialization/version injection +// will be handled by other components. +type backupStreamReader struct { + io.ReadCloser + version uint32 + readBytes int +} + +func (rw *backupStreamReader) Read(p []byte) (n int, err error) { + if rw.readBytes < versionSize { + marshalled := make([]byte, versionSize) + + toCopy := len(marshalled) - rw.readBytes + if len(p) < toCopy { + toCopy = len(p) + } + + binary.BigEndian.PutUint32(marshalled, rw.version) + + copy(p, marshalled[rw.readBytes:rw.readBytes+toCopy]) + rw.readBytes += toCopy + + return toCopy, nil + } + + return rw.ReadCloser.Read(p) +} + +// restoreStreamReader is a wrapper around the io.Reader that kopia returns when +// reading data from an item. It examines and strips off the version number of +// the restored data. Future versions of Corso may not need this if they use +// more complex serialization logic as version checking/deserialization will be +// handled by other components. A reader that returns a version error is no +// longer valid and should not be used once the version error is returned. +type restoreStreamReader struct { + io.ReadCloser + expectedVersion uint32 + readVersion bool +} + +func (rw *restoreStreamReader) checkVersion() error { + versionBuf := make([]byte, versionSize) + + for newlyRead := 0; newlyRead < versionSize; { + n, err := rw.ReadCloser.Read(versionBuf[newlyRead:]) + if err != nil { + return errors.Wrap(err, "reading data format version") + } + + newlyRead += n + } + + version := binary.BigEndian.Uint32(versionBuf) + + if version != rw.expectedVersion { + return errors.Errorf("unexpected data format %v", version) + } + + return nil +} + +func (rw *restoreStreamReader) Read(p []byte) (n int, err error) { + if !rw.readVersion { + rw.readVersion = true + + if err := rw.checkVersion(); err != nil { + return 0, err + } + } + + return rw.ReadCloser.Read(p) +} + type BackupStats struct { SnapshotID string @@ -252,7 +335,13 @@ func getStreamItemFunc( d := &itemDetails{info: ei.Info(), repoPath: itemPath} progress.put(encodeAsPath(itemPath.PopFront().Elements()...), d) - entry := virtualfs.StreamingFileFromReader(encodeAsPath(e.UUID()), e.ToReader()) + entry := virtualfs.StreamingFileFromReader( + encodeAsPath(e.UUID()), + &backupStreamReader{ + version: serializationVersion, + ReadCloser: e.ToReader(), + }, + ) if err := cb(ctx, entry); err != nil { // Kopia's uploader swallows errors in most cases, so if we see // something here it's probably a big issue and we should return. @@ -544,9 +633,12 @@ func getItemStream( } return &kopiaDataStream{ - uuid: decodedName, - reader: r, - size: f.Size(), + uuid: decodedName, + reader: &restoreStreamReader{ + ReadCloser: r, + expectedVersion: serializationVersion, + }, + size: f.Size() - int64(versionSize), }, nil } diff --git a/src/internal/kopia/wrapper_test.go b/src/internal/kopia/wrapper_test.go index 7c960deb3..cd01f0333 100644 --- a/src/internal/kopia/wrapper_test.go +++ b/src/internal/kopia/wrapper_test.go @@ -3,10 +3,12 @@ package kopia import ( "bytes" "context" + "errors" "io" "io/ioutil" stdpath "path" "testing" + "unsafe" "github.com/google/uuid" "github.com/kopia/kopia/fs" @@ -121,6 +123,137 @@ func getDirEntriesForEntry( // --------------- // unit tests // --------------- +type limitedRangeReader struct { + readLen int + io.ReadCloser +} + +func (lrr *limitedRangeReader) Read(p []byte) (int, error) { + if len(p) == 0 { + // Not well specified behavior, defer to underlying reader. + return lrr.ReadCloser.Read(p) + } + + toRead := lrr.readLen + if len(p) < toRead { + toRead = len(p) + } + + return lrr.ReadCloser.Read(p[:toRead]) +} + +type VersionReadersUnitSuite struct { + suite.Suite +} + +func TestVersionReadersUnitSuite(t *testing.T) { + suite.Run(t, new(VersionReadersUnitSuite)) +} + +func (suite *VersionReadersUnitSuite) TestWriteAndRead() { + inputData := []byte("This is some data for the reader to test with") + table := []struct { + name string + readVersion uint32 + writeVersion uint32 + check assert.ErrorAssertionFunc + }{ + { + name: "SameVersionSucceeds", + readVersion: 42, + writeVersion: 42, + check: assert.NoError, + }, + { + name: "DifferentVersionsFail", + readVersion: 7, + writeVersion: 42, + check: assert.Error, + }, + } + + for _, test := range table { + suite.T().Run(test.name, func(t *testing.T) { + baseReader := bytes.NewReader(inputData) + + reversible := &restoreStreamReader{ + expectedVersion: test.readVersion, + ReadCloser: &backupStreamReader{ + version: test.writeVersion, + ReadCloser: io.NopCloser(baseReader), + }, + } + + defer reversible.Close() + + allData, err := io.ReadAll(reversible) + test.check(t, err) + + if err != nil { + return + } + + assert.Equal(t, inputData, allData) + }) + } +} + +func readAllInParts( + t *testing.T, + partLen int, + reader io.ReadCloser, +) ([]byte, int) { + res := []byte{} + read := 0 + tmp := make([]byte, partLen) + + for { + n, err := reader.Read(tmp) + if errors.Is(err, io.EOF) { + break + } + + require.NoError(t, err) + + read += n + res = append(res, tmp[:n]...) + } + + return res, read +} + +func (suite *VersionReadersUnitSuite) TestWriteHandlesShortReads() { + t := suite.T() + inputData := []byte("This is some data for the reader to test with") + version := uint32(42) + baseReader := bytes.NewReader(inputData) + versioner := &backupStreamReader{ + version: version, + ReadCloser: io.NopCloser(baseReader), + } + expectedToWrite := len(inputData) + int(unsafe.Sizeof(versioner.version)) + + // "Write" all the data. + versionedData, writtenLen := readAllInParts(t, 1, versioner) + assert.Equal(t, expectedToWrite, writtenLen) + + // Read all of the data back. + baseReader = bytes.NewReader(versionedData) + reader := &restoreStreamReader{ + expectedVersion: version, + // Be adversarial and only allow reads of length 1 from the byte reader. + ReadCloser: &limitedRangeReader{ + readLen: 1, + ReadCloser: io.NopCloser(baseReader), + }, + } + readData, readLen := readAllInParts(t, 1, reader) + // This reports the bytes read and returned to the user, excluding the version + // that is stripped off at the start. + assert.Equal(t, len(inputData), readLen) + assert.Equal(t, inputData, readData) +} + type CorsoProgressUnitSuite struct { suite.Suite targetFilePath path.Path