Use io.MultiReader to inject kopia file version (#1767)

## Description

Instead of rolling our own logic for injecting a version, use an
io.MultiReader to concatenate the version bytes and the data stream.
Handling Close() is now more complex, though, since the wrapper must
close each underlying reader itself.

## Type of change

<!--- Please check the type of change your PR introduces: --->
- [ ] 🌻 Feature
- [ ] 🐛 Bugfix
- [ ] 🗺️ Documentation
- [ ] 🤖 Test
- [ ] 💻 CI/Deployment
- [x] 🐹 Trivial/Minor

## Issue(s)

* closes #1766 

## Test Plan

<!-- How will this be tested prior to merging.-->
- [ ] 💪 Manual
- [x] ⚡ Unit test
- [ ] 💚 E2E
This commit is contained in:
ashmrtn 2022-12-12 13:46:38 -08:00 committed by GitHub
parent dd96a87611
commit 893bc978ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 35 deletions

View File

@ -1,9 +1,11 @@
package kopia package kopia
import ( import (
"bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"io" "io"
"os"
"runtime/trace" "runtime/trace"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -25,35 +27,47 @@ import (
var versionSize = int(unsafe.Sizeof(serializationVersion)) var versionSize = int(unsafe.Sizeof(serializationVersion))
// newBackupStreamReader returns a backupStreamReader that yields version,
// encoded as a big-endian uint32, followed by the contents of reader.
func newBackupStreamReader(version uint32, reader io.ReadCloser) *backupStreamReader {
// Serialize the version into its own buffer so it can be read as a stream.
// NOTE(review): PutUint32 writes 4 bytes, so this assumes versionSize >= 4 —
// confirm serializationVersion is a uint32.
buf := make([]byte, versionSize)
binary.BigEndian.PutUint32(buf, version)
// bytes.Reader has no Close method; NopCloser adapts it to io.ReadCloser.
bufReader := io.NopCloser(bytes.NewReader(buf))
return &backupStreamReader{
// Keep the individual readers so Close can close each of them.
readers: []io.ReadCloser{bufReader, reader},
// MultiReader reads bufReader to EOF before reader; the NopCloser wrapper
// means closing combined does not close the underlying readers.
combined: io.NopCloser(io.MultiReader(bufReader, reader)),
}
}
// backupStreamReader is a wrapper around the io.Reader that other Corso // backupStreamReader is a wrapper around the io.Reader that other Corso
// components return when backing up information. It injects a version number at // components return when backing up information. It injects a version number at
// the start of the data stream. Future versions of Corso may not need this if // the start of the data stream. Future versions of Corso may not need this if
// they use more complex serialization logic as serialization/version injection // they use more complex serialization logic as serialization/version injection
// will be handled by other components. // will be handled by other components.
type backupStreamReader struct { type backupStreamReader struct {
io.ReadCloser readers []io.ReadCloser
version uint32 combined io.ReadCloser
readBytes int
} }
func (rw *backupStreamReader) Read(p []byte) (n int, err error) { func (rw *backupStreamReader) Read(p []byte) (n int, err error) {
if rw.readBytes < versionSize { if rw.combined == nil {
marshalled := make([]byte, versionSize) return 0, os.ErrClosed
toCopy := len(marshalled) - rw.readBytes
if len(p) < toCopy {
toCopy = len(p)
} }
binary.BigEndian.PutUint32(marshalled, rw.version) return rw.combined.Read(p)
}
copy(p, marshalled[rw.readBytes:rw.readBytes+toCopy]) func (rw *backupStreamReader) Close() error {
rw.readBytes += toCopy if rw.combined == nil {
return nil
return toCopy, nil
} }
return rw.ReadCloser.Read(p) rw.combined = nil
for _, r := range rw.readers {
r.Close()
}
return nil
} }
// restoreStreamReader is a wrapper around the io.Reader that kopia returns when // restoreStreamReader is a wrapper around the io.Reader that kopia returns when
@ -169,11 +183,11 @@ func (cp *corsoProgress) FinishedFile(relativePath string, err error) {
} }
// Kopia interface function used as a callback when kopia finishes hashing a file. // Kopia interface function used as a callback when kopia finishes hashing a file.
func (cp *corsoProgress) FinishedHashingFile(fname string, bytes int64) { func (cp *corsoProgress) FinishedHashingFile(fname string, bs int64) {
// Pass the call through as well so we don't break expected functionality. // Pass the call through as well so we don't break expected functionality.
defer cp.UploadProgress.FinishedHashingFile(fname, bytes) defer cp.UploadProgress.FinishedHashingFile(fname, bs)
atomic.AddInt64(&cp.totalBytes, bytes) atomic.AddInt64(&cp.totalBytes, bs)
} }
func (cp *corsoProgress) put(k string, v *itemDetails) { func (cp *corsoProgress) put(k string, v *itemDetails) {
@ -275,10 +289,7 @@ func collectionEntries(
entry := virtualfs.StreamingFileWithModTimeFromReader( entry := virtualfs.StreamingFileWithModTimeFromReader(
encodeAsPath(e.UUID()), encodeAsPath(e.UUID()),
modTime, modTime,
&backupStreamReader{ newBackupStreamReader(serializationVersion, e.ToReader()),
version: serializationVersion,
ReadCloser: e.ToReader(),
},
) )
if err := cb(ctx, entry); err != nil { if err := cb(ctx, entry); err != nil {
// Kopia's uploader swallows errors in most cases, so if we see // Kopia's uploader swallows errors in most cases, so if we see

View File

@ -6,7 +6,6 @@ import (
"io" "io"
stdpath "path" stdpath "path"
"testing" "testing"
"unsafe"
"github.com/kopia/kopia/fs" "github.com/kopia/kopia/fs"
"github.com/kopia/kopia/snapshot/snapshotfs" "github.com/kopia/kopia/snapshot/snapshotfs"
@ -116,10 +115,10 @@ func (suite *VersionReadersUnitSuite) TestWriteAndRead() {
reversible := &restoreStreamReader{ reversible := &restoreStreamReader{
expectedVersion: test.readVersion, expectedVersion: test.readVersion,
ReadCloser: &backupStreamReader{ ReadCloser: newBackupStreamReader(
version: test.writeVersion, test.writeVersion,
ReadCloser: io.NopCloser(baseReader), io.NopCloser(baseReader),
}, ),
} }
defer reversible.Close() defer reversible.Close()
@ -165,11 +164,8 @@ func (suite *VersionReadersUnitSuite) TestWriteHandlesShortReads() {
inputData := []byte("This is some data for the reader to test with") inputData := []byte("This is some data for the reader to test with")
version := uint32(42) version := uint32(42)
baseReader := bytes.NewReader(inputData) baseReader := bytes.NewReader(inputData)
versioner := &backupStreamReader{ versioner := newBackupStreamReader(version, io.NopCloser(baseReader))
version: version, expectedToWrite := len(inputData) + int(versionSize)
ReadCloser: io.NopCloser(baseReader),
}
expectedToWrite := len(inputData) + int(unsafe.Sizeof(versioner.version))
// "Write" all the data. // "Write" all the data.
versionedData, writtenLen := readAllInParts(t, 1, versioner) versionedData, writtenLen := readAllInParts(t, 1, versioner)