Parallelize restores within a collection for OneDrive (#3492)

This should significantly speed up restores of collections that contain many items. It won't help much if we have a lot of collections with only a few items each (see the concurrency sketch after the numbers below).

Numbers 🔢:
- Restoring ~7,000 mostly small files totaling 1.5GB
  - Sequential: ~70m
  - Parallel: ~50m
- Restoring 1,200 files of 50MB each
  - Sequential: 4h 45m
  - Parallel: <40m
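
For context, below is a minimal, self-contained sketch of the pattern `RestoreCollection` now uses: a semaphore channel (sized from the new `ItemUpload()` setting, 7 for OneDrive) bounds in-flight uploads, a shared `sync.Pool` recycles the 5MB copy buffers, and results are tallied with atomics behind a `sync.WaitGroup`. `restoreOne` and the integer items are placeholders for the real per-item restore; this is an illustration, not the production code.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// 5MB chunks, per the Graph upload-session guidance referenced in the diff.
const copyBufferSize = 5 * 1024 * 1024

// restoreOne stands in for the real per-item restore/upload call.
func restoreOne(item int, buf []byte) error { return nil }

func main() {
	var (
		items     = make([]int, 100) // stand-in for the collection's item stream
		maxUpload = 7                // e.g. graph.Parallelism(path.OneDriveService).ItemUpload()
		semaphore = make(chan struct{}, maxUpload)
		wg        sync.WaitGroup
		successes int64
		pool      = sync.Pool{
			New: func() interface{} {
				b := make([]byte, copyBufferSize)
				return &b
			},
		}
	)

	for i := range items {
		wg.Add(1)
		semaphore <- struct{}{} // blocks once maxUpload uploads are in flight

		go func(item int) {
			defer wg.Done()
			defer func() { <-semaphore }()

			// Borrow a copy buffer instead of allocating 5MB per item.
			bufPtr := pool.Get().(*[]byte)
			defer pool.Put(bufPtr)

			if err := restoreOne(item, *bufPtr); err == nil {
				atomic.AddInt64(&successes, 1)
			}
		}(i)
	}

	wg.Wait()
	fmt.Println("restored items:", successes)
}
```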

---

#### Does this PR need a docs update or release note?

- [x]  Yes, it's included
- [ ] 🕐 Yes, but in a later PR
- [ ]  No

#### Type of change

<!--- Please check the type of change your PR introduces: --->
- [x] 🌻 Feature
- [ ] 🐛 Bugfix
- [ ] 🗺️ Documentation
- [ ] 🤖 Supportability/Tests
- [ ] 💻 CI/Deployment
- [ ] 🧹 Tech Debt/Cleanup

#### Issue(s)

<!-- Can reference multiple issues. Use one of the following "magic words" - "closes, fixes" to auto-close the Github issue. -->
* https://github.com/alcionai/corso/issues/3011
* closes https://github.com/alcionai/corso/issues/3536

#### Test Plan

<!-- How will this be tested prior to merging.-->
- [ ] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
Abin Simon 2023-06-02 10:26:11 +05:30 committed by GitHub
parent cdf26b7988
commit 5c4d57b416
20 changed files with 297 additions and 179 deletions

@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- Added ProtectedResourceName to the backup list json output. ProtectedResourceName holds either a UPN or a WebURL, depending on the resource type. - Added ProtectedResourceName to the backup list json output. ProtectedResourceName holds either a UPN or a WebURL, depending on the resource type.
- Rework base selection logic for incremental backups so it's more likely to find a valid base. - Rework base selection logic for incremental backups so it's more likely to find a valid base.
- Improve OneDrive restore performance by parallelizing item restores
### Fixed ### Fixed
- Fix Exchange folder cache population error when parent folder isn't found. - Fix Exchange folder cache population error when parent folder isn't found.

@ -3,6 +3,7 @@ package connector
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/alcionai/clues" "github.com/alcionai/clues"
@ -26,6 +27,13 @@ import (
"github.com/alcionai/corso/src/pkg/selectors" "github.com/alcionai/corso/src/pkg/selectors"
) )
const (
// copyBufferSize is used for chunked upload
// Microsoft recommends 5-10MB buffers
// https://docs.microsoft.com/en-us/graph/api/driveitem-createuploadsession?view=graph-rest-1.0#best-practices
copyBufferSize = 5 * 1024 * 1024
)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Data Collections // Data Collections
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -256,13 +264,46 @@ func (gc *GraphConnector) ConsumeRestoreCollections(
return nil, clues.Wrap(err, "malformed azure credentials") return nil, clues.Wrap(err, "malformed azure credentials")
} }
// Buffer pool for uploads
pool := sync.Pool{
New: func() interface{} {
b := make([]byte, copyBufferSize)
return &b
},
}
switch sels.Service { switch sels.Service {
case selectors.ServiceExchange: case selectors.ServiceExchange:
status, err = exchange.RestoreCollections(ctx, creds, gc.Discovery, gc.Service, dest, dcs, deets, errs) status, err = exchange.RestoreCollections(ctx,
creds,
gc.Discovery,
gc.Service,
dest,
dcs,
deets,
errs)
case selectors.ServiceOneDrive: case selectors.ServiceOneDrive:
status, err = onedrive.RestoreCollections(ctx, creds, backupVersion, gc.Service, dest, opts, dcs, deets, errs) status, err = onedrive.RestoreCollections(ctx,
creds,
backupVersion,
gc.Service,
dest,
opts,
dcs,
deets,
&pool,
errs)
case selectors.ServiceSharePoint: case selectors.ServiceSharePoint:
status, err = sharepoint.RestoreCollections(ctx, backupVersion, creds, gc.Service, dest, opts, dcs, deets, errs) status, err = sharepoint.RestoreCollections(ctx,
backupVersion,
creds,
gc.Service,
dest,
opts,
dcs,
deets,
&pool,
errs)
default: default:
err = clues.Wrap(clues.New(sels.Service.String()), "service not supported") err = clues.Wrap(clues.New(sels.Service.String()), "service not supported")
} }
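
A note on the pool introduced here: it is created once in `ConsumeRestoreCollections` and handed to both the OneDrive and SharePoint restore paths, so roughly one 5MB buffer exists per in-flight upload rather than one per restored item. Storing `*[]byte` (not `[]byte`) is the usual `sync.Pool` idiom to avoid re-boxing the slice header on every `Put` (staticcheck SA6002). A small, hedged sketch of the borrow/return flow; `copyWithPooledBuffer` is illustrative, not a function from the diff:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
	"sync"
)

// copyBufferSize matches the 5MB chunk size recommended for Graph upload sessions.
const copyBufferSize = 5 * 1024 * 1024

// One pool shared by every concurrent upload: roughly one buffer per
// in-flight upload stays live, instead of one per item.
var pool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, copyBufferSize)
		return &b // a *[]byte avoids re-boxing the slice header on every Put
	},
}

func copyWithPooledBuffer(dst io.Writer, src io.Reader) (int64, error) {
	bufPtr := pool.Get().(*[]byte)
	defer pool.Put(bufPtr)

	return io.CopyBuffer(dst, src, *bufPtr)
}

func main() {
	var out bytes.Buffer

	n, err := copyWithPooledBuffer(&out, strings.NewReader("hello, pool"))
	fmt.Println(n, err) // 11 <nil>
}
```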

@ -268,10 +268,9 @@ func createCollections(
return nil, clues.New("unsupported backup category type").WithClues(ctx) return nil, clues.New("unsupported backup category type").WithClues(ctx)
} }
foldersComplete, closer := observe.MessageWithCompletion( foldersComplete := observe.MessageWithCompletion(
ctx, ctx,
observe.Bulletf("%s", qp.Category)) observe.Bulletf("%s", qp.Category))
defer closer()
defer close(foldersComplete) defer close(foldersComplete)
rootFolder, cc := handler.NewContainerCache(user.ID()) rootFolder, cc := handler.NewContainerCache(user.ID())

@ -163,14 +163,11 @@ func (col *Collection) streamItems(ctx context.Context, errs *fault.Bus) {
}() }()
if len(col.added)+len(col.removed) > 0 { if len(col.added)+len(col.removed) > 0 {
var closer func() colProgress = observe.CollectionProgress(
colProgress, closer = observe.CollectionProgress(
ctx, ctx,
col.fullPath.Category().String(), col.fullPath.Category().String(),
col.LocationPath().Elements()) col.LocationPath().Elements())
go closer()
defer func() { defer func() {
close(colProgress) close(colProgress)
}() }()

@ -145,11 +145,10 @@ func restoreCollection(
category = fullPath.Category() category = fullPath.Category()
) )
colProgress, closer := observe.CollectionProgress( colProgress := observe.CollectionProgress(
ctx, ctx,
category.String(), category.String(),
fullPath.Folder(false)) fullPath.Folder(false))
defer closer()
defer close(colProgress) defer close(colProgress)
for { for {

@ -142,6 +142,8 @@ type limiterConsumptionKey string
const limiterConsumptionCtxKey limiterConsumptionKey = "corsoGraphRateLimiterConsumption" const limiterConsumptionCtxKey limiterConsumptionKey = "corsoGraphRateLimiterConsumption"
const ( const (
// https://learn.microsoft.com/en-us/sharepoint/dev/general-development
// /how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#application-throttling
defaultLC = 1 defaultLC = 1
driveDefaultLC = 2 driveDefaultLC = 2
// limit consumption rate for single-item GETs requests, // limit consumption rate for single-item GETs requests,

@ -7,7 +7,15 @@ import (
"github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/path"
) )
const AttachmentChunkSize = 4 * 1024 * 1024 const (
AttachmentChunkSize = 4 * 1024 * 1024
// Upper limit on the number of concurrent uploads as we
// create buffer pools for each upload. This is not the actual
// number of uploads, but the max that can be specified. This is
// added as a safeguard in case we misconfigure the values.
maxConccurrentUploads = 20
)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// item response AdditionalData // item response AdditionalData
@ -44,6 +52,8 @@ type parallelism struct {
collectionBuffer int collectionBuffer int
// sets the parallelism of item population within a collection. // sets the parallelism of item population within a collection.
item int item int
// sets the parallelism of concurrent uploads within a collection
itemUpload int
} }
func (p parallelism) CollectionBufferSize() int { func (p parallelism) CollectionBufferSize() int {
@ -88,6 +98,18 @@ func (p parallelism) Item() int {
return p.item return p.item
} }
func (p parallelism) ItemUpload() int {
if p.itemUpload == 0 {
return 1
}
if p.itemUpload > maxConccurrentUploads {
return maxConccurrentUploads
}
return p.itemUpload
}
// returns low <= v <= high // returns low <= v <= high
// if high < low, returns low <= v // if high < low, returns low <= v
func isWithin(low, high, v int) bool { func isWithin(low, high, v int) bool {
@ -102,6 +124,7 @@ var sp = map[path.ServiceType]parallelism{
path.OneDriveService: { path.OneDriveService: {
collectionBuffer: 5, collectionBuffer: 5,
item: 4, item: 4,
itemUpload: 7,
}, },
// sharepoint libraries are considered "onedrive" parallelism. // sharepoint libraries are considered "onedrive" parallelism.
// this only controls lists/pages. // this only controls lists/pages.
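
The new `itemUpload` knob defaults OneDrive (and SharePoint libraries, which share OneDrive's parallelism settings) to 7 concurrent uploads, and is read on the restore path to size the per-collection semaphore, i.e. `make(chan struct{}, graph.Parallelism(path.OneDriveService).ItemUpload())`. A standalone sketch that mirrors the clamping behaviour of the accessor shown above:

```go
package main

import "fmt"

// Safety cap on concurrent uploads, as in the diff above.
const maxConccurrentUploads = 20

type parallelism struct{ itemUpload int }

// ItemUpload mirrors the accessor added in this PR: unset means serial
// uploads, and the value is clamped to the package-level safety cap.
func (p parallelism) ItemUpload() int {
	if p.itemUpload == 0 {
		return 1
	}

	if p.itemUpload > maxConccurrentUploads {
		return maxConccurrentUploads
	}

	return p.itemUpload
}

func main() {
	fmt.Println(parallelism{}.ItemUpload())               // 1: unset falls back to serial
	fmt.Println(parallelism{itemUpload: 7}.ItemUpload())  // 7: the new OneDrive default
	fmt.Println(parallelism{itemUpload: 99}.ItemUpload()) // 20: clamped to the cap
}
```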

@ -439,12 +439,11 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
queuedPath = "/" + oc.driveName + queuedPath queuedPath = "/" + oc.driveName + queuedPath
} }
folderProgress, colCloser := observe.ProgressWithCount( folderProgress := observe.ProgressWithCount(
ctx, ctx,
observe.ItemQueueMsg, observe.ItemQueueMsg,
path.NewElements(queuedPath), path.NewElements(queuedPath),
int64(len(oc.driveItems))) int64(len(oc.driveItems)))
defer colCloser()
defer close(folderProgress) defer close(folderProgress)
semaphoreCh := make(chan struct{}, graph.Parallelism(path.OneDriveService).Item()) semaphoreCh := make(chan struct{}, graph.Parallelism(path.OneDriveService).Item())
@ -535,13 +534,12 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
} }
// display/log the item download // display/log the item download
progReader, closer := observe.ItemProgress( progReader, _ := observe.ItemProgress(
ctx, ctx,
itemData, itemData,
observe.ItemBackupMsg, observe.ItemBackupMsg,
clues.Hide(itemName+dataSuffix), clues.Hide(itemName+dataSuffix),
itemSize) itemSize)
go closer()
return progReader, nil return progReader, nil
}) })
@ -554,13 +552,12 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
} }
metaReader := lazy.NewLazyReadCloser(func() (io.ReadCloser, error) { metaReader := lazy.NewLazyReadCloser(func() (io.ReadCloser, error) {
progReader, closer := observe.ItemProgress( progReader, _ := observe.ItemProgress(
ctx, ctx,
itemMeta, itemMeta,
observe.ItemBackupMsg, observe.ItemBackupMsg,
clues.Hide(itemName+metaSuffix), clues.Hide(itemName+metaSuffix),
int64(itemMetaSize)) int64(itemMetaSize))
go closer()
return progReader, nil return progReader, nil
}) })

@ -283,8 +283,7 @@ func (c *Collections) Get(
driveTombstones[driveID] = struct{}{} driveTombstones[driveID] = struct{}{}
} }
driveComplete, closer := observe.MessageWithCompletion(ctx, observe.Bulletf("files")) driveComplete := observe.MessageWithCompletion(ctx, observe.Bulletf("files"))
defer closer()
defer close(driveComplete) defer close(driveComplete)
// Enumerate drives for the specified resourceOwner // Enumerate drives for the specified resourceOwner

@ -346,27 +346,6 @@ func sharePointItemInfo(di models.DriveItemable, itemSize int64) *details.ShareP
} }
} }
// driveItemWriter is used to initialize and return an io.Writer to upload data for the specified item
// It does so by creating an upload session and using that URL to initialize an `itemWriter`
// TODO: @vkamra verify if var session is the desired input
func driveItemWriter(
ctx context.Context,
gs graph.Servicer,
driveID, itemID string,
itemSize int64,
) (io.Writer, error) {
ctx = clues.Add(ctx, "upload_item_id", itemID)
r, err := api.PostDriveItem(ctx, gs, driveID, itemID)
if err != nil {
return nil, clues.Stack(err)
}
iw := graph.NewLargeItemWriter(itemID, ptr.Val(r.GetUploadUrl()), itemSize)
return iw, nil
}
// constructWebURL helper function for recreating the webURL // constructWebURL helper function for recreating the webURL
// for the originating SharePoint site. Uses additional data map // for the originating SharePoint site. Uses additional data map
// from a models.DriveItemable that possesses a downloadURL within the map. // from a models.DriveItemable that possesses a downloadURL within the map.

@ -189,9 +189,13 @@ func (suite *ItemIntegrationSuite) TestItemWriter() {
// Initialize a 100KB mockDataProvider // Initialize a 100KB mockDataProvider
td, writeSize := mockDataReader(int64(100 * 1024)) td, writeSize := mockDataReader(int64(100 * 1024))
w, err := driveItemWriter(ctx, srv, test.driveID, ptr.Val(newItem.GetId()), writeSize) itemID := ptr.Val(newItem.GetId())
r, err := api.PostDriveItem(ctx, srv, test.driveID, itemID)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
w := graph.NewLargeItemWriter(itemID, ptr.Val(r.GetUploadUrl()), writeSize)
// Using a 32 KB buffer for the copy allows us to validate the // Using a 32 KB buffer for the copy allows us to validate the
// multi-part upload. `io.CopyBuffer` will only write 32 KB at // multi-part upload. `io.CopyBuffer` will only write 32 KB at
// a time // a time
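
The 32 KB buffer matters here because `io.CopyBuffer` issues one `Write` per buffer-sized chunk (provided neither end exposes a `WriteTo`/`ReadFrom` fast path), which is what drives the writer through multiple upload parts. A quick, Graph-free illustration:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// countingWriter records how many Write calls it receives, the way the
// large-item writer sees one ranged upload per buffer-sized chunk.
type countingWriter struct{ writes int }

func (w *countingWriter) Write(p []byte) (int, error) {
	w.writes++
	return len(p), nil
}

func main() {
	payload := strings.Repeat("x", 100*1024) // 100KB payload, as in the test

	// LimitReader hides strings.Reader's WriteTo fast path so CopyBuffer
	// actually uses the 32KB buffer.
	src := io.LimitReader(strings.NewReader(payload), int64(len(payload)))
	dst := &countingWriter{}
	buf := make([]byte, 32*1024)

	if _, err := io.CopyBuffer(dst, src, buf); err != nil {
		panic(err)
	}

	fmt.Println(dst.writes) // 4: three full 32KB chunks plus a 4KB tail
}
```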

@ -3,10 +3,13 @@ package onedrive
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"runtime/trace" "runtime/trace"
"sort" "sort"
"strings" "strings"
"sync"
"sync/atomic"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -28,10 +31,10 @@ import (
"github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api"
) )
// copyBufferSize is used for chunked upload const (
// Microsoft recommends 5-10MB buffers // Maximum number of retries for upload failures
// https://docs.microsoft.com/en-us/graph/api/driveitem-createuploadsession?view=graph-rest-1.0#best-practices maxUploadRetries = 3
const copyBufferSize = 5 * 1024 * 1024 )
type restoreCaches struct { type restoreCaches struct {
Folders *folderCache Folders *folderCache
@ -59,6 +62,7 @@ func RestoreCollections(
opts control.Options, opts control.Options,
dcs []data.RestoreCollection, dcs []data.RestoreCollection,
deets *details.Builder, deets *details.Builder,
pool *sync.Pool,
errs *fault.Bus, errs *fault.Bus,
) (*support.ConnectorOperationStatus, error) { ) (*support.ConnectorOperationStatus, error) {
var ( var (
@ -104,6 +108,7 @@ func RestoreCollections(
dest.ContainerName, dest.ContainerName,
deets, deets,
opts.RestorePermissions, opts.RestorePermissions,
pool,
errs) errs)
if err != nil { if err != nil {
el.AddRecoverable(err) el.AddRecoverable(err)
@ -142,13 +147,18 @@ func RestoreCollection(
restoreContainerName string, restoreContainerName string,
deets *details.Builder, deets *details.Builder,
restorePerms bool, restorePerms bool,
pool *sync.Pool,
errs *fault.Bus, errs *fault.Bus,
) (support.CollectionMetrics, error) { ) (support.CollectionMetrics, error) {
var ( var (
metrics = support.CollectionMetrics{} metrics = support.CollectionMetrics{}
copyBuffer = make([]byte, copyBufferSize)
directory = dc.FullPath() directory = dc.FullPath()
el = errs.Local() el = errs.Local()
metricsObjects int64
metricsBytes int64
metricsSuccess int64
wg sync.WaitGroup
complete bool
) )
ctx, end := diagnostics.Span(ctx, "gc:drive:restoreCollection", diagnostics.Label("path", directory)) ctx, end := diagnostics.Span(ctx, "gc:drive:restoreCollection", diagnostics.Label("path", directory))
@ -212,8 +222,30 @@ func RestoreCollection(
caches.ParentDirToMeta[dc.FullPath().String()] = colMeta caches.ParentDirToMeta[dc.FullPath().String()] = colMeta
items := dc.Items(ctx, errs) items := dc.Items(ctx, errs)
semaphoreCh := make(chan struct{}, graph.Parallelism(path.OneDriveService).ItemUpload())
defer close(semaphoreCh)
deetsLock := sync.Mutex{}
updateDeets := func(
ctx context.Context,
repoRef path.Path,
locationRef *path.Builder,
updated bool,
info details.ItemInfo,
) {
deetsLock.Lock()
defer deetsLock.Unlock()
err = deets.Add(repoRef, locationRef, updated, info)
if err != nil {
// Not critical enough to need to stop restore operation.
logger.CtxErr(ctx, err).Infow("adding restored item to details")
}
}
for { for {
if el.Failure() != nil { if el.Failure() != nil || complete {
break break
} }
@ -223,15 +255,29 @@ func RestoreCollection(
case itemData, ok := <-items: case itemData, ok := <-items:
if !ok { if !ok {
return metrics, nil // We've processed all items in this collection, exit the loop
complete = true
break
} }
wg.Add(1)
semaphoreCh <- struct{}{}
go func(ctx context.Context, itemData data.Stream) {
defer wg.Done()
defer func() { <-semaphoreCh }()
copyBufferPtr := pool.Get().(*[]byte)
defer pool.Put(copyBufferPtr)
copyBuffer := *copyBufferPtr
ictx := clues.Add(ctx, "restore_item_id", itemData.UUID()) ictx := clues.Add(ctx, "restore_item_id", itemData.UUID())
itemPath, err := dc.FullPath().AppendItem(itemData.UUID()) itemPath, err := dc.FullPath().AppendItem(itemData.UUID())
if err != nil { if err != nil {
el.AddRecoverable(clues.Wrap(err, "appending item to full path").WithClues(ictx)) el.AddRecoverable(clues.Wrap(err, "appending item to full path").WithClues(ictx))
continue return
} }
itemInfo, skipped, err := restoreItem( itemInfo, skipped, err := restoreItem(
@ -251,33 +297,33 @@ func RestoreCollection(
// skipped items don't get counted, but they can error // skipped items don't get counted, but they can error
if !skipped { if !skipped {
metrics.Objects++ atomic.AddInt64(&metricsObjects, 1)
metrics.Bytes += int64(len(copyBuffer)) atomic.AddInt64(&metricsBytes, int64(len(copyBuffer)))
} }
if err != nil { if err != nil {
el.AddRecoverable(clues.Wrap(err, "restoring item")) el.AddRecoverable(clues.Wrap(err, "restoring item"))
continue return
} }
if skipped { if skipped {
logger.Ctx(ictx).With("item_path", itemPath).Debug("did not restore item") logger.Ctx(ictx).With("item_path", itemPath).Debug("did not restore item")
continue return
} }
err = deets.Add( // TODO: implement locationRef
itemPath, updateDeets(ictx, itemPath, &path.Builder{}, true, itemInfo)
&path.Builder{}, // TODO: implement locationRef
true, atomic.AddInt64(&metricsSuccess, 1)
itemInfo) }(ctx, itemData)
if err != nil { }
// Not critical enough to need to stop restore operation.
logger.CtxErr(ictx, err).Infow("adding restored item to details")
} }
metrics.Successes++ wg.Wait()
}
} metrics.Objects = int(metricsObjects)
metrics.Bytes = metricsBytes
metrics.Successes = int(metricsSuccess)
return metrics, el.Failure() return metrics, el.Failure()
} }
@ -308,6 +354,7 @@ func restoreItem(
source, source,
service, service,
drivePath, drivePath,
dc,
restoreFolderID, restoreFolderID,
copyBuffer, copyBuffer,
itemData) itemData)
@ -399,6 +446,7 @@ func restoreV0File(
source driveSource, source driveSource,
service graph.Servicer, service graph.Servicer,
drivePath *path.DrivePath, drivePath *path.DrivePath,
fetcher fileFetcher,
restoreFolderID string, restoreFolderID string,
copyBuffer []byte, copyBuffer []byte,
itemData data.Stream, itemData data.Stream,
@ -406,6 +454,7 @@ func restoreV0File(
_, itemInfo, err := restoreData( _, itemInfo, err := restoreData(
ctx, ctx,
service, service,
fetcher,
itemData.UUID(), itemData.UUID(),
itemData, itemData,
drivePath.DriveID, drivePath.DriveID,
@ -442,6 +491,7 @@ func restoreV1File(
itemID, itemInfo, err := restoreData( itemID, itemInfo, err := restoreData(
ctx, ctx,
service, service,
fetcher,
trimmedName, trimmedName,
itemData, itemData,
drivePath.DriveID, drivePath.DriveID,
@ -525,6 +575,7 @@ func restoreV6File(
itemID, itemInfo, err := restoreData( itemID, itemInfo, err := restoreData(
ctx, ctx,
service, service,
fetcher,
meta.FileName, meta.FileName,
itemData, itemData,
drivePath.DriveID, drivePath.DriveID,
@ -673,6 +724,7 @@ func createRestoreFolders(
func restoreData( func restoreData(
ctx context.Context, ctx context.Context,
service graph.Servicer, service graph.Servicer,
fetcher fileFetcher,
name string, name string,
itemData data.Stream, itemData data.Stream,
driveID, parentFolderID string, driveID, parentFolderID string,
@ -696,26 +748,65 @@ func restoreData(
return "", details.ItemInfo{}, err return "", details.ItemInfo{}, err
} }
// Get a drive item writer itemID := ptr.Val(newItem.GetId())
w, err := driveItemWriter(ctx, service, driveID, ptr.Val(newItem.GetId()), ss.Size()) ctx = clues.Add(ctx, "upload_item_id", itemID)
r, err := api.PostDriveItem(ctx, service, driveID, itemID)
if err != nil { if err != nil {
return "", details.ItemInfo{}, err return "", details.ItemInfo{}, clues.Wrap(err, "get upload session")
} }
var written int64
// This is just to retry file upload, the uploadSession creation is
// not retried here We need extra logic to retry file upload as we
// have to pull the file again from kopia If we fail a file upload,
// we restart from scratch and try to upload again. Graph does not
// show "register" any partial file uploads and so if we fail an
// upload the file size will be 0.
for i := 0; i <= maxUploadRetries; i++ {
// Initialize and return an io.Writer to upload data for the
// specified item It does so by creating an upload session and
// using that URL to initialize an `itemWriter`
// TODO: @vkamra verify if var session is the desired input
w := graph.NewLargeItemWriter(itemID, ptr.Val(r.GetUploadUrl()), ss.Size())
pname := name
iReader := itemData.ToReader() iReader := itemData.ToReader()
progReader, closer := observe.ItemProgress(
if i > 0 {
pname = fmt.Sprintf("%s (retry %d)", name, i)
// If it is not the first try, we have to pull the file
// again from kopia. Ideally we could just seek the stream
// but we don't have a Seeker available here.
itemData, err := fetcher.Fetch(ctx, itemData.UUID())
if err != nil {
return "", details.ItemInfo{}, clues.Wrap(err, "get data file")
}
iReader = itemData.ToReader()
}
progReader, abort := observe.ItemProgress(
ctx, ctx,
iReader, iReader,
observe.ItemRestoreMsg, observe.ItemRestoreMsg,
clues.Hide(name), clues.Hide(pname),
ss.Size()) ss.Size())
go closer()
// Upload the stream data // Upload the stream data
written, err := io.CopyBuffer(w, progReader, copyBuffer) written, err = io.CopyBuffer(w, progReader, copyBuffer)
if err == nil {
break
}
// clear out the bar if err
abort()
}
if err != nil { if err != nil {
return "", details.ItemInfo{}, graph.Wrap(ctx, err, "writing item bytes") return "", details.ItemInfo{}, clues.Wrap(err, "uploading file")
} }
dii := details.ItemInfo{} dii := details.ItemInfo{}
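
In condensed form, the upload flow in `restoreData` is now: create the upload session once, then loop up to `maxUploadRetries` times, rebuilding the `LargeItemWriter` and the progress reader on every attempt, re-fetching the item from kopia on retries (the original stream can't be rewound, and Graph keeps no partial progress), and aborting the progress bar when an attempt fails. A hedged sketch with stand-in helpers; `fetchItem` and `newSessionWriter` are not real corso functions:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

const maxUploadRetries = 3

// fetchItem and newSessionWriter stand in for re-fetching the item from
// kopia and for graph.NewLargeItemWriter against the existing upload
// session; only their shapes matter here.
func fetchItem(name string) io.Reader { return strings.NewReader("item bytes") }

func newSessionWriter(uploadURL string) io.Writer { return &bytes.Buffer{} }

// uploadWithRetry mirrors restoreData: the session URL is obtained once by
// the caller, while the writer and source reader are rebuilt per attempt.
func uploadWithRetry(name, uploadURL string, first io.Reader, buf []byte) (int64, error) {
	var (
		written int64
		err     error
		src     = first
	)

	for attempt := 0; attempt <= maxUploadRetries; attempt++ {
		if attempt > 0 {
			src = fetchItem(name) // start over with a fresh copy of the bytes
		}

		w := newSessionWriter(uploadURL)

		written, err = io.CopyBuffer(w, src, buf)
		if err == nil {
			break
		}
	}

	if err != nil {
		return 0, fmt.Errorf("uploading file: %w", err)
	}

	return written, nil
}

func main() {
	n, err := uploadWithRetry(
		"report.docx",
		"https://example.invalid/uploadSession",
		strings.NewReader("item bytes"),
		make([]byte, 32*1024))
	fmt.Println(n, err) // 10 <nil>
}
```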

@ -183,11 +183,10 @@ func (sc *Collection) runPopulate(ctx context.Context, errs *fault.Bus) (support
) )
// TODO: Insert correct ID for CollectionProgress // TODO: Insert correct ID for CollectionProgress
colProgress, closer := observe.CollectionProgress( colProgress := observe.CollectionProgress(
ctx, ctx,
sc.fullPath.Category().String(), sc.fullPath.Category().String(),
sc.fullPath.Folders()) sc.fullPath.Folders())
go closer()
defer func() { defer func() {
close(colProgress) close(colProgress)

@ -61,10 +61,9 @@ func DataCollections(
break break
} }
foldersComplete, closer := observe.MessageWithCompletion( foldersComplete := observe.MessageWithCompletion(
ctx, ctx,
observe.Bulletf("%s", scope.Category().PathType())) observe.Bulletf("%s", scope.Category().PathType()))
defer closer()
defer close(foldersComplete) defer close(foldersComplete)
var spcs []data.BackupCollection var spcs []data.BackupCollection

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"runtime/trace" "runtime/trace"
"sync"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models"
@ -48,6 +49,7 @@ func RestoreCollections(
opts control.Options, opts control.Options,
dcs []data.RestoreCollection, dcs []data.RestoreCollection,
deets *details.Builder, deets *details.Builder,
pool *sync.Pool,
errs *fault.Bus, errs *fault.Bus,
) (*support.ConnectorOperationStatus, error) { ) (*support.ConnectorOperationStatus, error) {
var ( var (
@ -90,6 +92,7 @@ func RestoreCollections(
dest.ContainerName, dest.ContainerName,
deets, deets,
opts.RestorePermissions, opts.RestorePermissions,
pool,
errs) errs)
case path.ListsCategory: case path.ListsCategory:

@ -180,7 +180,7 @@ func Message(ctx context.Context, msgs ...any) {
func MessageWithCompletion( func MessageWithCompletion(
ctx context.Context, ctx context.Context,
msg any, msg any,
) (chan<- struct{}, func()) { ) chan<- struct{} {
var ( var (
plain = plainString(msg) plain = plainString(msg)
loggable = fmt.Sprintf("%v", msg) loggable = fmt.Sprintf("%v", msg)
@ -191,7 +191,8 @@ func MessageWithCompletion(
log.Info(loggable) log.Info(loggable)
if cfg.hidden() { if cfg.hidden() {
return ch, func() { log.Info("done - " + loggable) } defer log.Info("done - " + loggable)
return ch
} }
wg.Add(1) wg.Add(1)
@ -219,11 +220,11 @@ func MessageWithCompletion(
bar.SetTotal(-1, true) bar.SetTotal(-1, true)
}) })
wacb := waitAndCloseBar(bar, func() { go waitAndCloseBar(bar, func() {
log.Info("done - " + loggable) log.Info("done - " + loggable)
}) })()
return ch, wacb return ch
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -247,7 +248,8 @@ func ItemProgress(
log.Debug(header) log.Debug(header)
if cfg.hidden() || rc == nil || totalBytes == 0 { if cfg.hidden() || rc == nil || totalBytes == 0 {
return rc, func() { log.Debug("done - " + header) } defer log.Debug("done - " + header)
return rc, func() {}
} }
wg.Add(1) wg.Add(1)
@ -266,12 +268,17 @@ func ItemProgress(
bar := progress.New(totalBytes, mpb.NopStyle(), barOpts...) bar := progress.New(totalBytes, mpb.NopStyle(), barOpts...)
wacb := waitAndCloseBar(bar, func() { go waitAndCloseBar(bar, func() {
// might be overly chatty, we can remove if needed. // might be overly chatty, we can remove if needed.
log.Debug("done - " + header) log.Debug("done - " + header)
}) })()
return bar.ProxyReader(rc), wacb abort := func() {
bar.SetTotal(-1, true)
bar.Abort(true)
}
return bar.ProxyReader(rc), abort
} }
// ProgressWithCount tracks the display of a bar that tracks the completion // ProgressWithCount tracks the display of a bar that tracks the completion
@ -283,7 +290,7 @@ func ProgressWithCount(
header string, header string,
msg any, msg any,
count int64, count int64,
) (chan<- struct{}, func()) { ) chan<- struct{} {
var ( var (
plain = plainString(msg) plain = plainString(msg)
loggable = fmt.Sprintf("%s %v - %d", header, msg, count) loggable = fmt.Sprintf("%s %v - %d", header, msg, count)
@ -295,7 +302,10 @@ func ProgressWithCount(
if cfg.hidden() { if cfg.hidden() {
go listen(ctx, ch, nop, nop) go listen(ctx, ch, nop, nop)
return ch, func() { log.Info("done - " + loggable) }
defer log.Info("done - " + loggable)
return ch
} }
wg.Add(1) wg.Add(1)
@ -319,11 +329,11 @@ func ProgressWithCount(
func() { bar.Abort(true) }, func() { bar.Abort(true) },
bar.Increment) bar.Increment)
wacb := waitAndCloseBar(bar, func() { go waitAndCloseBar(bar, func() {
log.Info("done - " + loggable) log.Info("done - " + loggable)
}) })()
return ch, wacb return ch
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -365,7 +375,7 @@ func CollectionProgress(
ctx context.Context, ctx context.Context,
category string, category string,
dirName any, dirName any,
) (chan<- struct{}, func()) { ) chan<- struct{} {
var ( var (
counted int counted int
plain = plainString(dirName) plain = plainString(dirName)
@ -388,7 +398,10 @@ func CollectionProgress(
if cfg.hidden() || len(plain) == 0 { if cfg.hidden() || len(plain) == 0 {
go listen(ctx, ch, nop, incCount) go listen(ctx, ch, nop, incCount)
return ch, func() { log.Infow("done - "+message, "count", counted) }
defer log.Infow("done - "+message, "count", counted)
return ch
} }
wg.Add(1) wg.Add(1)
@ -420,19 +433,22 @@ func CollectionProgress(
bar.Increment() bar.Increment()
}) })
wacb := waitAndCloseBar(bar, func() { go waitAndCloseBar(bar, func() {
log.Infow("done - "+message, "count", counted) log.Infow("done - "+message, "count", counted)
}) })()
return ch, wacb return ch
} }
func waitAndCloseBar(bar *mpb.Bar, log func()) func() { func waitAndCloseBar(bar *mpb.Bar, log func()) func() {
return func() { return func() {
bar.Wait() bar.Wait()
wg.Done() wg.Done()
if !bar.Aborted() {
log() log()
} }
}
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
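
Callers of the observe package now receive only the progress channel; the bar is finalized by its own goroutine via `waitAndCloseBar`, so the old `defer closer()` goes away and closing the channel (or cancelling the context) is enough. `ItemProgress` is the exception: it returns an `abort` func that the upload retry loop uses to clear a failed attempt's bar. A toy model of the new contract, not the real observe package:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

var wg sync.WaitGroup

// messageWithCompletion mimics the new observe signature: it hands back
// only the progress channel and finishes its "bar" on its own goroutine,
// either when a completion signal arrives or the channel/context closes.
func messageWithCompletion(ctx context.Context, msg string) chan<- struct{} {
	ch := make(chan struct{}, 1)

	wg.Add(1)

	go func() {
		defer wg.Done()

		select {
		case <-ctx.Done():
		case <-ch:
		}

		fmt.Println("done -", msg)
	}()

	return ch
}

func main() {
	ctx := context.Background()

	// New caller pattern: no closer to invoke, just close the channel.
	complete := messageWithCompletion(ctx, "Connecting to repository")
	defer wg.Wait()

	complete <- struct{}{}
	close(complete)
}
```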

@ -51,16 +51,14 @@ func (suite *ObserveProgressUnitSuite) TestItemProgress() {
}() }()
from := make([]byte, 100) from := make([]byte, 100)
prog, closer := ItemProgress( prog, abort := ItemProgress(
ctx, ctx,
io.NopCloser(bytes.NewReader(from)), io.NopCloser(bytes.NewReader(from)),
"folder", "folder",
tst, tst,
100) 100)
require.NotNil(t, prog) require.NotNil(t, prog)
require.NotNil(t, closer) require.NotNil(t, abort)
defer closer()
var i int var i int
@ -105,9 +103,8 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel
SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
progCh, closer := CollectionProgress(ctx, testcat, testertons) progCh := CollectionProgress(ctx, testcat, testertons)
require.NotNil(t, progCh) require.NotNil(t, progCh)
require.NotNil(t, closer)
defer close(progCh) defer close(progCh)
@ -119,9 +116,6 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
cancel() cancel()
}() }()
// blocks, but should resolve due to the ctx cancel
closer()
} }
func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelClose() { func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelClose() {
@ -140,9 +134,8 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelCl
SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
progCh, closer := CollectionProgress(ctx, testcat, testertons) progCh := CollectionProgress(ctx, testcat, testertons)
require.NotNil(t, progCh) require.NotNil(t, progCh)
require.NotNil(t, closer)
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
progCh <- struct{}{} progCh <- struct{}{}
@ -152,9 +145,6 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelCl
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
close(progCh) close(progCh)
}() }()
// blocks, but should resolve due to the cancel
closer()
} }
func (suite *ObserveProgressUnitSuite) TestObserveProgress() { func (suite *ObserveProgressUnitSuite) TestObserveProgress() {
@ -197,14 +187,11 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() {
message := "Test Message" message := "Test Message"
ch, closer := MessageWithCompletion(ctx, message) ch := MessageWithCompletion(ctx, message)
// Trigger completion // Trigger completion
ch <- struct{}{} ch <- struct{}{}
// Run the closer - this should complete because the bar was compelted above
closer()
Complete() Complete()
require.NotEmpty(t, recorder.String()) require.NotEmpty(t, recorder.String())
@ -229,14 +216,11 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() {
message := "Test Message" message := "Test Message"
ch, closer := MessageWithCompletion(ctx, message) ch := MessageWithCompletion(ctx, message)
// Close channel without completing // Close channel without completing
close(ch) close(ch)
// Run the closer - this should complete because the channel was closed above
closer()
Complete() Complete()
require.NotEmpty(t, recorder.String()) require.NotEmpty(t, recorder.String())
@ -263,14 +247,11 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled()
message := "Test Message" message := "Test Message"
_, closer := MessageWithCompletion(ctx, message) _ = MessageWithCompletion(ctx, message)
// cancel context // cancel context
cancel() cancel()
// Run the closer - this should complete because the context was closed above
closer()
Complete() Complete()
require.NotEmpty(t, recorder.String()) require.NotEmpty(t, recorder.String())
@ -296,15 +277,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() {
message := "Test Message" message := "Test Message"
count := 3 count := 3
ch, closer := ProgressWithCount(ctx, header, message, int64(count)) ch := ProgressWithCount(ctx, header, message, int64(count))
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ch <- struct{}{} ch <- struct{}{}
} }
// Run the closer - this should complete because the context was closed above
closer()
Complete() Complete()
require.NotEmpty(t, recorder.String()) require.NotEmpty(t, recorder.String())
@ -331,13 +309,10 @@ func (suite *ObserveProgressUnitSuite) TestrogressWithCountChannelClosed() {
message := "Test Message" message := "Test Message"
count := 3 count := 3
ch, closer := ProgressWithCount(ctx, header, message, int64(count)) ch := ProgressWithCount(ctx, header, message, int64(count))
close(ch) close(ch)
// Run the closer - this should complete because the context was closed above
closer()
Complete() Complete()
require.NotEmpty(t, recorder.String()) require.NotEmpty(t, recorder.String())

@ -407,11 +407,10 @@ func produceBackupDataCollections(
ctrlOpts control.Options, ctrlOpts control.Options,
errs *fault.Bus, errs *fault.Bus,
) ([]data.BackupCollection, prefixmatcher.StringSetReader, error) { ) ([]data.BackupCollection, prefixmatcher.StringSetReader, error) {
complete, closer := observe.MessageWithCompletion(ctx, "Discovering items to backup") complete := observe.MessageWithCompletion(ctx, "Discovering items to backup")
defer func() { defer func() {
complete <- struct{}{} complete <- struct{}{}
close(complete) close(complete)
closer()
}() }()
return bp.ProduceBackupCollections( return bp.ProduceBackupCollections(
@ -490,11 +489,10 @@ func consumeBackupCollections(
isIncremental bool, isIncremental bool,
errs *fault.Bus, errs *fault.Bus,
) (*kopia.BackupStats, *details.Builder, kopia.DetailsMergeInfoer, error) { ) (*kopia.BackupStats, *details.Builder, kopia.DetailsMergeInfoer, error) {
complete, closer := observe.MessageWithCompletion(ctx, "Backing up data") complete := observe.MessageWithCompletion(ctx, "Backing up data")
defer func() { defer func() {
complete <- struct{}{} complete <- struct{}{}
close(complete) close(complete)
closer()
}() }()
tags := map[string]string{ tags := map[string]string{

@ -236,8 +236,7 @@ func (op *RestoreOperation) do(
observe.Message(ctx, fmt.Sprintf("Discovered %d items in backup %s to restore", len(paths), op.BackupID)) observe.Message(ctx, fmt.Sprintf("Discovered %d items in backup %s to restore", len(paths), op.BackupID))
kopiaComplete, closer := observe.MessageWithCompletion(ctx, "Enumerating items in repository") kopiaComplete := observe.MessageWithCompletion(ctx, "Enumerating items in repository")
defer closer()
defer close(kopiaComplete) defer close(kopiaComplete)
dcs, err := op.kopia.ProduceRestoreCollections(ctx, bup.SnapshotID, paths, opStats.bytesRead, op.Errors) dcs, err := op.kopia.ProduceRestoreCollections(ctx, bup.SnapshotID, paths, opStats.bytesRead, op.Errors)
@ -322,11 +321,10 @@ func consumeRestoreCollections(
dcs []data.RestoreCollection, dcs []data.RestoreCollection,
errs *fault.Bus, errs *fault.Bus,
) (*details.Details, error) { ) (*details.Details, error) {
complete, closer := observe.MessageWithCompletion(ctx, "Restoring data") complete := observe.MessageWithCompletion(ctx, "Restoring data")
defer func() { defer func() {
complete <- struct{}{} complete <- struct{}{}
close(complete) close(complete)
closer()
}() }()
deets, err := rc.ConsumeRestoreCollections( deets, err := rc.ConsumeRestoreCollections(

@ -203,8 +203,7 @@ func Connect(
// their output getting clobbered (#1720) // their output getting clobbered (#1720)
defer observe.Complete() defer observe.Complete()
complete, closer := observe.MessageWithCompletion(ctx, "Connecting to repository") complete := observe.MessageWithCompletion(ctx, "Connecting to repository")
defer closer()
defer close(complete) defer close(complete)
kopiaRef := kopia.NewConn(s) kopiaRef := kopia.NewConn(s)
@ -630,11 +629,10 @@ func connectToM365(
sel selectors.Selector, sel selectors.Selector,
acct account.Account, acct account.Account,
) (*connector.GraphConnector, error) { ) (*connector.GraphConnector, error) {
complete, closer := observe.MessageWithCompletion(ctx, "Connecting to M365") complete := observe.MessageWithCompletion(ctx, "Connecting to M365")
defer func() { defer func() {
complete <- struct{}{} complete <- struct{}{}
close(complete) close(complete)
closer()
}() }()
// retrieve data from the producer // retrieve data from the producer