reset body reader on corso retry handling (#3280)

The kiota compressor middleware will attempt to compress the request body.  In the event that we have a corso-middleware- retriable response (eg: status 500), we need to reset the seek position of the req.Body, similar to how graph api does in their retry handler, or else the re-run of the compressor will already have read the full req.Body and the retried call will have a zero len body.

---

#### Does this PR need a docs update or release note?

- [y]  Yes, it's included

#### Type of change

- [x] 🐛 Bugfix

#### Test Plan

- [x] 💪 Manual
This commit is contained in:
Keepers 2023-05-04 12:31:12 -06:00 committed by GitHub
parent 67d5c53420
commit 9b21699b6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 13 deletions

View File

@ -7,7 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] (beta)
### Added
### Fixed
- Graph requests now automatically retry in case of a Bad Gateway or Gateway Timeout.
- POST Retries following certain status codes (500, 502, 504) will re-use the post body instead of retrying with a no-content request.
- Fix nil pointer exception when running an incremental backup on SharePoint where the base backup used an older index data format.
## [v0.7.0] (beta) - 2023-05-02

View File

@ -3,6 +3,7 @@ package graph
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
"os"
@ -250,7 +251,6 @@ func (mw RetryMiddleware) retryRequest(
executionCount++
delay := mw.getRetryDelay(req, resp, exponentialBackoff)
cumulativeDelay += delay
req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))
@ -266,6 +266,18 @@ func (mw RetryMiddleware) retryRequest(
case <-timer.C:
}
// we have to reset the original body reader for each retry, or else the graph
// compressor will produce a 0 length body following an error response such
// as a 500.
if req.Body != nil {
if s, ok := req.Body.(io.Seeker); ok {
_, err := s.Seek(0, io.SeekStart)
if err != nil {
return nil, Wrap(ctx, err, "resetting request body reader")
}
}
}
nextResp, err := pipeline.Next(req, middlewareIndex)
if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
return nextResp, stackReq(ctx, req, nextResp, err)
@ -381,6 +393,10 @@ func (mw *ThrottleControlMiddleware) Intercept(
return pipeline.Next(req, middlewareIndex)
}
// ---------------------------------------------------------------------------
// Metrics
// ---------------------------------------------------------------------------
// MetricsMiddleware aggregates per-request metrics on the events bus
type MetricsMiddleware struct{}

View File

@ -1,44 +1,81 @@
package graph
import (
"bytes"
"io"
"net/http"
"testing"
"time"
"github.com/alcionai/clues"
"github.com/google/uuid"
khttp "github.com/microsoft/kiota-http-go"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core"
"github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/microsoftgraph/msgraph-sdk-go/users"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/tester"
"github.com/alcionai/corso/src/pkg/account"
)
func newBodylessTestMW(onIntercept func(), code int, err error) testMW {
return testMW{
type mwReturns struct {
err error
resp *http.Response
}
func newMWReturns(code int, body []byte, err error) mwReturns {
var brc io.ReadCloser
if len(body) > 0 {
brc = io.NopCloser(bytes.NewBuffer(body))
}
return mwReturns{
err: err,
resp: &http.Response{
StatusCode: code,
Body: brc,
},
}
}
func newTestMW(onIntercept func(*http.Request), mrs ...mwReturns) *testMW {
return &testMW{
onIntercept: onIntercept,
resp: &http.Response{StatusCode: code},
toReturn: mrs,
}
}
type testMW struct {
err error
onIntercept func()
resp *http.Response
repeatReturn0 bool
iter int
toReturn []mwReturns
onIntercept func(*http.Request)
}
func (mw testMW) Intercept(
func (mw *testMW) Intercept(
pipeline khttp.Pipeline,
middlewareIndex int,
req *http.Request,
) (*http.Response, error) {
mw.onIntercept()
return mw.resp, mw.err
mw.onIntercept(req)
i := mw.iter
if mw.repeatReturn0 {
i = 0
}
// panic on out-of-bounds intentionally not protected
tr := mw.toReturn[i]
mw.iter++
return tr.resp, tr.err
}
// can't use graph/mock.CreateAdapter() due to circular references.
@ -58,7 +95,7 @@ func mockAdapter(creds account.M365Config, mw khttp.Middleware) (*msgraphsdkgo.G
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
)
httpClient.Timeout = 5 * time.Second
httpClient.Timeout = 15 * time.Second
cc.apply(httpClient)
@ -135,7 +172,10 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
t := suite.T()
called := 0
mw := newBodylessTestMW(func() { called++ }, test.status, nil)
mw := newTestMW(
func(*http.Request) { called++ },
newMWReturns(test.status, nil, nil))
mw.repeatReturn0 = true
adpt, err := mockAdapter(suite.creds, mw)
require.NoError(t, err, clues.ToCore(err))
@ -150,3 +190,43 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() {
})
}
}
func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter500() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
body = models.NewMailFolder()
checkOnIntercept = func(req *http.Request) {
bs, err := io.ReadAll(req.Body)
require.NoError(t, err, clues.ToCore(err))
// an expired body, after graph compression, will
// normally contain 25 bytes. So we should see more
// than that at least.
require.Less(
t,
25,
len(bs),
"body should be longer than 25 bytes; shorter indicates the body was sliced on a retry")
}
)
body.SetDisplayName(ptr.To(uuid.NewString()))
mw := newTestMW(
checkOnIntercept,
newMWReturns(http.StatusInternalServerError, nil, nil),
newMWReturns(http.StatusOK, nil, nil))
adpt, err := mockAdapter(suite.creds, mw)
require.NoError(t, err, clues.ToCore(err))
_, err = NewService(adpt).
Client().
UsersById("user").
MailFolders().
Post(ctx, body, nil)
require.NoError(t, err, clues.ToCore(err))
}