diff --git a/src/internal/m365/graph/errors_test.go b/src/internal/m365/graph/errors_test.go index d419d5641..81d9787bb 100644 --- a/src/internal/m365/graph/errors_test.go +++ b/src/internal/m365/graph/errors_test.go @@ -2,14 +2,18 @@ package graph import ( "context" + "encoding/json" "net/http" "syscall" "testing" "github.com/alcionai/clues" + "github.com/microsoft/kiota-abstractions-go/serialization" + kjson "github.com/microsoft/kiota-serialization-json-go" "github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/alcionai/corso/src/internal/common/ptr" @@ -44,6 +48,22 @@ func odErrMsg(code, message string) *odataerrors.ODataError { return odErr } +func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any { + sw := kjson.NewJsonSerializationWriter() + + err := sw.WriteObjectValue("", thing) + require.NoError(t, err, "serialize") + + content, err := sw.GetSerializedContent() + require.NoError(t, err, "deserialize") + + var out map[string]any + err = json.Unmarshal([]byte(content), &out) + require.NoError(t, err, "unmarshall") + + return out +} + func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() { table := []struct { name string diff --git a/src/internal/m365/graph/middleware.go b/src/internal/m365/graph/middleware.go index 80d74144d..456fb402b 100644 --- a/src/internal/m365/graph/middleware.go +++ b/src/internal/m365/graph/middleware.go @@ -206,18 +206,14 @@ func (mw RetryMiddleware) Intercept( req *http.Request, ) (*http.Response, error) { ctx := req.Context() - resp, err := pipeline.Next(req, middlewareIndex) - retriable := IsErrTimeout(err) || IsErrConnectionReset(err) || - (resp != nil && (resp.StatusCode/100 == 4 || resp.StatusCode/100 == 5)) + retriable := IsErrTimeout(err) || + IsErrConnectionReset(err) || + mw.isRetriableRespCode(ctx, resp) if !retriable { - if err != nil { - return resp, stackReq(ctx, req, resp, err) - } - - return resp, nil + return resp, stackReq(ctx, req, resp, err).OrNil() } exponentialBackOff := backoff.NewExponentialBackOff() @@ -234,11 +230,8 @@ func (mw RetryMiddleware) Intercept( 0, exponentialBackOff, err) - if err != nil { - return nil, stackReq(ctx, req, resp, err) - } - return resp, nil + return resp, stackReq(ctx, req, resp, err).OrNil() } func (mw RetryMiddleware) retryRequest( @@ -252,78 +245,70 @@ func (mw RetryMiddleware) retryRequest( exponentialBackoff *backoff.ExponentialBackOff, priorErr error, ) (*http.Response, error) { - status := "unknown_resp_status" - statusCode := -1 + ctx = clues.Add(ctx, "retry_count", executionCount) if resp != nil { - status = resp.Status - statusCode = resp.StatusCode + ctx = clues.Add(ctx, "prev_resp_status", resp.Status) } - ctx = clues.Add( - ctx, - "prev_resp_status", status, - "retry_count", executionCount) - - // only retry under certain conditions: - // 1, there was an error. 2, the resp and/or status code match retriable conditions. - // 3, the request is retriable. - // 4, we haven't hit our max retries already. - if (priorErr != nil || mw.isRetriableRespCode(ctx, resp, statusCode)) && + // only retry if all the following conditions are met: + // 1, there was a prior error OR the status code match retriable conditions. + // 3, the request method is retriable. + // 4, we haven't already hit maximum retries. + shouldRetry := (priorErr != nil || mw.isRetriableRespCode(ctx, resp)) && mw.isRetriableRequest(req) && - executionCount < mw.MaxRetries { - executionCount++ + executionCount < mw.MaxRetries - delay := mw.getRetryDelay(req, resp, exponentialBackoff) - cumulativeDelay += delay + if !shouldRetry { + return resp, stackReq(ctx, req, resp, priorErr).OrNil() + } - req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount)) + executionCount++ - timer := time.NewTimer(delay) + delay := mw.getRetryDelay(req, resp, exponentialBackoff) + cumulativeDelay += delay - select { - case <-ctx.Done(): - // Don't retry if the context is marked as done, it will just error out - // when we attempt to send the retry anyway. - return resp, clues.Stack(ctx.Err()).WithClues(ctx) + req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount)) - case <-timer.C: - } + timer := time.NewTimer(delay) - // we have to reset the original body reader for each retry, or else the graph - // compressor will produce a 0 length body following an error response such - // as a 500. - if req.Body != nil { - if s, ok := req.Body.(io.Seeker); ok { - _, err := s.Seek(0, io.SeekStart) - if err != nil { - return nil, Wrap(ctx, err, "resetting request body reader") - } + select { + case <-ctx.Done(): + // Don't retry if the context is marked as done, it will just error out + // when we attempt to send the retry anyway. + return resp, clues.Stack(ctx.Err()).WithClues(ctx) + + case <-timer.C: + } + + // we have to reset the original body reader for each retry, or else the graph + // compressor will produce a 0 length body following an error response such + // as a 500. + if req.Body != nil { + if s, ok := req.Body.(io.Seeker); ok { + if _, err := s.Seek(0, io.SeekStart); err != nil { + return resp, Wrap(ctx, err, "resetting request body reader") } + } else { + logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body") } - - nextResp, err := pipeline.Next(req, middlewareIndex) - if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) { - return nextResp, stackReq(ctx, req, nextResp, err) - } - - return mw.retryRequest( - ctx, - pipeline, - middlewareIndex, - req, - nextResp, - executionCount, - cumulativeDelay, - exponentialBackoff, - err) } - if priorErr != nil { - return nil, stackReq(ctx, req, nil, priorErr) + nextResp, err := pipeline.Next(req, middlewareIndex) + if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) { + return nextResp, stackReq(ctx, req, nextResp, err) } - return resp, nil + return mw.retryRequest( + ctx, + pipeline, + middlewareIndex, + req, + nextResp, + executionCount, + cumulativeDelay, + exponentialBackoff, + err) } var retryableRespCodes = []int{ @@ -331,14 +316,18 @@ var retryableRespCodes = []int{ http.StatusBadGateway, } -func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response, code int) bool { - if slices.Contains(retryableRespCodes, code) { +func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response) bool { + if resp == nil { + return false + } + + if slices.Contains(retryableRespCodes, resp.StatusCode) { return true } // prevent the body dump below in case of a 2xx response. // There's no reason to check the body on a healthy status. - if code/100 != 4 && code/100 != 5 { + if resp.StatusCode/100 != 4 && resp.StatusCode/100 != 5 { return false } diff --git a/src/internal/m365/graph/middleware_test.go b/src/internal/m365/graph/middleware_test.go index 921f3b4df..0a5dc8eb5 100644 --- a/src/internal/m365/graph/middleware_test.go +++ b/src/internal/m365/graph/middleware_test.go @@ -2,6 +2,7 @@ package graph import ( "bytes" + "encoding/json" "io" "net/http" "syscall" @@ -80,7 +81,10 @@ func (mw *testMW) Intercept( i = 0 } - // panic on out-of-bounds intentionally not protected + if i >= len(mw.toReturn) { + panic(clues.New("middleware test had more calls than responses")) + } + tr := mw.toReturn[i] mw.iter++ @@ -89,7 +93,11 @@ func (mw *testMW) Intercept( } // can't use graph/mock.CreateAdapter() due to circular references. -func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) { +func mockAdapter( + creds account.M365Config, + mw khttp.Middleware, + timeout time.Duration, +) (*msgraphsdkgo.GraphRequestAdapter, error) { auth, err := GetAuth( creds.AzureTenantID, creds.AzureClientID, @@ -105,7 +113,7 @@ func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.G httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) ) - httpClient.Timeout = 15 * time.Second + httpClient.Timeout = timeout cc.apply(httpClient) @@ -229,7 +237,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { newMWReturns(test.status, nil, test.providedErr)) mw.repeatReturn0 = true - adpt, err := mockAdapter(suite.creds, mw) + adpt, err := mockAdapter(suite.creds, mw, 15*time.Second) require.NoError(t, err, clues.ToCore(err)) // url doesn't fit the builder, but that shouldn't matter @@ -273,7 +281,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50 newMWReturns(http.StatusInternalServerError, nil, nil), newMWReturns(http.StatusOK, nil, nil)) - adpt, err := mockAdapter(suite.creds, mw) + adpt, err := mockAdapter(suite.creds, mw, 15*time.Second) require.NoError(t, err, clues.ToCore(err)) // no api package needed here, this is a mocked request that works @@ -287,6 +295,45 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50 require.NoError(t, err, clues.ToCore(err)) } +func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryResponse_maintainBodyAfter503() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + InitializeConcurrencyLimiter(ctx, false, -1) + + odem := odErrMsg("SystemDown", "The System, Is Down, bah-dup-da-woo-woo!") + m := parseableToMap(t, odem) + + body, err := json.Marshal(m) + require.NoError(t, err, clues.ToCore(err)) + + mw := newTestMW( + // intentional no-op, just need to conrol the response code + func(*http.Request) {}, + newMWReturns(http.StatusServiceUnavailable, body, nil), + newMWReturns(http.StatusServiceUnavailable, body, nil), + newMWReturns(http.StatusServiceUnavailable, body, nil), + newMWReturns(http.StatusServiceUnavailable, body, nil)) + + adpt, err := mockAdapter(suite.creds, mw, 55*time.Second) + require.NoError(t, err, clues.ToCore(err)) + + // no api package needed here, + // this is a mocked request that works + // independent of the query. + _, err = NewService(adpt). + Client(). + Users(). + ByUserId("user"). + MailFolders(). + Post(ctx, models.NewMailFolder(), nil) + require.Error(t, err, clues.ToCore(err)) + require.NotContains(t, err.Error(), "content is empty", clues.ToCore(err)) + require.Contains(t, err.Error(), "503", clues.ToCore(err)) +} + type MiddlewareUnitSuite struct { tester.Suite } diff --git a/src/internal/m365/graph/service.go b/src/internal/m365/graph/service.go index 379b969a0..5c94431c9 100644 --- a/src/internal/m365/graph/service.go +++ b/src/internal/m365/graph/service.go @@ -244,9 +244,7 @@ func kiotaMiddlewares( options *msgraphgocore.GraphClientOptions, cc *clientConfig, ) []khttp.Middleware { - mw := []khttp.Middleware{} - - mw = append(mw, []khttp.Middleware{ + mw := []khttp.Middleware{ msgraphgocore.NewGraphTelemetryHandler(options), &RetryMiddleware{ MaxRetries: cc.maxRetries, @@ -258,7 +256,7 @@ func kiotaMiddlewares( khttp.NewParametersNameDecodingHandler(), khttp.NewUserAgentHandler(), &LoggingMiddleware{}, - }...) + } // Optionally add concurrency limiter middleware if it has been initialized. if concurrencyLimitMiddlewareSingleton != nil { diff --git a/src/pkg/services/m365/api/helper_test.go b/src/pkg/services/m365/api/helper_test.go index ff25677c8..05e16b00e 100644 --- a/src/pkg/services/m365/api/helper_test.go +++ b/src/pkg/services/m365/api/helper_test.go @@ -60,7 +60,7 @@ func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any { require.NoError(t, err, "serialize") content, err := sw.GetSerializedContent() - require.NoError(t, err, "serialize") + require.NoError(t, err, "deserialize") var out map[string]any err = json.Unmarshal([]byte(content), &out)