Plug connection reset wrapper into OneDrive code (#3947)
Also add basic test to ensure everything is wired up as expected. --- #### Does this PR need a docs update or release note? - [ ] ✅ Yes, it's included - [ ] 🕐 Yes, but in a later PR - [x] ⛔ No #### Type of change - [ ] 🌻 Feature - [x] 🐛 Bugfix - [ ] 🗺️ Documentation - [ ] 🤖 Supportability/Tests - [ ] 💻 CI/Deployment - [ ] 🧹 Tech Debt/Cleanup #### Test Plan - [ ] 💪 Manual - [x] ⚡ Unit test - [ ] 💚 E2E
This commit is contained in:
parent
4ebb2d3bfb
commit
3b73b61c90
@ -167,7 +167,9 @@ func (rrh *resetRetryHandler) reconnect(maxRetries int) (int, error) {
|
|||||||
err = retryErrs[0]
|
err = retryErrs[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
if rrh.getter.SupportsRange() {
|
// Only set the range header if we've already read data. Otherwise we could
|
||||||
|
// get 416 (range not satisfiable) if the file is empty.
|
||||||
|
if rrh.getter.SupportsRange() && rrh.offset > 0 {
|
||||||
headers[rangeHeaderKey] = fmt.Sprintf(
|
headers[rangeHeaderKey] = fmt.Sprintf(
|
||||||
rangeHeaderOneSidedValueTmpl,
|
rangeHeaderOneSidedValueTmpl,
|
||||||
rrh.offset)
|
rrh.offset)
|
||||||
|
|||||||
@ -152,10 +152,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
{
|
{
|
||||||
name: "OnlyFirstReadErrors RangeSupport",
|
name: "OnlyFirstReadErrors RangeSupport",
|
||||||
supportsRange: true,
|
supportsRange: true,
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=0-"},
|
|
||||||
},
|
|
||||||
getterResps: map[int]getterResp{
|
getterResps: map[int]getterResp{
|
||||||
0: {
|
0: {
|
||||||
err: syscall.ECONNRESET,
|
err: syscall.ECONNRESET,
|
||||||
@ -180,7 +176,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
1: {offset: 12},
|
1: {offset: 12},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=12-"},
|
1: {"Range": "bytes=12-"},
|
||||||
},
|
},
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
@ -213,7 +208,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
2: {offset: 20},
|
2: {offset: 20},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=12-"},
|
1: {"Range": "bytes=12-"},
|
||||||
2: {"Range": "bytes=20-"},
|
2: {"Range": "bytes=20-"},
|
||||||
},
|
},
|
||||||
@ -246,7 +240,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
1: {offset: 14},
|
1: {offset: 14},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=14-"},
|
1: {"Range": "bytes=14-"},
|
||||||
},
|
},
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
@ -275,7 +268,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
1: {offset: 16},
|
1: {offset: 16},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=16-"},
|
1: {"Range": "bytes=16-"},
|
||||||
},
|
},
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
@ -305,7 +297,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
1: {offset: 12},
|
1: {offset: 12},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=12-"},
|
1: {"Range": "bytes=12-"},
|
||||||
},
|
},
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
@ -347,7 +338,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
1: {offset: 14},
|
1: {offset: 14},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=14-"},
|
1: {"Range": "bytes=14-"},
|
||||||
},
|
},
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
@ -391,7 +381,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
3: {err: syscall.ECONNRESET},
|
3: {err: syscall.ECONNRESET},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
getterExpectHeaders: map[int]map[string]string{
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=12-"},
|
1: {"Range": "bytes=12-"},
|
||||||
2: {"Range": "bytes=13-"},
|
2: {"Range": "bytes=13-"},
|
||||||
3: {"Range": "bytes=14-"},
|
3: {"Range": "bytes=14-"},
|
||||||
@ -423,14 +412,6 @@ func (suite *ResetRetryHandlerUnitSuite) TestResetRetryHandler() {
|
|||||||
4: {offset: -1},
|
4: {offset: -1},
|
||||||
5: {offset: -1},
|
5: {offset: -1},
|
||||||
},
|
},
|
||||||
getterExpectHeaders: map[int]map[string]string{
|
|
||||||
0: {"Range": "bytes=0-"},
|
|
||||||
1: {"Range": "bytes=0-"},
|
|
||||||
2: {"Range": "bytes=0-"},
|
|
||||||
3: {"Range": "bytes=0-"},
|
|
||||||
4: {"Range": "bytes=0-"},
|
|
||||||
5: {"Range": "bytes=0-"},
|
|
||||||
},
|
|
||||||
readerResps: map[int]readResp{
|
readerResps: map[int]readResp{
|
||||||
0: {
|
0: {
|
||||||
sticky: true,
|
sticky: true,
|
||||||
|
|||||||
@ -8,14 +8,21 @@ import (
|
|||||||
|
|
||||||
"github.com/alcionai/clues"
|
"github.com/alcionai/clues"
|
||||||
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/alcionai/corso/src/internal/common/ptr"
|
"github.com/alcionai/corso/src/internal/common/ptr"
|
||||||
|
"github.com/alcionai/corso/src/internal/common/readers"
|
||||||
"github.com/alcionai/corso/src/internal/common/str"
|
"github.com/alcionai/corso/src/internal/common/str"
|
||||||
"github.com/alcionai/corso/src/internal/m365/graph"
|
"github.com/alcionai/corso/src/internal/m365/graph"
|
||||||
"github.com/alcionai/corso/src/internal/m365/onedrive/metadata"
|
"github.com/alcionai/corso/src/internal/m365/onedrive/metadata"
|
||||||
"github.com/alcionai/corso/src/pkg/services/m365/api"
|
"github.com/alcionai/corso/src/pkg/services/m365/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
acceptHeaderKey = "Accept"
|
||||||
|
acceptHeaderValue = "*/*"
|
||||||
|
)
|
||||||
|
|
||||||
// downloadUrlKeys is used to find the download URL in a DriveItem response.
|
// downloadUrlKeys is used to find the download URL in a DriveItem response.
|
||||||
var downloadURLKeys = []string{
|
var downloadURLKeys = []string{
|
||||||
"@microsoft.graph.downloadUrl",
|
"@microsoft.graph.downloadUrl",
|
||||||
@ -59,25 +66,42 @@ func downloadItem(
|
|||||||
return rc, nil
|
return rc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func downloadFile(
|
type downloadWithRetries struct {
|
||||||
ctx context.Context,
|
getter api.Getter
|
||||||
ag api.Getter,
|
url string
|
||||||
url string,
|
|
||||||
) (io.ReadCloser, error) {
|
|
||||||
if len(url) == 0 {
|
|
||||||
return nil, clues.New("empty file url")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := ag.Get(ctx, url, nil)
|
func (dg *downloadWithRetries) SupportsRange() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dg *downloadWithRetries) Get(
|
||||||
|
ctx context.Context,
|
||||||
|
additionalHeaders map[string]string,
|
||||||
|
) (io.ReadCloser, error) {
|
||||||
|
headers := maps.Clone(additionalHeaders)
|
||||||
|
// Set the accept header like curl does. Local testing showed range headers
|
||||||
|
// wouldn't work without it (get 416 responses instead of 206).
|
||||||
|
headers[acceptHeaderKey] = acceptHeaderValue
|
||||||
|
|
||||||
|
resp, err := dg.getter.Get(ctx, dg.url, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, clues.Wrap(err, "getting file")
|
return nil, clues.Wrap(err, "getting file")
|
||||||
}
|
}
|
||||||
|
|
||||||
if graph.IsMalwareResp(ctx, resp) {
|
if graph.IsMalwareResp(ctx, resp) {
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
|
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil && (resp.StatusCode/100) != 2 {
|
if resp != nil && (resp.StatusCode/100) != 2 {
|
||||||
|
if resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// upstream error checks can compare the status with
|
// upstream error checks can compare the status with
|
||||||
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
|
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
|
||||||
return nil, clues.
|
return nil, clues.
|
||||||
@ -88,6 +112,25 @@ func downloadFile(
|
|||||||
return resp.Body, nil
|
return resp.Body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func downloadFile(
|
||||||
|
ctx context.Context,
|
||||||
|
ag api.Getter,
|
||||||
|
url string,
|
||||||
|
) (io.ReadCloser, error) {
|
||||||
|
if len(url) == 0 {
|
||||||
|
return nil, clues.New("empty file url").WithClues(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := readers.NewResetRetryHandler(
|
||||||
|
ctx,
|
||||||
|
&downloadWithRetries{
|
||||||
|
getter: ag,
|
||||||
|
url: url,
|
||||||
|
})
|
||||||
|
|
||||||
|
return rc, clues.Stack(err).OrNil()
|
||||||
|
}
|
||||||
|
|
||||||
func downloadItemMeta(
|
func downloadItemMeta(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
gip GetItemPermissioner,
|
gip GetItemPermissioner,
|
||||||
|
|||||||
@ -5,10 +5,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/alcionai/clues"
|
"github.com/alcionai/clues"
|
||||||
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
"github.com/microsoftgraph/msgraph-sdk-go/models"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
@ -438,3 +440,64 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type errReader struct{}
|
||||||
|
|
||||||
|
func (r errReader) Read(p []byte) (int, error) {
|
||||||
|
return 0, syscall.ECONNRESET
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead() {
|
||||||
|
var (
|
||||||
|
callCount int
|
||||||
|
|
||||||
|
testData = []byte("test")
|
||||||
|
testRc = io.NopCloser(bytes.NewReader(testData))
|
||||||
|
url = "https://example.com"
|
||||||
|
|
||||||
|
itemFunc = func() models.DriveItemable {
|
||||||
|
di := newItem("test", false)
|
||||||
|
di.SetAdditionalData(map[string]any{
|
||||||
|
"@microsoft.graph.downloadUrl": url,
|
||||||
|
})
|
||||||
|
|
||||||
|
return di
|
||||||
|
}
|
||||||
|
|
||||||
|
GetFunc = func(ctx context.Context, url string) (*http.Response, error) {
|
||||||
|
defer func() {
|
||||||
|
callCount++
|
||||||
|
}()
|
||||||
|
|
||||||
|
if callCount == 0 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(errReader{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: testRc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
errorExpected = require.NoError
|
||||||
|
rcExpected = require.NotNil
|
||||||
|
)
|
||||||
|
|
||||||
|
t := suite.T()
|
||||||
|
|
||||||
|
ctx, flush := tester.NewContext(t)
|
||||||
|
defer flush()
|
||||||
|
|
||||||
|
mg := mockGetter{
|
||||||
|
GetFunc: GetFunc,
|
||||||
|
}
|
||||||
|
rc, err := downloadItem(ctx, mg, itemFunc())
|
||||||
|
errorExpected(t, err, clues.ToCore(err))
|
||||||
|
rcExpected(t, rc)
|
||||||
|
|
||||||
|
data, err := io.ReadAll(rc)
|
||||||
|
require.NoError(t, err, clues.ToCore(err))
|
||||||
|
assert.Equal(t, testData, data)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user