custom http wrapper for item downloads (#3132)
The http client customized by the graph client for use in downloading files doesn't include our middleware as expected. This introduces a new http client wrapper whose round tripper is populated with a middleware wrapper, so that item downloads can utilize our kiota middleware in the same way as the graph client does. --- #### Does this PR need a docs update or release note? - [x] ⛔ No #### Type of change - [x] 🌻 Feature - [x] 🐛 Bugfix - [x] 🧹 Tech Debt/Cleanup #### Issue(s) * #3129 #### Test Plan - [x] 💪 Manual - [x] ⚡ Unit test - [x] 💚 E2E
This commit is contained in:
parent
315c0cc5f3
commit
d5fac8a480
@ -77,7 +77,10 @@ func handleOneDriveCmd(cmd *cobra.Command, args []string) error {
|
||||
return Only(ctx, clues.Wrap(err, "creating graph adapter"))
|
||||
}
|
||||
|
||||
err = runDisplayM365JSON(ctx, graph.NewService(adpt), creds, user, m365ID)
|
||||
svc := graph.NewService(adpt)
|
||||
gr := graph.NewNoTimeoutHTTPWrapper()
|
||||
|
||||
err = runDisplayM365JSON(ctx, svc, gr, creds, user, m365ID)
|
||||
if err != nil {
|
||||
cmd.SilenceUsage = true
|
||||
cmd.SilenceErrors = true
|
||||
@ -105,6 +108,7 @@ func (i itemPrintable) MinimumPrintable() any {
|
||||
func runDisplayM365JSON(
|
||||
ctx context.Context,
|
||||
srv graph.Servicer,
|
||||
gr graph.Requester,
|
||||
creds account.M365Config,
|
||||
user, itemID string,
|
||||
) error {
|
||||
@ -123,7 +127,7 @@ func runDisplayM365JSON(
|
||||
}
|
||||
|
||||
if item != nil {
|
||||
content, err := getDriveItemContent(item)
|
||||
content, err := getDriveItemContent(ctx, gr, item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -180,22 +184,19 @@ func serializeObject(data serialization.Parsable) (string, error) {
|
||||
return string(content), err
|
||||
}
|
||||
|
||||
func getDriveItemContent(item models.DriveItemable) ([]byte, error) {
|
||||
func getDriveItemContent(
|
||||
ctx context.Context,
|
||||
gr graph.Requester,
|
||||
item models.DriveItemable,
|
||||
) ([]byte, error) {
|
||||
url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
|
||||
if !ok {
|
||||
return nil, clues.New("get download url")
|
||||
return nil, clues.New("retrieving download url")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, *url, nil)
|
||||
resp, err := gr.Request(ctx, http.MethodGet, *url, nil, nil)
|
||||
if err != nil {
|
||||
return nil, clues.New("create download request").With("error", err)
|
||||
}
|
||||
|
||||
hc := graph.HTTPClient(graph.NoTimeout())
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, clues.New("download item").With("error", err)
|
||||
return nil, clues.New("downloading item").With("error", err)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
|
||||
@ -258,7 +258,7 @@ func (suite *DataCollectionIntgSuite) TestSharePointDataCollection() {
|
||||
|
||||
collections, excludes, err := sharepoint.DataCollections(
|
||||
ctx,
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
sel,
|
||||
connector.credentials,
|
||||
connector.Service,
|
||||
|
||||
@ -1,36 +1,21 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/alcionai/clues"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/connector/exchange/api"
|
||||
"github.com/alcionai/corso/src/internal/connector/graph"
|
||||
"github.com/alcionai/corso/src/internal/connector/graph/mock"
|
||||
"github.com/alcionai/corso/src/pkg/account"
|
||||
)
|
||||
|
||||
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
|
||||
a, err := mock.CreateAdapter(
|
||||
creds.AzureTenantID,
|
||||
creds.AzureClientID,
|
||||
creds.AzureClientSecret,
|
||||
opts...)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "generating graph adapter")
|
||||
}
|
||||
|
||||
return graph.NewService(a), nil
|
||||
}
|
||||
|
||||
// NewClient produces a new exchange api client that can be
|
||||
// mocked using gock.
|
||||
func NewClient(creds account.M365Config) (api.Client, error) {
|
||||
s, err := NewService(creds)
|
||||
s, err := mock.NewService(creds)
|
||||
if err != nil {
|
||||
return api.Client{}, err
|
||||
}
|
||||
|
||||
li, err := NewService(creds, graph.NoTimeout())
|
||||
li, err := mock.NewService(creds, graph.NoTimeout())
|
||||
if err != nil {
|
||||
return api.Client{}, err
|
||||
}
|
||||
|
||||
@ -234,6 +234,9 @@ func Stack(ctx context.Context, e error) *clues.Err {
|
||||
return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg)
|
||||
}
|
||||
|
||||
// Checks for the following conditions and labels the error accordingly:
|
||||
// * mysiteNotFound | mysiteURLNotFound
|
||||
// * malware
|
||||
func setLabels(err *clues.Err, msg string) *clues.Err {
|
||||
if err == nil {
|
||||
return nil
|
||||
@ -244,6 +247,10 @@ func setLabels(err *clues.Err, msg string) *clues.Err {
|
||||
err = err.Label(LabelsMysiteNotFound)
|
||||
}
|
||||
|
||||
if IsMalware(err) {
|
||||
err = err.Label(LabelsMalware)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
152
src/internal/connector/graph/http_wrapper.go
Normal file
152
src/internal/connector/graph/http_wrapper.go
Normal file
@ -0,0 +1,152 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
khttp "github.com/microsoft/kiota-http-go"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/version"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// constructors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Requester describes anything able to issue a single http request built
// from a method, a url, a body reader, and a header map, handing back the
// raw *http.Response or an error.  Callers in this package pass nil for
// body and headers when they have none to send.
type Requester interface {
	Request(
		ctx context.Context,
		method string,
		url string,
		body io.Reader,
		headers map[string]string,
	) (*http.Response, error)
}
|
||||
|
||||
// NewHTTPWrapper produces a http.Client wrapper that ensures
|
||||
// calls use all the middleware we expect from the graph api client.
|
||||
//
|
||||
// Re-use of http clients is critical, or else we leak OS resources
|
||||
// and consume relatively unbound socket connections. It is important
|
||||
// to centralize this client to be passed downstream where api calls
|
||||
// can utilize it on a per-download basis.
|
||||
func NewHTTPWrapper(opts ...Option) *httpWrapper {
|
||||
var (
|
||||
cc = populateConfig(opts...)
|
||||
rt = customTransport{
|
||||
n: pipeline{
|
||||
middlewares: internalMiddleware(cc),
|
||||
transport: defaultTransport(),
|
||||
},
|
||||
}
|
||||
redirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
hc = &http.Client{
|
||||
CheckRedirect: redirect,
|
||||
Timeout: defaultHTTPClientTimeout,
|
||||
Transport: rt,
|
||||
}
|
||||
)
|
||||
|
||||
cc.apply(hc)
|
||||
|
||||
return &httpWrapper{hc}
|
||||
}
|
||||
|
||||
// NewNoTimeoutHTTPWrapper constructs a http wrapper with no context timeout.
|
||||
//
|
||||
// Re-use of http clients is critical, or else we leak OS resources
|
||||
// and consume relatively unbound socket connections. It is important
|
||||
// to centralize this client to be passed downstream where api calls
|
||||
// can utilize it on a per-download basis.
|
||||
func NewNoTimeoutHTTPWrapper(opts ...Option) *httpWrapper {
|
||||
opts = append(opts, NoTimeout())
|
||||
return NewHTTPWrapper(opts...)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// requests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Request does the provided request.
|
||||
func (hw httpWrapper) Request(
|
||||
ctx context.Context,
|
||||
method, url string,
|
||||
body io.Reader,
|
||||
headers map[string]string,
|
||||
) (*http.Response, error) {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "new http request")
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
// Decorate the traffic
|
||||
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
|
||||
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
|
||||
|
||||
resp, err := hw.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, Stack(ctx, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// constructor internals
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type (
|
||||
httpWrapper struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
customTransport struct {
|
||||
n nexter
|
||||
}
|
||||
|
||||
pipeline struct {
|
||||
transport http.RoundTripper
|
||||
middlewares []khttp.Middleware
|
||||
}
|
||||
)
|
||||
|
||||
// RoundTrip kicks off the middleware chain and returns a response
|
||||
func (ct customTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return ct.n.Next(req, 0)
|
||||
}
|
||||
|
||||
// Next moves the request object through middlewares in the pipeline
|
||||
func (pl pipeline) Next(req *http.Request, idx int) (*http.Response, error) {
|
||||
if idx < len(pl.middlewares) {
|
||||
return pl.middlewares[idx].Intercept(pl, idx+1, req)
|
||||
}
|
||||
|
||||
return pl.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// defaultTransport returns the round tripper used at the end of the
// middleware pipeline.
//
// When http.DefaultTransport is the standard *http.Transport, a clone of it
// is returned with ForceAttemptHTTP2 enabled, leaving the global value
// untouched.  When something has replaced http.DefaultTransport with a
// different RoundTripper implementation (for example a request-mocking
// library), that replacement is returned as-is instead of panicking on a
// failed type assertion.
func defaultTransport() http.RoundTripper {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		// Not a *http.Transport, so it cannot be cloned or tuned; honor
		// whatever was installed.
		return http.DefaultTransport
	}

	tpt := base.Clone()
	tpt.ForceAttemptHTTP2 = true

	return tpt
}
|
||||
|
||||
func internalMiddleware(cc *clientConfig) []khttp.Middleware {
|
||||
return []khttp.Middleware{
|
||||
&RetryHandler{
|
||||
MaxRetries: cc.maxRetries,
|
||||
Delay: cc.minDelay,
|
||||
},
|
||||
&LoggingMiddleware{},
|
||||
&ThrottleControlMiddleware{},
|
||||
&MetricsMiddleware{},
|
||||
}
|
||||
}
|
||||
45
src/internal/connector/graph/http_wrapper_test.go
Normal file
45
src/internal/connector/graph/http_wrapper_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/tester"
|
||||
)
|
||||
|
||||
type HTTPWrapperIntgSuite struct {
|
||||
tester.Suite
|
||||
}
|
||||
|
||||
func TestHTTPWrapperIntgSuite(t *testing.T) {
|
||||
suite.Run(t, &HTTPWrapperIntgSuite{
|
||||
Suite: tester.NewIntegrationSuite(
|
||||
t,
|
||||
[][]string{tester.M365AcctCredEnvs}),
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() {
|
||||
ctx, flush := tester.NewContext()
|
||||
defer flush()
|
||||
|
||||
var (
|
||||
t = suite.T()
|
||||
hw = NewHTTPWrapper()
|
||||
)
|
||||
|
||||
resp, err := hw.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
"https://www.corsobackup.io",
|
||||
nil,
|
||||
nil)
|
||||
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
@ -20,6 +20,10 @@ import (
|
||||
"github.com/alcionai/corso/src/pkg/logger"
|
||||
)
|
||||
|
||||
// nexter is anything that can carry a request onward from the given
// position in a middleware chain and report the resulting response.
type nexter interface {
	Next(
		req *http.Request,
		middlewareIndex int,
	) (*http.Response, error)
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Logging
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@ -1,12 +1,27 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/alcionai/clues"
|
||||
"github.com/h2non/gock"
|
||||
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/connector/graph"
|
||||
"github.com/alcionai/corso/src/pkg/account"
|
||||
)
|
||||
|
||||
func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) {
|
||||
a, err := CreateAdapter(
|
||||
creds.AzureTenantID,
|
||||
creds.AzureClientID,
|
||||
creds.AzureClientSecret,
|
||||
opts...)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "generating graph adapter")
|
||||
}
|
||||
|
||||
return graph.NewService(a), nil
|
||||
}
|
||||
|
||||
// CreateAdapter is similar to graph.CreateAdapter, but with option to
|
||||
// enable interceptions via gock to make it mockable.
|
||||
func CreateAdapter(
|
||||
@ -18,7 +33,7 @@ func CreateAdapter(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := graph.HTTPClient(opts...)
|
||||
httpClient := graph.KiotaHTTPClient(opts...)
|
||||
|
||||
// This makes sure that we are able to intercept any requests via
|
||||
// gock. Only necessary for testing.
|
||||
|
||||
@ -114,7 +114,7 @@ func CreateAdapter(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := HTTPClient(opts...)
|
||||
httpClient := KiotaHTTPClient(opts...)
|
||||
|
||||
return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(
|
||||
auth,
|
||||
@ -140,21 +140,24 @@ func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityA
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// HTTPClient creates the httpClient with middlewares and timeout configured
|
||||
// KiotaHTTPClient creates a httpClient with middlewares and timeout configured
|
||||
// for use in the graph adapter.
|
||||
//
|
||||
// Re-use of http clients is critical, or else we leak OS resources
|
||||
// and consume relatively unbound socket connections. It is important
|
||||
// to centralize this client to be passed downstream where api calls
|
||||
// can utilize it on a per-download basis.
|
||||
func HTTPClient(opts ...Option) *http.Client {
|
||||
clientOptions := msgraphsdkgo.GetDefaultClientOptions()
|
||||
clientconfig := (&clientConfig{}).populate(opts...)
|
||||
noOfRetries, minRetryDelay := clientconfig.applyMiddlewareConfig()
|
||||
middlewares := GetKiotaMiddlewares(&clientOptions, noOfRetries, minRetryDelay)
|
||||
httpClient := msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
|
||||
func KiotaHTTPClient(opts ...Option) *http.Client {
|
||||
var (
|
||||
clientOptions = msgraphsdkgo.GetDefaultClientOptions()
|
||||
cc = populateConfig(opts...)
|
||||
middlewares = kiotaMiddlewares(&clientOptions, cc)
|
||||
httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
|
||||
)
|
||||
|
||||
httpClient.Timeout = defaultHTTPClientTimeout
|
||||
|
||||
clientconfig.apply(httpClient)
|
||||
cc.apply(httpClient)
|
||||
|
||||
return httpClient
|
||||
}
|
||||
@ -175,27 +178,17 @@ type clientConfig struct {
|
||||
type Option func(*clientConfig)
|
||||
|
||||
// populate constructs a clientConfig according to the provided options.
|
||||
func (c *clientConfig) populate(opts ...Option) *clientConfig {
|
||||
func populateConfig(opts ...Option) *clientConfig {
|
||||
cc := clientConfig{
|
||||
maxRetries: defaultMaxRetries,
|
||||
minDelay: defaultDelay,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
opt(&cc)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// apply updates the http.Client with the expected options.
|
||||
func (c *clientConfig) applyMiddlewareConfig() (retry int, delay time.Duration) {
|
||||
retry = defaultMaxRetries
|
||||
if c.overrideRetryCount {
|
||||
retry = c.maxRetries
|
||||
}
|
||||
|
||||
delay = defaultDelay
|
||||
if c.minDelay > 0 {
|
||||
delay = c.minDelay
|
||||
}
|
||||
|
||||
return
|
||||
return &cc
|
||||
}
|
||||
|
||||
// apply updates the http.Client with the expected options.
|
||||
@ -236,14 +229,16 @@ func MinimumBackoff(dur time.Duration) Option {
|
||||
// Middleware Control
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter
|
||||
func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware {
|
||||
// kiotaMiddlewares creates a default slice of middleware for the Graph Client.
|
||||
func kiotaMiddlewares(
|
||||
options *msgraphgocore.GraphClientOptions,
|
||||
cc *clientConfig,
|
||||
) []khttp.Middleware {
|
||||
return []khttp.Middleware{
|
||||
msgraphgocore.NewGraphTelemetryHandler(options),
|
||||
&RetryHandler{
|
||||
// The maximum number of times a request can be retried
|
||||
MaxRetries: maxRetry,
|
||||
// The delay in seconds between retries
|
||||
Delay: delay,
|
||||
MaxRetries: cc.maxRetries,
|
||||
Delay: cc.minDelay,
|
||||
},
|
||||
khttp.NewRetryHandler(),
|
||||
khttp.NewRedirectHandler(),
|
||||
@ -255,21 +250,3 @@ func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware {
|
||||
&MetricsMiddleware{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetKiotaMiddlewares creates a default slice of middleware for the Graph Client.
|
||||
func GetKiotaMiddlewares(
|
||||
options *msgraphgocore.GraphClientOptions,
|
||||
maxRetry int,
|
||||
minDelay time.Duration,
|
||||
) []khttp.Middleware {
|
||||
kiotaMiddlewares := GetMiddlewares(maxRetry, minDelay)
|
||||
graphMiddlewares := []khttp.Middleware{
|
||||
msgraphgocore.NewGraphTelemetryHandler(options),
|
||||
}
|
||||
graphMiddlewaresLen := len(graphMiddlewares)
|
||||
resultMiddlewares := make([]khttp.Middleware, len(kiotaMiddlewares)+graphMiddlewaresLen)
|
||||
copy(resultMiddlewares, graphMiddlewares)
|
||||
copy(resultMiddlewares[graphMiddlewaresLen:], kiotaMiddlewares)
|
||||
|
||||
return resultMiddlewares
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ func (suite *GraphUnitSuite) TestHTTPClient() {
|
||||
suite.Run(test.name, func() {
|
||||
t := suite.T()
|
||||
|
||||
cli := HTTPClient(test.opts...)
|
||||
cli := KiotaHTTPClient(test.opts...)
|
||||
assert.NotNil(t, cli)
|
||||
test.check(t, cli)
|
||||
})
|
||||
|
||||
@ -4,7 +4,6 @@ package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"runtime/trace"
|
||||
"sync"
|
||||
|
||||
@ -36,7 +35,7 @@ var (
|
||||
type GraphConnector struct {
|
||||
Service graph.Servicer
|
||||
Discovery api.Client
|
||||
itemClient *http.Client // configured to handle large item downloads
|
||||
itemClient graph.Requester // configured to handle large item downloads
|
||||
|
||||
tenant string
|
||||
credentials account.M365Config
|
||||
@ -88,7 +87,7 @@ func NewGraphConnector(
|
||||
Service: service,
|
||||
|
||||
credentials: creds,
|
||||
itemClient: graph.HTTPClient(graph.NoTimeout()),
|
||||
itemClient: graph.NewNoTimeoutHTTPWrapper(),
|
||||
ownerLookup: rc,
|
||||
tenant: acct.ID(),
|
||||
wg: &sync.WaitGroup{},
|
||||
|
||||
@ -35,10 +35,6 @@ const (
|
||||
// TODO: Tune this later along with collectionChannelBufferSize
|
||||
urlPrefetchChannelBufferSize = 5
|
||||
|
||||
// maxDownloadRetires specifies the number of times a file download should
|
||||
// be retried
|
||||
maxDownloadRetires = 3
|
||||
|
||||
// Used to compare in case of OneNote files
|
||||
MaxOneNoteFileSize = 2 * 1024 * 1024 * 1024
|
||||
)
|
||||
@ -62,7 +58,7 @@ const (
|
||||
// Collection represents a set of OneDrive objects retrieved from M365
|
||||
type Collection struct {
|
||||
// configured to handle large item downloads
|
||||
itemClient *http.Client
|
||||
itemClient graph.Requester
|
||||
|
||||
// data is used to share data streams with the collection consumer
|
||||
data chan data.Stream
|
||||
@ -110,7 +106,7 @@ type Collection struct {
|
||||
doNotMergeItems bool
|
||||
}
|
||||
|
||||
// itemGetterFunc gets an specified item
|
||||
// itemGetterFunc gets a specified item
|
||||
type itemGetterFunc func(
|
||||
ctx context.Context,
|
||||
srv graph.Servicer,
|
||||
@ -120,7 +116,7 @@ type itemGetterFunc func(
|
||||
// itemReaderFunc returns a reader for the specified item
|
||||
type itemReaderFunc func(
|
||||
ctx context.Context,
|
||||
hc *http.Client,
|
||||
client graph.Requester,
|
||||
item models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error)
|
||||
|
||||
@ -148,7 +144,7 @@ func pathToLocation(p path.Path) (*path.Builder, error) {
|
||||
|
||||
// NewCollection creates a Collection
|
||||
func NewCollection(
|
||||
itemClient *http.Client,
|
||||
itemClient graph.Requester,
|
||||
folderPath path.Path,
|
||||
prevPath path.Path,
|
||||
driveID string,
|
||||
@ -372,45 +368,29 @@ func (oc *Collection) getDriveItemContent(
|
||||
itemID = ptr.Val(item.GetId())
|
||||
itemName = ptr.Val(item.GetName())
|
||||
el = errs.Local()
|
||||
|
||||
itemData io.ReadCloser
|
||||
err error
|
||||
)
|
||||
|
||||
// Initial try with url from delta + 2 retries
|
||||
for i := 1; i <= maxDownloadRetires; i++ {
|
||||
_, itemData, err = oc.itemReader(ctx, oc.itemClient, item)
|
||||
if err == nil || !graph.IsErrUnauthorized(err) {
|
||||
break
|
||||
}
|
||||
|
||||
// Assume unauthorized requests are a sign of an expired jwt
|
||||
// token, and that we've overrun the available window to
|
||||
// download the actual file. Re-downloading the item will
|
||||
// refresh that download url.
|
||||
di, diErr := oc.itemGetter(ctx, oc.service, oc.driveID, itemID)
|
||||
if diErr != nil {
|
||||
err = clues.Wrap(diErr, "retrieving expired item")
|
||||
break
|
||||
}
|
||||
|
||||
item = di
|
||||
}
|
||||
|
||||
// check for errors following retries
|
||||
itemData, err := downloadContent(
|
||||
ctx,
|
||||
oc.service,
|
||||
oc.itemGetter,
|
||||
oc.itemReader,
|
||||
oc.itemClient,
|
||||
item,
|
||||
oc.driveID)
|
||||
if err != nil {
|
||||
if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) {
|
||||
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware")
|
||||
el.AddSkip(fault.FileSkip(fault.SkipMalware, itemID, itemName, graph.ItemInfo(item)))
|
||||
|
||||
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
|
||||
return nil, clues.Wrap(err, "malware item").Label(graph.LabelsSkippable)
|
||||
}
|
||||
|
||||
if clues.HasLabel(err, graph.LabelStatus(http.StatusNotFound)) || graph.IsErrDeletedInFlight(err) {
|
||||
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipNotFound).Info("item not found")
|
||||
el.AddSkip(fault.FileSkip(fault.SkipNotFound, itemID, itemName, graph.ItemInfo(item)))
|
||||
|
||||
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
|
||||
return nil, clues.Wrap(err, "deleted item").Label(graph.LabelsSkippable)
|
||||
}
|
||||
|
||||
// Skip big OneNote files as they can't be downloaded
|
||||
@ -425,7 +405,7 @@ func (oc *Collection) getDriveItemContent(
|
||||
logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipBigOneNote).Info("max OneNote file size exceeded")
|
||||
el.AddSkip(fault.FileSkip(fault.SkipBigOneNote, itemID, itemName, graph.ItemInfo(item)))
|
||||
|
||||
return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable)
|
||||
return nil, clues.Wrap(err, "max oneNote item").Label(graph.LabelsSkippable)
|
||||
}
|
||||
|
||||
logger.CtxErr(ctx, err).Error("downloading item")
|
||||
@ -433,12 +413,48 @@ func (oc *Collection) getDriveItemContent(
|
||||
|
||||
// return err, not el.Err(), because the lazy reader needs to communicate to
|
||||
// the data consumer that this item is unreadable, regardless of the fault state.
|
||||
return nil, clues.Wrap(err, "downloading item")
|
||||
return nil, clues.Wrap(err, "fetching item content")
|
||||
}
|
||||
|
||||
return itemData, nil
|
||||
}
|
||||
|
||||
// downloadContent attempts to fetch the item content. If the content url
|
||||
// is expired (ie, returns a 401), it re-fetches the item to get a new download
|
||||
// url and tries again.
|
||||
func downloadContent(
|
||||
ctx context.Context,
|
||||
svc graph.Servicer,
|
||||
igf itemGetterFunc,
|
||||
irf itemReaderFunc,
|
||||
gr graph.Requester,
|
||||
item models.DriveItemable,
|
||||
driveID string,
|
||||
) (io.ReadCloser, error) {
|
||||
_, content, err := irf(ctx, gr, item)
|
||||
if err == nil {
|
||||
return content, nil
|
||||
} else if !graph.IsErrUnauthorized(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Assume unauthorized requests are a sign of an expired jwt
|
||||
// token, and that we've overrun the available window to
|
||||
// download the actual file. Re-downloading the item will
|
||||
// refresh that download url.
|
||||
di, err := igf(ctx, svc, driveID, ptr.Val(item.GetId()))
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "retrieving expired item")
|
||||
}
|
||||
|
||||
_, content, err = irf(ctx, gr, di)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "content download retry")
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// populateItems iterates through items added to the collection
|
||||
// and uses the collection `itemReader` to read the item
|
||||
func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
|
||||
@ -504,9 +520,9 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
|
||||
|
||||
ctx = clues.Add(
|
||||
ctx,
|
||||
"backup_item_id", itemID,
|
||||
"backup_item_name", itemName,
|
||||
"backup_item_size", itemSize)
|
||||
"item_id", itemID,
|
||||
"item_name", itemName,
|
||||
"item_size", itemSize)
|
||||
|
||||
item.SetParentReference(setName(item.GetParentReference(), oc.driveName))
|
||||
|
||||
@ -545,7 +561,7 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) {
|
||||
itemInfo.OneDrive.ParentPath = parentPathString
|
||||
}
|
||||
|
||||
ctx = clues.Add(ctx, "backup_item_info", itemInfo)
|
||||
ctx = clues.Add(ctx, "item_info", itemInfo)
|
||||
|
||||
if isFile {
|
||||
dataSuffix := metadata.DataFileSuffix
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/common/ptr"
|
||||
"github.com/alcionai/corso/src/internal/connector/graph"
|
||||
"github.com/alcionai/corso/src/internal/connector/onedrive/metadata"
|
||||
"github.com/alcionai/corso/src/internal/connector/support"
|
||||
@ -98,7 +99,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
numInstances: 1,
|
||||
source: OneDriveSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
|
||||
io.NopCloser(bytes.NewReader(testItemData)),
|
||||
nil
|
||||
@ -114,7 +115,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
numInstances: 3,
|
||||
source: OneDriveSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}},
|
||||
io.NopCloser(bytes.NewReader(testItemData)),
|
||||
nil
|
||||
@ -130,7 +131,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
numInstances: 3,
|
||||
source: OneDriveSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, nil, clues.New("test malware").Label(graph.LabelsMalware)
|
||||
},
|
||||
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
|
||||
@ -146,7 +147,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
source: OneDriveSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
// Usually `Not Found` is returned from itemGetter and not itemReader
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, nil, clues.New("test not found").Label(graph.LabelStatus(http.StatusNotFound))
|
||||
},
|
||||
infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) {
|
||||
@ -161,7 +162,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
numInstances: 1,
|
||||
source: SharePointSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
|
||||
io.NopCloser(bytes.NewReader(testItemData)),
|
||||
nil
|
||||
@ -177,7 +178,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
numInstances: 3,
|
||||
source: SharePointSource,
|
||||
itemDeets: nst{testItemName, 42, now},
|
||||
itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}},
|
||||
io.NopCloser(bytes.NewReader(testItemData)),
|
||||
nil
|
||||
@ -207,7 +208,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
coll, err := NewCollection(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
folderPath,
|
||||
nil,
|
||||
"drive-id",
|
||||
@ -278,7 +279,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() {
|
||||
|
||||
if err != nil {
|
||||
for _, label := range test.expectLabels {
|
||||
assert.True(t, clues.HasLabel(err, label), "has clues label:", label)
|
||||
assert.Truef(t, clues.HasLabel(err, label), "has clues label: %s", label)
|
||||
}
|
||||
|
||||
return
|
||||
@ -347,7 +348,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
coll, err := NewCollection(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
folderPath,
|
||||
nil,
|
||||
"fakeDriveID",
|
||||
@ -370,7 +371,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() {
|
||||
|
||||
coll.itemReader = func(
|
||||
context.Context,
|
||||
*http.Client,
|
||||
graph.Requester,
|
||||
models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, nil, assert.AnError
|
||||
@ -437,7 +438,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
|
||||
require.NoError(t, err)
|
||||
|
||||
coll, err := NewCollection(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
folderPath,
|
||||
nil,
|
||||
"fakeDriveID",
|
||||
@ -470,10 +471,10 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
|
||||
|
||||
coll.itemReader = func(
|
||||
context.Context,
|
||||
*http.Client,
|
||||
graph.Requester,
|
||||
models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
if count < 2 {
|
||||
if count < 1 {
|
||||
count++
|
||||
return details.ItemInfo{}, nil, clues.Stack(assert.AnError).
|
||||
Label(graph.LabelStatus(http.StatusUnauthorized))
|
||||
@ -494,13 +495,13 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry()
|
||||
assert.True(t, ok)
|
||||
|
||||
_, err = io.ReadAll(collItem.ToReader())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, 1, collStatus.Metrics.Objects, "only one object should be counted")
|
||||
require.Equal(t, 1, collStatus.Metrics.Successes, "read object successfully")
|
||||
require.Equal(t, 2, count, "retry count")
|
||||
require.Equal(t, 1, count, "retry count")
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -537,7 +538,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
coll, err := NewCollection(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
folderPath,
|
||||
nil,
|
||||
"drive-id",
|
||||
@ -561,7 +562,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim
|
||||
|
||||
coll.itemReader = func(
|
||||
context.Context,
|
||||
*http.Client,
|
||||
graph.Requester,
|
||||
models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: "fakeName", Modified: time.Now()}},
|
||||
@ -611,7 +612,7 @@ func TestGetDriveItemUnitTestSuite(t *testing.T) {
|
||||
suite.Run(t, &GetDriveItemUnitTestSuite{Suite: tester.NewUnitSuite(t)})
|
||||
}
|
||||
|
||||
func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
func (suite *GetDriveItemUnitTestSuite) TestGetDriveItem_error() {
|
||||
strval := "not-important"
|
||||
|
||||
table := []struct {
|
||||
@ -637,14 +638,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
name: "malware error",
|
||||
colScope: CollectionScopeFolder,
|
||||
itemSize: 10,
|
||||
err: clues.New("test error").Label(graph.LabelsMalware),
|
||||
err: clues.New("malware error").Label(graph.LabelsMalware),
|
||||
labels: []string{graph.LabelsMalware, graph.LabelsSkippable},
|
||||
},
|
||||
{
|
||||
name: "file not found error",
|
||||
colScope: CollectionScopeFolder,
|
||||
itemSize: 10,
|
||||
err: clues.New("test error").Label(graph.LabelStatus(http.StatusNotFound)),
|
||||
err: clues.New("not found error").Label(graph.LabelStatus(http.StatusNotFound)),
|
||||
labels: []string{graph.LabelStatus(http.StatusNotFound), graph.LabelsSkippable},
|
||||
},
|
||||
{
|
||||
@ -652,14 +653,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
name: "small OneNote file",
|
||||
colScope: CollectionScopePackage,
|
||||
itemSize: 10,
|
||||
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
err: clues.New("small onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
|
||||
},
|
||||
{
|
||||
name: "big OneNote file",
|
||||
colScope: CollectionScopePackage,
|
||||
itemSize: MaxOneNoteFileSize,
|
||||
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
err: clues.New("big onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable), graph.LabelsSkippable},
|
||||
},
|
||||
{
|
||||
@ -667,7 +668,7 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
name: "big file",
|
||||
colScope: CollectionScopeFolder,
|
||||
itemSize: MaxOneNoteFileSize,
|
||||
err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
err: clues.New("big file error").Label(graph.LabelStatus(http.StatusServiceUnavailable)),
|
||||
labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)},
|
||||
},
|
||||
}
|
||||
@ -689,9 +690,9 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
item.SetSize(&test.itemSize)
|
||||
|
||||
col.itemReader = func(
|
||||
ctx context.Context,
|
||||
hc *http.Client,
|
||||
item models.DriveItemable,
|
||||
_ context.Context,
|
||||
_ graph.Requester,
|
||||
_ models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, nil, test.err
|
||||
}
|
||||
@ -707,11 +708,11 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
|
||||
_, err := col.getDriveItemContent(ctx, item, errs)
|
||||
if test.err == nil {
|
||||
assert.NoError(t, err, "no error")
|
||||
assert.NoError(t, err, clues.ToCore(err))
|
||||
return
|
||||
}
|
||||
|
||||
assert.EqualError(t, err, clues.Wrap(test.err, "downloading item").Error(), "error")
|
||||
assert.ErrorIs(t, err, test.err, clues.ToCore(err))
|
||||
|
||||
labelsMap := map[string]struct{}{}
|
||||
for _, l := range test.labels {
|
||||
@ -722,3 +723,103 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() {
|
||||
var (
|
||||
svc graph.Servicer
|
||||
gr graph.Requester
|
||||
driveID string
|
||||
iorc = io.NopCloser(bytes.NewReader([]byte("fnords")))
|
||||
item = &models.DriveItem{}
|
||||
itemWID = &models.DriveItem{}
|
||||
)
|
||||
|
||||
itemWID.SetId(ptr.To("brainhooldy"))
|
||||
|
||||
table := []struct {
|
||||
name string
|
||||
igf itemGetterFunc
|
||||
irf itemReaderFunc
|
||||
expectErr require.ErrorAssertionFunc
|
||||
expect require.ValueAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "good",
|
||||
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, iorc, nil
|
||||
},
|
||||
expectErr: require.NoError,
|
||||
expect: require.NotNil,
|
||||
},
|
||||
{
|
||||
name: "expired url redownloads",
|
||||
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
|
||||
return itemWID, nil
|
||||
},
|
||||
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
// a bit hacky: assume only igf returns an item with a non-zero id.
|
||||
if len(ptr.Val(m.GetId())) == 0 {
|
||||
return details.ItemInfo{},
|
||||
nil,
|
||||
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
|
||||
}
|
||||
|
||||
return details.ItemInfo{}, iorc, nil
|
||||
},
|
||||
expectErr: require.NoError,
|
||||
expect: require.NotNil,
|
||||
},
|
||||
{
|
||||
name: "immediate error",
|
||||
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{}, nil, assert.AnError
|
||||
},
|
||||
expectErr: require.Error,
|
||||
expect: require.Nil,
|
||||
},
|
||||
{
|
||||
name: "re-fetching the item fails",
|
||||
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
|
||||
return nil, assert.AnError
|
||||
},
|
||||
irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
return details.ItemInfo{},
|
||||
nil,
|
||||
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
|
||||
},
|
||||
expectErr: require.Error,
|
||||
expect: require.Nil,
|
||||
},
|
||||
{
|
||||
name: "expired url fails redownload",
|
||||
igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) {
|
||||
return itemWID, nil
|
||||
},
|
||||
irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) {
|
||||
// a bit hacky: assume only igf returns an item with a non-zero id.
|
||||
if len(ptr.Val(m.GetId())) == 0 {
|
||||
return details.ItemInfo{},
|
||||
nil,
|
||||
clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized))
|
||||
}
|
||||
|
||||
return details.ItemInfo{}, iorc, assert.AnError
|
||||
},
|
||||
expectErr: require.Error,
|
||||
expect: require.Nil,
|
||||
},
|
||||
}
|
||||
for _, test := range table {
|
||||
suite.Run(test.name, func() {
|
||||
ctx, flush := tester.NewContext()
|
||||
defer flush()
|
||||
|
||||
t := suite.T()
|
||||
|
||||
r, err := downloadContent(ctx, svc, test.igf, test.irf, gr, item, driveID)
|
||||
|
||||
test.expect(t, r)
|
||||
test.expectErr(t, err, clues.ToCore(err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
@ -73,7 +72,7 @@ type folderMatcher interface {
|
||||
// resource owner, which can be either a user or a sharepoint site.
|
||||
type Collections struct {
|
||||
// configured to handle large item downloads
|
||||
itemClient *http.Client
|
||||
itemClient graph.Requester
|
||||
|
||||
tenant string
|
||||
resourceOwner string
|
||||
@ -109,7 +108,7 @@ type Collections struct {
|
||||
}
|
||||
|
||||
func NewCollections(
|
||||
itemClient *http.Client,
|
||||
itemClient graph.Requester,
|
||||
tenant string,
|
||||
resourceOwner string,
|
||||
source driveSource,
|
||||
|
||||
@ -780,7 +780,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestUpdateCollections() {
|
||||
maps.Copy(outputFolderMap, tt.inputFolderMap)
|
||||
|
||||
c := NewCollections(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
tenant,
|
||||
user,
|
||||
OneDriveSource,
|
||||
@ -2231,7 +2231,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestGet() {
|
||||
}
|
||||
|
||||
c := NewCollections(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
tenant,
|
||||
user,
|
||||
OneDriveSource,
|
||||
|
||||
@ -2,7 +2,6 @@ package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
"golang.org/x/exp/maps"
|
||||
@ -38,7 +37,7 @@ func DataCollections(
|
||||
user common.IDNamer,
|
||||
metadata []data.RestoreCollection,
|
||||
tenant string,
|
||||
itemClient *http.Client,
|
||||
itemClient graph.Requester,
|
||||
service graph.Servicer,
|
||||
su support.StatusUpdater,
|
||||
ctrlOpts control.Options,
|
||||
|
||||
@ -426,7 +426,7 @@ func (suite *OneDriveSuite) TestOneDriveNewCollections() {
|
||||
)
|
||||
|
||||
colls := NewCollections(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
creds.AzureTenantID,
|
||||
test.user,
|
||||
OneDriveSource,
|
||||
|
||||
@ -16,7 +16,6 @@ import (
|
||||
"github.com/alcionai/corso/src/internal/connector/graph"
|
||||
"github.com/alcionai/corso/src/internal/connector/onedrive/api"
|
||||
"github.com/alcionai/corso/src/internal/connector/uploadsession"
|
||||
"github.com/alcionai/corso/src/internal/version"
|
||||
"github.com/alcionai/corso/src/pkg/backup/details"
|
||||
"github.com/alcionai/corso/src/pkg/logger"
|
||||
)
|
||||
@ -33,12 +32,12 @@ const (
|
||||
// TODO: Add metadata fetching to SharePoint
|
||||
func sharePointItemReader(
|
||||
ctx context.Context,
|
||||
hc *http.Client,
|
||||
client graph.Requester,
|
||||
item models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
resp, err := downloadItem(ctx, hc, item)
|
||||
resp, err := downloadItem(ctx, client, item)
|
||||
if err != nil {
|
||||
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item")
|
||||
return details.ItemInfo{}, nil, clues.Wrap(err, "sharepoint reader")
|
||||
}
|
||||
|
||||
dii := details.ItemInfo{
|
||||
@ -107,7 +106,7 @@ func baseItemMetaReader(
|
||||
// and using a http client to initialize a reader
|
||||
func oneDriveItemReader(
|
||||
ctx context.Context,
|
||||
hc *http.Client,
|
||||
client graph.Requester,
|
||||
item models.DriveItemable,
|
||||
) (details.ItemInfo, io.ReadCloser, error) {
|
||||
var (
|
||||
@ -116,9 +115,9 @@ func oneDriveItemReader(
|
||||
)
|
||||
|
||||
if isFile {
|
||||
resp, err := downloadItem(ctx, hc, item)
|
||||
resp, err := downloadItem(ctx, client, item)
|
||||
if err != nil {
|
||||
return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item")
|
||||
return details.ItemInfo{}, nil, clues.Wrap(err, "onedrive reader")
|
||||
}
|
||||
|
||||
rc = resp.Body
|
||||
@ -131,38 +130,26 @@ func oneDriveItemReader(
|
||||
return dii, rc, nil
|
||||
}
|
||||
|
||||
func downloadItem(ctx context.Context, hc *http.Client, item models.DriveItemable) (*http.Response, error) {
|
||||
func downloadItem(
|
||||
ctx context.Context,
|
||||
client graph.Requester,
|
||||
item models.DriveItemable,
|
||||
) (*http.Response, error) {
|
||||
url, ok := item.GetAdditionalData()[downloadURLKey].(*string)
|
||||
if !ok {
|
||||
return nil, clues.New("extracting file url").With("item_id", ptr.Val(item.GetId()))
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, *url, nil)
|
||||
resp, err := client.Request(ctx, http.MethodGet, ptr.Val(url), nil, nil)
|
||||
if err != nil {
|
||||
return nil, graph.Wrap(ctx, err, "new item download request")
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
// Decorate the traffic
|
||||
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
|
||||
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
cerr := graph.Wrap(ctx, err, "downloading item")
|
||||
|
||||
if graph.IsMalware(err) {
|
||||
cerr = cerr.Label(graph.LabelsMalware)
|
||||
}
|
||||
|
||||
return nil, cerr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if (resp.StatusCode / 100) == 2 {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if graph.IsMalwareResp(context.Background(), resp) {
|
||||
if graph.IsMalwareResp(ctx, resp) {
|
||||
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
|
||||
}
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
|
||||
)
|
||||
|
||||
// Read data for the file
|
||||
itemInfo, itemData, err := oneDriveItemReader(ctx, graph.HTTPClient(graph.NoTimeout()), driveItem)
|
||||
itemInfo, itemData, err := oneDriveItemReader(ctx, graph.NewNoTimeoutHTTPWrapper(), driveItem)
|
||||
|
||||
require.NoError(suite.T(), err, clues.ToCore(err))
|
||||
require.NotNil(suite.T(), itemInfo.OneDrive)
|
||||
|
||||
@ -2,7 +2,6 @@ package sharepoint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
|
||||
@ -29,7 +28,7 @@ type statusUpdater interface {
|
||||
// for the specified user
|
||||
func DataCollections(
|
||||
ctx context.Context,
|
||||
itemClient *http.Client,
|
||||
itemClient graph.Requester,
|
||||
selector selectors.Selector,
|
||||
creds account.M365Config,
|
||||
serv graph.Servicer,
|
||||
@ -182,7 +181,7 @@ func collectLists(
|
||||
// all the drives associated with the site.
|
||||
func collectLibraries(
|
||||
ctx context.Context,
|
||||
itemClient *http.Client,
|
||||
itemClient graph.Requester,
|
||||
serv graph.Servicer,
|
||||
tenantID, siteID string,
|
||||
scope selectors.SharePointScope,
|
||||
|
||||
@ -109,7 +109,7 @@ func (suite *SharePointLibrariesUnitSuite) TestUpdateCollections() {
|
||||
)
|
||||
|
||||
c := onedrive.NewCollections(
|
||||
graph.HTTPClient(graph.NoTimeout()),
|
||||
graph.NewNoTimeoutHTTPWrapper(),
|
||||
tenant,
|
||||
site,
|
||||
onedrive.SharePointSource,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user