centralize observe channel listener handling (#2523)

## Description

To ensure that the observe channel handling
doesn't accidentally spawn unkillable routines,
adds a centralized channel listener func for
standard channel management in observe
progress bars.

## Does this PR need a docs update or release note?

- [x]  No 

## Type of change

- [x] 🐛 Bugfix

## Test Plan

- [x] 💪 Manual
- [x]  Unit test
This commit is contained in:
Keepers 2023-02-15 13:20:07 -07:00 committed by GitHub
parent e60ae2351f
commit 7187b969e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 214 additions and 150 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed ### Fixed
- Support for item.Attachment:Mail restore - Support for item.Attachment:Mail restore
- Errors from duplicate names in Exchange Calendars - Errors from duplicate names in Exchange Calendars
- Resolved an issue where progress bar displays could fail to exit, causing unbounded CPU consumption.
### Changed ### Changed
- When using Restore and Details on Exchange Calendars, the `--event-calendar` flag can now identify calendars by either a Display Name or a Microsoft 365 ID. - When using Restore and Details on Exchange Calendars, the `--event-calendar` flag can now identify calendars by either a Display Name or a Microsoft 365 ID.

View File

@ -22,6 +22,9 @@ const (
progressBarWidth = 32 progressBarWidth = 32
) )
// styling
const Bullet = "∙"
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
// TODO: Revisit this being a global nd make it a parameter to the progress methods // TODO: Revisit this being a global nd make it a parameter to the progress methods
@ -168,16 +171,17 @@ func MessageWithCompletion(
ctx context.Context, ctx context.Context,
msg cleanable, msg cleanable,
) (chan<- struct{}, func()) { ) (chan<- struct{}, func()) {
clean := msg.clean() var (
message := msg.String() clean = msg.clean()
message = msg.String()
log = logger.Ctx(ctx)
ch = make(chan struct{}, 1)
)
log := logger.Ctx(ctx)
log.Info(clean) log.Info(clean)
completionCh := make(chan struct{}, 1)
if cfg.hidden() { if cfg.hidden() {
return completionCh, func() { log.Info("done - " + clean) } return ch, func() { log.Info("done - " + clean) }
} }
wg.Add(1) wg.Add(1)
@ -194,24 +198,24 @@ func MessageWithCompletion(
mpb.BarFillerOnComplete("done"), mpb.BarFillerOnComplete("done"),
) )
go func(ci <-chan struct{}) { go listen(
for { ctx,
select { ch,
case <-contxt.Done(): func() {
bar.SetTotal(-1, true) bar.SetTotal(-1, true)
case <-ci: bar.Abort(true)
},
func() {
// We don't care whether the channel was signalled or closed // We don't care whether the channel was signalled or closed
// Use either one as an indication that the bar is done // Use either one as an indication that the bar is done
bar.SetTotal(-1, true) bar.SetTotal(-1, true)
} })
}
}(completionCh)
wacb := waitAndCloseBar(bar, func() { wacb := waitAndCloseBar(bar, func() {
log.Info("done - " + clean) log.Info("done - " + clean)
}) })
return completionCh, wacb return ch, wacb
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -228,7 +232,9 @@ func ItemProgress(
iname cleanable, iname cleanable,
totalBytes int64, totalBytes int64,
) (io.ReadCloser, func()) { ) (io.ReadCloser, func()) {
log := logger.Ctx(ctx).With("item", iname.clean(), "size", humanize.Bytes(uint64(totalBytes))) log := logger.Ctx(ctx).With(
"item", iname.clean(),
"size", humanize.Bytes(uint64(totalBytes)))
log.Debug(header) log.Debug(header)
if cfg.hidden() || rc == nil || totalBytes == 0 { if cfg.hidden() || rc == nil || totalBytes == 0 {
@ -270,23 +276,17 @@ func ProgressWithCount(
message cleanable, message cleanable,
count int64, count int64,
) (chan<- struct{}, func()) { ) (chan<- struct{}, func()) {
log := logger.Ctx(ctx) var (
lmsg := fmt.Sprintf("%s %s - %d", header, message.clean(), count) log = logger.Ctx(ctx)
lmsg = fmt.Sprintf("%s %s - %d", header, message.clean(), count)
ch = make(chan struct{})
)
log.Info(lmsg) log.Info(lmsg)
progressCh := make(chan struct{})
if cfg.hidden() { if cfg.hidden() {
go func(ci <-chan struct{}) { go listen(ctx, ch, nop, nop)
for { return ch, func() { log.Info("done - " + lmsg) }
_, ok := <-ci
if !ok {
return
}
}
}(progressCh)
return progressCh, func() { log.Info("done - " + lmsg) }
} }
wg.Add(1) wg.Add(1)
@ -305,24 +305,11 @@ func ProgressWithCount(
bar := progress.New(count, mpb.NopStyle(), barOpts...) bar := progress.New(count, mpb.NopStyle(), barOpts...)
ch := make(chan struct{}) go listen(
go func(ci <-chan struct{}) { ctx,
for { ch,
select { func() { bar.Abort(true) },
case <-contxt.Done(): bar.Increment)
bar.Abort(true)
return
case _, ok := <-ci:
if !ok {
bar.Abort(true)
return
}
bar.Increment()
}
}
}(ch)
wacb := waitAndCloseBar(bar, func() { wacb := waitAndCloseBar(bar, func() {
log.Info("done - " + lmsg) log.Info("done - " + lmsg)
@ -371,33 +358,28 @@ func CollectionProgress(
category string, category string,
user, dirName cleanable, user, dirName cleanable,
) (chan<- struct{}, func()) { ) (chan<- struct{}, func()) {
log := logger.Ctx(ctx).With( var (
counted int
ch = make(chan struct{})
log = logger.Ctx(ctx).With(
"user", user.clean(), "user", user.clean(),
"category", category, "category", category,
"dir", dirName.clean()) "dir", dirName.clean())
message := "Collecting Directory" message = "Collecting Directory"
)
log.Info(message) log.Info(message)
if cfg.hidden() || len(user.String()) == 0 || len(dirName.String()) == 0 { incCount := func() {
ch := make(chan struct{})
counted := 0
go func(ci <-chan struct{}) {
for {
_, ok := <-ci
if !ok {
return
}
counted++ counted++
// Log every 1000 items that are processed // Log every 1000 items that are processed
if counted%1000 == 0 { if counted%1000 == 0 {
log.Infow("uploading", "count", counted) log.Infow("uploading", "count", counted)
} }
} }
}(ch)
if cfg.hidden() || len(user.String()) == 0 || len(dirName.String()) == 0 {
go listen(ctx, ch, nop, incCount)
return ch, func() { log.Infow("done - "+message, "count", counted) } return ch, func() { log.Infow("done - "+message, "count", counted) }
} }
@ -419,36 +401,16 @@ func CollectionProgress(
bar := progress.New( bar := progress.New(
-1, // -1 to indicate an unbounded count -1, // -1 to indicate an unbounded count
mpb.SpinnerStyle(spinFrames...), mpb.SpinnerStyle(spinFrames...),
barOpts..., barOpts...)
)
var counted int
ch := make(chan struct{})
go func(ci <-chan struct{}) {
for {
select {
case <-contxt.Done():
bar.SetTotal(-1, true)
return
case _, ok := <-ci:
if !ok {
bar.SetTotal(-1, true)
return
}
counted++
// Log every 1000 items that are processed
if counted%1000 == 0 {
log.Infow("uploading", "count", counted)
}
go listen(
ctx,
ch,
func() { bar.SetTotal(-1, true) },
func() {
incCount()
bar.Increment() bar.Increment()
} })
}
}(ch)
wacb := waitAndCloseBar(bar, func() { wacb := waitAndCloseBar(bar, func() {
log.Infow("done - "+message, "count", counted) log.Infow("done - "+message, "count", counted)
@ -469,7 +431,30 @@ func waitAndCloseBar(bar *mpb.Bar, log func()) func() {
// other funcs // other funcs
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
const Bullet = "∙" var nop = func() {}
// listen handles reading, and exiting, from a channel. It assumes the
// caller will run it inside a goroutine (ex: go listen(...)).
// On context timeout or channel close, the loop exits.
// onEnd() is called on both ctx.Done() and channel close. onInc is
// called on every channel read except when closing.
func listen(ctx context.Context, ch <-chan struct{}, onEnd, onInc func()) {
for {
select {
case <-ctx.Done():
onEnd()
return
case _, ok := <-ch:
if !ok {
onEnd()
return
}
onInc()
}
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// PII redaction // PII redaction

View File

@ -1,4 +1,4 @@
package observe_test package observe
import ( import (
"bytes" "bytes"
@ -14,22 +14,23 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/observe"
"github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/tester"
) )
type ObserveProgressUnitSuite struct { type ObserveProgressUnitSuite struct {
suite.Suite tester.Suite
} }
func TestObserveProgressUnitSuite(t *testing.T) { func TestObserveProgressUnitSuite(t *testing.T) {
suite.Run(t, new(ObserveProgressUnitSuite)) suite.Run(t, &ObserveProgressUnitSuite{
Suite: tester.NewUnitSuite(t),
})
} }
var ( var (
tst = observe.Safe("test") tst = Safe("test")
testcat = observe.Safe("testcat") testcat = Safe("testcat")
testertons = observe.Safe("testertons") testertons = Safe("testertons")
) )
func (suite *ObserveProgressUnitSuite) TestItemProgress() { func (suite *ObserveProgressUnitSuite) TestItemProgress() {
@ -39,17 +40,17 @@ func (suite *ObserveProgressUnitSuite) TestItemProgress() {
t := suite.T() t := suite.T()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
observe.Complete() Complete()
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
from := make([]byte, 100) from := make([]byte, 100)
prog, closer := observe.ItemProgress( prog, closer := ItemProgress(
ctx, ctx,
io.NopCloser(bytes.NewReader(from)), io.NopCloser(bytes.NewReader(from)),
"folder", "folder",
@ -94,16 +95,16 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel
t := suite.T() t := suite.T()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
observe.Complete() Complete()
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
progCh, closer := observe.CollectionProgress(ctx, "test", testcat, testertons) progCh, closer := CollectionProgress(ctx, "test", testcat, testertons)
require.NotNil(t, progCh) require.NotNil(t, progCh)
require.NotNil(t, closer) require.NotNil(t, closer)
@ -129,16 +130,16 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelCl
t := suite.T() t := suite.T()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
observe.Complete() Complete()
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
progCh, closer := observe.CollectionProgress(ctx, "test", testcat, testertons) progCh, closer := CollectionProgress(ctx, "test", testcat, testertons)
require.NotNil(t, progCh) require.NotNil(t, progCh)
require.NotNil(t, closer) require.NotNil(t, closer)
@ -160,18 +161,18 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgress() {
defer flush() defer flush()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
message := "Test Message" message := "Test Message"
observe.Message(ctx, observe.Safe(message)) Message(ctx, Safe(message))
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
} }
@ -181,17 +182,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() {
defer flush() defer flush()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
message := "Test Message" message := "Test Message"
ch, closer := observe.MessageWithCompletion(ctx, observe.Safe(message)) ch, closer := MessageWithCompletion(ctx, Safe(message))
// Trigger completion // Trigger completion
ch <- struct{}{} ch <- struct{}{}
@ -199,7 +200,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() {
// Run the closer - this should complete because the bar was compelted above // Run the closer - this should complete because the bar was compelted above
closer() closer()
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
@ -211,17 +212,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() {
defer flush() defer flush()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
message := "Test Message" message := "Test Message"
ch, closer := observe.MessageWithCompletion(ctx, observe.Safe(message)) ch, closer := MessageWithCompletion(ctx, Safe(message))
// Close channel without completing // Close channel without completing
close(ch) close(ch)
@ -229,7 +230,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() {
// Run the closer - this should complete because the channel was closed above // Run the closer - this should complete because the channel was closed above
closer() closer()
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
@ -243,17 +244,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
message := "Test Message" message := "Test Message"
_, closer := observe.MessageWithCompletion(ctx, observe.Safe(message)) _, closer := MessageWithCompletion(ctx, Safe(message))
// cancel context // cancel context
cancel() cancel()
@ -261,7 +262,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled()
// Run the closer - this should complete because the context was closed above // Run the closer - this should complete because the context was closed above
closer() closer()
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
@ -272,19 +273,19 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() {
defer flush() defer flush()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
header := "Header" header := "Header"
message := "Test Message" message := "Test Message"
count := 3 count := 3
ch, closer := observe.ProgressWithCount(ctx, header, observe.Safe(message), int64(count)) ch, closer := ProgressWithCount(ctx, header, Safe(message), int64(count))
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ch <- struct{}{} ch <- struct{}{}
@ -293,40 +294,117 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() {
// Run the closer - this should complete because the context was closed above // Run the closer - this should complete because the context was closed above
closer() closer()
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", count, count)) require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", count, count))
} }
func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCountChannelClosed() { func (suite *ObserveProgressUnitSuite) TestrogressWithCountChannelClosed() {
ctx, flush := tester.NewContext() ctx, flush := tester.NewContext()
defer flush() defer flush()
recorder := strings.Builder{} recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil) SeedWriter(ctx, &recorder, nil)
defer func() { defer func() {
// don't cross-contaminate other tests. // don't cross-contaminate other tests.
//nolint:forbidigo //nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil) SeedWriter(context.Background(), nil, nil)
}() }()
header := "Header" header := "Header"
message := "Test Message" message := "Test Message"
count := 3 count := 3
ch, closer := observe.ProgressWithCount(ctx, header, observe.Safe(message), int64(count)) ch, closer := ProgressWithCount(ctx, header, Safe(message), int64(count))
close(ch) close(ch)
// Run the closer - this should complete because the context was closed above // Run the closer - this should complete because the context was closed above
closer() closer()
observe.Complete() Complete()
require.NotEmpty(suite.T(), recorder.String()) require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message) require.Contains(suite.T(), recorder.String(), message)
require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", 0, count)) require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", 0, count))
} }
func (suite *ObserveProgressUnitSuite) TestListen() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
ch <- struct{}{}
time.Sleep(500 * time.Millisecond)
close(ch)
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.True(t, inc)
}
func (suite *ObserveProgressUnitSuite) TestListen_close() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
close(ch)
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.False(t, inc)
}
func (suite *ObserveProgressUnitSuite) TestListen_cancel() {
ctx, flush := tester.NewContext()
defer flush()
ctx, cancelFn := context.WithCancel(ctx)
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
cancelFn()
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.False(t, inc)
}