From 3b73b61c905cca6e4cf0749341509676f821145a Mon Sep 17 00:00:00 2001 From: ashmrtn <3891298+ashmrtn@users.noreply.github.com> Date: Tue, 8 Aug 2023 13:53:56 -0700 Subject: [PATCH] Plug connection reset wrapper into OneDrive code (#3947) Also add basic test to ensure everything is wired up as expected. --- #### Does this PR need a docs update or release note? - [ ] :white_check_mark: Yes, it's included - [ ] :clock1: Yes, but in a later PR - [x] :no_entry: No #### Type of change - [ ] :sunflower: Feature - [x] :bug: Bugfix - [ ] :world_map: Documentation - [ ] :robot: Supportability/Tests - [ ] :computer: CI/Deployment - [ ] :broom: Tech Debt/Cleanup #### Test Plan - [ ] :muscle: Manual - [x] :zap: Unit test - [ ] :green_heart: E2E --- src/internal/common/readers/retry_handler.go | 4 +- .../common/readers/retry_handler_test.go | 19 ------ src/internal/m365/onedrive/item.go | 61 +++++++++++++++--- src/internal/m365/onedrive/item_test.go | 63 +++++++++++++++++++ 4 files changed, 118 insertions(+), 29 deletions(-) diff --git a/src/internal/common/readers/retry_handler.go b/src/internal/common/readers/retry_handler.go index ea6ece185..b52389f83 100644 --- a/src/internal/common/readers/retry_handler.go +++ b/src/internal/common/readers/retry_handler.go @@ -167,7 +167,9 @@ func (rrh *resetRetryHandler) reconnect(maxRetries int) (int, error) { err = retryErrs[0] ) - if rrh.getter.SupportsRange() { + // Only set the range header if we've already read data. Otherwise we could + // get 416 (range not satisfiable) if the file is empty. + if rrh.getter.SupportsRange() && rrh.offset > 0 { headers[rangeHeaderKey] = fmt.Sprintf( rangeHeaderOneSidedValueTmpl, rrh.offset) diff --git a/src/internal/common/readers/retry_handler_test.go b/src/internal/common/readers/retry_handler_test.go index f5842e6fa..e6bca2585 100644 --- a/src/internal/common/readers/retry_handler_test.go +++ b/src/internal/common/readers/retry_handler_test.go @@ -152,10 +152,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { { name: "OnlyFirstReadErrors RangeSupport", supportsRange: true, - getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, - 1: {"Range": "bytes=0-"}, - }, getterResps: map[int]getterResp{ 0: { err: syscall.ECONNRESET, @@ -180,7 +176,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 1: {offset: 12}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=12-"}, }, readerResps: map[int]readResp{ @@ -213,7 +208,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 2: {offset: 20}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=12-"}, 2: {"Range": "bytes=20-"}, }, @@ -246,7 +240,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 1: {offset: 14}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=14-"}, }, readerResps: map[int]readResp{ @@ -275,7 +268,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 1: {offset: 16}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=16-"}, }, readerResps: map[int]readResp{ @@ -305,7 +297,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 1: {offset: 12}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=12-"}, }, readerResps: map[int]readResp{ @@ -347,7 +338,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 1: {offset: 14}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=14-"}, }, readerResps: map[int]readResp{ @@ -391,7 +381,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 3: {err: syscall.ECONNRESET}, }, getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, 1: {"Range": "bytes=12-"}, 2: {"Range": "bytes=13-"}, 3: {"Range": "bytes=14-"}, @@ -423,14 +412,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() { 4: {offset: -1}, 5: {offset: -1}, }, - getterExpectHeaders: map[int]map[string]string{ - 0: {"Range": "bytes=0-"}, - 1: {"Range": "bytes=0-"}, - 2: {"Range": "bytes=0-"}, - 3: {"Range": "bytes=0-"}, - 4: {"Range": "bytes=0-"}, - 5: {"Range": "bytes=0-"}, - }, readerResps: map[int]readResp{ 0: { sticky: true, diff --git a/src/internal/m365/onedrive/item.go b/src/internal/m365/onedrive/item.go index a149efd12..3bf35000e 100644 --- a/src/internal/m365/onedrive/item.go +++ b/src/internal/m365/onedrive/item.go @@ -8,14 +8,21 @@ import ( "github.com/alcionai/clues" "github.com/microsoftgraph/msgraph-sdk-go/models" + "golang.org/x/exp/maps" "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/common/readers" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/m365/onedrive/metadata" "github.com/alcionai/corso/src/pkg/services/m365/api" ) +const ( + acceptHeaderKey = "Accept" + acceptHeaderValue = "*/*" +) + // downloadUrlKeys is used to find the download URL in a DriveItem response. var downloadURLKeys = []string{ "@microsoft.graph.downloadUrl", @@ -59,25 +66,42 @@ func downloadItem( return rc, nil } -func downloadFile( - ctx context.Context, - ag api.Getter, - url string, -) (io.ReadCloser, error) { - if len(url) == 0 { - return nil, clues.New("empty file url") - } +type downloadWithRetries struct { + getter api.Getter + url string +} - resp, err := ag.Get(ctx, url, nil) +func (dg *downloadWithRetries) SupportsRange() bool { + return true +} + +func (dg *downloadWithRetries) Get( + ctx context.Context, + additionalHeaders map[string]string, +) (io.ReadCloser, error) { + headers := maps.Clone(additionalHeaders) + // Set the accept header like curl does. Local testing showed range headers + // wouldn't work without it (get 416 responses instead of 206). + headers[acceptHeaderKey] = acceptHeaderValue + + resp, err := dg.getter.Get(ctx, dg.url, headers) if err != nil { return nil, clues.Wrap(err, "getting file") } if graph.IsMalwareResp(ctx, resp) { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + return nil, clues.New("malware detected").Label(graph.LabelsMalware) } if resp != nil && (resp.StatusCode/100) != 2 { + if resp.Body != nil { + resp.Body.Close() + } + // upstream error checks can compare the status with // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) return nil, clues. @@ -88,6 +112,25 @@ func downloadFile( return resp.Body, nil } +func downloadFile( + ctx context.Context, + ag api.Getter, + url string, +) (io.ReadCloser, error) { + if len(url) == 0 { + return nil, clues.New("empty file url").WithClues(ctx) + } + + rc, err := readers.NewResetRetryHandler( + ctx, + &downloadWithRetries{ + getter: ag, + url: url, + }) + + return rc, clues.Stack(err).OrNil() +} + func downloadItemMeta( ctx context.Context, gip GetItemPermissioner, diff --git a/src/internal/m365/onedrive/item_test.go b/src/internal/m365/onedrive/item_test.go index b3f352bbf..f8f11992f 100644 --- a/src/internal/m365/onedrive/item_test.go +++ b/src/internal/m365/onedrive/item_test.go @@ -5,10 +5,12 @@ import ( "context" "io" "net/http" + "syscall" "testing" "github.com/alcionai/clues" "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -438,3 +440,64 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { }) } } + +type errReader struct{} + +func (r errReader) Read(p []byte) (int, error) { + return 0, syscall.ECONNRESET +} + +func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead() { + var ( + callCount int + + testData = []byte("test") + testRc = io.NopCloser(bytes.NewReader(testData)) + url = "https://example.com" + + itemFunc = func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]any{ + "@microsoft.graph.downloadUrl": url, + }) + + return di + } + + GetFunc = func(ctx context.Context, url string) (*http.Response, error) { + defer func() { + callCount++ + }() + + if callCount == 0 { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(errReader{}), + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + }, nil + } + errorExpected = require.NoError + rcExpected = require.NotNil + ) + + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + mg := mockGetter{ + GetFunc: GetFunc, + } + rc, err := downloadItem(ctx, mg, itemFunc()) + errorExpected(t, err, clues.ToCore(err)) + rcExpected(t, rc) + + data, err := io.ReadAll(rc) + require.NoError(t, err, clues.ToCore(err)) + assert.Equal(t, testData, data) +}