diff --git a/src/internal/connector/exchange/exchange_data_collection.go b/src/internal/connector/exchange/exchange_data_collection.go index c65db87f4..53fda6717 100644 --- a/src/internal/connector/exchange/exchange_data_collection.go +++ b/src/internal/connector/exchange/exchange_data_collection.go @@ -18,6 +18,7 @@ import ( "github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/support" "github.com/alcionai/corso/src/internal/data" + "github.com/alcionai/corso/src/internal/observe" "github.com/alcionai/corso/src/pkg/backup/details" "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" @@ -116,14 +117,19 @@ func (col *Collection) populateByOptionIdentifier( errs error success int totalBytes int64 + + user = col.user + objectWriter = kw.NewJsonSerializationWriter() ) + colProgress, closer := observe.CollectionProgress(user, col.fullPath.Category().String(), col.fullPath.Folder()) + go closer() + defer func() { + close(colProgress) col.finishPopulation(ctx, success, totalBytes, errs) }() - user := col.user - objectWriter := kw.NewJsonSerializationWriter() // get QueryBasedonIdentifier // verify that it is the correct type in called function // serializationFunction @@ -159,6 +165,7 @@ func (col *Collection) populateByOptionIdentifier( success++ totalBytes += int64(byteCount) + colProgress <- struct{}{} } } diff --git a/src/internal/connector/exchange/service_restore.go b/src/internal/connector/exchange/service_restore.go index 07e1c7480..c6b8f7b7e 100644 --- a/src/internal/connector/exchange/service_restore.go +++ b/src/internal/connector/exchange/service_restore.go @@ -14,6 +14,7 @@ import ( "github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/support" "github.com/alcionai/corso/src/internal/data" + "github.com/alcionai/corso/src/internal/observe" "github.com/alcionai/corso/src/pkg/backup/details" "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/logger" @@ -322,6 +323,10 @@ func restoreCollection( user = directory.ResourceOwner() ) + colProgress, closer := observe.CollectionProgress(user, category.String(), directory.Folder()) + defer closer() + defer close(colProgress) + for { select { case <-ctx.Done(): @@ -372,6 +377,8 @@ func restoreCollection( details.ItemInfo{ Exchange: info, }) + + colProgress <- struct{}{} } } } diff --git a/src/internal/observe/observe.go b/src/internal/observe/observe.go index 55c504b10..58754b439 100644 --- a/src/internal/observe/observe.go +++ b/src/internal/observe/observe.go @@ -2,6 +2,7 @@ package observe import ( "context" + "fmt" "io" "sync" @@ -9,6 +10,8 @@ import ( "github.com/vbauerster/mpb/v8/decor" ) +const progressBarWidth = 32 + var ( wg sync.WaitGroup con context.Context @@ -16,6 +19,10 @@ var ( progress *mpb.Progress ) +func init() { + makeSpinFrames(progressBarWidth) +} + // SeedWriter adds default writer to the observe package. // Uses a noop writer until seeded. func SeedWriter(ctx context.Context, w io.Writer) { @@ -28,7 +35,7 @@ func SeedWriter(ctx context.Context, w io.Writer) { progress = mpb.NewWithContext( con, - mpb.WithWidth(32), + mpb.WithWidth(progressBarWidth), mpb.WithWaitGroup(&wg), mpb.WithOutput(writer), ) @@ -67,10 +74,98 @@ func ItemProgress(rc io.ReadCloser, iname string, totalBytes int64) (io.ReadClos ), ) - return bar.ProxyReader(rc), waitAndCloseBar(iname, bar) + return bar.ProxyReader(rc), waitAndCloseBar(bar) } -func waitAndCloseBar(n string, bar *mpb.Bar) func() { +var spinFrames []string + +// The bar width is set to a static 32 characters. The default spinner is only +// one char wide, which puts a lot of white space between it and the useful text. +// This builds a custom spinner animation to fill up that whitespace for a cleaner +// display. +func makeSpinFrames(barWidth int) { + s, l := rune('∙'), rune('●') + + line := []rune{} + for i := 0; i < barWidth; i++ { + line = append(line, s) + } + + sl := make([]string, 0, barWidth+1) + sl = append(sl, string(line)) + + for i := 1; i < barWidth; i++ { + l2 := make([]rune, len(line)) + copy(l2, line) + l2[i] = l + + sl = append(sl, string(l2)) + } + + spinFrames = sl +} + +// ItemProgress tracks the display a spinner that idles while the collection +// incrementing the count of items handled. Each write to the provided channel +// counts as a single increment. The caller is expected to close the channel. +func CollectionProgress(user, category, dirName string) (chan<- struct{}, func()) { + if writer == nil || len(user) == 0 || len(dirName) == 0 { + ch := make(chan struct{}) + + go func(ci <-chan struct{}) { + for { + _, ok := <-ci + if !ok { + return + } + } + }(ch) + + return ch, func() {} + } + + wg.Add(1) + + bar := progress.New( + -1, // -1 to indicate an unbounded count + mpb.SpinnerStyle(spinFrames...), + mpb.BarFillerOnComplete(""), + mpb.BarRemoveOnComplete(), + mpb.PrependDecorators( + decor.OnComplete(decor.Name(category), ""), + ), + mpb.AppendDecorators( + decor.OnComplete(decor.CurrentNoUnit("%d - ", decor.WCSyncSpace), ""), + decor.OnComplete( + decor.Name(fmt.Sprintf("%s - %s", user, dirName)), + ""), + ), + ) + + ch := make(chan struct{}) + + go func(ci <-chan struct{}) { + for { + select { + case <-con.Done(): + bar.SetTotal(-1, true) + return + + case _, ok := <-ci: + if !ok { + bar.SetTotal(-1, true) + return + } + + bar.Increment() + } + } + }(ch) + + return ch, waitAndCloseBar(bar) +} + +func waitAndCloseBar(bar *mpb.Bar) func() { return func() { bar.Wait() wg.Done() diff --git a/src/internal/observe/observe_test.go b/src/internal/observe/observe_test.go index 84183f9cd..16b559e4a 100644 --- a/src/internal/observe/observe_test.go +++ b/src/internal/observe/observe_test.go @@ -7,6 +7,7 @@ import ( "io" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,7 +25,7 @@ func TestObserveProgressUnitSuite(t *testing.T) { suite.Run(t, new(ObserveProgressUnitSuite)) } -func (suite *ObserveProgressUnitSuite) TestDoesThings() { +func (suite *ObserveProgressUnitSuite) TestItemProgress() { ctx, flush := tester.NewContext() defer flush() @@ -74,3 +75,73 @@ func (suite *ObserveProgressUnitSuite) TestDoesThings() { // assert.Contains(t, recorded, "75%") assert.Equal(t, 4, i) } + +func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel() { + ctx, flush := tester.NewContext() + defer flush() + + ctx, cancel := context.WithCancel(ctx) + + t := suite.T() + + recorder := strings.Builder{} + observe.SeedWriter(ctx, &recorder) + + defer func() { + // don't cross-contaminate other tests. + observe.Complete() + //nolint:forbidigo + observe.SeedWriter(context.Background(), nil) + }() + + progCh, closer := observe.CollectionProgress("test", "testcat", "testertons") + require.NotNil(t, progCh) + require.NotNil(t, closer) + + defer close(progCh) + + for i := 0; i < 50; i++ { + progCh <- struct{}{} + } + + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + + // blocks, but should resolve due to the ctx cancel + closer() +} + +func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelClose() { + ctx, flush := tester.NewContext() + defer flush() + + t := suite.T() + + recorder := strings.Builder{} + observe.SeedWriter(ctx, &recorder) + + defer func() { + // don't cross-contaminate other tests. + observe.Complete() + //nolint:forbidigo + observe.SeedWriter(context.Background(), nil) + }() + + progCh, closer := observe.CollectionProgress("test", "testcat", "testertons") + require.NotNil(t, progCh) + require.NotNil(t, closer) + + for i := 0; i < 50; i++ { + progCh <- struct{}{} + } + + go func() { + time.Sleep(1 * time.Second) + close(progCh) + }() + + // blocks, but should resolve due to the cancel + closer() +}