Plug connection reset wrapper into OneDrive code (#3947)
Also add basic test to ensure everything is wired up as expected. --- #### Does this PR need a docs update or release note? - [ ] ✅ Yes, it's included - [ ] 🕐 Yes, but in a later PR - [x] ⛔ No #### Type of change - [ ] 🌻 Feature - [x] 🐛 Bugfix - [ ] 🗺️ Documentation - [ ] 🤖 Supportability/Tests - [ ] 💻 CI/Deployment - [ ] 🧹 Tech Debt/Cleanup #### Test Plan - [ ] 💪 Manual - [x] ⚡ Unit test - [ ] 💚 E2E
This commit is contained in:
parent
4ebb2d3bfb
commit
3b73b61c90
@ -167,7 +167,9 @@ func (rrh *resetRetryHandler) reconnect(maxRetries int) (int, error) {
|
||||
err = retryErrs[0]
|
||||
)
|
||||
|
||||
if rrh.getter.SupportsRange() {
|
||||
// Only set the range header if we've already read data. Otherwise we could
|
||||
// get 416 (range not satisfiable) if the file is empty.
|
||||
if rrh.getter.SupportsRange() && rrh.offset > 0 {
|
||||
headers[rangeHeaderKey] = fmt.Sprintf(
|
||||
rangeHeaderOneSidedValueTmpl,
|
||||
rrh.offset)
|
||||
|
||||
@ -152,10 +152,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
{
|
||||
name: "OnlyFirstReadErrors RangeSupport",
|
||||
supportsRange: true,
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=0-"},
|
||||
},
|
||||
getterResps: map[int]getterResp{
|
||||
0: {
|
||||
err: syscall.ECONNRESET,
|
||||
@ -180,7 +176,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
1: {offset: 12},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=12-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
@ -213,7 +208,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
2: {offset: 20},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=12-"},
|
||||
2: {"Range": "bytes=20-"},
|
||||
},
|
||||
@ -246,7 +240,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
1: {offset: 14},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=14-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
@ -275,7 +268,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
1: {offset: 16},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=16-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
@ -305,7 +297,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
1: {offset: 12},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=12-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
@ -347,7 +338,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
1: {offset: 14},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=14-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
@ -391,7 +381,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
3: {err: syscall.ECONNRESET},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=12-"},
|
||||
2: {"Range": "bytes=13-"},
|
||||
3: {"Range": "bytes=14-"},
|
||||
@ -423,14 +412,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
||||
4: {offset: -1},
|
||||
5: {offset: -1},
|
||||
},
|
||||
getterExpectHeaders: map[int]map[string]string{
|
||||
0: {"Range": "bytes=0-"},
|
||||
1: {"Range": "bytes=0-"},
|
||||
2: {"Range": "bytes=0-"},
|
||||
3: {"Range": "bytes=0-"},
|
||||
4: {"Range": "bytes=0-"},
|
||||
5: {"Range": "bytes=0-"},
|
||||
},
|
||||
readerResps: map[int]readResp{
|
||||
0: {
|
||||
sticky: true,
|
||||
|
||||
@ -8,14 +8,21 @@ import (
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/common/ptr"
|
||||
"github.com/alcionai/corso/src/internal/common/readers"
|
||||
"github.com/alcionai/corso/src/internal/common/str"
|
||||
"github.com/alcionai/corso/src/internal/m365/graph"
|
||||
"github.com/alcionai/corso/src/internal/m365/onedrive/metadata"
|
||||
"github.com/alcionai/corso/src/pkg/services/m365/api"
|
||||
)
|
||||
|
||||
const (
|
||||
acceptHeaderKey = "Accept"
|
||||
acceptHeaderValue = "*/*"
|
||||
)
|
||||
|
||||
// downloadUrlKeys is used to find the download URL in a DriveItem response.
|
||||
var downloadURLKeys = []string{
|
||||
"@microsoft.graph.downloadUrl",
|
||||
@ -59,25 +66,42 @@ func downloadItem(
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func downloadFile(
|
||||
ctx context.Context,
|
||||
ag api.Getter,
|
||||
url string,
|
||||
) (io.ReadCloser, error) {
|
||||
if len(url) == 0 {
|
||||
return nil, clues.New("empty file url")
|
||||
type downloadWithRetries struct {
|
||||
getter api.Getter
|
||||
url string
|
||||
}
|
||||
|
||||
resp, err := ag.Get(ctx, url, nil)
|
||||
func (dg *downloadWithRetries) SupportsRange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (dg *downloadWithRetries) Get(
|
||||
ctx context.Context,
|
||||
additionalHeaders map[string]string,
|
||||
) (io.ReadCloser, error) {
|
||||
headers := maps.Clone(additionalHeaders)
|
||||
// Set the accept header like curl does. Local testing showed range headers
|
||||
// wouldn't work without it (get 416 responses instead of 206).
|
||||
headers[acceptHeaderKey] = acceptHeaderValue
|
||||
|
||||
resp, err := dg.getter.Get(ctx, dg.url, headers)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "getting file")
|
||||
}
|
||||
|
||||
if graph.IsMalwareResp(ctx, resp) {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
|
||||
}
|
||||
|
||||
if resp != nil && (resp.StatusCode/100) != 2 {
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// upstream error checks can compare the status with
|
||||
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
|
||||
return nil, clues.
|
||||
@ -88,6 +112,25 @@ func downloadFile(
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func downloadFile(
|
||||
ctx context.Context,
|
||||
ag api.Getter,
|
||||
url string,
|
||||
) (io.ReadCloser, error) {
|
||||
if len(url) == 0 {
|
||||
return nil, clues.New("empty file url").WithClues(ctx)
|
||||
}
|
||||
|
||||
rc, err := readers.NewResetRetryHandler(
|
||||
ctx,
|
||||
&downloadWithRetries{
|
||||
getter: ag,
|
||||
url: url,
|
||||
})
|
||||
|
||||
return rc, clues.Stack(err).OrNil()
|
||||
}
|
||||
|
||||
func downloadItemMeta(
|
||||
ctx context.Context,
|
||||
gip GetItemPermissioner,
|
||||
|
||||
@ -5,10 +5,12 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
@ -438,3 +440,64 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errReader struct{}
|
||||
|
||||
func (r errReader) Read(p []byte) (int, error) {
|
||||
return 0, syscall.ECONNRESET
|
||||
}
|
||||
|
||||
func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead() {
|
||||
var (
|
||||
callCount int
|
||||
|
||||
testData = []byte("test")
|
||||
testRc = io.NopCloser(bytes.NewReader(testData))
|
||||
url = "https://example.com"
|
||||
|
||||
itemFunc = func() models.DriveItemable {
|
||||
di := newItem("test", false)
|
||||
di.SetAdditionalData(map[string]any{
|
||||
"@microsoft.graph.downloadUrl": url,
|
||||
})
|
||||
|
||||
return di
|
||||
}
|
||||
|
||||
GetFunc = func(ctx context.Context, url string) (*http.Response, error) {
|
||||
defer func() {
|
||||
callCount++
|
||||
}()
|
||||
|
||||
if callCount == 0 {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(errReader{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: testRc,
|
||||
}, nil
|
||||
}
|
||||
errorExpected = require.NoError
|
||||
rcExpected = require.NotNil
|
||||
)
|
||||
|
||||
t := suite.T()
|
||||
|
||||
ctx, flush := tester.NewContext(t)
|
||||
defer flush()
|
||||
|
||||
mg := mockGetter{
|
||||
GetFunc: GetFunc,
|
||||
}
|
||||
rc, err := downloadItem(ctx, mg, itemFunc())
|
||||
errorExpected(t, err, clues.ToCore(err))
|
||||
rcExpected(t, rc)
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
assert.Equal(t, testData, data)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user