diff --git a/src/internal/kopia/upload.go b/src/internal/kopia/upload.go index 2215ed9c1..e22ad4031 100644 --- a/src/internal/kopia/upload.go +++ b/src/internal/kopia/upload.go @@ -1,9 +1,11 @@ package kopia import ( + "bytes" "context" "encoding/binary" "io" + "os" "runtime/trace" "sync" "sync/atomic" @@ -25,35 +27,47 @@ import ( var versionSize = int(unsafe.Sizeof(serializationVersion)) +func newBackupStreamReader(version uint32, reader io.ReadCloser) *backupStreamReader { + buf := make([]byte, versionSize) + binary.BigEndian.PutUint32(buf, version) + bufReader := io.NopCloser(bytes.NewReader(buf)) + + return &backupStreamReader{ + readers: []io.ReadCloser{bufReader, reader}, + combined: io.NopCloser(io.MultiReader(bufReader, reader)), + } +} + // backupStreamReader is a wrapper around the io.Reader that other Corso // components return when backing up information. It injects a version number at // the start of the data stream. Future versions of Corso may not need this if // they use more complex serialization logic as serialization/version injection // will be handled by other components. type backupStreamReader struct { - io.ReadCloser - version uint32 - readBytes int + readers []io.ReadCloser + combined io.ReadCloser } func (rw *backupStreamReader) Read(p []byte) (n int, err error) { - if rw.readBytes < versionSize { - marshalled := make([]byte, versionSize) - - toCopy := len(marshalled) - rw.readBytes - if len(p) < toCopy { - toCopy = len(p) - } - - binary.BigEndian.PutUint32(marshalled, rw.version) - - copy(p, marshalled[rw.readBytes:rw.readBytes+toCopy]) - rw.readBytes += toCopy - - return toCopy, nil + if rw.combined == nil { + return 0, os.ErrClosed } - return rw.ReadCloser.Read(p) + return rw.combined.Read(p) +} + +func (rw *backupStreamReader) Close() error { + if rw.combined == nil { + return nil + } + + rw.combined = nil + + for _, r := range rw.readers { + r.Close() + } + + return nil } // restoreStreamReader is a wrapper around the io.Reader that kopia returns when @@ -169,11 +183,11 @@ func (cp *corsoProgress) FinishedFile(relativePath string, err error) { } // Kopia interface function used as a callback when kopia finishes hashing a file. -func (cp *corsoProgress) FinishedHashingFile(fname string, bytes int64) { +func (cp *corsoProgress) FinishedHashingFile(fname string, bs int64) { // Pass the call through as well so we don't break expected functionality. - defer cp.UploadProgress.FinishedHashingFile(fname, bytes) + defer cp.UploadProgress.FinishedHashingFile(fname, bs) - atomic.AddInt64(&cp.totalBytes, bytes) + atomic.AddInt64(&cp.totalBytes, bs) } func (cp *corsoProgress) put(k string, v *itemDetails) { @@ -275,10 +289,7 @@ func collectionEntries( entry := virtualfs.StreamingFileWithModTimeFromReader( encodeAsPath(e.UUID()), modTime, - &backupStreamReader{ - version: serializationVersion, - ReadCloser: e.ToReader(), - }, + newBackupStreamReader(serializationVersion, e.ToReader()), ) if err := cb(ctx, entry); err != nil { // Kopia's uploader swallows errors in most cases, so if we see diff --git a/src/internal/kopia/upload_test.go b/src/internal/kopia/upload_test.go index dbb239734..0b2d07fcc 100644 --- a/src/internal/kopia/upload_test.go +++ b/src/internal/kopia/upload_test.go @@ -6,7 +6,6 @@ import ( "io" stdpath "path" "testing" - "unsafe" "github.com/kopia/kopia/fs" "github.com/kopia/kopia/snapshot/snapshotfs" @@ -116,10 +115,10 @@ func (suite *VersionReadersUnitSuite) TestWriteAndRead() { reversible := &restoreStreamReader{ expectedVersion: test.readVersion, - ReadCloser: &backupStreamReader{ - version: test.writeVersion, - ReadCloser: io.NopCloser(baseReader), - }, + ReadCloser: newBackupStreamReader( + test.writeVersion, + io.NopCloser(baseReader), + ), } defer reversible.Close() @@ -165,11 +164,8 @@ func (suite *VersionReadersUnitSuite) TestWriteHandlesShortReads() { inputData := []byte("This is some data for the reader to test with") version := uint32(42) baseReader := bytes.NewReader(inputData) - versioner := &backupStreamReader{ - version: version, - ReadCloser: io.NopCloser(baseReader), - } - expectedToWrite := len(inputData) + int(unsafe.Sizeof(versioner.version)) + versioner := newBackupStreamReader(version, io.NopCloser(baseReader)) + expectedToWrite := len(inputData) + int(versionSize) // "Write" all the data. versionedData, writtenLen := readAllInParts(t, 1, versioner)