diff --git a/src/internal/common/readers/retry_handler.go b/src/internal/common/readers/retry_handler.go new file mode 100644 index 000000000..ea6ece185 --- /dev/null +++ b/src/internal/common/readers/retry_handler.go @@ -0,0 +1,226 @@ +package readers + +import ( + "context" + "errors" + "fmt" + "io" + "syscall" + "time" + + "github.com/alcionai/clues" + + "github.com/alcionai/corso/src/pkg/logger" +) + +var _ io.ReadCloser = &resetRetryHandler{} + +const ( + minSleepTime = 3 + numMaxRetries = 3 + rangeHeaderKey = "Range" + // One-sided range like this is defined as starting at the given byte and + // extending to the end of the item. + rangeHeaderOneSidedValueTmpl = "bytes=%d-" +) + +// Could make this per wrapper instance if we need additional flexibility +// between callers. +var retryErrs = []error{ + syscall.ECONNRESET, +} + +type Getter interface { + // SupportsRange returns true if this Getter supports adding Range headers to + // the Get call. Otherwise returns false. + SupportsRange() bool + // Get attempts to get another reader for the data this reader is returning. + // headers denotes any additional headers that should be added to the request, + // like a Range header. + // + // Don't allow passing a URL to Get so that we can hide the fact that some + // components may need to dynamically refresh the fetch URL (i.e. OneDrive) + // from this wrapper. + // + // Get should encapsulate all error handling and status code checking required + // for the component. This function is called both during NewResetRetryHandler + // and Read so it's possible to discover errors with the item prior to + // informing other components about it if desired. + Get(ctx context.Context, headers map[string]string) (io.ReadCloser, error) +} + +// NewResetRetryHandler returns an io.ReadCloser with the reader initialized to +// the result of getter. 
The reader is eagerly initialized during this call so +// if callers of this function want to delay initialization they should wrap +// this reader in a lazy initializer. +// +// Selected errors that the reader hits during Read calls (e.g. +// syscall.ECONNRESET) will be automatically retried by the returned reader. +func NewResetRetryHandler( + ctx context.Context, + getter Getter, +) (*resetRetryHandler, error) { + rrh := &resetRetryHandler{ + ctx: ctx, + getter: getter, + } + + // Retry logic encapsulated in reconnect so no need for it here. + _, err := rrh.reconnect(numMaxRetries) + + return rrh, clues.Wrap(err, "initializing reader").OrNil() +} + +//nolint:unused +type resetRetryHandler struct { + ctx context.Context + getter Getter + innerReader io.ReadCloser + offset int64 +} + +func isRetriable(err error) bool { + if err == nil { + return false + } + + for _, e := range retryErrs { + if errors.Is(err, e) { + return true + } + } + + return false + } + +func (rrh *resetRetryHandler) Read(p []byte) (int, error) { + if rrh.innerReader == nil { + return 0, clues.New("not initialized") + } + + var ( + // Use separate error variable just to make other assignments in the loop a + // bit cleaner. + finalErr error + read int + numRetries int + ) + + // Still need to check retry count in loop header so we don't go through one + // last time after failing to reconnect due to exhausting retries. + for numRetries < numMaxRetries { + n, err := rrh.innerReader.Read(p[read:]) + rrh.offset = rrh.offset + int64(n) + read = read + n + + // Catch short reads with no error and errors we don't know how to retry. + if !isRetriable(err) { + // Not everything knows how to handle a wrapped version of EOF (including + // io.ReadAll) so return the error itself here. + if errors.Is(err, io.EOF) { + // Log info about the error, but only if it's not directly an EOF. + // Otherwise this can be rather chatty and annoying to filter out. 
+ if err != io.EOF { + logger.CtxErr(rrh.ctx, err).Debug("dropping wrapped io.EOF") + } + + return read, io.EOF + } + + return read, clues.Stack(err).WithClues(rrh.ctx).OrNil() + } + + logger.Ctx(rrh.ctx).Infow( + "restarting reader", + "supports_range", rrh.getter.SupportsRange(), + "restart_at_offset", rrh.offset, + "retries_remaining", numMaxRetries-numRetries, + "retriable_error", err) + + attempts, err := rrh.reconnect(numMaxRetries - numRetries) + numRetries = numRetries + attempts + finalErr = err + } + + // We couldn't read anything through all the retries but never had an error + // getting another reader. Report this as an error so we don't get stuck in an + // infinite loop. + if read == 0 && finalErr == nil && numRetries >= numMaxRetries { + finalErr = clues.Wrap(io.ErrNoProgress, "unable to read data") + } + + return read, clues.Stack(finalErr).OrNil() +} + +// reconnect attempts to get another instance of the underlying reader and set +// the reader to pick up where the previous reader left off. +// +// Since this function can be called by functions that also implement retries on +// read errors pass an int in to denote how many times to attempt to reconnect. +// This avoids multiplicative retries when called from other functions. +func (rrh *resetRetryHandler) reconnect(maxRetries int) (int, error) { + var ( + attempts int + skip = rrh.offset + headers = map[string]string{} + // This is annoying but we want the equivalent of a do-while loop. + err = retryErrs[0] + ) + + if rrh.getter.SupportsRange() { + headers[rangeHeaderKey] = fmt.Sprintf( + rangeHeaderOneSidedValueTmpl, + rrh.offset) + skip = 0 + } + + ctx := clues.Add( + rrh.ctx, + "supports_range", rrh.getter.SupportsRange(), + "restart_at_offset", rrh.offset) + + for attempts < maxRetries && isRetriable(err) { + // Attempts will be 0 the first time through so it won't sleep then. 
+ time.Sleep(time.Duration(attempts*minSleepTime) * time.Second) + + attempts++ + + var r io.ReadCloser + + r, err = rrh.getter.Get(ctx, headers) + if err != nil { + err = clues.Wrap(err, "retrying connection"). + WithClues(ctx). + With("attempt_num", attempts) + + continue + } + + if rrh.innerReader != nil { + rrh.innerReader.Close() + } + + rrh.innerReader = r + + // If we can't request a specific range of content then read as many bytes + // as we've already processed into the equivalent of /dev/null so that the + // next read will get content we haven't seen before. + if skip > 0 { + _, err = io.CopyN(io.Discard, rrh.innerReader, skip) + if err != nil { + err = clues.Wrap(err, "seeking to correct offset"). + WithClues(ctx). + With("attempt_num", attempts) + } + } + } + + return attempts, err +} + +func (rrh *resetRetryHandler) Close() error { + err := rrh.innerReader.Close() + rrh.innerReader = nil + + return clues.Stack(err).OrNil() +} diff --git a/src/internal/common/readers/retry_handler_test.go b/src/internal/common/readers/retry_handler_test.go new file mode 100644 index 000000000..f5842e6fa --- /dev/null +++ b/src/internal/common/readers/retry_handler_test.go @@ -0,0 +1,496 @@ +package readers_test + +import ( + "bytes" + "context" + "io" + "syscall" + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common/readers" + "github.com/alcionai/corso/src/internal/tester" +) + +type readResp struct { + read int + // sticky denotes whether the error should continue to be returned until reset + // is called. + sticky bool + err error +} + +type mockReader struct { + r io.Reader + data []byte + // Associate return values for Read with calls. Allows partial reads as well. + // If a value for a particular read call is not in the map that means + // completing the request completely with no errors (i.e. 
all bytes requested + // are returned or as many as possible and EOF). + resps map[int]readResp + callCount int + stickyErr error +} + +func (mr *mockReader) Read(p []byte) (int, error) { + defer func() { + mr.callCount++ + }() + + if mr.r == nil { + mr.reset(0) + } + + if mr.stickyErr != nil { + return 0, clues.Wrap(mr.stickyErr, "sticky error") + } + + resp, ok := mr.resps[mr.callCount] + if !ok { + n, err := mr.r.Read(p) + return n, clues.Stack(err).OrNil() + } + + n, err := mr.r.Read(p[:resp.read]) + + if resp.err != nil { + if resp.sticky { + mr.stickyErr = resp.err + } + + return n, clues.Stack(resp.err) + } + + return n, clues.Stack(err).OrNil() +} + +func (mr *mockReader) reset(n int) { + mr.r = bytes.NewBuffer(mr.data[n:]) + mr.stickyErr = nil +} + +type getterResp struct { + offset int + err error +} + +type mockGetter struct { + t *testing.T + supportsRange bool + reader *mockReader + resps map[int]getterResp + expectHeaders map[int]map[string]string + callCount int +} + +func (mg *mockGetter) SupportsRange() bool { + return mg.supportsRange +} + +func (mg *mockGetter) Get( + ctx context.Context, + headers map[string]string, +) (io.ReadCloser, error) { + defer func() { + mg.callCount++ + }() + + expectHeaders := mg.expectHeaders[mg.callCount] + if expectHeaders == nil { + expectHeaders = map[string]string{} + } + + assert.Equal(mg.t, expectHeaders, headers) + + resp := mg.resps[mg.callCount] + + if resp.offset >= 0 { + mg.reader.reset(resp.offset) + } + + return io.NopCloser(mg.reader), clues.Stack(resp.err).OrNil() +} + +type ResetRetryHandlerUnitSuite struct { + tester.Suite +} + +func TestResetRetryHandlerUnitSuite(t *testing.T) { + suite.Run(t, &ResetRetryHandlerUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { + data := []byte("abcdefghijklmnopqrstuvwxyz") + // Pick a smaller read size so we can see how things will act if we have a + // "chunked" set of data. 
+ readSize := 4 + + table := []struct { + name string + supportsRange bool + // 0th entry is the return data when trying to initialize the wrapper. + getterResps map[int]getterResp + // 0th entry is the return data when trying to initialize the wrapper. + getterExpectHeaders map[int]map[string]string + readerResps map[int]readResp + expectData []byte + expectErr error + }{ + { + name: "OnlyFirstGetErrors NoRangeSupport", + getterResps: map[int]getterResp{ + 0: { + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "OnlyFirstReadErrors RangeSupport", + supportsRange: true, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=0-"}, + }, + getterResps: map[int]getterResp{ + 0: { + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ErrorInMiddle NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ErrorInMiddle RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 12}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=12-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "MultipleErrorsInMiddle NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + 7: { + read: 0, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "MultipleErrorsInMiddle RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 12}, + 2: {offset: 20}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=12-"}, + 2: {"Range": "bytes=20-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + 6: { + read: 0, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + 
name: "ShortReadWithError NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: readSize / 2, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ShortReadWithError RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 14}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=14-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: readSize / 2, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ErrorAtEndOfRead NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: readSize, + sticky: true, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ErrorAtEndOfRead RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 16}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=16-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: readSize, + sticky: true, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "UnexpectedError NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: 0, + err: assert.AnError, + }, + }, + expectData: data[:12], + expectErr: assert.AnError, + }, + { + name: "UnexpectedError RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 12}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=12-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: 0, + err: assert.AnError, + }, + }, + expectData: data[:12], + expectErr: assert.AnError, + }, + { + name: "ErrorWhileSeeking NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + 4: { + read: 0, + err: syscall.ECONNRESET, + }, + }, + expectData: data, + }, + { + name: "ShortReadNoError NoRangeSupport", + readerResps: map[int]readResp{ + 3: { + read: readSize / 2, + }, + }, + 
expectData: data, + }, + { + name: "ShortReadNoError RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 14}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=14-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: readSize / 2, + }, + }, + expectData: data, + }, + { + name: "TooManyRetriesDuringRead NoRangeSupport", + // Fail the final reconnect attempt so we run out of retries. Otherwise we + // exit with a short read and successful reconnect. + getterResps: map[int]getterResp{ + 3: {err: syscall.ECONNRESET}, + }, + // Even numbered read requests are seeks to the proper offset. + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + 5: { + read: 1, + err: syscall.ECONNRESET, + }, + 7: { + read: 1, + err: syscall.ECONNRESET, + }, + }, + expectData: data[:14], + expectErr: syscall.ECONNRESET, + }, + { + name: "TooManyRetriesDuringRead RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: 12}, + 2: {offset: 12}, + 3: {err: syscall.ECONNRESET}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=12-"}, + 2: {"Range": "bytes=13-"}, + 3: {"Range": "bytes=14-"}, + }, + readerResps: map[int]readResp{ + 3: { + read: 0, + err: syscall.ECONNRESET, + }, + 4: { + read: 1, + err: syscall.ECONNRESET, + }, + 5: { + read: 1, + err: syscall.ECONNRESET, + }, + }, + expectData: data[:14], + expectErr: syscall.ECONNRESET, + }, + { + name: "TooManyRetriesDuringRead AlwaysReturnError RangeSupport", + supportsRange: true, + getterResps: map[int]getterResp{ + 1: {offset: -1}, + 2: {offset: -1}, + 3: {offset: -1}, + 4: {offset: -1}, + 5: {offset: -1}, + }, + getterExpectHeaders: map[int]map[string]string{ + 0: {"Range": "bytes=0-"}, + 1: {"Range": "bytes=0-"}, + 2: {"Range": "bytes=0-"}, + 3: {"Range": "bytes=0-"}, + 4: {"Range": "bytes=0-"}, + 5: {"Range": "bytes=0-"}, + }, + 
readerResps: map[int]readResp{ + 0: { + sticky: true, + err: syscall.ECONNRESET, + }, + }, + expectData: []byte{}, + expectErr: io.ErrNoProgress, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + reader := &mockReader{ + data: data, + resps: test.readerResps, + } + + getter := &mockGetter{ + t: t, + supportsRange: test.supportsRange, + reader: reader, + resps: test.getterResps, + expectHeaders: test.getterExpectHeaders, + } + + var ( + err error + n int + offset int + resData = make([]byte, len(data)) + ) + + rrh, err := readers.NewResetRetryHandler(ctx, getter) + require.NoError(t, err, "making reader wrapper: %v", clues.ToCore(err)) + + for err == nil && offset < len(data) { + end := offset + readSize + if end > len(data) { + end = len(data) + } + + n, err = rrh.Read(resData[offset:end]) + + offset = offset + n + } + + assert.Equal(t, test.expectData, data[:offset]) + + if test.expectErr == nil { + assert.NoError(t, err, clues.ToCore(err)) + return + } + + assert.ErrorIs(t, err, test.expectErr, clues.ToCore(err)) + }) + } +}