diff --git a/src/internal/observe/observe.go b/src/internal/observe/observe.go index b91562efb..fc50cfd29 100644 --- a/src/internal/observe/observe.go +++ b/src/internal/observe/observe.go @@ -14,15 +14,16 @@ import ( ) const ( - noProgressBarsFN = "no-progress-bars" - progressBarWidth = 32 + hideProgressBarsFN = "hide-progress" + retainProgressBarsFN = "retain-progress" + progressBarWidth = 32 ) var ( wg sync.WaitGroup // TODO: Revisit this being a global nd make it a parameter to the progress methods // so that each bar can be initialized with different contexts if needed. - con context.Context + contxt context.Context writer io.Writer progress *mpb.Progress cfg *config @@ -34,37 +35,49 @@ func init() { makeSpinFrames(progressBarWidth) } -// adds the persistent boolean flag --no-progress-bars to the provided command. +// adds the persistent boolean flag --hide-progress to the provided command. // This is a hack for help displays. Due to seeding the context, we also // need to parse the configuration before we execute the command. func AddProgressBarFlags(parent *cobra.Command) { fs := parent.PersistentFlags() - fs.Bool(noProgressBarsFN, false, "turn off the progress bar displays") + fs.Bool(hideProgressBarsFN, false, "turn off the progress bar displays") + fs.Bool(retainProgressBarsFN, false, "retain the progress bar displays after completion") } // Due to races between the lazy evaluation of flags in cobra and the need to init observer // behavior in a ctx, these options get pre-processed manually here using pflags. The canonical // AddProgressBarFlag() ensures the flags are displayed as part of the help/usage output. -func PreloadFlags() bool { +func PreloadFlags() *config { fs := pflag.NewFlagSet("seed-observer", pflag.ContinueOnError) fs.ParseErrorsWhitelist.UnknownFlags = true - fs.Bool(noProgressBarsFN, false, "turn off the progress bar displays") + fs.Bool(hideProgressBarsFN, false, "turn off the progress bar displays") + fs.Bool(retainProgressBarsFN, false, "retain the progress bar displays after completion") // prevents overriding the corso/cobra help processor fs.BoolP("help", "h", false, "") - // parse the os args list to find the log level flag + // parse the os args list to find the observer display flags if err := fs.Parse(os.Args[1:]); err != nil { - return false + return nil } // retrieve the user's preferred display // automatically defaults to "info" - shouldHide, err := fs.GetBool(noProgressBarsFN) + shouldHide, err := fs.GetBool(hideProgressBarsFN) if err != nil { - return false + return nil } - return shouldHide + // retrieve the user's preferred display + // automatically defaults to "info" + shouldAlwaysShow, err := fs.GetBool(retainProgressBarsFN) + if err != nil { + return nil + } + + return &config{ + doNotDisplay: shouldHide, + keepBarsAfterComplete: shouldAlwaysShow, + } } // --------------------------------------------------------------------------- @@ -73,7 +86,8 @@ func PreloadFlags() bool { // config handles observer configuration type config struct { - doNotDisplay bool + doNotDisplay bool + keepBarsAfterComplete bool } func (c config) hidden() bool { @@ -82,20 +96,20 @@ func (c config) hidden() bool { // SeedWriter adds default writer to the observe package. // Uses a noop writer until seeded. -func SeedWriter(ctx context.Context, w io.Writer, hide bool) { +func SeedWriter(ctx context.Context, w io.Writer, c *config) { writer = w - con = ctx + contxt = ctx - if con == nil { - con = context.Background() + if contxt == nil { + contxt = context.Background() } - cfg = &config{ - doNotDisplay: hide, + if c != nil { + cfg = c } progress = mpb.NewWithContext( - con, + contxt, mpb.WithWidth(progressBarWidth), mpb.WithWaitGroup(&wg), mpb.WithOutput(writer), @@ -109,7 +123,7 @@ func Complete() { progress.Wait() } - SeedWriter(con, writer, cfg.doNotDisplay) + SeedWriter(contxt, writer, cfg) } const ( @@ -168,7 +182,7 @@ func MessageWithCompletion(message string) (chan<- struct{}, func()) { go func(ci <-chan struct{}) { for { select { - case <-con.Done(): + case <-contxt.Done(): bar.SetTotal(-1, true) case <-ci: // We don't care whether the channel was signalled or closed @@ -195,17 +209,20 @@ func ItemProgress(rc io.ReadCloser, header, iname string, totalBytes int64) (io. wg.Add(1) - bar := progress.New( - totalBytes, - mpb.NopStyle(), - mpb.BarRemoveOnComplete(), + barOpts := []mpb.BarOption{ mpb.PrependDecorators( decor.Name(header, decor.WCSyncSpaceR), decor.Name(iname, decor.WCSyncSpaceR), decor.CountersKibiByte(" %.1f/%.1f ", decor.WC{W: 8}), decor.NewPercentage("%d ", decor.WC{W: 4}), ), - ) + } + + if !cfg.keepBarsAfterComplete { + barOpts = append(barOpts, mpb.BarRemoveOnComplete()) + } + + bar := progress.New(totalBytes, mpb.NopStyle(), barOpts...) return bar.ProxyReader(rc), waitAndCloseBar(bar) } @@ -232,23 +249,25 @@ func ProgressWithCount(header, message string, count int64) (chan<- struct{}, fu wg.Add(1) - bar := progress.New( - count, - mpb.NopStyle(), - mpb.BarRemoveOnComplete(), + barOpts := []mpb.BarOption{ mpb.PrependDecorators( decor.Name(header, decor.WCSyncSpaceR), decor.Counters(0, " %d/%d "), decor.Name(message), ), - ) + } + + if !cfg.keepBarsAfterComplete { + barOpts = append(barOpts, mpb.BarRemoveOnComplete()) + } + + bar := progress.New(count, mpb.NopStyle(), barOpts...) ch := make(chan struct{}) - go func(ci <-chan struct{}) { for { select { - case <-con.Done(): + case <-contxt.Done(): bar.Abort(true) return @@ -319,25 +338,29 @@ func CollectionProgress(user, category, dirName string) (chan<- struct{}, func() wg.Add(1) - bar := progress.New( - -1, // -1 to indicate an unbounded count - mpb.SpinnerStyle(spinFrames...), - mpb.BarRemoveOnComplete(), - mpb.PrependDecorators( - decor.Name(category), - ), + barOpts := []mpb.BarOption{ + mpb.PrependDecorators(decor.Name(category)), mpb.AppendDecorators( decor.CurrentNoUnit("%d - ", decor.WCSyncSpace), decor.Name(fmt.Sprintf("%s - %s", user, dirName)), ), + } + + if !cfg.keepBarsAfterComplete { + barOpts = append(barOpts, mpb.BarRemoveOnComplete()) + } + + bar := progress.New( + -1, // -1 to indicate an unbounded count + mpb.SpinnerStyle(spinFrames...), + barOpts..., ) ch := make(chan struct{}) - go func(ci <-chan struct{}) { for { select { - case <-con.Done(): + case <-contxt.Done(): bar.SetTotal(-1, true) return diff --git a/src/internal/observe/observe_test.go b/src/internal/observe/observe_test.go index d4f273685..96809a235 100644 --- a/src/internal/observe/observe_test.go +++ b/src/internal/observe/observe_test.go @@ -33,13 +33,13 @@ func (suite *ObserveProgressUnitSuite) TestItemProgress() { t := suite.T() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. observe.Complete() //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() from := make([]byte, 100) @@ -87,13 +87,13 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnCtxCancel t := suite.T() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. observe.Complete() //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() progCh, closer := observe.CollectionProgress("test", "testcat", "testertons") @@ -122,13 +122,13 @@ func (suite *ObserveProgressUnitSuite) TestCollectionProgress_unblockOnChannelCl t := suite.T() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. observe.Complete() //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() progCh, closer := observe.CollectionProgress("test", "testcat", "testertons") @@ -153,12 +153,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgress() { defer flush() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() message := "Test Message" @@ -174,12 +174,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCompletion() { defer flush() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() message := "Test Message" @@ -204,12 +204,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithChannelClosed() { defer flush() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() message := "Test Message" @@ -236,12 +236,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithContextCancelled() ctx, cancel := context.WithCancel(ctx) recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() message := "Test Message" @@ -265,12 +265,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCount() { defer flush() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() header := "Header" @@ -298,12 +298,12 @@ func (suite *ObserveProgressUnitSuite) TestObserveProgressWithCountChannelClosed defer flush() recorder := strings.Builder{} - observe.SeedWriter(ctx, &recorder, false) + observe.SeedWriter(ctx, &recorder, nil) defer func() { // don't cross-contaminate other tests. //nolint:forbidigo - observe.SeedWriter(context.Background(), nil, false) + observe.SeedWriter(context.Background(), nil, nil) }() header := "Header"