diff --git a/src/internal/connector/graph_connector.go b/src/internal/connector/graph_connector.go index 42d4ece0f..989459dcb 100644 --- a/src/internal/connector/graph_connector.go +++ b/src/internal/connector/graph_connector.go @@ -240,6 +240,7 @@ func (gc *GraphConnector) RestoreMessages(ctx context.Context, dcs []data.Collec attempts, successes int errs error ) + gc.incrementAwaitingMessages() for _, dc := range dcs { // must be user.GetId(), PrimaryName no longer works 6-15-2022 @@ -302,7 +303,10 @@ func (gc *GraphConnector) RestoreMessages(ctx context.Context, dcs []data.Collec } status := support.CreateStatus(ctx, support.Restore, attempts, successes, len(pathCounter), errs) - gc.SetStatus(*status) + // set the channel asynchronously so that this func doesn't block. + go func(cos *support.ConnectorOperationStatus) { + gc.statusCh <- cos + }(status) logger.Ctx(ctx).Debug(gc.PrintableStatus()) return errs } @@ -356,19 +360,13 @@ func (gc *GraphConnector) serializeMessages(ctx context.Context, user string) (m return collections, err } -// SetStatus helper function -func (gc *GraphConnector) SetStatus(cos support.ConnectorOperationStatus) { - gc.status = &cos -} - // AwaitStatus updates status field based on item within statusChannel. func (gc *GraphConnector) AwaitStatus() *support.ConnectorOperationStatus { if gc.awaitingMessages > 0 { gc.status = <-gc.statusCh atomic.AddInt32(&gc.awaitingMessages, -1) - return gc.status } - return nil + return gc.status } // Status returns the current status of the graphConnector operaion. diff --git a/src/internal/connector/graph_connector_disconnected_test.go b/src/internal/connector/graph_connector_disconnected_test.go index 9268ebf1e..44956460f 100644 --- a/src/internal/connector/graph_connector_disconnected_test.go +++ b/src/internal/connector/graph_connector_disconnected_test.go @@ -93,14 +93,20 @@ func (suite *DisconnectedGraphConnectorSuite) TestInterfaceAlignment() { } func (suite *DisconnectedGraphConnectorSuite) TestGraphConnector_Status() { - gc := GraphConnector{} + gc := GraphConnector{ + statusCh: make(chan *support.ConnectorOperationStatus), + } suite.Equal(len(gc.PrintableStatus()), 0) - status := support.CreateStatus( - context.Background(), - support.Restore, - 12, 9, 8, - support.WrapAndAppend("tres", errors.New("three"), support.WrapAndAppend("arc376", errors.New("one"), errors.New("two")))) - gc.SetStatus(*status) + gc.incrementAwaitingMessages() + go func() { + status := support.CreateStatus( + context.Background(), + support.Restore, + 12, 9, 8, + support.WrapAndAppend("tres", errors.New("three"), support.WrapAndAppend("arc376", errors.New("one"), errors.New("two")))) + gc.statusCh <- status + }() + gc.AwaitStatus() suite.Greater(len(gc.PrintableStatus()), 0) suite.Greater(gc.Status().ObjectCount, 0) } diff --git a/src/internal/operations/restore.go b/src/internal/operations/restore.go index b8b6f5b0f..e75c36646 100644 --- a/src/internal/operations/restore.go +++ b/src/internal/operations/restore.go @@ -127,7 +127,7 @@ func (op *RestoreOperation) Run(ctx context.Context) error { stats.writeErr = errors.Wrap(err, "restoring service data") return stats.writeErr } - stats.gc = gc.Status() + stats.gc = gc.AwaitStatus() op.Status = Successful return nil @@ -139,7 +139,8 @@ func (op *RestoreOperation) persistResults( stats *restoreStats, ) { op.Status = Successful - if stats.readErr != nil || stats.writeErr != nil { + if (stats.readErr != nil || stats.writeErr != nil) && + (stats.gc == nil || stats.gc.Successful == 0) { op.Status = Failed } op.Results.ReadErrors = stats.readErr @@ -148,7 +149,7 @@ func (op *RestoreOperation) persistResults( op.Results.ItemsRead = len(stats.cs) // TODO: file count, not collection count if stats.gc != nil { - op.Results.ItemsWritten = stats.gc.ObjectCount + op.Results.ItemsWritten = stats.gc.Successful } op.Results.StartedAt = started diff --git a/src/internal/operations/restore_test.go b/src/internal/operations/restore_test.go index 6fe6af4f5..611808798 100644 --- a/src/internal/operations/restore_test.go +++ b/src/internal/operations/restore_test.go @@ -59,13 +59,13 @@ func (suite *RestoreOpSuite) TestRestoreOperation_PersistResults() { op.persistResults(now, &stats) - assert.Equal(t, op.Status, Failed) - assert.Equal(t, op.Results.ItemsRead, len(stats.cs)) - assert.Equal(t, op.Results.ReadErrors, stats.readErr) - assert.Equal(t, op.Results.ItemsWritten, stats.gc.ObjectCount) - assert.Equal(t, op.Results.WriteErrors, stats.writeErr) - assert.Equal(t, op.Results.StartedAt, now) - assert.Less(t, now, op.Results.CompletedAt) + assert.Equal(t, op.Status, Failed, "status") + assert.Equal(t, op.Results.ItemsRead, len(stats.cs), "items read") + assert.Equal(t, op.Results.ReadErrors, stats.readErr, "read errors") + assert.Equal(t, op.Results.ItemsWritten, stats.gc.Successful, "items written") + assert.Equal(t, op.Results.WriteErrors, stats.writeErr, "write errors") + assert.Equal(t, op.Results.StartedAt, now, "started at") + assert.Less(t, now, op.Results.CompletedAt, "completed at") } // ---------------------------------------------------------------------------