add test for 503 resp body (#3812)
#### Type of change - [x] 🤖 Supportability/Tests #### Issue(s) * closes #3811 #### Test Plan - [x] ⚡ Unit test
This commit is contained in:
parent
fd6dff3270
commit
09e5e9464a
@ -2,14 +2,18 @@ package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
"github.com/microsoft/kiota-abstractions-go/serialization"
|
||||
kjson "github.com/microsoft/kiota-serialization-json-go"
|
||||
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
||||
"github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/common/ptr"
|
||||
@ -44,6 +48,22 @@ func odErrMsg(code, message string) *odataerrors.ODataError {
|
||||
return odErr
|
||||
}
|
||||
|
||||
func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
|
||||
sw := kjson.NewJsonSerializationWriter()
|
||||
|
||||
err := sw.WriteObjectValue("", thing)
|
||||
require.NoError(t, err, "serialize")
|
||||
|
||||
content, err := sw.GetSerializedContent()
|
||||
require.NoError(t, err, "deserialize")
|
||||
|
||||
var out map[string]any
|
||||
err = json.Unmarshal([]byte(content), &out)
|
||||
require.NoError(t, err, "unmarshall")
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() {
|
||||
table := []struct {
|
||||
name string
|
||||
|
||||
@ -206,18 +206,14 @@ func (mw RetryMiddleware) Intercept(
|
||||
req *http.Request,
|
||||
) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
|
||||
resp, err := pipeline.Next(req, middlewareIndex)
|
||||
|
||||
retriable := IsErrTimeout(err) || IsErrConnectionReset(err) ||
|
||||
(resp != nil && (resp.StatusCode/100 == 4 || resp.StatusCode/100 == 5))
|
||||
retriable := IsErrTimeout(err) ||
|
||||
IsErrConnectionReset(err) ||
|
||||
mw.isRetriableRespCode(ctx, resp)
|
||||
|
||||
if !retriable {
|
||||
if err != nil {
|
||||
return resp, stackReq(ctx, req, resp, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp, stackReq(ctx, req, resp, err).OrNil()
|
||||
}
|
||||
|
||||
exponentialBackOff := backoff.NewExponentialBackOff()
|
||||
@ -234,11 +230,8 @@ func (mw RetryMiddleware) Intercept(
|
||||
0,
|
||||
exponentialBackOff,
|
||||
err)
|
||||
if err != nil {
|
||||
return nil, stackReq(ctx, req, resp, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp, stackReq(ctx, req, resp, err).OrNil()
|
||||
}
|
||||
|
||||
func (mw RetryMiddleware) retryRequest(
|
||||
@ -252,78 +245,70 @@ func (mw RetryMiddleware) retryRequest(
|
||||
exponentialBackoff *backoff.ExponentialBackOff,
|
||||
priorErr error,
|
||||
) (*http.Response, error) {
|
||||
status := "unknown_resp_status"
|
||||
statusCode := -1
|
||||
ctx = clues.Add(ctx, "retry_count", executionCount)
|
||||
|
||||
if resp != nil {
|
||||
status = resp.Status
|
||||
statusCode = resp.StatusCode
|
||||
ctx = clues.Add(ctx, "prev_resp_status", resp.Status)
|
||||
}
|
||||
|
||||
ctx = clues.Add(
|
||||
ctx,
|
||||
"prev_resp_status", status,
|
||||
"retry_count", executionCount)
|
||||
|
||||
// only retry under certain conditions:
|
||||
// 1, there was an error. 2, the resp and/or status code match retriable conditions.
|
||||
// 3, the request is retriable.
|
||||
// 4, we haven't hit our max retries already.
|
||||
if (priorErr != nil || mw.isRetriableRespCode(ctx, resp, statusCode)) &&
|
||||
// only retry if all the following conditions are met:
|
||||
// 1, there was a prior error OR the status code match retriable conditions.
|
||||
// 3, the request method is retriable.
|
||||
// 4, we haven't already hit maximum retries.
|
||||
shouldRetry := (priorErr != nil || mw.isRetriableRespCode(ctx, resp)) &&
|
||||
mw.isRetriableRequest(req) &&
|
||||
executionCount < mw.MaxRetries {
|
||||
executionCount++
|
||||
executionCount < mw.MaxRetries
|
||||
|
||||
delay := mw.getRetryDelay(req, resp, exponentialBackoff)
|
||||
cumulativeDelay += delay
|
||||
if !shouldRetry {
|
||||
return resp, stackReq(ctx, req, resp, priorErr).OrNil()
|
||||
}
|
||||
|
||||
req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))
|
||||
executionCount++
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
delay := mw.getRetryDelay(req, resp, exponentialBackoff)
|
||||
cumulativeDelay += delay
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Don't retry if the context is marked as done, it will just error out
|
||||
// when we attempt to send the retry anyway.
|
||||
return resp, clues.Stack(ctx.Err()).WithClues(ctx)
|
||||
req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))
|
||||
|
||||
case <-timer.C:
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
|
||||
// we have to reset the original body reader for each retry, or else the graph
|
||||
// compressor will produce a 0 length body following an error response such
|
||||
// as a 500.
|
||||
if req.Body != nil {
|
||||
if s, ok := req.Body.(io.Seeker); ok {
|
||||
_, err := s.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return nil, Wrap(ctx, err, "resetting request body reader")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Don't retry if the context is marked as done, it will just error out
|
||||
// when we attempt to send the retry anyway.
|
||||
return resp, clues.Stack(ctx.Err()).WithClues(ctx)
|
||||
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
// we have to reset the original body reader for each retry, or else the graph
|
||||
// compressor will produce a 0 length body following an error response such
|
||||
// as a 500.
|
||||
if req.Body != nil {
|
||||
if s, ok := req.Body.(io.Seeker); ok {
|
||||
if _, err := s.Seek(0, io.SeekStart); err != nil {
|
||||
return resp, Wrap(ctx, err, "resetting request body reader")
|
||||
}
|
||||
} else {
|
||||
logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body")
|
||||
}
|
||||
|
||||
nextResp, err := pipeline.Next(req, middlewareIndex)
|
||||
if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
|
||||
return nextResp, stackReq(ctx, req, nextResp, err)
|
||||
}
|
||||
|
||||
return mw.retryRequest(
|
||||
ctx,
|
||||
pipeline,
|
||||
middlewareIndex,
|
||||
req,
|
||||
nextResp,
|
||||
executionCount,
|
||||
cumulativeDelay,
|
||||
exponentialBackoff,
|
||||
err)
|
||||
}
|
||||
|
||||
if priorErr != nil {
|
||||
return nil, stackReq(ctx, req, nil, priorErr)
|
||||
nextResp, err := pipeline.Next(req, middlewareIndex)
|
||||
if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
|
||||
return nextResp, stackReq(ctx, req, nextResp, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return mw.retryRequest(
|
||||
ctx,
|
||||
pipeline,
|
||||
middlewareIndex,
|
||||
req,
|
||||
nextResp,
|
||||
executionCount,
|
||||
cumulativeDelay,
|
||||
exponentialBackoff,
|
||||
err)
|
||||
}
|
||||
|
||||
var retryableRespCodes = []int{
|
||||
@ -331,14 +316,18 @@ var retryableRespCodes = []int{
|
||||
http.StatusBadGateway,
|
||||
}
|
||||
|
||||
func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response, code int) bool {
|
||||
if slices.Contains(retryableRespCodes, code) {
|
||||
func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response) bool {
|
||||
if resp == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.Contains(retryableRespCodes, resp.StatusCode) {
|
||||
return true
|
||||
}
|
||||
|
||||
// prevent the body dump below in case of a 2xx response.
|
||||
// There's no reason to check the body on a healthy status.
|
||||
if code/100 != 4 && code/100 != 5 {
|
||||
if resp.StatusCode/100 != 4 && resp.StatusCode/100 != 5 {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ package graph
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"syscall"
|
||||
@ -80,7 +81,10 @@ func (mw *testMW) Intercept(
|
||||
i = 0
|
||||
}
|
||||
|
||||
// panic on out-of-bounds intentionally not protected
|
||||
if i >= len(mw.toReturn) {
|
||||
panic(clues.New("middleware test had more calls than responses"))
|
||||
}
|
||||
|
||||
tr := mw.toReturn[i]
|
||||
|
||||
mw.iter++
|
||||
@ -89,7 +93,11 @@ func (mw *testMW) Intercept(
|
||||
}
|
||||
|
||||
// can't use graph/mock.CreateAdapter() due to circular references.
|
||||
func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) {
|
||||
func mockAdapter(
|
||||
creds account.M365Config,
|
||||
mw khttp.Middleware,
|
||||
timeout time.Duration,
|
||||
) (*msgraphsdkgo.GraphRequestAdapter, error) {
|
||||
auth, err := GetAuth(
|
||||
creds.AzureTenantID,
|
||||
creds.AzureClientID,
|
||||
@ -105,7 +113,7 @@ func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.G
|
||||
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
|
||||
)
|
||||
|
||||
httpClient.Timeout = 15 * time.Second
|
||||
httpClient.Timeout = timeout
|
||||
|
||||
cc.apply(httpClient)
|
||||
|
||||
@ -229,7 +237,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
|
||||
newMWReturns(test.status, nil, test.providedErr))
|
||||
mw.repeatReturn0 = true
|
||||
|
||||
adpt, err := mockAdapter(suite.creds, mw)
|
||||
adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
// url doesn't fit the builder, but that shouldn't matter
|
||||
@ -273,7 +281,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
|
||||
newMWReturns(http.StatusInternalServerError, nil, nil),
|
||||
newMWReturns(http.StatusOK, nil, nil))
|
||||
|
||||
adpt, err := mockAdapter(suite.creds, mw)
|
||||
adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
// no api package needed here, this is a mocked request that works
|
||||
@ -287,6 +295,45 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
}
|
||||
|
||||
func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryResponse_maintainBodyAfter503() {
|
||||
t := suite.T()
|
||||
|
||||
ctx, flush := tester.NewContext(t)
|
||||
defer flush()
|
||||
|
||||
InitializeConcurrencyLimiter(ctx, false, -1)
|
||||
|
||||
odem := odErrMsg("SystemDown", "The System, Is Down, bah-dup-da-woo-woo!")
|
||||
m := parseableToMap(t, odem)
|
||||
|
||||
body, err := json.Marshal(m)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
mw := newTestMW(
|
||||
// intentional no-op, just need to conrol the response code
|
||||
func(*http.Request) {},
|
||||
newMWReturns(http.StatusServiceUnavailable, body, nil),
|
||||
newMWReturns(http.StatusServiceUnavailable, body, nil),
|
||||
newMWReturns(http.StatusServiceUnavailable, body, nil),
|
||||
newMWReturns(http.StatusServiceUnavailable, body, nil))
|
||||
|
||||
adpt, err := mockAdapter(suite.creds, mw, 55*time.Second)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
// no api package needed here,
|
||||
// this is a mocked request that works
|
||||
// independent of the query.
|
||||
_, err = NewService(adpt).
|
||||
Client().
|
||||
Users().
|
||||
ByUserId("user").
|
||||
MailFolders().
|
||||
Post(ctx, models.NewMailFolder(), nil)
|
||||
require.Error(t, err, clues.ToCore(err))
|
||||
require.NotContains(t, err.Error(), "content is empty", clues.ToCore(err))
|
||||
require.Contains(t, err.Error(), "503", clues.ToCore(err))
|
||||
}
|
||||
|
||||
type MiddlewareUnitSuite struct {
|
||||
tester.Suite
|
||||
}
|
||||
|
||||
@ -244,9 +244,7 @@ func kiotaMiddlewares(
|
||||
options *msgraphgocore.GraphClientOptions,
|
||||
cc *clientConfig,
|
||||
) []khttp.Middleware {
|
||||
mw := []khttp.Middleware{}
|
||||
|
||||
mw = append(mw, []khttp.Middleware{
|
||||
mw := []khttp.Middleware{
|
||||
msgraphgocore.NewGraphTelemetryHandler(options),
|
||||
&RetryMiddleware{
|
||||
MaxRetries: cc.maxRetries,
|
||||
@ -258,7 +256,7 @@ func kiotaMiddlewares(
|
||||
khttp.NewParametersNameDecodingHandler(),
|
||||
khttp.NewUserAgentHandler(),
|
||||
&LoggingMiddleware{},
|
||||
}...)
|
||||
}
|
||||
|
||||
// Optionally add concurrency limiter middleware if it has been initialized.
|
||||
if concurrencyLimitMiddlewareSingleton != nil {
|
||||
|
||||
@ -60,7 +60,7 @@ func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
|
||||
require.NoError(t, err, "serialize")
|
||||
|
||||
content, err := sw.GetSerializedContent()
|
||||
require.NoError(t, err, "serialize")
|
||||
require.NoError(t, err, "deserialize")
|
||||
|
||||
var out map[string]any
|
||||
err = json.Unmarshal([]byte(content), &out)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user