custom http wrapper for item downloads (#3132)

the http client customized by the graph client for use in downloading files doesn't include our middleware like expected. This introduces a new http client wrapper that populates the roundtripper with a middleware wrapper that can utilize our kiota middleware in the same way as the graph client does.

---

#### Does this PR need a docs update or release note?

- [x]  No

#### Type of change

- [x] 🌻 Feature
- [x] 🐛 Bugfix
- [x] 🧹 Tech Debt/Cleanup

#### Issue(s)

* #3129

#### Test Plan

- [x] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
This commit is contained in:
Keepers 2023-04-20 17:20:23 -06:00 committed by GitHub
parent 315c0cc5f3
commit d5fac8a480
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 481 additions and 195 deletions

View File

@ -77,7 +77,10 @@ func handleOneDriveCmd(cmd *cobra.Command, args []string) error {
return Only(ctx, clues.Wrap(err, "creating graph adapter")) return Only(ctx, clues.Wrap(err, "creating graph adapter"))
} }
err = runDisplayM365JSON(ctx, graph.NewService(adpt), creds, user, m365ID) svc := graph.NewService(adpt)
gr := graph.NewNoTimeoutHTTPWrapper()
err = runDisplayM365JSON(ctx, svc, gr, creds, user, m365ID)
if err != nil { if err != nil {
cmd.SilenceUsage = true cmd.SilenceUsage = true
cmd.SilenceErrors = true cmd.SilenceErrors = true
@ -105,6 +108,7 @@ func (i itemPrintable) MinimumPrintable() any {
func runDisplayM365JSON( func runDisplayM365JSON(
ctx context.Context, ctx context.Context,
srv graph.Servicer, srv graph.Servicer,
gr graph.Requester,
creds account.M365Config, creds account.M365Config,
user, itemID string, user, itemID string,
) error { ) error {
@ -123,7 +127,7 @@ func runDisplayM365JSON(
} }
if item != nil { if item != nil {
content, err := getDriveItemContent(item) content, err := getDriveItemContent(ctx, gr, item)
if err != nil { if err != nil {
return err return err
} }
@ -180,22 +184,19 @@ func serializeObject(data serialization.Parsable) (string, error) {
return string(content), err return string(content), err
} }
func getDriveItemContent(item models.DriveItemable) ([]byte, error) { func getDriveItemContent(
ctx context.Context,
gr graph.Requester,
item models.DriveItemable,
) ([]byte, error) {
url, ok := item.GetAdditionalData()[downloadURLKey].(*string) url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
if !ok { if !ok {
return nil, clues.New("get download url") return nil, clues.New("retrieving download url")
} }
req, err := http.NewRequest(http.MethodGet, *url, nil) resp, err := gr.Request(ctx, http.MethodGet, *url, nil, nil)
if err != nil { if err != nil {
return nil, clues.New("create download request").With("error", err) return nil, clues.New("downloading item").With("error", err)
}
hc := graph.HTTPClient(graph.NoTimeout())
resp, err := hc.Do(req)
if err != nil {
return nil, clues.New("download item").With("error", err)
} }
content, err := io.ReadAll(resp.Body) content, err := io.ReadAll(resp.Body)

View File

@ -258,7 +258,7 @@ func (suite *DataCollectionIntgSuite) TestSharePointDataCollection() {
collections, excludes, err := sharepoint.DataCollections( collections, excludes, err := sharepoint.DataCollections(
ctx, ctx,
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
sel, sel,
connector.credentials, connector.credentials,
connector.Service, connector.Service,

View File

@ -1,36 +1,21 @@
package mock package mock
import ( import (
"github.com/alcionai/clues"
"github.com/alcionai/corso/src/internal/connector/exchange/api" "github.com/alcionai/corso/src/internal/connector/exchange/api"
"github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/graph/mock" "github.com/alcionai/corso/src/internal/connector/graph/mock"
"github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/account"
) )
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
a, err := mock.CreateAdapter(
creds.AzureTenantID,
creds.AzureClientID,
creds.AzureClientSecret,
opts...)
if err != nil {
return nil, clues.Wrap(err, "generating graph adapter")
}
return graph.NewService(a), nil
}
// NewClient produces a new exchange api client that can be // NewClient produces a new exchange api client that can be
// mocked using gock. // mocked using gock.
func NewClient(creds account.M365Config) (api.Client, error) { func NewClient(creds account.M365Config) (api.Client, error) {
s, err := NewService(creds) s, err := mock.NewService(creds)
if err != nil { if err != nil {
return api.Client{}, err return api.Client{}, err
} }
li, err := NewService(creds, graph.NoTimeout()) li, err := mock.NewService(creds, graph.NoTimeout())
if err != nil { if err != nil {
return api.Client{}, err return api.Client{}, err
} }

View File

@ -234,6 +234,9 @@ func Stack(ctx context.Context, e error) *clues.Err {
return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg) return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg)
} }
// Checks for the following conditions and labels the error accordingly:
// * mysiteNotFound | mysiteURLNotFound
// * malware
func setLabels(err *clues.Err, msg string) *clues.Err { func setLabels(err *clues.Err, msg string) *clues.Err {
if err == nil { if err == nil {
return nil return nil
@ -244,6 +247,10 @@ func setLabels(err *clues.Err, msg string) *clues.Err {
err = err.Label(LabelsMysiteNotFound) err = err.Label(LabelsMysiteNotFound)
} }
if IsMalware(err) {
err = err.Label(LabelsMalware)
}
return err return err
} }

View File

@ -0,0 +1,152 @@
package graph
import (
"context"
"io"
"net/http"
"github.com/alcionai/clues"
khttp "github.com/microsoft/kiota-http-go"
"github.com/alcionai/corso/src/internal/version"
)
// ---------------------------------------------------------------------------
// constructors
// ---------------------------------------------------------------------------
type Requester interface {
Request(
ctx context.Context,
method, url string,
body io.Reader,
headers map[string]string,
) (*http.Response, error)
}
// NewHTTPWrapper produces a http.Client wrapper that ensures
// calls use all the middleware we expect from the graph api client.
//
// Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis.
func NewHTTPWrapper(opts ...Option) *httpWrapper {
var (
cc = populateConfig(opts...)
rt = customTransport{
n: pipeline{
middlewares: internalMiddleware(cc),
transport: defaultTransport(),
},
}
redirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
hc = &http.Client{
CheckRedirect: redirect,
Timeout: defaultHTTPClientTimeout,
Transport: rt,
}
)
cc.apply(hc)
return &httpWrapper{hc}
}
// NewNoTimeoutHTTPWrapper constructs a http wrapper with no context timeout.
//
// Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis.
func NewNoTimeoutHTTPWrapper(opts ...Option) *httpWrapper {
opts = append(opts, NoTimeout())
return NewHTTPWrapper(opts...)
}
// ---------------------------------------------------------------------------
// requests
// ---------------------------------------------------------------------------
// Request does the provided request.
func (hw httpWrapper) Request(
ctx context.Context,
method, url string,
body io.Reader,
headers map[string]string,
) (*http.Response, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, clues.Wrap(err, "new http request")
}
for k, v := range headers {
req.Header.Set(k, v)
}
//nolint:lll
// Decorate the traffic
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
resp, err := hw.client.Do(req)
if err != nil {
return nil, Stack(ctx, err)
}
return resp, nil
}
// ---------------------------------------------------------------------------
// constructor internals
// ---------------------------------------------------------------------------
type (
httpWrapper struct {
client *http.Client
}
customTransport struct {
n nexter
}
pipeline struct {
transport http.RoundTripper
middlewares []khttp.Middleware
}
)
// RoundTrip kicks off the middleware chain and returns a response
func (ct customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return ct.n.Next(req, 0)
}
// Next moves the request object through middlewares in the pipeline
func (pl pipeline) Next(req *http.Request, idx int) (*http.Response, error) {
if idx < len(pl.middlewares) {
return pl.middlewares[idx].Intercept(pl, idx+1, req)
}
return pl.transport.RoundTrip(req)
}
func defaultTransport() http.RoundTripper {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.ForceAttemptHTTP2 = true
return defaultTransport
}
func internalMiddleware(cc *clientConfig) []khttp.Middleware {
return []khttp.Middleware{
&RetryHandler{
MaxRetries: cc.maxRetries,
Delay: cc.minDelay,
},
&LoggingMiddleware{},
&ThrottleControlMiddleware{},
&MetricsMiddleware{},
}
}

View File

@ -0,0 +1,45 @@
package graph
import (
"net/http"
"testing"
"github.com/alcionai/clues"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/tester"
)
type HTTPWrapperIntgSuite struct {
tester.Suite
}
func TestHTTPWrapperIntgSuite(t *testing.T) {
suite.Run(t, &HTTPWrapperIntgSuite{
Suite: tester.NewIntegrationSuite(
t,
[][]string{tester.M365AcctCredEnvs}),
})
}
func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() {
ctx, flush := tester.NewContext()
defer flush()
var (
t = suite.T()
hw = NewHTTPWrapper()
)
resp, err := hw.Request(
ctx,
http.MethodGet,
"https://www.corsobackup.io",
nil,
nil)
require.NoError(t, err, clues.ToCore(err))
require.NotNil(t, resp)
require.Equal(t, http.StatusOK, resp.StatusCode)
}

View File

@ -20,6 +20,10 @@ import (
"github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/logger"
) )
type nexter interface {
Next(req *http.Request, middlewareIndex int) (*http.Response, error)
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Logging // Logging
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@ -1,12 +1,27 @@
package mock package mock
import ( import (
"github.com/alcionai/clues"
"github.com/h2non/gock" "github.com/h2non/gock"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
"github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/pkg/account"
) )
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
a, err := CreateAdapter(
creds.AzureTenantID,
creds.AzureClientID,
creds.AzureClientSecret,
opts...)
if err != nil {
return nil, clues.Wrap(err, "generating graph adapter")
}
return graph.NewService(a), nil
}
// CreateAdapter is similar to graph.CreateAdapter, but with option to // CreateAdapter is similar to graph.CreateAdapter, but with option to
// enable interceptions via gock to make it mockable. // enable interceptions via gock to make it mockable.
func CreateAdapter( func CreateAdapter(
@ -18,7 +33,7 @@ func CreateAdapter(
return nil, err return nil, err
} }
httpClient := graph.HTTPClient(opts...) httpClient := graph.KiotaHTTPClient(opts...)
// This makes sure that we are able to intercept any requests via // This makes sure that we are able to intercept any requests via
// gock. Only necessary for testing. // gock. Only necessary for testing.

View File

@ -114,7 +114,7 @@ func CreateAdapter(
return nil, err return nil, err
} }
httpClient := HTTPClient(opts...) httpClient := KiotaHTTPClient(opts...)
return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(
auth, auth,
@ -140,21 +140,24 @@ func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityA
return auth, nil return auth, nil
} }
// HTTPClient creates the httpClient with middlewares and timeout configured // KiotaHTTPClient creates a httpClient with middlewares and timeout configured
// for use in the graph adapter.
// //
// Re-use of http clients is critical, or else we leak OS resources // Re-use of http clients is critical, or else we leak OS resources
// and consume relatively unbound socket connections. It is important // and consume relatively unbound socket connections. It is important
// to centralize this client to be passed downstream where api calls // to centralize this client to be passed downstream where api calls
// can utilize it on a per-download basis. // can utilize it on a per-download basis.
func HTTPClient(opts ...Option) *http.Client { func KiotaHTTPClient(opts ...Option) *http.Client {
clientOptions := msgraphsdkgo.GetDefaultClientOptions() var (
clientconfig := (&clientConfig{}).populate(opts...) clientOptions = msgraphsdkgo.GetDefaultClientOptions()
noOfRetries, minRetryDelay := clientconfig.applyMiddlewareConfig() cc = populateConfig(opts...)
middlewares := GetKiotaMiddlewares(&clientOptions, noOfRetries, minRetryDelay) middlewares = kiotaMiddlewares(&clientOptions, cc)
httpClient := msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
)
httpClient.Timeout = defaultHTTPClientTimeout httpClient.Timeout = defaultHTTPClientTimeout
clientconfig.apply(httpClient) cc.apply(httpClient)
return httpClient return httpClient
} }
@ -175,27 +178,17 @@ type clientConfig struct {
type Option func(*clientConfig) type Option func(*clientConfig)
// populate constructs a clientConfig according to the provided options. // populate constructs a clientConfig according to the provided options.
func (c *clientConfig) populate(opts ...Option) *clientConfig { func populateConfig(opts ...Option) *clientConfig {
cc := clientConfig{
maxRetries: defaultMaxRetries,
minDelay: defaultDelay,
}
for _, opt := range opts { for _, opt := range opts {
opt(c) opt(&cc)
} }
return c return &cc
}
// apply updates the http.Client with the expected options.
func (c *clientConfig) applyMiddlewareConfig() (retry int, delay time.Duration) {
retry = defaultMaxRetries
if c.overrideRetryCount {
retry = c.maxRetries
}
delay = defaultDelay
if c.minDelay > 0 {
delay = c.minDelay
}
return
} }
// apply updates the http.Client with the expected options. // apply updates the http.Client with the expected options.
@ -236,14 +229,16 @@ func MinimumBackoff(dur time.Duration) Option {
// Middleware Control // Middleware Control
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter // kiotaMiddlewares creates a default slice of middleware for the Graph Client.
func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware { func kiotaMiddlewares(
options *msgraphgocore.GraphClientOptions,
cc *clientConfig,
) []khttp.Middleware {
return []khttp.Middleware{ return []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options),
&RetryHandler{ &RetryHandler{
// The maximum number of times a request can be retried MaxRetries: cc.maxRetries,
MaxRetries: maxRetry, Delay: cc.minDelay,
// The delay in seconds between retries
Delay: delay,
}, },
khttp.NewRetryHandler(), khttp.NewRetryHandler(),
khttp.NewRedirectHandler(), khttp.NewRedirectHandler(),
@ -255,21 +250,3 @@ func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware {
&MetricsMiddleware{}, &MetricsMiddleware{},
} }
} }
// GetKiotaMiddlewares creates a default slice of middleware for the Graph Client.
func GetKiotaMiddlewares(
options *msgraphgocore.GraphClientOptions,
maxRetry int,
minDelay time.Duration,
) []khttp.Middleware {
kiotaMiddlewares := GetMiddlewares(maxRetry, minDelay)
graphMiddlewares := []khttp.Middleware{
msgraphgocore.NewGraphTelemetryHandler(options),
}
graphMiddlewaresLen := len(graphMiddlewares)
resultMiddlewares := make([]khttp.Middleware, len(kiotaMiddlewares)+graphMiddlewaresLen)
copy(resultMiddlewares, graphMiddlewares)
copy(resultMiddlewares[graphMiddlewaresLen:], kiotaMiddlewares)
return resultMiddlewares
}

View File

@ -70,7 +70,7 @@ func (suite *GraphUnitSuite) TestHTTPClient() {
suite.Run(test.name, func() { suite.Run(test.name, func() {
t := suite.T() t := suite.T()
cli := HTTPClient(test.opts...) cli := KiotaHTTPClient(test.opts...)
assert.NotNil(t, cli) assert.NotNil(t, cli)
test.check(t, cli) test.check(t, cli)
}) })

View File

@ -4,7 +4,6 @@ package connector
import ( import (
"context" "context"
"net/http"
"runtime/trace" "runtime/trace"
"sync" "sync"
@ -36,7 +35,7 @@ var (
type GraphConnector struct { type GraphConnector struct {
Service graph.Servicer Service graph.Servicer
Discovery api.Client Discovery api.Client
itemClient *http.Client // configured to handle large item downloads itemClient graph.Requester // configured to handle large item downloads
tenant string tenant string
credentials account.M365Config credentials account.M365Config
@ -88,7 +87,7 @@ func NewGraphConnector(
Service: service, Service: service,
credentials: creds, credentials: creds,
itemClient: graph.HTTPClient(graph.NoTimeout()), itemClient: graph.NewNoTimeoutHTTPWrapper(),
ownerLookup: rc, ownerLookup: rc,
tenant: acct.ID(), tenant: acct.ID(),
wg: &sync.WaitGroup{}, wg: &sync.WaitGroup{},

View File

@ -35,10 +35,6 @@ const (
// TODO: Tune this later along with collectionChannelBufferSize // TODO: Tune this later along with collectionChannelBufferSize
urlPrefetchChannelBufferSize = 5 urlPrefetchChannelBufferSize = 5
// maxDownloadRetires specifies the number of times a file download should
// be retried
maxDownloadRetires = 3
// Used to compare in case of OneNote files // Used to compare in case of OneNote files
MaxOneNoteFileSize = 2 * 1024 * 1024 * 1024 MaxOneNoteFileSize = 2 * 1024 * 1024 * 1024
) )
@ -62,7 +58,7 @@ const (
// Collection represents a set of OneDrive objects retrieved from M365 // Collection represents a set of OneDrive objects retrieved from M365
type Collection struct { type Collection struct {
// configured to handle large item downloads // configured to handle large item downloads
itemClient *http.Client itemClient graph.Requester
// data is used to share data streams with the collection consumer // data is used to share data streams with the collection consumer
data chan data.Stream data chan data.Stream
@ -110,7 +106,7 @@ type Collection struct {
doNotMergeItems bool doNotMergeItems bool
} }
// itemGetterFunc gets an specified item // itemGetterFunc gets a specified item
type itemGetterFunc func( type itemGetterFunc func(
ctx context.Context, ctx context.Context,
srv graph.Servicer, srv graph.Servicer,
@ -120,7 +116,7 @@ type itemGetterFunc func(
// itemReadFunc returns a reader for the specified item // itemReadFunc returns a reader for the specified item
type itemReaderFunc func( type itemReaderFunc func(
ctx context.Context, ctx context.Context,
hc *http.Client, client graph.Requester,
item models.DriveItemable, item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) ) (details.ItemInfo, io.ReadCloser, error)
@ -148,7 +144,7 @@ func pathToLocation(p path.Path) (*path.Builder, error) {
// NewCollection creates a Collection // NewCollection creates a Collection
func NewCollection( func NewCollection(
itemClient *http.Client, itemClient graph.Requester,
folderPath path.Path, folderPath path.Path,
prevPath path.Path, prevPath path.Path,
driveID string, driveID string,
@ -372,45 +368,29 @@ func (oc *Collection) getDriveItemContent(
itemID = ptr.Val(item.GetId()) itemID = ptr.Val(item.GetId())
itemName = ptr.Val(item.GetName()) itemName = ptr.Val(item.GetName())
el = errs.Local() el = errs.Local()
itemData io.ReadCloser
err error
) )
// Initial try with url from delta + 2 retries itemData, err := downloadContent(
for i := 1; i <= maxDownloadRetires; i++ { ctx,
_, itemData, err = oc.itemReader(ctx, oc.itemClient, item) oc.service,
if err == nil || !graph.IsErrUnauthorized(err) { oc.itemGetter,
break oc.itemReader,
} oc.itemClient,
item,
// Assume unauthorized requests are a sign of an expired jwt oc.driveID)
// token, and that we've overrun the available window to
// download the actual file. Re-downloading the item will
// refresh that download url.
di, diErr := oc.itemGetter(ctx, oc.service, oc.driveID, itemID)
if diErr != nil {
err = clues.Wrap(diErr, "retrieving expired item")
break
}
item = di
}
// check for errors following retries
if err != nil { if err != nil {
if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) { if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) {
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware") logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware")
el.AddSkip(fault.FileSkip(fault.SkipMalware, itemID, itemName, graph.ItemInfo(item))) el.AddSkip(fault.FileSkip(fault.SkipMalware, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) return nil, clues.Wrap(err, "malware item").Label(graph.LabelsSkippable)
} }
if clues.HasLabel(err, graph.LabelStatus(http.StatusNotFound)) || graph.IsErrDeletedInFlight(err) { if clues.HasLabel(err, graph.LabelStatus(http.StatusNotFound)) || graph.IsErrDeletedInFlight(err) {
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipNotFound).Info("item not found") logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipNotFound).Info("item not found")
el.AddSkip(fault.FileSkip(fault.SkipNotFound, itemID, itemName, graph.ItemInfo(item))) el.AddSkip(fault.FileSkip(fault.SkipNotFound, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) return nil, clues.Wrap(err, "deleted item").Label(graph.LabelsSkippable)
} }
// Skip big OneNote files as they can't be downloaded // Skip big OneNote files as they can't be downloaded
@ -425,7 +405,7 @@ func (oc *Collection) getDriveItemContent(
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipBigOneNote).Info("max OneNote file size exceeded") logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipBigOneNote).Info("max OneNote file size exceeded")
el.AddSkip(fault.FileSkip(fault.SkipBigOneNote, itemID, itemName, graph.ItemInfo(item))) el.AddSkip(fault.FileSkip(fault.SkipBigOneNote, itemID, itemName, graph.ItemInfo(item)))
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) return nil, clues.Wrap(err, "max oneNote item").Label(graph.LabelsSkippable)
} }
logger.CtxErr(ctx, err).Error("downloading item") logger.CtxErr(ctx, err).Error("downloading item")
@ -433,12 +413,48 @@ func (oc *Collection) getDriveItemContent(
// return err, not el.Err(), because the lazy reader needs to communicate to // return err, not el.Err(), because the lazy reader needs to communicate to
// the data consumer that this item is unreadable, regardless of the fault state. // the data consumer that this item is unreadable, regardless of the fault state.
return nil, clues.Wrap(err, "downloading item") return nil, clues.Wrap(err, "fetching item content")
} }
return itemData, nil return itemData, nil
} }
// downloadContent attempts to fetch the item content. If the content url
// is expired (ie, returns a 401), it re-fetches the item to get a new download
// url and tries again.
func downloadContent(
ctx context.Context,
svc graph.Servicer,
igf itemGetterFunc,
irf itemReaderFunc,
gr graph.Requester,
item models.DriveItemable,
driveID string,
) (io.ReadCloser, error) {
_, content, err := irf(ctx, gr, item)
if err == nil {
return content, nil
} else if !graph.IsErrUnauthorized(err) {
return nil, err
}
// Assume unauthorized requests are a sign of an expired jwt
// token, and that we've overrun the available window to
// download the actual file. Re-downloading the item will
// refresh that download url.
di, err := igf(ctx, svc, driveID, ptr.Val(item.GetId()))
if err != nil {
return nil, clues.Wrap(err, "retrieving expired item")
}
_, content, err = irf(ctx, gr, di)
if err != nil {
return nil, clues.Wrap(err, "content download retry")
}
return content, nil
}
// populateItems iterates through items added to the collection // populateItems iterates through items added to the collection
// and uses the collection `itemReader` to read the item // and uses the collection `itemReader` to read the item
func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) { func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
@ -504,9 +520,9 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
ctx = clues.Add( ctx = clues.Add(
ctx, ctx,
"backup_item_id", itemID, "item_id", itemID,
"backup_item_name", itemName, "item_name", itemName,
"backup_item_size", itemSize) "item_size", itemSize)
item.SetParentReference(setName(item.GetParentReference(), oc.driveName)) item.SetParentReference(setName(item.GetParentReference(), oc.driveName))
@ -545,7 +561,7 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
itemInfo.OneDrive.ParentPath = parentPathString itemInfo.OneDrive.ParentPath = parentPathString
} }
ctx = clues.Add(ctx, "backup_item_info", itemInfo) ctx = clues.Add(ctx, "item_info", itemInfo)
if isFile { if isFile {
dataSuffix := metadata.DataFileSuffix dataSuffix := metadata.DataFileSuffix

View File

@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/onedrive/metadata" "github.com/alcionai/corso/src/internal/connector/onedrive/metadata"
"github.com/alcionai/corso/src/internal/connector/support" "github.com/alcionai/corso/src/internal/connector/support"
@ -98,7 +99,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 1, numInstances: 1,
source: OneDriveSource, source: OneDriveSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}}, return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)), io.NopCloser(bytes.NewReader(testItemData)),
nil nil
@ -114,7 +115,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3, numInstances: 3,
source: OneDriveSource, source: OneDriveSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}}, return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)), io.NopCloser(bytes.NewReader(testItemData)),
nil nil
@ -130,7 +131,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3, numInstances: 3,
source: OneDriveSource, source: OneDriveSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, clues.New("test malware").Label(graph.LabelsMalware) return details.ItemInfo{}, nil, clues.New("test malware").Label(graph.LabelsMalware)
}, },
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) { infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
@ -146,7 +147,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
source: OneDriveSource, source: OneDriveSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
// Usually `Not Found` is returned from itemGetter and not itemReader // Usually `Not Found` is returned from itemGetter and not itemReader
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, clues.New("test not found").Label(graph.LabelStatus(http.StatusNotFound)) return details.ItemInfo{}, nil, clues.New("test not found").Label(graph.LabelStatus(http.StatusNotFound))
}, },
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) { infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
@ -161,7 +162,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 1, numInstances: 1,
source: SharePointSource, source: SharePointSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}}, return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)), io.NopCloser(bytes.NewReader(testItemData)),
nil nil
@ -177,7 +178,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
numInstances: 3, numInstances: 3,
source: SharePointSource, source: SharePointSource,
itemDeets: nst{testItemName, 42, now}, itemDeets: nst{testItemName, 42, now},
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}}, return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
io.NopCloser(bytes.NewReader(testItemData)), io.NopCloser(bytes.NewReader(testItemData)),
nil nil
@ -207,7 +208,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection( coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
folderPath, folderPath,
nil, nil,
"drive-id", "drive-id",
@ -278,7 +279,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
if err != nil { if err != nil {
for _, label := range test.expectLabels { for _, label := range test.expectLabels {
assert.True(t, clues.HasLabel(err, label), "has clues label:", label) assert.Truef(t, clues.HasLabel(err, label), "has clues label: %s", label)
} }
return return
@ -347,7 +348,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection( coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
folderPath, folderPath,
nil, nil,
"fakeDriveID", "fakeDriveID",
@ -370,7 +371,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
coll.itemReader = func( coll.itemReader = func(
context.Context, context.Context,
*http.Client, graph.Requester,
models.DriveItemable, models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, assert.AnError return details.ItemInfo{}, nil, assert.AnError
@ -437,7 +438,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
require.NoError(t, err) require.NoError(t, err)
coll, err := NewCollection( coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
folderPath, folderPath,
nil, nil,
"fakeDriveID", "fakeDriveID",
@ -470,10 +471,10 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
coll.itemReader = func( coll.itemReader = func(
context.Context, context.Context,
*http.Client, graph.Requester,
models.DriveItemable, models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
if count < 2 { if count < 1 {
count++ count++
return details.ItemInfo{}, nil, clues.Stack(assert.AnError). return details.ItemInfo{}, nil, clues.Stack(assert.AnError).
Label(graph.LabelStatus(http.StatusUnauthorized)) Label(graph.LabelStatus(http.StatusUnauthorized))
@ -494,13 +495,13 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
assert.True(t, ok) assert.True(t, ok)
_, err = io.ReadAll(collItem.ToReader()) _, err = io.ReadAll(collItem.ToReader())
assert.NoError(t, err) assert.NoError(t, err, clues.ToCore(err))
wg.Wait() wg.Wait()
require.Equal(t, 1, collStatus.Metrics.Objects, "only one object should be counted") require.Equal(t, 1, collStatus.Metrics.Objects, "only one object should be counted")
require.Equal(t, 1, collStatus.Metrics.Successes, "read object successfully") require.Equal(t, 1, collStatus.Metrics.Successes, "read object successfully")
require.Equal(t, 2, count, "retry count") require.Equal(t, 1, count, "retry count")
}) })
} }
} }
@ -537,7 +538,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
coll, err := NewCollection( coll, err := NewCollection(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
folderPath, folderPath,
nil, nil,
"drive-id", "drive-id",
@ -561,7 +562,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
coll.itemReader = func( coll.itemReader = func(
context.Context, context.Context,
*http.Client, graph.Requester,
models.DriveItemable, models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: "fakeName", Modified: time.Now()}}, return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: "fakeName", Modified: time.Now()}},
@ -611,7 +612,7 @@ func TestGetDriveItemUnitTestSuite(t *testing.T) {
suite.Run(t, &GetDriveItemUnitTestSuite{Suite: tester.NewUnitSuite(t)}) suite.Run(t, &GetDriveItemUnitTestSuite{Suite: tester.NewUnitSuite(t)})
} }
func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { func (suite *GetDriveItemUnitTestSuite) TestGetDriveItem_error() {
strval := "not-important" strval := "not-important"
table := []struct { table := []struct {
@ -637,14 +638,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "malware error", name: "malware error",
colScope: CollectionScopeFolder, colScope: CollectionScopeFolder,
itemSize: 10, itemSize: 10,
err: clues.New("test error").Label(graph.LabelsMalware), err: clues.New("malware error").Label(graph.LabelsMalware),
labels: []string{graph.LabelsMalware, graph.LabelsSkippable}, labels: []string{graph.LabelsMalware, graph.LabelsSkippable},
}, },
{ {
name: "file not found error", name: "file not found error",
colScope: CollectionScopeFolder, colScope: CollectionScopeFolder,
itemSize: 10, itemSize: 10,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusNotFound)), err: clues.New("not found error").Label(graph.LabelStatus(http.StatusNotFound)),
labels: []string{graph.LabelStatus(http.StatusNotFound), graph.LabelsSkippable}, labels: []string{graph.LabelStatus(http.StatusNotFound), graph.LabelsSkippable},
}, },
{ {
@ -652,14 +653,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "small OneNote file", name: "small OneNote file",
colScope: CollectionScopePackage, colScope: CollectionScopePackage,
itemSize: 10, itemSize: 10,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), err: clues.New("small onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)}, labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
}, },
{ {
name: "big OneNote file", name: "big OneNote file",
colScope: CollectionScopePackage, colScope: CollectionScopePackage,
itemSize: MaxOneNoteFileSize, itemSize: MaxOneNoteFileSize,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), err: clues.New("big onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable), graph.LabelsSkippable}, labels: []string{graph.LabelStatus(http.StatusServiceUnavailable), graph.LabelsSkippable},
}, },
{ {
@ -667,7 +668,7 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
name: "big file", name: "big file",
colScope: CollectionScopeFolder, colScope: CollectionScopeFolder,
itemSize: MaxOneNoteFileSize, itemSize: MaxOneNoteFileSize,
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), err: clues.New("big file error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)}, labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
}, },
} }
@ -689,9 +690,9 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
item.SetSize(&test.itemSize) item.SetSize(&test.itemSize)
col.itemReader = func( col.itemReader = func(
ctx context.Context, _ context.Context,
hc *http.Client, _ graph.Requester,
item models.DriveItemable, _ models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, test.err return details.ItemInfo{}, nil, test.err
} }
@ -707,11 +708,11 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
_, err := col.getDriveItemContent(ctx, item, errs) _, err := col.getDriveItemContent(ctx, item, errs)
if test.err == nil { if test.err == nil {
assert.NoError(t, err, "no error") assert.NoError(t, err, clues.ToCore(err))
return return
} }
assert.EqualError(t, err, clues.Wrap(test.err, "downloading item").Error(), "error") assert.ErrorIs(t, err, test.err, clues.ToCore(err))
labelsMap := map[string]struct{}{} labelsMap := map[string]struct{}{}
for _, l := range test.labels { for _, l := range test.labels {
@ -722,3 +723,103 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
}) })
} }
} }
func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
var (
svc graph.Servicer
gr graph.Requester
driveID string
iorc = io.NopCloser(bytes.NewReader([]byte("fnords")))
item = &models.DriveItem{}
itemWID = &models.DriveItem{}
)
itemWID.SetId(ptr.To("brainhooldy"))
table := []struct {
name string
igf itemGetterFunc
irf itemReaderFunc
expectErr require.ErrorAssertionFunc
expect require.ValueAssertionFunc
}{
{
name: "good",
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, iorc, nil
},
expectErr: require.NoError,
expect: require.NotNil,
},
{
name: "expired url redownloads",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return itemWID, nil
},
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
// a bit hacky: assume only igf returns an item with a non-zero id.
if len(ptr.Val(m.GetId())) == 0 {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
}
return details.ItemInfo{}, iorc, nil
},
expectErr: require.NoError,
expect: require.NotNil,
},
{
name: "immediate error",
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{}, nil, assert.AnError
},
expectErr: require.Error,
expect: require.Nil,
},
{
name: "re-fetching the item fails",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return nil, assert.AnError
},
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
},
expectErr: require.Error,
expect: require.Nil,
},
{
name: "expired url fails redownload",
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
return itemWID, nil
},
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
// a bit hacky: assume only igf returns an item with a non-zero id.
if len(ptr.Val(m.GetId())) == 0 {
return details.ItemInfo{},
nil,
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
}
return details.ItemInfo{}, iorc, assert.AnError
},
expectErr: require.Error,
expect: require.Nil,
},
}
for _, test := range table {
suite.Run(test.name, func() {
ctx, flush := tester.NewContext()
defer flush()
t := suite.T()
r, err := downloadContent(ctx, svc, test.igf, test.irf, gr, item, driveID)
test.expect(t, r)
test.expectErr(t, err, clues.ToCore(err))
})
}
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"strings" "strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
@ -73,7 +72,7 @@ type folderMatcher interface {
// resource owner, which can be either a user or a sharepoint site. // resource owner, which can be either a user or a sharepoint site.
type Collections struct { type Collections struct {
// configured to handle large item downloads // configured to handle large item downloads
itemClient *http.Client itemClient graph.Requester
tenant string tenant string
resourceOwner string resourceOwner string
@ -109,7 +108,7 @@ type Collections struct {
} }
func NewCollections( func NewCollections(
itemClient *http.Client, itemClient graph.Requester,
tenant string, tenant string,
resourceOwner string, resourceOwner string,
source driveSource, source driveSource,

View File

@ -780,7 +780,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestUpdateCollections() {
maps.Copy(outputFolderMap, tt.inputFolderMap) maps.Copy(outputFolderMap, tt.inputFolderMap)
c := NewCollections( c := NewCollections(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
tenant, tenant,
user, user,
OneDriveSource, OneDriveSource,
@ -2231,7 +2231,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestGet() {
} }
c := NewCollections( c := NewCollections(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
tenant, tenant,
user, user,
OneDriveSource, OneDriveSource,

View File

@ -2,7 +2,6 @@ package onedrive
import ( import (
"context" "context"
"net/http"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -38,7 +37,7 @@ func DataCollections(
user common.IDNamer, user common.IDNamer,
metadata []data.RestoreCollection, metadata []data.RestoreCollection,
tenant string, tenant string,
itemClient *http.Client, itemClient graph.Requester,
service graph.Servicer, service graph.Servicer,
su support.StatusUpdater, su support.StatusUpdater,
ctrlOpts control.Options, ctrlOpts control.Options,

View File

@ -426,7 +426,7 @@ func (suite *OneDriveSuite) TestOneDriveNewCollections() {
) )
colls := NewCollections( colls := NewCollections(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
creds.AzureTenantID, creds.AzureTenantID,
test.user, test.user,
OneDriveSource, OneDriveSource,

View File

@ -16,7 +16,6 @@ import (
"github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph"
"github.com/alcionai/corso/src/internal/connector/onedrive/api" "github.com/alcionai/corso/src/internal/connector/onedrive/api"
"github.com/alcionai/corso/src/internal/connector/uploadsession" "github.com/alcionai/corso/src/internal/connector/uploadsession"
"github.com/alcionai/corso/src/internal/version"
"github.com/alcionai/corso/src/pkg/backup/details" "github.com/alcionai/corso/src/pkg/backup/details"
"github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/logger"
) )
@ -33,12 +32,12 @@ const (
// TODO: Add metadata fetching to SharePoint // TODO: Add metadata fetching to SharePoint
func sharePointItemReader( func sharePointItemReader(
ctx context.Context, ctx context.Context,
hc *http.Client, client graph.Requester,
item models.DriveItemable, item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
resp, err := downloadItem(ctx, hc, item) resp, err := downloadItem(ctx, client, item)
if err != nil { if err != nil {
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item") return details.ItemInfo{}, nil, clues.Wrap(err, "sharepoint reader")
} }
dii := details.ItemInfo{ dii := details.ItemInfo{
@ -107,7 +106,7 @@ func baseItemMetaReader(
// and using a http client to initialize a reader // and using a http client to initialize a reader
func oneDriveItemReader( func oneDriveItemReader(
ctx context.Context, ctx context.Context,
hc *http.Client, client graph.Requester,
item models.DriveItemable, item models.DriveItemable,
) (details.ItemInfo, io.ReadCloser, error) { ) (details.ItemInfo, io.ReadCloser, error) {
var ( var (
@ -116,9 +115,9 @@ func oneDriveItemReader(
) )
if isFile { if isFile {
resp, err := downloadItem(ctx, hc, item) resp, err := downloadItem(ctx, client, item)
if err != nil { if err != nil {
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item") return details.ItemInfo{}, nil, clues.Wrap(err, "onedrive reader")
} }
rc = resp.Body rc = resp.Body
@ -131,38 +130,26 @@ func oneDriveItemReader(
return dii, rc, nil return dii, rc, nil
} }
func downloadItem(ctx context.Context, hc *http.Client, item models.DriveItemable) (*http.Response, error) { func downloadItem(
ctx context.Context,
client graph.Requester,
item models.DriveItemable,
) (*http.Response, error) {
url, ok := item.GetAdditionalData()[downloadURLKey].(*string) url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
if !ok { if !ok {
return nil, clues.New("extracting file url").With("item_id", ptr.Val(item.GetId())) return nil, clues.New("extracting file url").With("item_id", ptr.Val(item.GetId()))
} }
req, err := http.NewRequest(http.MethodGet, *url, nil) resp, err := client.Request(ctx, http.MethodGet, ptr.Val(url), nil, nil)
if err != nil { if err != nil {
return nil, graph.Wrap(ctx, err, "new item download request") return nil, err
}
//nolint:lll
// Decorate the traffic
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
resp, err := hc.Do(req)
if err != nil {
cerr := graph.Wrap(ctx, err, "downloading item")
if graph.IsMalware(err) {
cerr = cerr.Label(graph.LabelsMalware)
}
return nil, cerr
} }
if (resp.StatusCode / 100) == 2 { if (resp.StatusCode / 100) == 2 {
return resp, nil return resp, nil
} }
if graph.IsMalwareResp(context.Background(), resp) { if graph.IsMalwareResp(ctx, resp) {
return nil, clues.New("malware detected").Label(graph.LabelsMalware) return nil, clues.New("malware detected").Label(graph.LabelsMalware)
} }

View File

@ -112,7 +112,7 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
) )
// Read data for the file // Read data for the file
itemInfo, itemData, err := oneDriveItemReader(ctx, graph.HTTPClient(graph.NoTimeout()), driveItem) itemInfo, itemData, err := oneDriveItemReader(ctx, graph.NewNoTimeoutHTTPWrapper(), driveItem)
require.NoError(suite.T(), err, clues.ToCore(err)) require.NoError(suite.T(), err, clues.ToCore(err))
require.NotNil(suite.T(), itemInfo.OneDrive) require.NotNil(suite.T(), itemInfo.OneDrive)

View File

@ -2,7 +2,6 @@ package sharepoint
import ( import (
"context" "context"
"net/http"
"github.com/alcionai/clues" "github.com/alcionai/clues"
@ -29,7 +28,7 @@ type statusUpdater interface {
// for the specified user // for the specified user
func DataCollections( func DataCollections(
ctx context.Context, ctx context.Context,
itemClient *http.Client, itemClient graph.Requester,
selector selectors.Selector, selector selectors.Selector,
creds account.M365Config, creds account.M365Config,
serv graph.Servicer, serv graph.Servicer,
@ -182,7 +181,7 @@ func collectLists(
// all the drives associated with the site. // all the drives associated with the site.
func collectLibraries( func collectLibraries(
ctx context.Context, ctx context.Context,
itemClient *http.Client, itemClient graph.Requester,
serv graph.Servicer, serv graph.Servicer,
tenantID, siteID string, tenantID, siteID string,
scope selectors.SharePointScope, scope selectors.SharePointScope,

View File

@ -109,7 +109,7 @@ func (suite *SharePointLibrariesUnitSuite) TestUpdateCollections() {
) )
c := onedrive.NewCollections( c := onedrive.NewCollections(
graph.HTTPClient(graph.NoTimeout()), graph.NewNoTimeoutHTTPWrapper(),
tenant, tenant,
site, site,
onedrive.SharePointSource, onedrive.SharePointSource,