diff --git a/src/internal/connector/exchange/api/mail_test.go b/src/internal/connector/exchange/api/mail_test.go index ad134041e..2ce0cd537 100644 --- a/src/internal/connector/exchange/api/mail_test.go +++ b/src/internal/connector/exchange/api/mail_test.go @@ -157,7 +157,7 @@ func (suite *MailAPIUnitSuite) TestMailInfo() { } } -type MailAPIE2ESuite struct { +type MailAPIIntgSuite struct { tester.Suite credentials account.M365Config ac api.Client @@ -165,9 +165,9 @@ type MailAPIE2ESuite struct { } // We do end up mocking the actual request, but creating the rest -// similar to E2E suite -func TestMailAPIE2ESuite(t *testing.T) { - suite.Run(t, &MailAPIE2ESuite{ +// similar to full integration tests. +func TestMailAPIIntgSuite(t *testing.T) { + suite.Run(t, &MailAPIIntgSuite{ Suite: tester.NewIntegrationSuite( t, [][]string{tester.M365AcctCredEnvs}, @@ -175,7 +175,7 @@ func TestMailAPIE2ESuite(t *testing.T) { }) } -func (suite *MailAPIE2ESuite) SetupSuite() { +func (suite *MailAPIIntgSuite) SetupSuite() { t := suite.T() a := tester.NewM365Account(t) @@ -205,7 +205,7 @@ func getJSONObject(t *testing.T, thing serialization.Parsable) map[string]interf return out } -func (suite *MailAPIE2ESuite) TestHugeAttachmentListDownload() { +func (suite *MailAPIIntgSuite) TestHugeAttachmentListDownload() { mid := "fake-message-id" aid := "fake-attachment-id" diff --git a/src/internal/connector/graph/errors.go b/src/internal/connector/graph/errors.go index 9f83a1c50..527465621 100644 --- a/src/internal/connector/graph/errors.go +++ b/src/internal/connector/graph/errors.go @@ -43,6 +43,12 @@ const ( syncStateNotFound errorCode = "SyncStateNotFound" ) +type errorMessage string + +const ( + IOErrDuringRead errorMessage = "IO error during request payload read" +) + const ( mysiteURLNotFound = "unable to retrieve user's mysite url" mysiteNotFound = "user's mysite not found" @@ -241,6 +247,26 @@ func Stack(ctx context.Context, e error) *clues.Err { return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg) } +// stackReq is a helper function that extracts ODataError metadata from +// the error, plus http req/resp data. If the error is not an ODataError +// type, returns the error with only the req/resp values. +func stackReq( + ctx context.Context, + req *http.Request, + resp *http.Response, + e error, +) *clues.Err { + if e == nil { + return nil + } + + se := Stack(ctx, e). + WithMap(reqData(req)). + WithMap(respData(resp)) + + return se +} + // Checks for the following conditions and labels the error accordingly: // * mysiteNotFound | mysiteURLNotFound // * malware @@ -290,6 +316,34 @@ func errData(err odataerrors.ODataErrorable) (string, []any, string) { return mainMsg, data, strings.ToLower(msgConcat) } +func reqData(req *http.Request) map[string]any { + if req == nil { + return nil + } + + r := map[string]any{} + r["req_method"] = req.Method + r["req_len"] = req.ContentLength + + if req.URL != nil { + r["req_url"] = LoggableURL(req.URL.String()) + } + + return r +} + +func respData(resp *http.Response) map[string]any { + if resp == nil { + return nil + } + + r := map[string]any{} + r["resp_status"] = resp.Status + r["resp_len"] = resp.ContentLength + + return r +} + func appendIf(a []any, k string, v *string) []any { if v == nil { return a diff --git a/src/internal/connector/graph/http_wrapper.go b/src/internal/connector/graph/http_wrapper.go index 1410fb194..bc469c5f2 100644 --- a/src/internal/connector/graph/http_wrapper.go +++ b/src/internal/connector/graph/http_wrapper.go @@ -141,7 +141,7 @@ func defaultTransport() http.RoundTripper { func internalMiddleware(cc *clientConfig) []khttp.Middleware { return []khttp.Middleware{ - &RetryHandler{ + &RetryMiddleware{ MaxRetries: cc.maxRetries, Delay: cc.minDelay, }, diff --git a/src/internal/connector/graph/middleware.go b/src/internal/connector/graph/middleware.go index 57825c38f..4bd914dbf 100644 --- a/src/internal/connector/graph/middleware.go +++ b/src/internal/connector/graph/middleware.go @@ -13,6 +13,7 @@ import ( "github.com/alcionai/clues" backoff "github.com/cenkalti/backoff/v4" khttp "github.com/microsoft/kiota-http-go" + "golang.org/x/exp/slices" "golang.org/x/time/rate" "github.com/alcionai/corso/src/internal/common/pii" @@ -98,7 +99,7 @@ func LoggableURL(url string) pii.SafeURL { } } -func (handler *LoggingMiddleware) Intercept( +func (mw *LoggingMiddleware) Intercept( pipeline khttp.Pipeline, middlewareIndex int, req *http.Request, @@ -173,15 +174,49 @@ func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string // Retry & Backoff // --------------------------------------------------------------------------- -// RetryHandler handles transient HTTP responses and retries the request given the retry options -type RetryHandler struct { +// RetryMiddleware handles transient HTTP responses and retries the request given the retry options +type RetryMiddleware struct { // The maximum number of times a request can be retried MaxRetries int // The delay in seconds between retries Delay time.Duration } -func (middleware RetryHandler) retryRequest( +// Intercept implements the interface and evaluates whether to retry a failed request. +func (mw RetryMiddleware) Intercept( + pipeline khttp.Pipeline, + middlewareIndex int, + req *http.Request, +) (*http.Response, error) { + ctx := req.Context() + + resp, err := pipeline.Next(req, middlewareIndex) + if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) { + return resp, stackReq(ctx, req, resp, err) + } + + exponentialBackOff := backoff.NewExponentialBackOff() + exponentialBackOff.InitialInterval = mw.Delay + exponentialBackOff.Reset() + + resp, err = mw.retryRequest( + ctx, + pipeline, + middlewareIndex, + req, + resp, + 0, + 0, + exponentialBackOff, + err) + if err != nil { + return nil, stackReq(ctx, req, resp, err) + } + + return resp, nil +} + +func (mw RetryMiddleware) retryRequest( ctx context.Context, pipeline khttp.Pipeline, middlewareIndex int, @@ -190,14 +225,23 @@ func (middleware RetryHandler) retryRequest( executionCount int, cumulativeDelay time.Duration, exponentialBackoff *backoff.ExponentialBackOff, - respErr error, + priorErr error, ) (*http.Response, error) { - if (respErr != nil || middleware.isRetriableErrorCode(req, resp.StatusCode)) && - middleware.isRetriableRequest(req) && - executionCount < middleware.MaxRetries { + ctx = clues.Add( + ctx, + "retry_count", executionCount, + "prev_resp_status", resp.Status) + + // only retry under certain conditions: + // 1, there was an error. 2, the resp and/or status code match retriable conditions. + // 3, the request is retriable. + // 4, we haven't hit our max retries already. + if (priorErr != nil || mw.isRetriableRespCode(ctx, resp, resp.StatusCode)) && + mw.isRetriableRequest(req) && + executionCount < mw.MaxRetries { executionCount++ - delay := middleware.getRetryDelay(req, resp, exponentialBackoff) + delay := mw.getRetryDelay(req, resp, exponentialBackoff) cumulativeDelay += delay @@ -209,19 +253,17 @@ func (middleware RetryHandler) retryRequest( case <-ctx.Done(): // Don't retry if the context is marked as done, it will just error out // when we attempt to send the retry anyway. - return resp, ctx.Err() + return resp, clues.Stack(ctx.Err()).WithClues(ctx) - // Will exit switch-block so the remainder of the code doesn't need to be - // indented. case <-timer.C: } response, err := pipeline.Next(req, middlewareIndex) if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) { - return response, Stack(ctx, err).With("retry_count", executionCount) + return response, stackReq(ctx, req, response, err) } - return middleware.retryRequest(ctx, + return mw.retryRequest(ctx, pipeline, middlewareIndex, req, @@ -232,18 +274,33 @@ func (middleware RetryHandler) retryRequest( err) } - if respErr != nil { - return nil, Stack(ctx, respErr).With("retry_count", executionCount) + if priorErr != nil { + return nil, stackReq(ctx, req, nil, priorErr) } return resp, nil } -func (middleware RetryHandler) isRetriableErrorCode(req *http.Request, code int) bool { - return code == http.StatusInternalServerError || code == http.StatusServiceUnavailable +var retryableRespCodes = []int{ + http.StatusInternalServerError, + http.StatusServiceUnavailable, + http.StatusBadGateway, + http.StatusGatewayTimeout, } -func (middleware RetryHandler) isRetriableRequest(req *http.Request) bool { +func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response, code int) bool { + if slices.Contains(retryableRespCodes, code) { + return true + } + + // not a status code, but the message itself might indicate a connectivity issue that + // can be retried independent of the status code. + return strings.Contains( + strings.ToLower(getRespDump(ctx, resp, true)), + strings.ToLower(string(IOErrDuringRead))) +} + +func (mw RetryMiddleware) isRetriableRequest(req *http.Request) bool { isBodiedMethod := req.Method == "POST" || req.Method == "PUT" || req.Method == "PATCH" if isBodiedMethod && req.Body != nil { return req.ContentLength != -1 @@ -252,7 +309,7 @@ func (middleware RetryHandler) isRetriableRequest(req *http.Request) bool { return true } -func (middleware RetryHandler) getRetryDelay( +func (mw RetryMiddleware) getRetryDelay( req *http.Request, resp *http.Response, exponentialBackoff *backoff.ExponentialBackOff, @@ -272,40 +329,6 @@ func (middleware RetryHandler) getRetryDelay( return exponentialBackoff.NextBackOff() } -// Intercept implements the interface and evaluates whether to retry a failed request. -func (middleware RetryHandler) Intercept( - pipeline khttp.Pipeline, - middlewareIndex int, - req *http.Request, -) (*http.Response, error) { - ctx := req.Context() - - response, err := pipeline.Next(req, middlewareIndex) - if err != nil && !IsErrTimeout(err) { - return response, Stack(ctx, err) - } - - exponentialBackOff := backoff.NewExponentialBackOff() - exponentialBackOff.InitialInterval = middleware.Delay - exponentialBackOff.Reset() - - response, err = middleware.retryRequest( - ctx, - pipeline, - middlewareIndex, - req, - response, - 0, - 0, - exponentialBackOff, - err) - if err != nil { - return nil, Stack(ctx, err) - } - - return response, nil -} - // We're trying to keep calls below the 10k-per-10-minute threshold. // 15 tokens every second nets 900 per minute. That's 9000 every 10 minutes, // which is a bit below the mark. @@ -341,7 +364,7 @@ func QueueRequest(ctx context.Context) { // request limits. type ThrottleControlMiddleware struct{} -func (handler *ThrottleControlMiddleware) Intercept( +func (mw *ThrottleControlMiddleware) Intercept( pipeline khttp.Pipeline, middlewareIndex int, req *http.Request, @@ -353,7 +376,7 @@ func (handler *ThrottleControlMiddleware) Intercept( // MetricsMiddleware aggregates per-request metrics on the events bus type MetricsMiddleware struct{} -func (handler *MetricsMiddleware) Intercept( +func (mw *MetricsMiddleware) Intercept( pipeline khttp.Pipeline, middlewareIndex int, req *http.Request, diff --git a/src/internal/connector/graph/middleware_test.go b/src/internal/connector/graph/middleware_test.go new file mode 100644 index 000000000..3a8ec7656 --- /dev/null +++ b/src/internal/connector/graph/middleware_test.go @@ -0,0 +1,152 @@ +package graph + +import ( + "net/http" + "testing" + "time" + + "github.com/alcionai/clues" + khttp "github.com/microsoft/kiota-http-go" + msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" + msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core" + "github.com/microsoftgraph/msgraph-sdk-go/users" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/account" +) + +func newBodylessTestMW(onIntercept func(), code int, err error) testMW { + return testMW{ + err: err, + onIntercept: onIntercept, + resp: &http.Response{StatusCode: code}, + } +} + +type testMW struct { + err error + onIntercept func() + resp *http.Response +} + +func (mw testMW) Intercept( + pipeline khttp.Pipeline, + middlewareIndex int, + req *http.Request, +) (*http.Response, error) { + mw.onIntercept() + return mw.resp, mw.err +} + +// can't use graph/mock.CreateAdapter() due to circular references. +func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) { + auth, err := GetAuth( + creds.AzureTenantID, + creds.AzureClientID, + creds.AzureClientSecret) + if err != nil { + return nil, err + } + + var ( + clientOptions = msgraphsdkgo.GetDefaultClientOptions() + cc = populateConfig(MinimumBackoff(10 * time.Millisecond)) + middlewares = append(kiotaMiddlewares(&clientOptions, cc), mw) + httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) + ) + + httpClient.Timeout = 5 * time.Second + + cc.apply(httpClient) + + return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( + auth, + nil, nil, + httpClient) +} + +type RetryMWIntgSuite struct { + tester.Suite + creds account.M365Config +} + +// We do end up mocking the actual request, but creating the rest +// similar to E2E suite +func TestRetryMWIntgSuite(t *testing.T) { + suite.Run(t, &RetryMWIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *RetryMWIntgSuite) SetupSuite() { + var ( + a = tester.NewM365Account(suite.T()) + err error + ) + + suite.creds, err = a.M365Config() + require.NoError(suite.T(), err, clues.ToCore(err)) +} + +func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { + var ( + uri = "https://graph.microsoft.com" + path = "/v1.0/users/user/messages/foo" + url = uri + path + ) + + tests := []struct { + name string + status int + expectRetryCount int + mw testMW + expectErr assert.ErrorAssertionFunc + }{ + { + name: "200, no retries", + status: http.StatusOK, + expectRetryCount: 0, + expectErr: assert.NoError, + }, + { + name: "400, no retries", + status: http.StatusBadRequest, + expectRetryCount: 0, + expectErr: assert.Error, + }, + { + // don't test 504: gets intercepted by graph client for long waits. + name: "502", + status: http.StatusBadGateway, + expectRetryCount: defaultMaxRetries, + expectErr: assert.Error, + }, + } + + for _, test := range tests { + suite.Run(test.name, func() { + ctx, flush := tester.NewContext() + defer flush() + + t := suite.T() + called := 0 + mw := newBodylessTestMW(func() { called++ }, test.status, nil) + + adpt, err := mockAdapter(suite.creds, mw) + require.NoError(t, err, clues.ToCore(err)) + + // url doesn't fit the builder, but that shouldn't matter + _, err = users.NewCountRequestBuilder(url, adpt).Get(ctx, nil) + test.expectErr(t, err, clues.ToCore(err)) + + // -1 because the non-retried call always counts for one, then + // we increment based on the number of retry attempts. + assert.Equal(t, test.expectRetryCount, called-1) + }) + } +} diff --git a/src/internal/connector/graph/service.go b/src/internal/connector/graph/service.go index 42ef4440c..9db7fb825 100644 --- a/src/internal/connector/graph/service.go +++ b/src/internal/connector/graph/service.go @@ -243,7 +243,7 @@ func kiotaMiddlewares( mw = append(mw, []khttp.Middleware{ msgraphgocore.NewGraphTelemetryHandler(options), - &RetryHandler{ + &RetryMiddleware{ MaxRetries: cc.maxRetries, Delay: cc.minDelay, },