centralize observe channel listener handling (#2523)

## Description

To ensure that the observe channel handling
doesn't accidentally spawn unkillable routines,
adds a centralized channel listener func for
standard channel management in observe
progress bars.

## Does this PR need a docs update or release note?

- [x]  No 

## Type of change

- [x] 🐛 Bugfix

## Test Plan

- [x] 💪 Manual
- [x]  Unit test
This commit is contained in:
Keepers 2023-02-15 13:20:07 -07:00 committed by GitHub
parent e60ae2351f
commit 7187b969e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 214 additions and 150 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Support for item.Attachment:Mail restore
- Errors from duplicate names in Exchange Calendars
- Resolved an issue where progress bar displays could fail to exit, causing unbounded CPU consumption.
### Changed
- When using Restore and Details on Exchange Calendars, the `--event-calendar` flag can now identify calendars by either a Display Name or a Microsoft 365 ID.

View File

@ -22,6 +22,9 @@ const (
progressBarWidth = 32
)
// styling
const Bullet = "∙"
var (
wg sync.WaitGroup
// TODO: Revisit this being a global nd make it a parameter to the progress methods
@ -168,16 +171,17 @@ func MessageWithCompletion(
ctx context.Context,
msg cleanable,
) (chan<- struct{}, func()) {
clean := msg.clean()
message := msg.String()
var (
clean = msg.clean()
message = msg.String()
log = logger.Ctx(ctx)
ch = make(chan struct{}, 1)
)
log := logger.Ctx(ctx)
log.Info(clean)
completionCh := make(chan struct{}, 1)
if cfg.hidden() {
return completionCh, func() { log.Info("done - " + clean) }
return ch, func() { log.Info("done - " + clean) }
}
wg.Add(1)
@ -194,24 +198,24 @@ func MessageWithCompletion(
mpb.BarFillerOnComplete("done"),
)
go func(ci <-chan struct{}) {
for {
select {
case <-contxt.Done():
bar.SetTotal(-1, true)
case <-ci:
// We don't care whether the channel was signalled or closed
// Use either one as an indication that the bar is done
bar.SetTotal(-1, true)
}
}
}(completionCh)
go listen(
ctx,
ch,
func() {
bar.SetTotal(-1, true)
bar.Abort(true)
},
func() {
// We don't care whether the channel was signalled or closed
// Use either one as an indication that the bar is done
bar.SetTotal(-1, true)
})
wacb := waitAndCloseBar(bar, func() {
log.Info("done - " + clean)
})
return completionCh, wacb
return ch, wacb
}
// ---------------------------------------------------------------------------
@ -228,7 +232,9 @@ func ItemProgress(
iname cleanable,
totalBytes int64,
) (io.ReadCloser, func()) {
log := logger.Ctx(ctx).With("item", iname.clean(), "size", humanize.Bytes(uint64(totalBytes)))
log := logger.Ctx(ctx).With(
"item", iname.clean(),
"size", humanize.Bytes(uint64(totalBytes)))
log.Debug(header)
if cfg.hidden() || rc == nil || totalBytes == 0 {
@ -270,23 +276,17 @@ func ProgressWithCount(
message cleanable,
count int64,
) (chan<- struct{}, func()) {
log := logger.Ctx(ctx)
lmsg := fmt.Sprintf("%s %s - %d", header, message.clean(), count)
var (
log = logger.Ctx(ctx)
lmsg = fmt.Sprintf("%s %s - %d", header, message.clean(), count)
ch = make(chan struct{})
)
log.Info(lmsg)
progressCh := make(chan struct{})
if cfg.hidden() {
go func(ci <-chan struct{}) {
for {
_, ok := <-ci
if !ok {
return
}
}
}(progressCh)
return progressCh, func() { log.Info("done - " + lmsg) }
go listen(ctx, ch, nop, nop)
return ch, func() { log.Info("done - " + lmsg) }
}
wg.Add(1)
@ -305,24 +305,11 @@ func ProgressWithCount(
bar := progress.New(count, mpb.NopStyle(), barOpts...)
ch := make(chan struct{})
go func(ci <-chan struct{}) {
for {
select {
case <-contxt.Done():
bar.Abort(true)
return
case _, ok := <-ci:
if !ok {
bar.Abort(true)
return
}
bar.Increment()
}
}
}(ch)
go listen(
ctx,
ch,
func() { bar.Abort(true) },
bar.Increment)
wacb := waitAndCloseBar(bar, func() {
log.Info("done - " + lmsg)
@ -371,33 +358,28 @@ func CollectionProgress(
category string,
user, dirName cleanable,
) (chan<- struct{}, func()) {
log := logger.Ctx(ctx).With(
"user", user.clean(),
"category", category,
"dir", dirName.clean())
message := "Collecting Directory"
var (
counted int
ch = make(chan struct{})
log = logger.Ctx(ctx).With(
"user", user.clean(),
"category", category,
"dir", dirName.clean())
message = "Collecting Directory"
)
log.Info(message)
incCount := func() {
counted++
// Log every 1000 items that are processed
if counted%1000 == 0 {
log.Infow("uploading", "count", counted)
}
}
if cfg.hidden() || len(user.String()) == 0 || len(dirName.String()) == 0 {
ch := make(chan struct{})
counted := 0
go func(ci <-chan struct{}) {
for {
_, ok := <-ci
if !ok {
return
}
counted++
// Log every 1000 items that are processed
if counted%1000 == 0 {
log.Infow("uploading", "count", counted)
}
}
}(ch)
go listen(ctx, ch, nop, incCount)
return ch, func() { log.Infow("done - "+message, "count", counted) }
}
@ -419,36 +401,16 @@ func CollectionProgress(
bar := progress.New(
-1, // -1 to indicate an unbounded count
mpb.SpinnerStyle(spinFrames...),
barOpts...,
)
barOpts...)
var counted int
ch := make(chan struct{})
go func(ci <-chan struct{}) {
for {
select {
case <-contxt.Done():
bar.SetTotal(-1, true)
return
case _, ok := <-ci:
if !ok {
bar.SetTotal(-1, true)
return
}
counted++
// Log every 1000 items that are processed
if counted%1000 == 0 {
log.Infow("uploading", "count", counted)
}
bar.Increment()
}
}
}(ch)
go listen(
ctx,
ch,
func() { bar.SetTotal(-1, true) },
func() {
incCount()
bar.Increment()
})
wacb := waitAndCloseBar(bar, func() {
log.Infow("done - "+message, "count", counted)
@ -469,7 +431,30 @@ func waitAndCloseBar(bar *mpb.Bar, log func()) func() {
// other funcs
// ---------------------------------------------------------------------------
const Bullet = "∙"
var nop = func() {}
// listen handles reading, and exiting, from a channel. It assumes the
// caller will run it inside a goroutine (ex: go listen(...)).
// On context timeout or channel close, the loop exits.
// onEnd() is called on both ctx.Done() and channel close. onInc is
// called on every channel read except when closing.
func listen(ctx context.Context, ch <-chan struct{}, onEnd, onInc func()) {
for {
select {
case <-ctx.Done():
onEnd()
return
case _, ok := <-ch:
if !ok {
onEnd()
return
}
onInc()
}
}
}
// ---------------------------------------------------------------------------
// PII redaction

View File

@ -1,4 +1,4 @@
package observe_test
package observe
import (
"bytes"
@ -14,22 +14,23 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/observe"
"github.com/alcionai/corso/src/internal/tester"
)
type ObserveProgressUnitSuite struct {
suite.Suite
tester.Suite
}
func TestObserveProgressUnitSuite(t *testing.T) {
suite.Run(t, new(ObserveProgressUnitSuite))
suite.Run(t, &ObserveProgressUnitSuite{
Suite: tester.NewUnitSuite(t),
})
}
var (
tst = observe.Safe("test")
testcat = observe.Safe("testcat")
testertons = observe.Safe("testertons")
tst = Safe("test")
testcat = Safe("testcat")
testertons = Safe("testertons")
)
func (suite *ObserveProgressUnitSuite) TestItemProgress() {
@ -39,17 +40,17 @@ func (suite *ObserveProgressUnitSuite) TestItemProgress() {
t := suite.T()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
observe.Complete()
Complete()
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
from := make([]byte, 100)
prog, closer := observe.ItemProgress(
prog, closer := ItemProgress(
ctx,
io.NopCloser(bytes.NewReader(from)),
"folder",
@ -94,16 +95,16 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel
t := suite.T()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
observe.Complete()
Complete()
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
progCh, closer := observe.CollectionProgress(ctx, "test", testcat, testertons)
progCh, closer := CollectionProgress(ctx, "test", testcat, testertons)
require.NotNil(t, progCh)
require.NotNil(t, closer)
@ -129,16 +130,16 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelCl
t := suite.T()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
observe.Complete()
Complete()
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
progCh, closer := observe.CollectionProgress(ctx, "test", testcat, testertons)
progCh, closer := CollectionProgress(ctx, "test", testcat, testertons)
require.NotNil(t, progCh)
require.NotNil(t, closer)
@ -160,18 +161,18 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgress() {
defer flush()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
message := "Test Message"
observe.Message(ctx, observe.Safe(message))
observe.Complete()
Message(ctx, Safe(message))
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
}
@ -181,17 +182,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() {
defer flush()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
message := "Test Message"
ch, closer := observe.MessageWithCompletion(ctx, observe.Safe(message))
ch, closer := MessageWithCompletion(ctx, Safe(message))
// Trigger completion
ch <- struct{}{}
@ -199,7 +200,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() {
// Run the closer - this should complete because the bar was compelted above
closer()
observe.Complete()
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
@ -211,17 +212,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() {
defer flush()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
message := "Test Message"
ch, closer := observe.MessageWithCompletion(ctx, observe.Safe(message))
ch, closer := MessageWithCompletion(ctx, Safe(message))
// Close channel without completing
close(ch)
@ -229,7 +230,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() {
// Run the closer - this should complete because the channel was closed above
closer()
observe.Complete()
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
@ -243,17 +244,17 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled()
ctx, cancel := context.WithCancel(ctx)
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
message := "Test Message"
_, closer := observe.MessageWithCompletion(ctx, observe.Safe(message))
_, closer := MessageWithCompletion(ctx, Safe(message))
// cancel context
cancel()
@ -261,7 +262,7 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled()
// Run the closer - this should complete because the context was closed above
closer()
observe.Complete()
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
@ -272,19 +273,19 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() {
defer flush()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
header := "Header"
message := "Test Message"
count := 3
ch, closer := observe.ProgressWithCount(ctx, header, observe.Safe(message), int64(count))
ch, closer := ProgressWithCount(ctx, header, Safe(message), int64(count))
for i := 0; i < count; i++ {
ch <- struct{}{}
@ -293,40 +294,117 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() {
// Run the closer - this should complete because the context was closed above
closer()
observe.Complete()
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", count, count))
}
func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCountChannelClosed() {
func (suite *ObserveProgressUnitSuite) TestrogressWithCountChannelClosed() {
ctx, flush := tester.NewContext()
defer flush()
recorder := strings.Builder{}
observe.SeedWriter(ctx, &recorder, nil)
SeedWriter(ctx, &recorder, nil)
defer func() {
// don't cross-contaminate other tests.
//nolint:forbidigo
observe.SeedWriter(context.Background(), nil, nil)
SeedWriter(context.Background(), nil, nil)
}()
header := "Header"
message := "Test Message"
count := 3
ch, closer := observe.ProgressWithCount(ctx, header, observe.Safe(message), int64(count))
ch, closer := ProgressWithCount(ctx, header, Safe(message), int64(count))
close(ch)
// Run the closer - this should complete because the context was closed above
closer()
observe.Complete()
Complete()
require.NotEmpty(suite.T(), recorder.String())
require.Contains(suite.T(), recorder.String(), message)
require.Contains(suite.T(), recorder.String(), fmt.Sprintf("%d/%d", 0, count))
}
func (suite *ObserveProgressUnitSuite) TestListen() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
ch <- struct{}{}
time.Sleep(500 * time.Millisecond)
close(ch)
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.True(t, inc)
}
func (suite *ObserveProgressUnitSuite) TestListen_close() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
close(ch)
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.False(t, inc)
}
func (suite *ObserveProgressUnitSuite) TestListen_cancel() {
ctx, flush := tester.NewContext()
defer flush()
ctx, cancelFn := context.WithCancel(ctx)
var (
t = suite.T()
ch = make(chan struct{})
end bool
onEnd = func() { end = true }
inc bool
onInc = func() { inc = true }
)
go func() {
time.Sleep(500 * time.Millisecond)
cancelFn()
}()
// regular channel close
listen(ctx, ch, onEnd, onInc)
assert.True(t, end)
assert.False(t, inc)
}