add test for 503 resp body (#3812)

#### Type of change

- [x] 🤖 Supportability/Tests

#### Issue(s)

* closes #3811

#### Test Plan

- [x]  Unit test
This commit is contained in:
Keepers 2023-07-14 17:40:30 -06:00 committed by GitHub
parent fd6dff3270
commit 09e5e9464a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 80 deletions

View File

@ -2,14 +2,18 @@ package graph
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"syscall" "syscall"
"testing" "testing"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoft/kiota-abstractions-go/serialization"
kjson "github.com/microsoft/kiota-serialization-json-go"
"github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors" "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/common/ptr"
@ -44,6 +48,22 @@ func odErrMsg(code, message string) *odataerrors.ODataError {
return odErr return odErr
} }
func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
sw := kjson.NewJsonSerializationWriter()
err := sw.WriteObjectValue("", thing)
require.NoError(t, err, "serialize")
content, err := sw.GetSerializedContent()
require.NoError(t, err, "deserialize")
var out map[string]any
err = json.Unmarshal([]byte(content), &out)
require.NoError(t, err, "unmarshall")
return out
}
func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() { func (suite *GraphErrorsUnitSuite) TestIsErrConnectionReset() {
table := []struct { table := []struct {
name string name string

View File

@ -206,18 +206,14 @@ func (mw RetryMiddleware) Intercept(
req *http.Request, req *http.Request,
) (*http.Response, error) { ) (*http.Response, error) {
ctx := req.Context() ctx := req.Context()
resp, err := pipeline.Next(req, middlewareIndex) resp, err := pipeline.Next(req, middlewareIndex)
retriable := IsErrTimeout(err) || IsErrConnectionReset(err) || retriable := IsErrTimeout(err) ||
(resp != nil && (resp.StatusCode/100 == 4 || resp.StatusCode/100 == 5)) IsErrConnectionReset(err) ||
mw.isRetriableRespCode(ctx, resp)
if !retriable { if !retriable {
if err != nil { return resp, stackReq(ctx, req, resp, err).OrNil()
return resp, stackReq(ctx, req, resp, err)
}
return resp, nil
} }
exponentialBackOff := backoff.NewExponentialBackOff() exponentialBackOff := backoff.NewExponentialBackOff()
@ -234,11 +230,8 @@ func (mw RetryMiddleware) Intercept(
0, 0,
exponentialBackOff, exponentialBackOff,
err) err)
if err != nil {
return nil, stackReq(ctx, req, resp, err)
}
return resp, nil return resp, stackReq(ctx, req, resp, err).OrNil()
} }
func (mw RetryMiddleware) retryRequest( func (mw RetryMiddleware) retryRequest(
@ -252,26 +245,24 @@ func (mw RetryMiddleware) retryRequest(
exponentialBackoff *backoff.ExponentialBackOff, exponentialBackoff *backoff.ExponentialBackOff,
priorErr error, priorErr error,
) (*http.Response, error) { ) (*http.Response, error) {
status := "unknown_resp_status" ctx = clues.Add(ctx, "retry_count", executionCount)
statusCode := -1
if resp != nil { if resp != nil {
status = resp.Status ctx = clues.Add(ctx, "prev_resp_status", resp.Status)
statusCode = resp.StatusCode
} }
ctx = clues.Add( // only retry if all the following conditions are met:
ctx, // 1, there was a prior error OR the status code match retriable conditions.
"prev_resp_status", status, // 3, the request method is retriable.
"retry_count", executionCount) // 4, we haven't already hit maximum retries.
shouldRetry := (priorErr != nil || mw.isRetriableRespCode(ctx, resp)) &&
// only retry under certain conditions:
// 1, there was an error. 2, the resp and/or status code match retriable conditions.
// 3, the request is retriable.
// 4, we haven't hit our max retries already.
if (priorErr != nil || mw.isRetriableRespCode(ctx, resp, statusCode)) &&
mw.isRetriableRequest(req) && mw.isRetriableRequest(req) &&
executionCount < mw.MaxRetries { executionCount < mw.MaxRetries
if !shouldRetry {
return resp, stackReq(ctx, req, resp, priorErr).OrNil()
}
executionCount++ executionCount++
delay := mw.getRetryDelay(req, resp, exponentialBackoff) delay := mw.getRetryDelay(req, resp, exponentialBackoff)
@ -295,10 +286,11 @@ func (mw RetryMiddleware) retryRequest(
// as a 500. // as a 500.
if req.Body != nil { if req.Body != nil {
if s, ok := req.Body.(io.Seeker); ok { if s, ok := req.Body.(io.Seeker); ok {
_, err := s.Seek(0, io.SeekStart) if _, err := s.Seek(0, io.SeekStart); err != nil {
if err != nil { return resp, Wrap(ctx, err, "resetting request body reader")
return nil, Wrap(ctx, err, "resetting request body reader")
} }
} else {
logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body")
} }
} }
@ -319,26 +311,23 @@ func (mw RetryMiddleware) retryRequest(
err) err)
} }
if priorErr != nil {
return nil, stackReq(ctx, req, nil, priorErr)
}
return resp, nil
}
var retryableRespCodes = []int{ var retryableRespCodes = []int{
http.StatusInternalServerError, http.StatusInternalServerError,
http.StatusBadGateway, http.StatusBadGateway,
} }
func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response, code int) bool { func (mw RetryMiddleware) isRetriableRespCode(ctx context.Context, resp *http.Response) bool {
if slices.Contains(retryableRespCodes, code) { if resp == nil {
return false
}
if slices.Contains(retryableRespCodes, resp.StatusCode) {
return true return true
} }
// prevent the body dump below in case of a 2xx response. // prevent the body dump below in case of a 2xx response.
// There's no reason to check the body on a healthy status. // There's no reason to check the body on a healthy status.
if code/100 != 4 && code/100 != 5 { if resp.StatusCode/100 != 4 && resp.StatusCode/100 != 5 {
return false return false
} }

View File

@ -2,6 +2,7 @@ package graph
import ( import (
"bytes" "bytes"
"encoding/json"
"io" "io"
"net/http" "net/http"
"syscall" "syscall"
@ -80,7 +81,10 @@ func (mw *testMW) Intercept(
i = 0 i = 0
} }
// panic on out-of-bounds intentionally not protected if i >= len(mw.toReturn) {
panic(clues.New("middleware test had more calls than responses"))
}
tr := mw.toReturn[i] tr := mw.toReturn[i]
mw.iter++ mw.iter++
@ -89,7 +93,11 @@ func (mw *testMW) Intercept(
} }
// can't use graph/mock.CreateAdapter() due to circular references. // can't use graph/mock.CreateAdapter() due to circular references.
func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) { func mockAdapter(
creds account.M365Config,
mw khttp.Middleware,
timeout time.Duration,
) (*msgraphsdkgo.GraphRequestAdapter, error) {
auth, err := GetAuth( auth, err := GetAuth(
creds.AzureTenantID, creds.AzureTenantID,
creds.AzureClientID, creds.AzureClientID,
@ -105,7 +113,7 @@ func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.G
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
) )
httpClient.Timeout = 15 * time.Second httpClient.Timeout = timeout
cc.apply(httpClient) cc.apply(httpClient)
@ -229,7 +237,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
newMWReturns(test.status, nil, test.providedErr)) newMWReturns(test.status, nil, test.providedErr))
mw.repeatReturn0 = true mw.repeatReturn0 = true
adpt, err := mockAdapter(suite.creds, mw) adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
// url doesn't fit the builder, but that shouldn't matter // url doesn't fit the builder, but that shouldn't matter
@ -273,7 +281,7 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
newMWReturns(http.StatusInternalServerError, nil, nil), newMWReturns(http.StatusInternalServerError, nil, nil),
newMWReturns(http.StatusOK, nil, nil)) newMWReturns(http.StatusOK, nil, nil))
adpt, err := mockAdapter(suite.creds, mw) adpt, err := mockAdapter(suite.creds, mw, 15*time.Second)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
// no api package needed here, this is a mocked request that works // no api package needed here, this is a mocked request that works
@ -287,6 +295,45 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
} }
func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryResponse_maintainBodyAfter503() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
InitializeConcurrencyLimiter(ctx, false, -1)
odem := odErrMsg("SystemDown", "The System, Is Down, bah-dup-da-woo-woo!")
m := parseableToMap(t, odem)
body, err := json.Marshal(m)
require.NoError(t, err, clues.ToCore(err))
mw := newTestMW(
// intentional no-op, just need to conrol the response code
func(*http.Request) {},
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil),
newMWReturns(http.StatusServiceUnavailable, body, nil))
adpt, err := mockAdapter(suite.creds, mw, 55*time.Second)
require.NoError(t, err, clues.ToCore(err))
// no api package needed here,
// this is a mocked request that works
// independent of the query.
_, err = NewService(adpt).
Client().
Users().
ByUserId("user").
MailFolders().
Post(ctx, models.NewMailFolder(), nil)
require.Error(t, err, clues.ToCore(err))
require.NotContains(t, err.Error(), "content is empty", clues.ToCore(err))
require.Contains(t, err.Error(), "503", clues.ToCore(err))
}
type MiddlewareUnitSuite struct { type MiddlewareUnitSuite struct {
tester.Suite tester.Suite
} }

