diff --git a/src/internal/connector/graph/middleware.go b/src/internal/connector/graph/middleware.go index bc9aabe2d..2e053f9d9 100644 --- a/src/internal/connector/graph/middleware.go +++ b/src/internal/connector/graph/middleware.go @@ -209,12 +209,16 @@ func (mw RetryMiddleware) Intercept( ctx := req.Context() resp, err := pipeline.Next(req, middlewareIndex) - if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) { - return resp, stackReq(ctx, req, resp, err) - } - if resp != nil && resp.StatusCode/100 != 4 && resp.StatusCode/100 != 5 { - return resp, err + retriable := IsErrTimeout(err) || IsErrConnectionReset(err) || + (resp != nil && (resp.StatusCode/100 == 4 || resp.StatusCode/100 == 5)) + + if !retriable { + if err != nil { + return resp, stackReq(ctx, req, resp, err) + } + + return resp, nil } exponentialBackOff := backoff.NewExponentialBackOff() @@ -304,7 +308,8 @@ func (mw RetryMiddleware) retryRequest( return nextResp, stackReq(ctx, req, nextResp, err) } - return mw.retryRequest(ctx, + return mw.retryRequest( + ctx, pipeline, middlewareIndex, req, diff --git a/src/internal/connector/graph/middleware_test.go b/src/internal/connector/graph/middleware_test.go index 15faf7a7a..f122cdd72 100644 --- a/src/internal/connector/graph/middleware_test.go +++ b/src/internal/connector/graph/middleware_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "syscall" "testing" "time" @@ -37,12 +38,18 @@ func newMWReturns(code int, body []byte, err error) mwReturns { brc = io.NopCloser(bytes.NewBuffer(body)) } + resp := &http.Response{ + StatusCode: code, + Body: brc, + } + + if code == 0 { + resp = nil + } + return mwReturns{ - err: err, - resp: &http.Response{ - StatusCode: code, - Body: brc, - }, + err: err, + resp: resp, } } @@ -142,6 +149,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { tests := []struct { name string status int + providedErr error expectRetryCount int mw testMW expectErr assert.ErrorAssertionFunc @@ -149,12 +157,14 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { { name: "200, no retries", status: http.StatusOK, + providedErr: nil, expectRetryCount: 0, expectErr: assert.NoError, }, { name: "400, no retries", status: http.StatusBadRequest, + providedErr: nil, expectRetryCount: 0, expectErr: assert.Error, }, @@ -162,9 +172,47 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { // don't test 504: gets intercepted by graph client for long waits. name: "502", status: http.StatusBadGateway, + providedErr: nil, expectRetryCount: defaultMaxRetries, expectErr: assert.Error, }, + { + name: "conn reset with 5xx", + status: http.StatusBadGateway, + providedErr: syscall.ECONNRESET, + expectRetryCount: defaultMaxRetries, + expectErr: assert.Error, + }, + { + name: "conn reset with 2xx", + status: http.StatusOK, + providedErr: syscall.ECONNRESET, + expectRetryCount: defaultMaxRetries, + expectErr: assert.Error, + }, + { + name: "conn reset with nil resp", + providedErr: syscall.ECONNRESET, + // Use 0 to denote nil http response + status: 0, + expectRetryCount: 3, + expectErr: assert.Error, + }, + { + // Unlikely but check if connection reset error takes precedence + name: "conn reset with 400 resp", + providedErr: syscall.ECONNRESET, + status: http.StatusBadRequest, + expectRetryCount: 3, + expectErr: assert.Error, + }, + { + name: "http timeout", + providedErr: http.ErrHandlerTimeout, + status: 0, + expectRetryCount: 3, + expectErr: assert.Error, + }, } for _, test := range tests { @@ -177,7 +225,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { called := 0 mw := newTestMW( func(*http.Request) { called++ }, - newMWReturns(test.status, nil, nil)) + newMWReturns(test.status, nil, test.providedErr)) mw.repeatReturn0 = true adpt, err := mockAdapter(suite.creds, mw)