custom http wrapper for item downloads (#3132)

the http client customized by the graph client for use in downloading files doesn't include our middleware like expected. This introduces a new http client wrapper that populates the roundtripper with a middleware wrapper that can utilize our kiota middleware in the same way as the graph client does.

---

#### Does this PR need a docs update or release note?

- [x]  No

#### Type of change

- [x] 🌻 Feature
- [x] 🐛 Bugfix
- [x] 🧹 Tech Debt/Cleanup

#### Issue(s)

* #3129

#### Test Plan

- [x] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
This commit is contained in:
Keepers 2023-04-20 17:20:23 -06:00 committed by GitHub
parent 315c0cc5f3
commit d5fac8a480
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 481 additions and 195 deletions

View File

@ -77,7 +77,10 @@ func handleOneDriveCmd(cmd *cobra.Command, args []string) error {
return Only(ctx, clues.Wrap(err, "creating graph adapter"))
}
err = runDisplayM365JSON(ctx, graph.NewService(adpt), creds, user, m365ID)
svc := graph.NewService(adpt)
gr := graph.NewNoTimeoutHTTPWrapper()
err = runDisplayM365JSON(ctx, svc, gr, creds, user, m365ID)
if err != nil {
cmd.SilenceUsage = true
cmd.SilenceErrors = true
@ -105,6 +108,7 @@ func (i itemPrintable) MinimumPrintable() any {
func runDisplayM365JSON(
ctx context.Context,
srv graph.Servicer,
gr graph.Requester,
creds account.M365Config,
user, itemID string,
) error {
@ -123,7 +127,7 @@ func runDisplayM365JSON(
}
if item != nil {
content, err := getDriveItemContent(item)
content, err := getDriveItemContent(ctx, gr, item)
if err != nil {
return err
}
@ -180,22 +184,19 @@ func serializeObject(data serialization.Parsable) (string, error) {
return string(content), err
}
func getDriveItemContent(item models.DriveItemable) ([]byte, error) {
func getDriveItemContent(
ctx context.Context,
gr graph.Requester,
item models.DriveItemable,
) ([]byte, error) {
url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
if !ok {
return nil, clues.New("get download url")
return nil, clues.New("retrieving download url")
}
req, err := http.NewRequest(http.MethodGet, *url, nil)
resp, err := gr.Request(ctx, http.MethodGet, *url, nil, nil)
if err != nil {
return nil, clues.New("create download request").With("error", err)
}
hc := graph.HTTPClient(graph.NoTimeout())
resp, err := hc.Do(req)
if err != nil {
return nil, clues.New("download item").With("error", err)
return nil, clues.New("downloading item").With("error", err)
}
content, err := io.ReadAll(resp.Body)

View File

@ -258,7 +258,7 @@ func (suite *DataCollectionIntgSuite) TestSharePointDataCollection() {
collections, excludes, err := sharepoint.DataCollections(
ctx,
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
sel,
connector.credentials,
connector.Service,

View File

@ -1,36 +1,21 @@
package mock
import (
"github.com/alcionai/clues"
"github.com/alcionai/corso/src/internal/connector/exchange/api"
"github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/graph/mock"
"github.com/alcionai/corso/src/pkg/account"
)
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
a, err := mock.CreateAdapter(
creds.AzureTenantID,
creds.AzureClientID,
creds.AzureClientSecret,
opts...)
if err != nil {
return nil, clues.Wrap(err, "generating graph adapter")
}
return graph.NewService(a), nil
}
// NewClient produces a new exchange api client that can be
// mocked using gock.
func NewClient(creds account.M365Config) (api.Client, error) {
s, err := NewService(creds)
s, err := mock.NewService(creds)
if err != nil {
return api.Client{}, err
}
li, err := NewService(creds, graph.NoTimeout())
li, err := mock.NewService(creds, graph.NoTimeout())
if err != nil {
return api.Client{}, err
}

View File

@ -234,6 +234,9 @@ func Stack(ctx context.Context, e error) *clues.Err {
return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg)
}
// Checks for the following conditions and labels the error accordingly:
// * mysiteNotFound | mysiteURLNotFound
// * malware
func setLabels(err *clues.Err, msg string) *clues.Err {
if err == nil {
return nil
@ -244,6 +247,10 @@ func setLabels(err *clues.Err, msg string) *clues.Err {
err = err.Label(LabelsMysiteNotFound)
}
if IsMalware(err) {
err = err.Label(LabelsMalware)
}
return err
}

View File

@ -0,0 +1,152 @@
package graph
import (
"context"
"io"
"net/http"
"github.com/alcionai/clues"
khttp "github.com/microsoft/kiota-http-go"
"github.com/alcionai/corso/src/internal/version"
)
// ---------------------------------------------------------------------------
// constructors
// ---------------------------------------------------------------------------
type Requester interface {
Request(
ctx context.Context,
method, url string,
body io.Reader,
headers map[string]string,
) (*http.Response, error)
}
// NewHTTPWrapper produces a http.Client wrapper that ensures
// calls use all the middleware we expect from the graph api client.
//
// Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis.
func NewHTTPWrapper(opts ...Option) *httpWrapper {
var (
cc = populateConfig(opts...)
rt = customTransport{
n: pipeline{
middlewares: internalMiddleware(cc),
transport: defaultTransport(),
},
}
redirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
hc = &http.Client{
CheckRedirect: redirect,
Timeout: defaultHTTPClientTimeout,
Transport: rt,
}
)
cc.apply(hc)
return &httpWrapper{hc}
}
// NewNoTimeoutHTTPWrapper constructs a http wrapper with no context timeout.
//
// Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis.
func NewNoTimeoutHTTPWrapper(opts ...Option) *httpWrapper {
opts = append(opts, NoTimeout())
return NewHTTPWrapper(opts...)
}
// ---------------------------------------------------------------------------
// requests
// ---------------------------------------------------------------------------
// Request does the provided request.
func (hw httpWrapper) Request(
ctx context.Context,
method, url string,
body io.Reader,
headers map[string]string,
) (*http.Response, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, clues.Wrap(err, "new http request")
}
for k, v := range headers {
req.Header.Set(k, v)
}
//nolint:lll
// Decorate the traffic
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
resp, err := hw.client.Do(req)
if err != nil {
return nil, Stack(ctx, err)
}
return resp, nil
}
// ---------------------------------------------------------------------------
// constructor internals
// ---------------------------------------------------------------------------
type (
httpWrapper struct {
client *http.Client
}
customTransport struct {
n nexter
}
pipeline struct {
transport http.RoundTripper
middlewares []khttp.Middleware
}
)
// RoundTrip kicks off the middleware chain and returns a response
func (ct customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return ct.n.Next(req, 0)
}
// Next moves the request object through middlewares in the pipeline
func (pl pipeline) Next(req *http.Request, idx int) (*http.Response, error) {
if idx < len(pl.middlewares) {
return pl.middlewares[idx].Intercept(pl, idx+1, req)
}
return pl.transport.RoundTrip(req)
}
func defaultTransport() http.RoundTripper {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.ForceAttemptHTTP2 = true
return defaultTransport
}
func internalMiddleware(cc *clientConfig) []khttp.Middleware {
return []khttp.Middleware{
&RetryHandler{
MaxRetries: cc.maxRetries,
Delay: cc.minDelay,
},
&LoggingMiddleware{},
&ThrottleControlMiddleware{},
&MetricsMiddleware{},
}
}

View File

@ -0,0 +1,45 @@
package graph
import (
"net/http"
"testing"
"github.com/alcionai/clues"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/tester"
)
type HTTPWrapperIntgSuite struct {
tester.Suite
}
func TestHTTPWrapperIntgSuite(t *testing.T) {
suite.Run(t, &HTTPWrapperIntgSuite{
Suite: tester.NewIntegrationSuite(
t,
[][]string{tester.M365AcctCredEnvs}),
})
}
func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
hw = NewHTTPWrapper()
)
resp, err := hw.Request(
ctx,
http.MethodGet,
"https://www.corsobackup.io",
nil,
nil)
require.NoError(t, err, clues.ToCore(err))
require.NotNil(t, resp)
require.Equal(t, http.StatusOK, resp.StatusCode)
}

View File

@ -20,6 +20,10 @@ import (
"github.com/alcionai/corso/src/pkg/logger"
)
type nexter interface {
Next(req *http.Request, middlewareIndex int) (*http.Response, error)
}
// ---------------------------------------------------------------------------
// Logging
// ---------------------------------------------------------------------------

View File

@ -1,12 +1,27 @@
package mock
import (
"github.com/alcionai/clues"
"github.com/h2non/gock"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
"github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/pkg/account"
)
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
a, err := CreateAdapter(
creds.AzureTenantID,
creds.AzureClientID,
creds.AzureClientSecret,
opts...)
if err != nil {
return nil, clues.Wrap(err, "generating graph adapter")
}
return graph.NewService(a), nil
}
// CreateAdapter is similar to graph.CreateAdapter, but with option to
// enable interceptions via gock to make it mockable.
func CreateAdapter(
@ -18,7 +33,7 @@ func CreateAdapter(
return nil, err
}
httpClient := graph.HTTPClient(opts...)
httpClient := graph.KiotaHTTPClient(opts...)
// This makes sure that we are able to intercept any requests via
// gock. Only necessary for testing.

View File

@ -114,7 +114,7 @@ func CreateAdapter(
return nil, err
}
httpClient := HTTPClient(opts...)
httpClient := KiotaHTTPClient(opts...)
return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(
auth,
@ -140,21 +140,24 @@ func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityA
return auth, nil
}
// HTTPClient creates the httpClient with middlewares and timeout configured
// KiotaHTTPClient creates a httpClient with middlewares and timeout configured
// for use in the graph adapter.
//
// Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis.
func HTTPClient(opts ...Option) *http.Client {
clientOptions := msgraphsdkgo.GetDefaultClientOptions()
clientconfig := (&clientConfig{}).populate(opts...)
noOfRetries, minRetryDelay := clientconfig.applyMiddlewareConfig()
middlewares := GetKiotaMiddlewares(&clientOptions, noOfRetries, minRetryDelay)
httpClient := msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
func KiotaHTTPClient(opts ...Option) *http.Client {
var (
clientOptions = msgraphsdkgo.GetDefaultClientOptions()
cc = populateConfig(opts...)
middlewares = kiotaMiddlewares(&clientOptions, cc)
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
)
httpClient.Timeout = defaultHTTPClientTimeout
clientconfig.apply(httpClient)
cc.apply(httpClient)
return httpClient
}
@ -175,27 +178,17 @@ type clientConfig struct {
type Option func(*clientConfig)
// populate constructs a clientConfig according to the provided options.
func (c *clientConfig) populate(opts ...Option) *clientConfig {
func populateConfig(opts ...Option) *clientConfig {
cc := clientConfig{
maxRetries: defaultMaxRetries,
minDelay: defaultDelay,
}
for _, opt := range opts {
opt(c)
opt(&cc)
}
return c
}
// apply updates the http.Client with the expected options.
func (c *clientConfig) applyMiddlewareConfig() (retry int, delay time.Duration) {
retry = defaultMaxRetries
if c.overrideRetryCount {
retry = c.maxRetries
}
delay = defaultDelay
if c.minDelay > 0 {
delay = c.minDelay
}
return
return &cc
}
// apply updates the http.Client with the expected options.
@ -236,14 +229,16 @@ func MinimumBackoff(dur time.Duration) Option {
// Middleware Control
// ---------------------------------------------------------------------------
// GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter
func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware {
// kiotaMiddlewares creates a default slice of middleware for the Graph Client.
func kiotaMiddlewares(
options *msgraphgocore.GraphClientOptions,
cc *clientConfig,
) []khttp.Middleware {
return []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options),
&RetryHandler{
// The maximum number of times a request can be retried
MaxRetries: maxRetry,
// The delay in seconds between retries
Delay: delay,
MaxRetries: cc.maxRetries,
Delay: cc.minDelay,
},
khttp.NewRetryHandler(),
khttp.NewRedirectHandler(),
@ -255,21 +250,3 @@ func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware {
&MetricsMiddleware{},
}
}
// GetKiotaMiddlewares creates a default slice of middleware for the Graph Client.
func GetKiotaMiddlewares(
options *msgraphgocore.GraphClientOptions,
maxRetry int,
minDelay time.Duration,
) []khttp.Middleware {
kiotaMiddlewares := GetMiddlewares(maxRetry, minDelay)
graphMiddlewares := []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options),
}
graphMiddlewaresLen := len(graphMiddlewares)
resultMiddlewares := make([]khttp.Middleware, len(kiotaMiddlewares)+graphMiddlewaresLen)
copy(resultMiddlewares, graphMiddlewares)
copy(resultMiddlewares[graphMiddlewaresLen:], kiotaMiddlewares)
return resultMiddlewares
}

View File

@ -70,7 +70,7 @@ func (suite *GraphUnitSuite) TestHTTPClient() {
suite.Run(test.name, func() {
t := suite.T()
cli := HTTPClient(test.opts...)
cli := KiotaHTTPClient(test.opts...)
assert.NotNil(t, cli)
test.check(t, cli)
})

View File

@ -4,7 +4,6 @@ package connector
import (
"context"
"net/http"
"runtime/trace"
"sync"
@ -36,7 +35,7 @@ var (
type GraphConnector struct {
Service graph.Servicer
Discovery api.Client
itemClient *http.Client // configured to handle large item downloads
itemClient graph.Requester // configured to handle large item downloads
tenant string
credentials account.M365Config
@ -88,7 +87,7 @@ func NewGraphConnector(
Service: service,
credentials: creds,
itemClient: graph.HTTPClient(graph.NoTimeout()),
itemClient: graph.NewNoTimeoutHTTPWrapper(),
ownerLookup: rc,
tenant: acct.ID(),
wg: &sync.WaitGroup{},

View File

@ -35,10 +35,6 @@ const (
// TODO: Tune this later along with collectionChannelBufferSize
urlPrefetchChannelBufferSize = 5
// maxDownloadRetires specifies the number of times a file download should
// be retried
maxDownloadRetires = 3
// Used to compare in case of OneNote files
MaxOneNoteFileSize = 2 * 1024 * 1024 * 1024
)
@ -62,7 +58,7 @@ const (
// Collection represents a set of OneDrive objects retrieved from M365
type Collection struct {
// configured to handle large item downloads
itemClient *http.Client
itemClient graph.Requester
// data is used to share data streams with the collection consumer
data chan data.Stream
@ -110,7 +106,7 @@ type Collection struct {
doNotMergeItems bool
}
// itemGetterFunc gets an specified item
// itemGetterFunc gets a specified item
type itemGetterFunc func(
ctx context.Context,
srv graph.Servicer,
@ -120,7 +116,7 @@ type itemGetterFunc func(
// itemReadFunc returns a reader for the specified item
type itemReaderFunc func(
ctx context.Context,
hc *http.Client,
client graph.Requester,
item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error)
@ -148,7 +144,7 @@ func pathToLocation(p path.Path) (*path.Builder, error) {
// NewCollection creates a Collection
func NewCollection(
itemClient *http.Client,
itemClient graph.Requester,
folderPath path.Path,
prevPath path.Path,
driveID string,
@ -372,45 +368,29 @@ func (oc *Collection) getDriveItemContent(
itemID = ptr.Val(item.GetId())
itemName = ptr.Val(item.GetName())
el = errs.Local()
itemData io.ReadCloser
err error
)
// Initial try with url from delta + 2 retries
for i := 1; i <= maxDownloadRetires; i++ {
_, itemData, err = oc.itemReader(ctx, oc.itemClient, item)
if err == nil || !graph.IsErrUnauthorized(err) {
break
}
// Assume unauthorized requests are a sign of an expired jwt
// token, and that we've overrun the available window to
// download the actual file. Re-downloading the item will
// refresh that download url.
di, diErr := oc.itemGetter(ctx, oc.service, oc.driveID, itemID)
if diErr != nil {
err = clues.Wrap(diErr, "retrieving expired item")
break
}
item = di
}
// check for errors following retries
itemData, err := downloadContent(
ctx,
oc.service,
oc.itemGetter,
oc.itemReader,
oc.itemClient,
item,
oc.driveID)
if err != nil {
if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) {
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware")
el.AddSkip(fault.FileSkip(fault.SkipMalware, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
return nil, clues.Wrap(err, "malware item").Label(graph.LabelsSkippable)
}
if clues.HasLabel(err, graph.LabelStatus(http.StatusNotFound)) || graph.IsErrDeletedInFlight(err) {
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipNotFound).Info("item not found")
el.AddSkip(fault.FileSkip(fault.SkipNotFound, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
return nil, clues.Wrap(err, "deleted item").Label(graph.LabelsSkippable)
}
// Skip big OneNote files as they can't be downloaded
@ -425,7 +405,7 @@ func (oc *Collection) getDriveItemContent(
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipBigOneNote).Info("max OneNote file size exceeded")
el.AddSkip(fault.FileSkip(fault.SkipBigOneNote, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
return nil, clues.Wrap(err, "max oneNote item").Label(graph.LabelsSkippable)
}
logger.CtxErr(ctx, err).Error("downloading item")
@ -433,12 +413,48 @@ func (oc *Collection) getDriveItemContent(
// return err, not el.Err(), because the lazy reader needs to communicate to
// the data consumer that this item is unreadable, regardless of the fault state.
return nil, clues.Wrap(err, "downloading item")
return nil, clues.Wrap(err, "fetching item content")
}
return itemData, nil
}
// downloadContent attempts to fetch the item content. If the content url
// is expired (ie, returns a 401), it re-fetches the item to get a new download
// url and tries again.
func downloadContent(
ctx context.Context,
svc graph.Servicer,
igf itemGetterFunc,
irf itemReaderFunc,
gr graph.Requester,
item models.DriveItemable,
driveID string,
) (io.ReadCloser, error) {
_, content, err := irf(ctx, gr, item)
if err == nil {
return content, nil
} else if !graph.IsErrUnauthorized(err) {
return nil, err
}
// Assume unauthorized requests are a sign of an expired jwt
// token, and that we've overrun the available window to
// download the actual file. Re-downloading the item will
// refresh that download url.
di, err := igf(ctx, svc, driveID, ptr.Val(item.GetId()))
if err != nil {
return nil, clues.Wrap(err, "retrieving expired item")
}
_, content, err = irf(ctx, gr, di)
if err != nil {
return nil, clues.Wrap(err, "content download retry")
}
return content, nil
}
// populateItems iterates through items added to the collection
// and uses the collection `itemReader` to read the item
func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
@ -504,9 +520,9 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
ctx = clues.Add(
ctx,
"backup_item_id", itemID,
"backup_item_name", itemName,
"backup_item_size", itemSize)
"item_id", itemID,
"item_name", itemName,
"item_size", itemSize)
item.SetParentReference(setName(item.GetParentReference(), oc.driveName))
@ -545,7 +561,7 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
itemInfo.OneDrive.ParentPath = parentPathString
}
ctx = clues.Add(ctx, "backup_item_info", itemInfo)
ctx = clues.Add(ctx, "item_info", itemInfo)
if isFile {
dataSuffix := metadata.DataFileSuffix

View File

@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/onedrive/metadata"
"github.com/alcionai/corso/src/internal/connector/support"
@ -98,7 +99,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 1,
source: OneDriveSource,
itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)),
nil
@ -114,7 +115,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3,
source: OneDriveSource,
itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)),
nil
@ -130,7 +131,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3,
source: OneDriveSource,
itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, clues.New("test malware").Label(graph.LabelsMalware)
},
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
@ -146,7 +147,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
source: OneDriveSource,
itemDeets: nst{testItemName, 42, now},
// Usually `Not Found` is returned from itemGetter and not itemReader
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, clues.New("test not found").Label(graph.LabelStatus(http.StatusNotFound))
},
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
@ -161,7 +162,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 1,
source: SharePointSource,
itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)),
nil
@ -177,7 +178,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3,
source: SharePointSource,
itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)),
nil
@ -207,7 +208,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
folderPath,
nil,
"drive-id",
@ -278,7 +279,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
if err != nil {
for _, label := range test.expectLabels {
assert.True(t, clues.HasLabel(err, label), "has clues label:", label)
assert.Truef(t, clues.HasLabel(err, label), "has clues label: %s", label)
}
return
@ -347,7 +348,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
folderPath,
nil,
"fakeDriveID",
@ -370,7 +371,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
coll.itemReader = func(
context.Context,
*http.Client,
graph.Requester,
models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, assert.AnError
@ -437,7 +438,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
require.NoError(t, err)
coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
folderPath,
nil,
"fakeDriveID",
@ -470,10 +471,10 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
coll.itemReader = func(
context.Context,
*http.Client,
graph.Requester,
models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
if count < 2 {
if count < 1 {
count++
return details.ItemInfo{}, nil, clues.Stack(assert.AnError).
Label(graph.LabelStatus(http.StatusUnauthorized))
@ -494,13 +495,13 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
assert.True(t, ok)
_, err = io.ReadAll(collItem.ToReader())
assert.NoError(t, err)
assert.NoError(t, err, clues.ToCore(err))
wg.Wait()
require.Equal(t, 1, collStatus.Metrics.Objects, "only one object should be counted")
require.Equal(t, 1, collStatus.Metrics.Successes, "read object successfully")
require.Equal(t, 2, count, "retry count")
require.Equal(t, 1, count, "retry count")
})
}
}
@ -537,7 +538,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
folderPath,
nil,
"drive-id",
@ -561,7 +562,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
coll.itemReader = func(
context.Context,
*http.Client,
graph.Requester,
models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: "fakeName", Modified: time.Now()}},
@ -611,7 +612,7 @@ func TestGetDriveItemUnitTestSuite(t *testing.T) {
suite.Run(t, &GetDriveItemUnitTestSuite{Suite: tester.NewUnitSuite(t)})
}
func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
func (suite *GetDriveItemUnitTestSuite) TestGetDriveItem_error() {
strval := "not-important"
table := []struct {
@ -637,14 +638,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "malware error",
colScope: CollectionScopeFolder,
itemSize: 10,
err: clues.New("test error").Label(graph.LabelsMalware),
err: clues.New("malware error").Label(graph.LabelsMalware),
labels: []string{graph.LabelsMalware, graph.LabelsSkippable},
},
{
name: "file not found error",
colScope: CollectionScopeFolder,
itemSize: 10,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusNotFound)),
err: clues.New("not found error").Label(graph.LabelStatus(http.StatusNotFound)),
labels: []string{graph.LabelStatus(http.StatusNotFound), graph.LabelsSkippable},
},
{
@ -652,14 +653,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "small OneNote file",
colScope: CollectionScopePackage,
itemSize: 10,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
err: clues.New("small onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
},
{
name: "big OneNote file",
colScope: CollectionScopePackage,
itemSize: MaxOneNoteFileSize,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
err: clues.New("big onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable), graph.LabelsSkippable},
},
{
@ -667,7 +668,7 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "big file",
colScope: CollectionScopeFolder,
itemSize: MaxOneNoteFileSize,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
err: clues.New("big file error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
},
}
@ -689,9 +690,9 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
item.SetSize(&test.itemSize)
col.itemReader = func(
ctx context.Context,
hc *http.Client,
item models.DriveItemable,
_ context.Context,
_ graph.Requester,
_ models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, test.err
}
@ -707,11 +708,11 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
_, err := col.getDriveItemContent(ctx, item, errs)
if test.err == nil {
assert.NoError(t, err, "no error")
assert.NoError(t, err, clues.ToCore(err))
return
}
assert.EqualError(t, err, clues.Wrap(test.err, "downloading item").Error(), "error")
assert.ErrorIs(t, err, test.err, clues.ToCore(err))
labelsMap := map[string]struct{}{}
for _, l := range test.labels {
@ -722,3 +723,103 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
})
}
}
func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
var (
svc graph.Servicer
gr graph.Requester
driveID string
iorc = io.NopCloser(bytes.NewReader([]byte("fnords")))
item = &models.DriveItem{}
itemWID = &models.DriveItem{}
)
itemWID.SetId(ptr.To("brainhooldy"))
table := []struct {
name string
igf itemGetterFunc
irf itemReaderFunc
expectErr require.ErrorAssertionFunc
expect require.ValueAssertionFunc
}{
{
name: "good",
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, iorc, nil
},
expectErr: require.NoError,
expect: require.NotNil,
},
{
name: "expired url redownloads",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return itemWID, nil
},
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
// a bit hacky: assume only igf returns an item with a non-zero id.
if len(ptr.Val(m.GetId())) == 0 {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
}
return details.ItemInfo{}, iorc, nil
},
expectErr: require.NoError,
expect: require.NotNil,
},
{
name: "immediate error",
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, assert.AnError
},
expectErr: require.Error,
expect: require.Nil,
},
{
name: "re-fetching the item fails",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return nil, assert.AnError
},
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
},
expectErr: require.Error,
expect: require.Nil,
},
{
name: "expired url fails redownload",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return itemWID, nil
},
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
// a bit hacky: assume only igf returns an item with a non-zero id.
if len(ptr.Val(m.GetId())) == 0 {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
}
return details.ItemInfo{}, iorc, assert.AnError
},
expectErr: require.Error,
expect: require.Nil,
},
}
for _, test := range table {
suite.Run(test.name, func() {
ctx, flush := tester.NewContext()
defer flush()
t := suite.T()
r, err := downloadContent(ctx, svc, test.igf, test.irf, gr, item, driveID)
test.expect(t, r)
test.expectErr(t, err, clues.ToCore(err))
})
}
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/alcionai/clues"
@ -73,7 +72,7 @@ type folderMatcher interface {
// resource owner, which can be either a user or a sharepoint site.
type Collections struct {
// configured to handle large item downloads
itemClient *http.Client
itemClient graph.Requester
tenant string
resourceOwner string
@ -109,7 +108,7 @@ type Collections struct {
}
func NewCollections(
itemClient *http.Client,
itemClient graph.Requester,
tenant string,
resourceOwner string,
source driveSource,

View File

@ -780,7 +780,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestUpdateCollections() {
maps.Copy(outputFolderMap, tt.inputFolderMap)
c := NewCollections(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
tenant,
user,
OneDriveSource,
@ -2231,7 +2231,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestGet() {
}
c := NewCollections(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
tenant,
user,
OneDriveSource,

View File

@ -2,7 +2,6 @@ package onedrive
import (
"context"
"net/http"
"github.com/alcionai/clues"
"golang.org/x/exp/maps"
@ -38,7 +37,7 @@ func DataCollections(
user common.IDNamer,
metadata []data.RestoreCollection,
tenant string,
itemClient *http.Client,
itemClient graph.Requester,
service graph.Servicer,
su support.StatusUpdater,
ctrlOpts control.Options,

View File

@ -426,7 +426,7 @@ func (suite *OneDriveSuite) TestOneDriveNewCollections() {
)
colls := NewCollections(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
creds.AzureTenantID,
test.user,
OneDriveSource,

View File

@ -16,7 +16,6 @@ import (
"github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/onedrive/api"
"github.com/alcionai/corso/src/internal/connector/uploadsession"
"github.com/alcionai/corso/src/internal/version"
"github.com/alcionai/corso/src/pkg/backup/details"
"github.com/alcionai/corso/src/pkg/logger"
)
@ -33,12 +32,12 @@ const (
// TODO: Add metadata fetching to SharePoint
func sharePointItemReader(
ctx context.Context,
hc *http.Client,
client graph.Requester,
item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
resp, err := downloadItem(ctx, hc, item)
resp, err := downloadItem(ctx, client, item)
if err != nil {
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item")
return details.ItemInfo{}, nil, clues.Wrap(err, "sharepoint reader")
}
dii := details.ItemInfo{
@ -107,7 +106,7 @@ func baseItemMetaReader(
// and using a http client to initialize a reader
func oneDriveItemReader(
ctx context.Context,
hc *http.Client,
client graph.Requester,
item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) {
var (
@ -116,9 +115,9 @@ func oneDriveItemReader(
)
if isFile {
resp, err := downloadItem(ctx, hc, item)
resp, err := downloadItem(ctx, client, item)
if err != nil {
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item")
return details.ItemInfo{}, nil, clues.Wrap(err, "onedrive reader")
}
rc = resp.Body
@ -131,38 +130,26 @@ func oneDriveItemReader(
return dii, rc, nil
}
func downloadItem(ctx context.Context, hc *http.Client, item models.DriveItemable) (*http.Response, error) {
func downloadItem(
ctx context.Context,
client graph.Requester,
item models.DriveItemable,
) (*http.Response, error) {
url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
if !ok {
return nil, clues.New("extracting file url").With("item_id", ptr.Val(item.GetId()))
}
req, err := http.NewRequest(http.MethodGet, *url, nil)
resp, err := client.Request(ctx, http.MethodGet, ptr.Val(url), nil, nil)
if err != nil {
return nil, graph.Wrap(ctx, err, "new item download request")
}
//nolint:lll
// Decorate the traffic
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
resp, err := hc.Do(req)
if err != nil {
cerr := graph.Wrap(ctx, err, "downloading item")
if graph.IsMalware(err) {
cerr = cerr.Label(graph.LabelsMalware)
}
return nil, cerr
return nil, err
}
if (resp.StatusCode / 100) == 2 {
return resp, nil
}
if graph.IsMalwareResp(context.Background(), resp) {
if graph.IsMalwareResp(ctx, resp) {
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
}

View File

@ -112,7 +112,7 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
)
// Read data for the file
itemInfo, itemData, err := oneDriveItemReader(ctx, graph.HTTPClient(graph.NoTimeout()), driveItem)
itemInfo, itemData, err := oneDriveItemReader(ctx, graph.NewNoTimeoutHTTPWrapper(), driveItem)
require.NoError(suite.T(), err, clues.ToCore(err))
require.NotNil(suite.T(), itemInfo.OneDrive)

View File

@ -2,7 +2,6 @@ package sharepoint
import (
"context"
"net/http"
"github.com/alcionai/clues"
@ -29,7 +28,7 @@ type statusUpdater interface {
// for the specified user
func DataCollections(
ctx context.Context,
itemClient *http.Client,
itemClient graph.Requester,
selector selectors.Selector,
creds account.M365Config,
serv graph.Servicer,
@ -182,7 +181,7 @@ func collectLists(
// all the drives associated with the site.
func collectLibraries(
ctx context.Context,
itemClient *http.Client,
itemClient graph.Requester,
serv graph.Servicer,
tenantID, siteID string,
scope selectors.SharePointScope,

View File

@ -109,7 +109,7 @@ func (suite *SharePointLibrariesUnitSuite) TestUpdateCollections() {
)
c := onedrive.NewCollections(
graph.HTTPClient(graph.NoTimeout()),
graph.NewNoTimeoutHTTPWrapper(),
tenant,
site,
onedrive.SharePointSource,