add test for 503 resp body (#3812)

#### Type of change

- [x] 🤖 Supportability/Tests

#### Issue(s)

* closes #3811

#### Test Plan

- [x]  Unit test
This commit is contained in:
Keepers 2023-07-14 17:40:30 -06:00 committed by GitHub
parent fd6dff3270
commit 09e5e9464a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 80 deletions

View File

@ -2,14 +2,18 @@ package graph
import (
"context"
"encoding/json"
"net/http"
"syscall"
"testing"
"github.com/alcionai/clues"
"github.com/microsoft/kiota-abstractions-go/serialization"
kjson "github.com/microsoft/kiota-serialization-json-go"
"github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/common/ptr"
@ -44,6 +48,22 @@ func odErrMsg(code, message string) *odataerrors.ODataError {
return odErr
}
func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
sw := kjson.NewJsonSerializationWriter()
err := sw.WriteObjectValue("", thing)
require.NoError(t, err, "serialize")
content, err := sw.GetSerializedContent()
require.NoError(t, err, "deserialize")
var out map[string]any
err = json.Unmarshal([]byte(content), &out)
require.NoError(t, err, "unmarshall")
return out
}
func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() {
table := []struct {
name string

View File

@ -206,18 +206,14 @@ func (mw RetryMiddleware) Intercept(
req *http.Request,
) (*http.Response, error) {
ctx := req.Context()
resp, err := pipeline.Next(req, middlewareIndex)
retriable := IsErrTimeout(err) || IsErrConnectionReset(err) ||
(resp != nil && (resp.StatusCode/100 == 4 || resp.StatusCode/100 == 5))
retriable := IsErrTimeout(err) ||
IsErrConnectionReset(err) ||
mw.isRetriableRespCode(ctx, resp)
if !retriable {
if err != nil {
return resp, stackReq(ctx, req, resp, err)
}
return resp, nil
return resp, stackReq(ctx, req, resp, err).OrNil()
}
exponentialBackOff := backoff.NewExponentialBackOff()
@ -234,11 +230,8 @@ func (mw RetryMiddleware) Intercept(
0,
exponentialBackOff,
err)
if err != nil {
return nil, stackReq(ctx, req, resp, err)
}
return resp, nil
return resp, stackReq(ctx, req, resp, err).OrNil()
}
func (mw RetryMiddleware) retryRequest(
@ -252,78 +245,70 @@ func (mw RetryMiddleware) retryRequest(
exponentialBackoff *backoff.ExponentialBackOff,
priorErr error,
) (*http.Response, error) {
status := "unknown_resp_status"
statusCode := -1
ctx = clues.Add(ctx, "retry_count", executionCount)
if resp != nil {
status = resp.Status
statusCode = resp.StatusCode
ctx = clues.Add(ctx, "prev_resp_status", resp.Status)
}
ctx = clues.Add(
ctx,
"prev_resp_status", status,
"retry_count", executionCount)
// only retry under certain conditions:
// 1, there was an error. 2, the resp and/or status code match retriable conditions.
// 3, the request is retriable.
// 4, we haven't hit our max retries already.
if (priorErr != nil || mw.isRetriableRespCode(ctx, resp, statusCode)) &&
// only retry if all the following conditions are met:
// 1, there was a prior error OR the status code match retriable conditions.
// 3, the request method is retriable.
// 4, we haven't already hit maximum retries.
shouldRetry := (priorErr != nil || mw.isRetriableRespCode(ctx, resp)) &&
mw.isRetriableRequest(req) &&
executionCount < mw.MaxRetries {
executionCount++
executionCount < mw.MaxRetries
delay := mw.getRetryDelay(req, resp, exponentialBackoff)
cumulativeDelay += delay
if !shouldRetry {
return resp, stackReq(ctx, req, resp, priorErr).OrNil()
}
req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))
executionCount++
timer := time.NewTimer(delay)
delay := mw.getRetryDelay(req, resp, exponentialBackoff)
cumulativeDelay += delay
select {
case <-ctx.Done():
// Don't retry if the context is marked as done, it will just error out
// when we attempt to send the retry anyway.
return resp, clues.Stack(ctx.Err()).WithClues(ctx)
req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))
case <-timer.C:
}
timer := time.NewTimer(delay)
// we have to reset the original body reader for each retry, or else the graph
// compressor will produce a 0 length body following an error response such
// as a 500.
if req.Body != nil {
if s, ok := req.Body.(io.Seeker); ok {
_, err := s.Seek(0, io.SeekStart)
if err != nil {
return nil, Wrap(ctx, err, "resetting request body reader")
}
select {
case <-ctx.Done():
// Don't retry if the context is marked as done, it will just error out
// when we attempt to send the retry anyway.
return resp, clues.Stack(ctx.Err()).WithClues(ctx)
case <-timer.C:
}
// we have to reset the original body reader for each retry, or else the graph
// compressor will produce a 0 length body following an error response such
// as a 500.
if req.Body != nil {
if s, ok := req.Body.(io.Seeker); ok {
if _, err := s.Seek(0, io.SeekStart); err != nil {
return resp, Wrap(ctx, err, "resetting request body reader")
}
} else {
logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body")
}
nextResp, err := pipeline.Next(req, middlewareIndex)
if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
return nextResp, stackReq(ctx, req, nextResp, err)
}
return mw.retryRequest(
ctx,
pipeline,
middlewareIndex,
req,
nextResp,
executionCount,
cumulativeDelay,
exponentialBackoff,
err)
}
if priorErr != nil {
return nil, stackReq(ctx, req, nil, priorErr)
nextResp, err := pipeline.Next(req, middlewareIndex)
if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
return nextResp, stackReq(ctx, req, nextResp, err)
}
return resp, nil
return mw.retryRequest(
ctx,
pipeline,
middlewareIndex,
req,
nextResp,
executionCount,
cumulativeDelay,
exponentialBackoff,
err)
}
var retryableRespCodes = []int{
@ -331,14 +316,18 @@ var retryableRespCodes = []int{
http.StatusBadGateway,
}
func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response, code int) bool {
if slices.Contains(retryableRespCodes, code) {
func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response) bool {
if resp == nil {
return false
}
if slices.Contains(retryableRespCodes, resp.StatusCode) {
return true
}
// prevent the body dump below in case of a 2xx response.
// There's no reason to check the body on a healthy status.
if code/100 != 4 && code/100 != 5 {
if resp.StatusCode/100 != 4 && resp.StatusCode/100 != 5 {
return false
}

View File

@ -2,6 +2,7 @@ package graph
import (
"bytes"
"encoding/json"
"io"
"net/http"
"syscall"
@ -80,7 +81,10 @@ func (mw *testMW) Intercept(
i = 0
}
// panic on out-of-bounds intentionally not protected
if i >= len(mw.toReturn) {
panic(clues.New("middleware test had more calls than responses"))
}
tr := mw.toReturn[i]
mw.iter++
@ -89,7 +93,11 @@ func (mw *testMW) Intercept(
}
// can't use graph/mock.CreateAdapter() due to circular references.
func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) {
func mockAdapter(
creds account.M365Config,
mw khttp.Middleware,
timeout time.Duration,
) (*msgraphsdkgo.GraphRequestAdapter, error) {
auth, err := GetAuth(
creds.AzureTenantID,
creds.AzureClientID,
@ -105,7 +113,7 @@ func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.G
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
)
httpClient.Timeout = 15 * time.Second
httpClient.Timeout = timeout
cc.apply(httpClient)
@ -229,7 +237,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
newMWReturns(test.status, nil, test.providedErr))
mw.repeatReturn0 = true
adpt, err := mockAdapter(suite.creds, mw)
adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
require.NoError(t, err, clues.ToCore(err))
// url doesn't fit the builder, but that shouldn't matter
@ -273,7 +281,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
newMWReturns(http.StatusInternalServerError, nil, nil),
newMWReturns(http.StatusOK, nil, nil))
adpt, err := mockAdapter(suite.creds, mw)
adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
require.NoError(t, err, clues.ToCore(err))
// no api package needed here, this is a mocked request that works
@ -287,6 +295,45 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
require.NoError(t, err, clues.ToCore(err))
}
func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryResponse_maintainBodyAfter503() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
InitializeConcurrencyLimiter(ctx, false, -1)
odem := odErrMsg("SystemDown", "The System, Is Down, bah-dup-da-woo-woo!")
m := parseableToMap(t, odem)
body, err := json.Marshal(m)
require.NoError(t, err, clues.ToCore(err))
mw := newTestMW(
// intentional no-op, just need to conrol the response code
func(*http.Request) {},
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil))
adpt, err := mockAdapter(suite.creds, mw, 55*time.Second)
require.NoError(t, err, clues.ToCore(err))
// no api package needed here,
// this is a mocked request that works
// independent of the query.
_, err = NewService(adpt).
Client().
Users().
ByUserId("user").
MailFolders().
Post(ctx, models.NewMailFolder(), nil)
require.Error(t, err, clues.ToCore(err))
require.NotContains(t, err.Error(), "content is empty", clues.ToCore(err))
require.Contains(t, err.Error(), "503", clues.ToCore(err))
}
type MiddlewareUnitSuite struct {
tester.Suite
}

View File

@ -244,9 +244,7 @@ func kiotaMiddlewares(
options *msgraphgocore.GraphClientOptions,
cc *clientConfig,
) []khttp.Middleware {
mw := []khttp.Middleware{}
mw = append(mw, []khttp.Middleware{
mw := []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options),
&RetryMiddleware{
MaxRetries: cc.maxRetries,
@ -258,7 +256,7 @@ func kiotaMiddlewares(
khttp.NewParametersNameDecodingHandler(),
khttp.NewUserAgentHandler(),
&LoggingMiddleware{},
}...)
}
// Optionally add concurrency limiter middleware if it has been initialized.
if concurrencyLimitMiddlewareSingleton != nil {

View File

@ -60,7 +60,7 @@ func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
require.NoError(t, err, "serialize")
content, err := sw.GetSerializedContent()
require.NoError(t, err, "serialize")
require.NoError(t, err, "deserialize")
var out map[string]any
err = json.Unmarshal([]byte(content), &out)