View File

@ -244,9 +244,7 @@ func kiotaMiddlewares(
options *msgraphgocore.GraphClientOptions, options *msgraphgocore.GraphClientOptions,
cc *clientConfig, cc *clientConfig,
) []khttp.Middleware { ) []khttp.Middleware {
mw := []khttp.Middleware{} mw := []khttp.Middleware{
mw = append(mw, []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options), msgraphgocore.NewGraphTelemetryHandler(options),
&RetryMiddleware{ &RetryMiddleware{
MaxRetries: cc.maxRetries, MaxRetries: cc.maxRetries,
@ -258,7 +256,7 @@ func kiotaMiddlewares(
khttp.NewParametersNameDecodingHandler(), khttp.NewParametersNameDecodingHandler(),
khttp.NewUserAgentHandler(), khttp.NewUserAgentHandler(),
&LoggingMiddleware{}, &LoggingMiddleware{},
}...) }
// Optionally add concurrency limiter middleware if it has been initialized. // Optionally add concurrency limiter middleware if it has been initialized.
if concurrencyLimitMiddlewareSingleton != nil { if concurrencyLimitMiddlewareSingleton != nil {

View File

@ -60,7 +60,7 @@ func parseableToMap(t *testing.T, thing serialization.Parsable) map[string]any {
require.NoError(t, err, "serialize") require.NoError(t, err, "serialize")
content, err := sw.GetSerializedContent() content, err := sw.GetSerializedContent()
require.NoError(t, err, "serialize") require.NoError(t, err, "deserialize")
var out map[string]any var out map[string]any
err = json.Unmarshal([]byte(content), &out) err = json.Unmarshal([]byte(content), &out)