corso/src/internal/m365/graph/middleware_test.go
Keepers 8ba79709a6
split tester into separate files (#3762)
This is primarily an exercise in reducing the number of circular imports we get from adding the tester package to other packages.

No logic changes; purely movement and renaming.

---

#### Does this PR need a docs update or release note?

- [x]  No

#### Type of change

- [x] 🤖 Supportability/Tests

#### Test Plan

- [x]  Unit test
- [x] 💚 E2E
2023-07-06 15:43:57 +00:00

397 lines
9.2 KiB
Go

package graph
import (
"bytes"
"io"
"net/http"
"syscall"
"testing"
"time"
"github.com/alcionai/clues"
"github.com/google/uuid"
khttp "github.com/microsoft/kiota-http-go"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core"
"github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/microsoftgraph/msgraph-sdk-go/users"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"golang.org/x/time/rate"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/tester"
"github.com/alcionai/corso/src/internal/tester/tconfig"
"github.com/alcionai/corso/src/pkg/account"
"github.com/alcionai/corso/src/pkg/path"
)
// mwReturns is a canned (response, error) pair that testMW hands back from
// Intercept in place of a real round trip.
type mwReturns struct {
	err  error
	resp *http.Response
}

// newMWReturns builds a canned middleware result.
//
// A code of 0 yields a nil *http.Response, which lets tests simulate
// failures where no response is received at all. A non-empty body is
// wrapped in a no-op closer; a nil or empty body leaves the response's Body
// nil. err is passed through unchanged in every case.
func newMWReturns(code int, body []byte, err error) mwReturns {
	if code == 0 {
		return mwReturns{err: err}
	}

	var rc io.ReadCloser
	if len(body) != 0 {
		rc = io.NopCloser(bytes.NewBuffer(body))
	}

	return mwReturns{
		err: err,
		resp: &http.Response{
			StatusCode: code,
			Body:       rc,
		},
	}
}
// newTestMW returns a testMW that invokes onIntercept for every request it
// receives and replies with the given canned results, in order.
func newTestMW(onIntercept func(*http.Request), mrs ...mwReturns) *testMW {
	mw := &testMW{}
	mw.onIntercept = onIntercept
	mw.toReturn = mrs

	return mw
}
// testMW is a khttp.Middleware stand-in that never forwards the request and
// instead returns canned results.
type testMW struct {
	// repeatReturn0, when true, makes every call return toReturn[0].
	repeatReturn0 bool
	// iter counts Intercept calls; it indexes toReturn unless repeatReturn0
	// is set.
	iter int
	// toReturn holds the canned results, consumed in order.
	toReturn []mwReturns
	// onIntercept is called with each request before a result is returned.
	onIntercept func(*http.Request)
}
// Intercept records the request via onIntercept and replies with the next
// canned mwReturns without calling further down the pipeline. When
// repeatReturn0 is set, the first canned result is returned on every call.
// Running past the end of toReturn panics on purpose, so tests that make
// more calls than they planned for fail loudly.
func (mw *testMW) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	mw.onIntercept(req)

	idx := 0
	if !mw.repeatReturn0 {
		idx = mw.iter
	}

	// out-of-bounds access intentionally panics
	canned := mw.toReturn[idx]
	mw.iter++

	return canned.resp, canned.err
}
// mockAdapter builds a GraphRequestAdapter whose HTTP client runs the
// middlewares returned by kiotaMiddlewares with mw appended last.
// We can't use graph/mock.CreateAdapter() here due to circular references.
func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.GraphRequestAdapter, error) {
	auth, err := GetAuth(creds.AzureTenantID, creds.AzureClientID, creds.AzureClientSecret)
	if err != nil {
		return nil, err
	}

	clientOptions := msgraphsdkgo.GetDefaultClientOptions()
	// a small minimum backoff so retrying tests don't wait long
	// (assumes MinimumBackoff governs the retry delay)
	cc := populateConfig(MinimumBackoff(10 * time.Millisecond))

	chain := kiotaMiddlewares(&clientOptions, cc)
	chain = append(chain, mw)

	httpClient := msgraphgocore.GetDefaultClient(&clientOptions, chain...)
	httpClient.Timeout = 15 * time.Second
	cc.apply(httpClient)

	return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(
		auth,
		nil,
		nil,
		httpClient)
}
// RetryMWIntgSuite exercises the retry middleware through a graph request
// adapter built from real M365 credentials, while testMW mocks the
// responses themselves.
type RetryMWIntgSuite struct {
	tester.Suite

	// creds is populated in SetupSuite.
	creds account.M365Config
}
// We do end up mocking the actual request, but creating the rest
// similar to E2E suite
func TestRetryMWIntgSuite(t *testing.T) {
suite.Run(t, &RetryMWIntgSuite{
Suite: tester.NewIntegrationSuite(
t,
[][]string{tconfig.M365AcctCredEnvs}),
})
}
// SetupSuite loads the M365 account configuration shared by every test in
// the suite.
func (suite *RetryMWIntgSuite) SetupSuite() {
	t := suite.T()
	acct := tconfig.NewM365Account(t)

	var err error
	suite.creds, err = acct.M365Config()
	require.NoError(t, err, clues.ToCore(err))
}
// TestRetryMiddleware_Intercept_byStatusCode sends one request through the
// adapter's middleware chain while testMW replies to every attempt with
// the same status code and/or error. It then checks how many times testMW
// was invoked beyond the initial attempt, i.e. how many retries happened.
func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
	var (
		uri     = "https://graph.microsoft.com"
		urlPath = "/v1.0/users/user/messages/foo"
		url     = uri + urlPath
	)

	// Every case that retries is expected to exhaust defaultMaxRetries;
	// reference the constant everywhere so these expectations can't drift
	// from the middleware's configuration.
	tests := []struct {
		name             string
		status           int
		providedErr      error
		expectRetryCount int
		expectErr        assert.ErrorAssertionFunc
	}{
		{
			name:             "200, no retries",
			status:           http.StatusOK,
			providedErr:      nil,
			expectRetryCount: 0,
			expectErr:        assert.NoError,
		},
		{
			name:             "400, no retries",
			status:           http.StatusBadRequest,
			providedErr:      nil,
			expectRetryCount: 0,
			expectErr:        assert.Error,
		},
		{
			// don't test 504: gets intercepted by graph client for long waits.
			name:             "502",
			status:           http.StatusBadGateway,
			providedErr:      nil,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
		{
			name:             "conn reset with 5xx",
			status:           http.StatusBadGateway,
			providedErr:      syscall.ECONNRESET,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
		{
			name:             "conn reset with 2xx",
			status:           http.StatusOK,
			providedErr:      syscall.ECONNRESET,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
		{
			name:        "conn reset with nil resp",
			providedErr: syscall.ECONNRESET,
			// Use 0 to denote nil http response
			status:           0,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
		{
			// Unlikely but check if connection reset error takes precedence
			name:             "conn reset with 400 resp",
			providedErr:      syscall.ECONNRESET,
			status:           http.StatusBadRequest,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
		{
			name:             "http timeout",
			providedErr:      http.ErrHandlerTimeout,
			status:           0,
			expectRetryCount: defaultMaxRetries,
			expectErr:        assert.Error,
		},
	}

	for _, test := range tests {
		suite.Run(test.name, func() {
			t := suite.T()

			ctx, flush := tester.NewContext(t)
			defer flush()

			called := 0
			mw := newTestMW(
				func(*http.Request) { called++ },
				newMWReturns(test.status, nil, test.providedErr))
			// reply with the same canned result on every attempt
			mw.repeatReturn0 = true

			adpt, err := mockAdapter(suite.creds, mw)
			require.NoError(t, err, clues.ToCore(err))

			// url doesn't fit the builder, but that shouldn't matter
			_, err = users.NewCountRequestBuilder(url, adpt).Get(ctx, nil)
			test.expectErr(t, err, clues.ToCore(err))

			// -1 because the non-retried call always counts for one, then
			// we increment based on the number of retry attempts.
			assert.Equal(t, test.expectRetryCount, called-1)
		})
	}
}
// TestRetryMiddleware_RetryRequest_resetBodyAfter500 posts a mail folder
// while testMW replies 500 to the first attempt and 200 to the second. On
// every attempt it asserts that the request body is still long enough,
// which catches a retry that re-sends a consumed or truncated body.
func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter500() {
	t := suite.T()

	ctx, flush := tester.NewContext(t)
	defer flush()

	folder := models.NewMailFolder()
	folder.SetDisplayName(ptr.To(uuid.NewString()))

	// runs on every intercepted attempt: the initial one and the retry
	assertFullBody := func(req *http.Request) {
		bs, err := io.ReadAll(req.Body)
		require.NoError(t, err, clues.ToCore(err))

		// An expired body, after graph compression, will normally contain
		// 25 bytes, so a healthy body should be longer than that.
		require.Less(
			t,
			25,
			len(bs),
			"body should be longer than 25 bytes; shorter indicates the body was sliced on a retry")
	}

	mw := newTestMW(
		assertFullBody,
		newMWReturns(http.StatusInternalServerError, nil, nil),
		newMWReturns(http.StatusOK, nil, nil))

	adpt, err := mockAdapter(suite.creds, mw)
	require.NoError(t, err, clues.ToCore(err))

	// No api package needed here: this is a mocked request that works
	// independent of the query.
	mailFolders := NewService(adpt).Client().Users().ByUserId("user").MailFolders()
	_, err = mailFolders.Post(ctx, folder, nil)
	require.NoError(t, err, clues.ToCore(err))
}
// MiddlewareUnitSuite holds the unit tests for the rate-limiter ctx helpers.
type MiddlewareUnitSuite struct {
	tester.Suite
}
func TestMiddlewareUnitSuite(t *testing.T) {
suite.Run(t, &MiddlewareUnitSuite{Suite: tester.NewUnitSuite(t)})
}
// TestBindExtractLimiterConfig checks three things:
//   - BindRateLimiterConfig stores a LimiterCfg in the ctx.
//   - extractRateLimiterConfig reads the same service back out.
//   - ctxLimiter selects the expected limiter for each service.
func (suite *MiddlewareUnitSuite) TestBindExtractLimiterConfig() {
	t := suite.T()

	ctx, flush := tester.NewContext(t)
	defer flush()

	// an unpopulated ctx should produce the default limiter
	assert.Equal(t, defaultLimiter, ctxLimiter(ctx))

	// Every case binds a config, so extraction is always expected to
	// succeed; no per-case "ok" expectation is needed.
	table := []struct {
		name          string
		service       path.ServiceType
		expectLimiter *rate.Limiter
	}{
		{
			name:          "exchange",
			service:       path.ExchangeService,
			expectLimiter: defaultLimiter,
		},
		{
			name:          "oneDrive",
			service:       path.OneDriveService,
			expectLimiter: driveLimiter,
		},
		{
			name:          "sharePoint",
			service:       path.SharePointService,
			expectLimiter: driveLimiter,
		},
		{
			name:          "unknownService",
			service:       path.UnknownService,
			expectLimiter: defaultLimiter,
		},
		{
			name:          "badService",
			service:       path.ServiceType(-1),
			expectLimiter: defaultLimiter,
		},
	}

	for _, test := range table {
		suite.Run(test.name, func() {
			t := suite.T()

			tctx := BindRateLimiterConfig(ctx, LimiterCfg{Service: test.service})

			lc, ok := extractRateLimiterConfig(tctx)
			require.True(t, ok, "found rate limiter in ctx")
			assert.Equal(t, test.service, lc.Service)
			assert.Equal(t, test.expectLimiter, ctxLimiter(tctx))
		})
	}
}
// TestLimiterConsumption checks that ConsumeNTokens records a per-request
// token cost in the ctx that ctxLimiterConsumption reads back. Per the
// table below, zero and negative counts fall back to the default
// consumption.
func (suite *MiddlewareUnitSuite) TestLimiterConsumption() {
	t := suite.T()

	ctx, flush := tester.NewContext(t)
	defer flush()

	// an unpopulated ctx should produce the default consumption
	assert.Equal(t, defaultLC, ctxLimiterConsumption(ctx, defaultLC))

	cases := []struct {
		name   string
		n      int
		expect int
	}{
		{name: "matches default", n: defaultLC, expect: defaultLC},
		{name: "default+1", n: defaultLC + 1, expect: defaultLC + 1},
		{name: "zero", n: 0, expect: defaultLC},
		{name: "negative", n: -1, expect: defaultLC},
	}

	for _, tc := range cases {
		suite.Run(tc.name, func() {
			got := ctxLimiterConsumption(ConsumeNTokens(ctx, tc.n), defaultLC)
			assert.Equal(suite.T(), tc.expect, got)
		})
	}
}