Plug connection reset wrapper into OneDrive code (#3947)

Also add basic test to ensure everything is wired up as expected.

---

#### Does this PR need a docs update or release note?

- [ ]  Yes, it's included
- [ ] 🕐 Yes, but in a later PR
- [x]  No

#### Type of change

- [ ] 🌻 Feature
- [x] 🐛 Bugfix
- [ ] 🗺️ Documentation
- [ ] 🤖 Supportability/Tests
- [ ] 💻 CI/Deployment
- [ ] 🧹 Tech Debt/Cleanup

#### Test Plan

- [ ] 💪 Manual
- [x]  Unit test
- [ ] 💚 E2E
This commit is contained in:
ashmrtn 2023-08-08 13:53:56 -07:00 committed by GitHub
parent 4ebb2d3bfb
commit 3b73b61c90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 29 deletions

View File

@ -167,7 +167,9 @@ func (rrh *resetRetryHandler) reconnect(maxRetries int) (int, error) {
err = retryErrs[0] err = retryErrs[0]
) )
if rrh.getter.SupportsRange() { // Only set the range header if we've already read data. Otherwise we could
// get 416 (range not satisfiable) if the file is empty.
if rrh.getter.SupportsRange() && rrh.offset > 0 {
headers[rangeHeaderKey] = fmt.Sprintf( headers[rangeHeaderKey] = fmt.Sprintf(
rangeHeaderOneSidedValueTmpl, rangeHeaderOneSidedValueTmpl,
rrh.offset) rrh.offset)

View File

@ -152,10 +152,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
{ {
name: "OnlyFirstReadErrors RangeSupport", name: "OnlyFirstReadErrors RangeSupport",
supportsRange: true, supportsRange: true,
getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=0-"},
},
getterResps: map[int]getterResp{ getterResps: map[int]getterResp{
0: { 0: {
err: syscall.ECONNRESET, err: syscall.ECONNRESET,
@ -180,7 +176,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
1: {offset: 12}, 1: {offset: 12},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=12-"}, 1: {"Range": "bytes=12-"},
}, },
readerResps: map[int]readResp{ readerResps: map[int]readResp{
@ -213,7 +208,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
2: {offset: 20}, 2: {offset: 20},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=12-"}, 1: {"Range": "bytes=12-"},
2: {"Range": "bytes=20-"}, 2: {"Range": "bytes=20-"},
}, },
@ -246,7 +240,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
1: {offset: 14}, 1: {offset: 14},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=14-"}, 1: {"Range": "bytes=14-"},
}, },
readerResps: map[int]readResp{ readerResps: map[int]readResp{
@ -275,7 +268,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
1: {offset: 16}, 1: {offset: 16},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=16-"}, 1: {"Range": "bytes=16-"},
}, },
readerResps: map[int]readResp{ readerResps: map[int]readResp{
@ -305,7 +297,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
1: {offset: 12}, 1: {offset: 12},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=12-"}, 1: {"Range": "bytes=12-"},
}, },
readerResps: map[int]readResp{ readerResps: map[int]readResp{
@ -347,7 +338,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
1: {offset: 14}, 1: {offset: 14},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=14-"}, 1: {"Range": "bytes=14-"},
}, },
readerResps: map[int]readResp{ readerResps: map[int]readResp{
@ -391,7 +381,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
3: {err: syscall.ECONNRESET}, 3: {err: syscall.ECONNRESET},
}, },
getterExpectHeaders: map[int]map[string]string{ getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=12-"}, 1: {"Range": "bytes=12-"},
2: {"Range": "bytes=13-"}, 2: {"Range": "bytes=13-"},
3: {"Range": "bytes=14-"}, 3: {"Range": "bytes=14-"},
@ -423,14 +412,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
4: {offset: -1}, 4: {offset: -1},
5: {offset: -1}, 5: {offset: -1},
}, },
getterExpectHeaders: map[int]map[string]string{
0: {"Range": "bytes=0-"},
1: {"Range": "bytes=0-"},
2: {"Range": "bytes=0-"},
3: {"Range": "bytes=0-"},
4: {"Range": "bytes=0-"},
5: {"Range": "bytes=0-"},
},
readerResps: map[int]readResp{ readerResps: map[int]readResp{
0: { 0: {
sticky: true, sticky: true,

View File

@ -8,14 +8,21 @@ import (
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models"
"golang.org/x/exp/maps"
"github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/common/readers"
"github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/internal/m365/graph"
"github.com/alcionai/corso/src/internal/m365/onedrive/metadata" "github.com/alcionai/corso/src/internal/m365/onedrive/metadata"
"github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api"
) )
const (
acceptHeaderKey = "Accept"
acceptHeaderValue = "*/*"
)
// downloadUrlKeys is used to find the download URL in a DriveItem response. // downloadUrlKeys is used to find the download URL in a DriveItem response.
var downloadURLKeys = []string{ var downloadURLKeys = []string{
"@microsoft.graph.downloadUrl", "@microsoft.graph.downloadUrl",
@ -59,25 +66,42 @@ func downloadItem(
return rc, nil return rc, nil
} }
func downloadFile( type downloadWithRetries struct {
ctx context.Context, getter api.Getter
ag api.Getter, url string
url string, }
) (io.ReadCloser, error) {
if len(url) == 0 {
return nil, clues.New("empty file url")
}
resp, err := ag.Get(ctx, url, nil) func (dg *downloadWithRetries) SupportsRange() bool {
return true
}
func (dg *downloadWithRetries) Get(
ctx context.Context,
additionalHeaders map[string]string,
) (io.ReadCloser, error) {
headers := maps.Clone(additionalHeaders)
// Set the accept header like curl does. Local testing showed range headers
// wouldn't work without it (get 416 responses instead of 206).
headers[acceptHeaderKey] = acceptHeaderValue
resp, err := dg.getter.Get(ctx, dg.url, headers)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "getting file") return nil, clues.Wrap(err, "getting file")
} }
if graph.IsMalwareResp(ctx, resp) { if graph.IsMalwareResp(ctx, resp) {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
return nil, clues.New("malware detected").Label(graph.LabelsMalware) return nil, clues.New("malware detected").Label(graph.LabelsMalware)
} }
if resp != nil && (resp.StatusCode/100) != 2 { if resp != nil && (resp.StatusCode/100) != 2 {
if resp.Body != nil {
resp.Body.Close()
}
// upstream error checks can compare the status with // upstream error checks can compare the status with
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
return nil, clues. return nil, clues.
@ -88,6 +112,25 @@ func downloadFile(
return resp.Body, nil return resp.Body, nil
} }
func downloadFile(
ctx context.Context,
ag api.Getter,
url string,
) (io.ReadCloser, error) {
if len(url) == 0 {
return nil, clues.New("empty file url").WithClues(ctx)
}
rc, err := readers.NewResetRetryHandler(
ctx,
&downloadWithRetries{
getter: ag,
url: url,
})
return rc, clues.Stack(err).OrNil()
}
func downloadItemMeta( func downloadItemMeta(
ctx context.Context, ctx context.Context,
gip GetItemPermissioner, gip GetItemPermissioner,

View File

@ -5,10 +5,12 @@ import (
"context" "context"
"io" "io"
"net/http" "net/http"
"syscall"
"testing" "testing"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -438,3 +440,64 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
}) })
} }
} }
type errReader struct{}
func (r errReader) Read(p []byte) (int, error) {
return 0, syscall.ECONNRESET
}
func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead() {
var (
callCount int
testData = []byte("test")
testRc = io.NopCloser(bytes.NewReader(testData))
url = "https://example.com"
itemFunc = func() models.DriveItemable {
di := newItem("test", false)
di.SetAdditionalData(map[string]any{
"@microsoft.graph.downloadUrl": url,
})
return di
}
GetFunc = func(ctx context.Context, url string) (*http.Response, error) {
defer func() {
callCount++
}()
if callCount == 0 {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(errReader{}),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: testRc,
}, nil
}
errorExpected = require.NoError
rcExpected = require.NotNil
)
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
mg := mockGetter{
GetFunc: GetFunc,
}
rc, err := downloadItem(ctx, mg, itemFunc())
errorExpected(t, err, clues.ToCore(err))
rcExpected(t, rc)
data, err := io.ReadAll(rc)
require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, testData, data)
}