diff --git a/src/cmd/getM365/onedrive/get_item.go b/src/cmd/getM365/onedrive/get_item.go index 8794fbb03..4868ab343 100644 --- a/src/cmd/getM365/onedrive/get_item.go +++ b/src/cmd/getM365/onedrive/get_item.go @@ -77,7 +77,10 @@ func handleOneDriveCmd(cmd *cobra.Command, args []string) error { return Only(ctx, clues.Wrap(err, "creating graph adapter")) } - err = runDisplayM365JSON(ctx, graph.NewService(adpt), creds, user, m365ID) + svc := graph.NewService(adpt) + gr := graph.NewNoTimeoutHTTPWrapper() + + err = runDisplayM365JSON(ctx, svc, gr, creds, user, m365ID) if err != nil { cmd.SilenceUsage = true cmd.SilenceErrors = true @@ -105,6 +108,7 @@ func (i itemPrintable) MinimumPrintable() any { func runDisplayM365JSON( ctx context.Context, srv graph.Servicer, + gr graph.Requester, creds account.M365Config, user, itemID string, ) error { @@ -123,7 +127,7 @@ func runDisplayM365JSON( } if item != nil { - content, err := getDriveItemContent(item) + content, err := getDriveItemContent(ctx, gr, item) if err != nil { return err } @@ -180,22 +184,19 @@ func serializeObject(data serialization.Parsable) (string, error) { return string(content), err } -func getDriveItemContent(item models.DriveItemable) ([]byte, error) { +func getDriveItemContent( + ctx context.Context, + gr graph.Requester, + item models.DriveItemable, +) ([]byte, error) { url, ok := item.GetAdditionalData()[downloadURLKey].(*string) if !ok { - return nil, clues.New("get download url") + return nil, clues.New("retrieving download url") } - req, err := http.NewRequest(http.MethodGet, *url, nil) + resp, err := gr.Request(ctx, http.MethodGet, *url, nil, nil) if err != nil { - return nil, clues.New("create download request").With("error", err) - } - - hc := graph.HTTPClient(graph.NoTimeout()) - - resp, err := hc.Do(req) - if err != nil { - return nil, clues.New("download item").With("error", err) + return nil, clues.New("downloading item").With("error", err) } content, err := io.ReadAll(resp.Body) diff --git a/src/internal/connector/data_collections_test.go b/src/internal/connector/data_collections_test.go index cfa4e171a..9bfd88dc0 100644 --- a/src/internal/connector/data_collections_test.go +++ b/src/internal/connector/data_collections_test.go @@ -258,7 +258,7 @@ func (suite *DataCollectionIntgSuite) TestSharePointDataCollection() { collections, excludes, err := sharepoint.DataCollections( ctx, - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), sel, connector.credentials, connector.Service, diff --git a/src/internal/connector/exchange/api/mock/mail.go b/src/internal/connector/exchange/api/mock/mail.go index 43f6f8d5c..6caf47f88 100644 --- a/src/internal/connector/exchange/api/mock/mail.go +++ b/src/internal/connector/exchange/api/mock/mail.go @@ -1,36 +1,21 @@ package mock import ( - "github.com/alcionai/clues" - "github.com/alcionai/corso/src/internal/connector/exchange/api" "github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/graph/mock" "github.com/alcionai/corso/src/pkg/account" ) -func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) { - a, err := mock.CreateAdapter( - creds.AzureTenantID, - creds.AzureClientID, - creds.AzureClientSecret, - opts...) - if err != nil { - return nil, clues.Wrap(err, "generating graph adapter") - } - - return graph.NewService(a), nil -} - // NewClient produces a new exchange api client that can be // mocked using gock. func NewClient(creds account.M365Config) (api.Client, error) { - s, err := NewService(creds) + s, err := mock.NewService(creds) if err != nil { return api.Client{}, err } - li, err := NewService(creds, graph.NoTimeout()) + li, err := mock.NewService(creds, graph.NoTimeout()) if err != nil { return api.Client{}, err } diff --git a/src/internal/connector/graph/errors.go b/src/internal/connector/graph/errors.go index 70348762d..d5dca985a 100644 --- a/src/internal/connector/graph/errors.go +++ b/src/internal/connector/graph/errors.go @@ -234,6 +234,9 @@ func Stack(ctx context.Context, e error) *clues.Err { return setLabels(clues.Stack(e).WithClues(ctx).With(data...), innerMsg) } +// Checks for the following conditions and labels the error accordingly: +// * mysiteNotFound | mysiteURLNotFound +// * malware func setLabels(err *clues.Err, msg string) *clues.Err { if err == nil { return nil @@ -244,6 +247,10 @@ func setLabels(err *clues.Err, msg string) *clues.Err { err = err.Label(LabelsMysiteNotFound) } + if IsMalware(err) { + err = err.Label(LabelsMalware) + } + return err } diff --git a/src/internal/connector/graph/http_wrapper.go b/src/internal/connector/graph/http_wrapper.go new file mode 100644 index 000000000..1410fb194 --- /dev/null +++ b/src/internal/connector/graph/http_wrapper.go @@ -0,0 +1,152 @@ +package graph + +import ( + "context" + "io" + "net/http" + + "github.com/alcionai/clues" + khttp "github.com/microsoft/kiota-http-go" + + "github.com/alcionai/corso/src/internal/version" +) + +// --------------------------------------------------------------------------- +// constructors +// --------------------------------------------------------------------------- + +type Requester interface { + Request( + ctx context.Context, + method, url string, + body io.Reader, + headers map[string]string, + ) (*http.Response, error) +} + +// NewHTTPWrapper produces a http.Client wrapper that ensures +// calls use all the middleware we expect from the graph api client. +// +// Re-use of http clients is critical, or else we leak OS resources +// and consume relatively unbound socket connections. It is important +// to centralize this client to be passed downstream where api calls +// can utilize it on a per-download basis. +func NewHTTPWrapper(opts ...Option) *httpWrapper { + var ( + cc = populateConfig(opts...) + rt = customTransport{ + n: pipeline{ + middlewares: internalMiddleware(cc), + transport: defaultTransport(), + }, + } + redirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + hc = &http.Client{ + CheckRedirect: redirect, + Timeout: defaultHTTPClientTimeout, + Transport: rt, + } + ) + + cc.apply(hc) + + return &httpWrapper{hc} +} + +// NewNoTimeoutHTTPWrapper constructs a http wrapper with no context timeout. +// +// Re-use of http clients is critical, or else we leak OS resources +// and consume relatively unbound socket connections. It is important +// to centralize this client to be passed downstream where api calls +// can utilize it on a per-download basis. +func NewNoTimeoutHTTPWrapper(opts ...Option) *httpWrapper { + opts = append(opts, NoTimeout()) + return NewHTTPWrapper(opts...) +} + +// --------------------------------------------------------------------------- +// requests +// --------------------------------------------------------------------------- + +// Request does the provided request. +func (hw httpWrapper) Request( + ctx context.Context, + method, url string, + body io.Reader, + headers map[string]string, +) (*http.Response, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, clues.Wrap(err, "new http request") + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + //nolint:lll + // Decorate the traffic + // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic + req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) + + resp, err := hw.client.Do(req) + if err != nil { + return nil, Stack(ctx, err) + } + + return resp, nil +} + +// --------------------------------------------------------------------------- +// constructor internals +// --------------------------------------------------------------------------- + +type ( + httpWrapper struct { + client *http.Client + } + + customTransport struct { + n nexter + } + + pipeline struct { + transport http.RoundTripper + middlewares []khttp.Middleware + } +) + +// RoundTrip kicks off the middleware chain and returns a response +func (ct customTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return ct.n.Next(req, 0) +} + +// Next moves the request object through middlewares in the pipeline +func (pl pipeline) Next(req *http.Request, idx int) (*http.Response, error) { + if idx < len(pl.middlewares) { + return pl.middlewares[idx].Intercept(pl, idx+1, req) + } + + return pl.transport.RoundTrip(req) +} + +func defaultTransport() http.RoundTripper { + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport.ForceAttemptHTTP2 = true + + return defaultTransport +} + +func internalMiddleware(cc *clientConfig) []khttp.Middleware { + return []khttp.Middleware{ + &RetryHandler{ + MaxRetries: cc.maxRetries, + Delay: cc.minDelay, + }, + &LoggingMiddleware{}, + &ThrottleControlMiddleware{}, + &MetricsMiddleware{}, + } +} diff --git a/src/internal/connector/graph/http_wrapper_test.go b/src/internal/connector/graph/http_wrapper_test.go new file mode 100644 index 000000000..483a5f0ba --- /dev/null +++ b/src/internal/connector/graph/http_wrapper_test.go @@ -0,0 +1,45 @@ +package graph + +import ( + "net/http" + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/tester" +) + +type HTTPWrapperIntgSuite struct { + tester.Suite +} + +func TestHTTPWrapperIntgSuite(t *testing.T) { + suite.Run(t, &HTTPWrapperIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}), + }) +} + +func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() { + ctx, flush := tester.NewContext() + defer flush() + + var ( + t = suite.T() + hw = NewHTTPWrapper() + ) + + resp, err := hw.Request( + ctx, + http.MethodGet, + "https://www.corsobackup.io", + nil, + nil) + + require.NoError(t, err, clues.ToCore(err)) + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/src/internal/connector/graph/middleware.go b/src/internal/connector/graph/middleware.go index bedfbd932..57825c38f 100644 --- a/src/internal/connector/graph/middleware.go +++ b/src/internal/connector/graph/middleware.go @@ -20,6 +20,10 @@ import ( "github.com/alcionai/corso/src/pkg/logger" ) +type nexter interface { + Next(req *http.Request, middlewareIndex int) (*http.Response, error) +} + // --------------------------------------------------------------------------- // Logging // --------------------------------------------------------------------------- diff --git a/src/internal/connector/graph/mock/service.go b/src/internal/connector/graph/mock/service.go index 9a2a9b292..a44d9f1ca 100644 --- a/src/internal/connector/graph/mock/service.go +++ b/src/internal/connector/graph/mock/service.go @@ -1,12 +1,27 @@ package mock import ( + "github.com/alcionai/clues" "github.com/h2non/gock" msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" "github.com/alcionai/corso/src/internal/connector/graph" + "github.com/alcionai/corso/src/pkg/account" ) +func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) { + a, err := CreateAdapter( + creds.AzureTenantID, + creds.AzureClientID, + creds.AzureClientSecret, + opts...) + if err != nil { + return nil, clues.Wrap(err, "generating graph adapter") + } + + return graph.NewService(a), nil +} + // CreateAdapter is similar to graph.CreateAdapter, but with option to // enable interceptions via gock to make it mockable. func CreateAdapter( @@ -18,7 +33,7 @@ func CreateAdapter( return nil, err } - httpClient := graph.HTTPClient(opts...) + httpClient := graph.KiotaHTTPClient(opts...) // This makes sure that we are able to intercept any requests via // gock. Only necessary for testing. diff --git a/src/internal/connector/graph/service.go b/src/internal/connector/graph/service.go index ff8b3a85d..96e7b0a52 100644 --- a/src/internal/connector/graph/service.go +++ b/src/internal/connector/graph/service.go @@ -114,7 +114,7 @@ func CreateAdapter( return nil, err } - httpClient := HTTPClient(opts...) + httpClient := KiotaHTTPClient(opts...) return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( auth, @@ -140,21 +140,24 @@ func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityA return auth, nil } -// HTTPClient creates the httpClient with middlewares and timeout configured +// KiotaHTTPClient creates a httpClient with middlewares and timeout configured +// for use in the graph adapter. // // Re-use of http clients is critical, or else we leak OS resources // and consume relatively unbound socket connections. It is important // to centralize this client to be passed downstream where api calls // can utilize it on a per-download basis. -func HTTPClient(opts ...Option) *http.Client { - clientOptions := msgraphsdkgo.GetDefaultClientOptions() - clientconfig := (&clientConfig{}).populate(opts...) - noOfRetries, minRetryDelay := clientconfig.applyMiddlewareConfig() - middlewares := GetKiotaMiddlewares(&clientOptions, noOfRetries, minRetryDelay) - httpClient := msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) +func KiotaHTTPClient(opts ...Option) *http.Client { + var ( + clientOptions = msgraphsdkgo.GetDefaultClientOptions() + cc = populateConfig(opts...) + middlewares = kiotaMiddlewares(&clientOptions, cc) + httpClient = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...) + ) + httpClient.Timeout = defaultHTTPClientTimeout - clientconfig.apply(httpClient) + cc.apply(httpClient) return httpClient } @@ -175,27 +178,17 @@ type clientConfig struct { type Option func(*clientConfig) // populate constructs a clientConfig according to the provided options. -func (c *clientConfig) populate(opts ...Option) *clientConfig { +func populateConfig(opts ...Option) *clientConfig { + cc := clientConfig{ + maxRetries: defaultMaxRetries, + minDelay: defaultDelay, + } + for _, opt := range opts { - opt(c) + opt(&cc) } - return c -} - -// apply updates the http.Client with the expected options. -func (c *clientConfig) applyMiddlewareConfig() (retry int, delay time.Duration) { - retry = defaultMaxRetries - if c.overrideRetryCount { - retry = c.maxRetries - } - - delay = defaultDelay - if c.minDelay > 0 { - delay = c.minDelay - } - - return + return &cc } // apply updates the http.Client with the expected options. @@ -236,14 +229,16 @@ func MinimumBackoff(dur time.Duration) Option { // Middleware Control // --------------------------------------------------------------------------- -// GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter -func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware { +// kiotaMiddlewares creates a default slice of middleware for the Graph Client. +func kiotaMiddlewares( + options *msgraphgocore.GraphClientOptions, + cc *clientConfig, +) []khttp.Middleware { return []khttp.Middleware{ + msgraphgocore.NewGraphTelemetryHandler(options), &RetryHandler{ - // The maximum number of times a request can be retried - MaxRetries: maxRetry, - // The delay in seconds between retries - Delay: delay, + MaxRetries: cc.maxRetries, + Delay: cc.minDelay, }, khttp.NewRetryHandler(), khttp.NewRedirectHandler(), @@ -255,21 +250,3 @@ func GetMiddlewares(maxRetry int, delay time.Duration) []khttp.Middleware { &MetricsMiddleware{}, } } - -// GetKiotaMiddlewares creates a default slice of middleware for the Graph Client. -func GetKiotaMiddlewares( - options *msgraphgocore.GraphClientOptions, - maxRetry int, - minDelay time.Duration, -) []khttp.Middleware { - kiotaMiddlewares := GetMiddlewares(maxRetry, minDelay) - graphMiddlewares := []khttp.Middleware{ - msgraphgocore.NewGraphTelemetryHandler(options), - } - graphMiddlewaresLen := len(graphMiddlewares) - resultMiddlewares := make([]khttp.Middleware, len(kiotaMiddlewares)+graphMiddlewaresLen) - copy(resultMiddlewares, graphMiddlewares) - copy(resultMiddlewares[graphMiddlewaresLen:], kiotaMiddlewares) - - return resultMiddlewares -} diff --git a/src/internal/connector/graph/service_test.go b/src/internal/connector/graph/service_test.go index 4565efca1..9d4aad624 100644 --- a/src/internal/connector/graph/service_test.go +++ b/src/internal/connector/graph/service_test.go @@ -70,7 +70,7 @@ func (suite *GraphUnitSuite) TestHTTPClient() { suite.Run(test.name, func() { t := suite.T() - cli := HTTPClient(test.opts...) + cli := KiotaHTTPClient(test.opts...) assert.NotNil(t, cli) test.check(t, cli) }) diff --git a/src/internal/connector/graph_connector.go b/src/internal/connector/graph_connector.go index 94e9e1634..22126b82f 100644 --- a/src/internal/connector/graph_connector.go +++ b/src/internal/connector/graph_connector.go @@ -4,7 +4,6 @@ package connector import ( "context" - "net/http" "runtime/trace" "sync" @@ -36,7 +35,7 @@ var ( type GraphConnector struct { Service graph.Servicer Discovery api.Client - itemClient *http.Client // configured to handle large item downloads + itemClient graph.Requester // configured to handle large item downloads tenant string credentials account.M365Config @@ -88,7 +87,7 @@ func NewGraphConnector( Service: service, credentials: creds, - itemClient: graph.HTTPClient(graph.NoTimeout()), + itemClient: graph.NewNoTimeoutHTTPWrapper(), ownerLookup: rc, tenant: acct.ID(), wg: &sync.WaitGroup{}, diff --git a/src/internal/connector/onedrive/collection.go b/src/internal/connector/onedrive/collection.go index 39624f3a6..aef3dd7ab 100644 --- a/src/internal/connector/onedrive/collection.go +++ b/src/internal/connector/onedrive/collection.go @@ -35,10 +35,6 @@ const ( // TODO: Tune this later along with collectionChannelBufferSize urlPrefetchChannelBufferSize = 5 - // maxDownloadRetires specifies the number of times a file download should - // be retried - maxDownloadRetires = 3 - // Used to compare in case of OneNote files MaxOneNoteFileSize = 2 * 1024 * 1024 * 1024 ) @@ -62,7 +58,7 @@ const ( // Collection represents a set of OneDrive objects retrieved from M365 type Collection struct { // configured to handle large item downloads - itemClient *http.Client + itemClient graph.Requester // data is used to share data streams with the collection consumer data chan data.Stream @@ -110,7 +106,7 @@ type Collection struct { doNotMergeItems bool } -// itemGetterFunc gets an specified item +// itemGetterFunc gets a specified item type itemGetterFunc func( ctx context.Context, srv graph.Servicer, @@ -120,7 +116,7 @@ type itemGetterFunc func( // itemReadFunc returns a reader for the specified item type itemReaderFunc func( ctx context.Context, - hc *http.Client, + client graph.Requester, item models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) @@ -148,7 +144,7 @@ func pathToLocation(p path.Path) (*path.Builder, error) { // NewCollection creates a Collection func NewCollection( - itemClient *http.Client, + itemClient graph.Requester, folderPath path.Path, prevPath path.Path, driveID string, @@ -372,45 +368,29 @@ func (oc *Collection) getDriveItemContent( itemID = ptr.Val(item.GetId()) itemName = ptr.Val(item.GetName()) el = errs.Local() - - itemData io.ReadCloser - err error ) - // Initial try with url from delta + 2 retries - for i := 1; i <= maxDownloadRetires; i++ { - _, itemData, err = oc.itemReader(ctx, oc.itemClient, item) - if err == nil || !graph.IsErrUnauthorized(err) { - break - } - - // Assume unauthorized requests are a sign of an expired jwt - // token, and that we've overrun the available window to - // download the actual file. Re-downloading the item will - // refresh that download url. - di, diErr := oc.itemGetter(ctx, oc.service, oc.driveID, itemID) - if diErr != nil { - err = clues.Wrap(diErr, "retrieving expired item") - break - } - - item = di - } - - // check for errors following retries + itemData, err := downloadContent( + ctx, + oc.service, + oc.itemGetter, + oc.itemReader, + oc.itemClient, + item, + oc.driveID) if err != nil { if clues.HasLabel(err, graph.LabelsMalware) || (item != nil && item.GetMalware() != nil) { logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipMalware).Info("item flagged as malware") el.AddSkip(fault.FileSkip(fault.SkipMalware, itemID, itemName, graph.ItemInfo(item))) - return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) + return nil, clues.Wrap(err, "malware item").Label(graph.LabelsSkippable) } if clues.HasLabel(err, graph.LabelStatus(http.StatusNotFound)) || graph.IsErrDeletedInFlight(err) { logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipNotFound).Info("item not found") el.AddSkip(fault.FileSkip(fault.SkipNotFound, itemID, itemName, graph.ItemInfo(item))) - return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) + return nil, clues.Wrap(err, "deleted item").Label(graph.LabelsSkippable) } // Skip big OneNote files as they can't be downloaded @@ -425,7 +405,7 @@ func (oc *Collection) getDriveItemContent( logger.CtxErr(ctx, err).With("skipped_reason", fault.SkipBigOneNote).Info("max OneNote file size exceeded") el.AddSkip(fault.FileSkip(fault.SkipBigOneNote, itemID, itemName, graph.ItemInfo(item))) - return nil, clues.Wrap(err, "downloading item").Label(graph.LabelsSkippable) + return nil, clues.Wrap(err, "max oneNote item").Label(graph.LabelsSkippable) } logger.CtxErr(ctx, err).Error("downloading item") @@ -433,12 +413,48 @@ func (oc *Collection) getDriveItemContent( // return err, not el.Err(), because the lazy reader needs to communicate to // the data consumer that this item is unreadable, regardless of the fault state. - return nil, clues.Wrap(err, "downloading item") + return nil, clues.Wrap(err, "fetching item content") } return itemData, nil } +// downloadContent attempts to fetch the item content. If the content url +// is expired (ie, returns a 401), it re-fetches the item to get a new download +// url and tries again. +func downloadContent( + ctx context.Context, + svc graph.Servicer, + igf itemGetterFunc, + irf itemReaderFunc, + gr graph.Requester, + item models.DriveItemable, + driveID string, +) (io.ReadCloser, error) { + _, content, err := irf(ctx, gr, item) + if err == nil { + return content, nil + } else if !graph.IsErrUnauthorized(err) { + return nil, err + } + + // Assume unauthorized requests are a sign of an expired jwt + // token, and that we've overrun the available window to + // download the actual file. Re-downloading the item will + // refresh that download url. + di, err := igf(ctx, svc, driveID, ptr.Val(item.GetId())) + if err != nil { + return nil, clues.Wrap(err, "retrieving expired item") + } + + _, content, err = irf(ctx, gr, di) + if err != nil { + return nil, clues.Wrap(err, "content download retry") + } + + return content, nil +} + // populateItems iterates through items added to the collection // and uses the collection `itemReader` to read the item func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) { @@ -504,9 +520,9 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) { ctx = clues.Add( ctx, - "backup_item_id", itemID, - "backup_item_name", itemName, - "backup_item_size", itemSize) + "item_id", itemID, + "item_name", itemName, + "item_size", itemSize) item.SetParentReference(setName(item.GetParentReference(), oc.driveName)) @@ -545,7 +561,7 @@ func (oc *Collection) populateItems(ctx context.Context, errs *fault.Bus) { itemInfo.OneDrive.ParentPath = parentPathString } - ctx = clues.Add(ctx, "backup_item_info", itemInfo) + ctx = clues.Add(ctx, "item_info", itemInfo) if isFile { dataSuffix := metadata.DataFileSuffix diff --git a/src/internal/connector/onedrive/collection_test.go b/src/internal/connector/onedrive/collection_test.go index b4328fe9b..682033f07 100644 --- a/src/internal/connector/onedrive/collection_test.go +++ b/src/internal/connector/onedrive/collection_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/onedrive/metadata" "github.com/alcionai/corso/src/internal/connector/support" @@ -98,7 +99,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { numInstances: 1, source: OneDriveSource, itemDeets: nst{testItemName, 42, now}, - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}}, io.NopCloser(bytes.NewReader(testItemData)), nil @@ -114,7 +115,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { numInstances: 3, source: OneDriveSource, itemDeets: nst{testItemName, 42, now}, - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: testItemName, Modified: now}}, io.NopCloser(bytes.NewReader(testItemData)), nil @@ -130,7 +131,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { numInstances: 3, source: OneDriveSource, itemDeets: nst{testItemName, 42, now}, - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{}, nil, clues.New("test malware").Label(graph.LabelsMalware) }, infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) { @@ -146,7 +147,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { source: OneDriveSource, itemDeets: nst{testItemName, 42, now}, // Usually `Not Found` is returned from itemGetter and not itemReader - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{}, nil, clues.New("test not found").Label(graph.LabelStatus(http.StatusNotFound)) }, infoFrom: func(t *testing.T, dii details.ItemInfo) (string, string) { @@ -161,7 +162,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { numInstances: 1, source: SharePointSource, itemDeets: nst{testItemName, 42, now}, - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}}, io.NopCloser(bytes.NewReader(testItemData)), nil @@ -177,7 +178,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { numInstances: 3, source: SharePointSource, itemDeets: nst{testItemName, 42, now}, - itemReader: func(context.Context, *http.Client, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + itemReader: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{SharePoint: &details.SharePointInfo{ItemName: testItemName, Modified: now}}, io.NopCloser(bytes.NewReader(testItemData)), nil @@ -207,7 +208,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { require.NoError(t, err, clues.ToCore(err)) coll, err := NewCollection( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), folderPath, nil, "drive-id", @@ -278,7 +279,7 @@ func (suite *CollectionUnitTestSuite) TestCollection() { if err != nil { for _, label := range test.expectLabels { - assert.True(t, clues.HasLabel(err, label), "has clues label:", label) + assert.Truef(t, clues.HasLabel(err, label), "has clues label: %s", label) } return @@ -347,7 +348,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() { require.NoError(t, err, clues.ToCore(err)) coll, err := NewCollection( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), folderPath, nil, "fakeDriveID", @@ -370,7 +371,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadError() { coll.itemReader = func( context.Context, - *http.Client, + graph.Requester, models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{}, nil, assert.AnError @@ -437,7 +438,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry() require.NoError(t, err) coll, err := NewCollection( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), folderPath, nil, "fakeDriveID", @@ -470,10 +471,10 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry() coll.itemReader = func( context.Context, - *http.Client, + graph.Requester, models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { - if count < 2 { + if count < 1 { count++ return details.ItemInfo{}, nil, clues.Stack(assert.AnError). Label(graph.LabelStatus(http.StatusUnauthorized)) @@ -494,13 +495,13 @@ func (suite *CollectionUnitTestSuite) TestCollectionReadUnauthorizedErrorRetry() assert.True(t, ok) _, err = io.ReadAll(collItem.ToReader()) - assert.NoError(t, err) + assert.NoError(t, err, clues.ToCore(err)) wg.Wait() require.Equal(t, 1, collStatus.Metrics.Objects, "only one object should be counted") require.Equal(t, 1, collStatus.Metrics.Successes, "read object successfully") - require.Equal(t, 2, count, "retry count") + require.Equal(t, 1, count, "retry count") }) } } @@ -537,7 +538,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim require.NoError(t, err, clues.ToCore(err)) coll, err := NewCollection( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), folderPath, nil, "drive-id", @@ -561,7 +562,7 @@ func (suite *CollectionUnitTestSuite) TestCollectionPermissionBackupLatestModTim coll.itemReader = func( context.Context, - *http.Client, + graph.Requester, models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{OneDrive: &details.OneDriveInfo{ItemName: "fakeName", Modified: time.Now()}}, @@ -611,7 +612,7 @@ func TestGetDriveItemUnitTestSuite(t *testing.T) { suite.Run(t, &GetDriveItemUnitTestSuite{Suite: tester.NewUnitSuite(t)}) } -func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { +func (suite *GetDriveItemUnitTestSuite) TestGetDriveItem_error() { strval := "not-important" table := []struct { @@ -637,14 +638,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { name: "malware error", colScope: CollectionScopeFolder, itemSize: 10, - err: clues.New("test error").Label(graph.LabelsMalware), + err: clues.New("malware error").Label(graph.LabelsMalware), labels: []string{graph.LabelsMalware, graph.LabelsSkippable}, }, { name: "file not found error", colScope: CollectionScopeFolder, itemSize: 10, - err: clues.New("test error").Label(graph.LabelStatus(http.StatusNotFound)), + err: clues.New("not found error").Label(graph.LabelStatus(http.StatusNotFound)), labels: []string{graph.LabelStatus(http.StatusNotFound), graph.LabelsSkippable}, }, { @@ -652,14 +653,14 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { name: "small OneNote file", colScope: CollectionScopePackage, itemSize: 10, - err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), + err: clues.New("small onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)}, }, { name: "big OneNote file", colScope: CollectionScopePackage, itemSize: MaxOneNoteFileSize, - err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), + err: clues.New("big onenote error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), labels: []string{graph.LabelStatus(http.StatusServiceUnavailable), graph.LabelsSkippable}, }, { @@ -667,7 +668,7 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { name: "big file", colScope: CollectionScopeFolder, itemSize: MaxOneNoteFileSize, - err: clues.New("test error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), + err: clues.New("big file error").Label(graph.LabelStatus(http.StatusServiceUnavailable)), labels: []string{graph.LabelStatus(http.StatusServiceUnavailable)}, }, } @@ -689,9 +690,9 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { item.SetSize(&test.itemSize) col.itemReader = func( - ctx context.Context, - hc *http.Client, - item models.DriveItemable, + _ context.Context, + _ graph.Requester, + _ models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { return details.ItemInfo{}, nil, test.err } @@ -707,11 +708,11 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { _, err := col.getDriveItemContent(ctx, item, errs) if test.err == nil { - assert.NoError(t, err, "no error") + assert.NoError(t, err, clues.ToCore(err)) return } - assert.EqualError(t, err, clues.Wrap(test.err, "downloading item").Error(), "error") + assert.ErrorIs(t, err, test.err, clues.ToCore(err)) labelsMap := map[string]struct{}{} for _, l := range test.labels { @@ -722,3 +723,103 @@ func (suite *GetDriveItemUnitTestSuite) TestGetDriveItemError() { }) } } + +func (suite *GetDriveItemUnitTestSuite) TestDownloadContent() { + var ( + svc graph.Servicer + gr graph.Requester + driveID string + iorc = io.NopCloser(bytes.NewReader([]byte("fnords"))) + item = &models.DriveItem{} + itemWID = &models.DriveItem{} + ) + + itemWID.SetId(ptr.To("brainhooldy")) + + table := []struct { + name string + igf itemGetterFunc + irf itemReaderFunc + expectErr require.ErrorAssertionFunc + expect require.ValueAssertionFunc + }{ + { + name: "good", + irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + return details.ItemInfo{}, iorc, nil + }, + expectErr: require.NoError, + expect: require.NotNil, + }, + { + name: "expired url redownloads", + igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) { + return itemWID, nil + }, + irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + // a bit hacky: assume only igf returns an item with a non-zero id. + if len(ptr.Val(m.GetId())) == 0 { + return details.ItemInfo{}, + nil, + clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized)) + } + + return details.ItemInfo{}, iorc, nil + }, + expectErr: require.NoError, + expect: require.NotNil, + }, + { + name: "immediate error", + irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + return details.ItemInfo{}, nil, assert.AnError + }, + expectErr: require.Error, + expect: require.Nil, + }, + { + name: "re-fetching the item fails", + igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) { + return nil, assert.AnError + }, + irf: func(context.Context, graph.Requester, models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + return details.ItemInfo{}, + nil, + clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized)) + }, + expectErr: require.Error, + expect: require.Nil, + }, + { + name: "expired url fails redownload", + igf: func(context.Context, graph.Servicer, string, string) (models.DriveItemable, error) { + return itemWID, nil + }, + irf: func(c context.Context, g graph.Requester, m models.DriveItemable) (details.ItemInfo, io.ReadCloser, error) { + // a bit hacky: assume only igf returns an item with a non-zero id. + if len(ptr.Val(m.GetId())) == 0 { + return details.ItemInfo{}, + nil, + clues.Stack(assert.AnError).Label(graph.LabelStatus(http.StatusUnauthorized)) + } + + return details.ItemInfo{}, iorc, assert.AnError + }, + expectErr: require.Error, + expect: require.Nil, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + ctx, flush := tester.NewContext() + defer flush() + + t := suite.T() + + r, err := downloadContent(ctx, svc, test.igf, test.irf, gr, item, driveID) + + test.expect(t, r) + test.expectErr(t, err, clues.ToCore(err)) + }) + } +} diff --git a/src/internal/connector/onedrive/collections.go b/src/internal/connector/onedrive/collections.go index fdac083c8..aca636b94 100644 --- a/src/internal/connector/onedrive/collections.go +++ b/src/internal/connector/onedrive/collections.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "net/http" "strings" "github.com/alcionai/clues" @@ -73,7 +72,7 @@ type folderMatcher interface { // resource owner, which can be either a user or a sharepoint site. type Collections struct { // configured to handle large item downloads - itemClient *http.Client + itemClient graph.Requester tenant string resourceOwner string @@ -109,7 +108,7 @@ type Collections struct { } func NewCollections( - itemClient *http.Client, + itemClient graph.Requester, tenant string, resourceOwner string, source driveSource, diff --git a/src/internal/connector/onedrive/collections_test.go b/src/internal/connector/onedrive/collections_test.go index 5598d701e..d9e6fde6c 100644 --- a/src/internal/connector/onedrive/collections_test.go +++ b/src/internal/connector/onedrive/collections_test.go @@ -780,7 +780,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestUpdateCollections() { maps.Copy(outputFolderMap, tt.inputFolderMap) c := NewCollections( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), tenant, user, OneDriveSource, @@ -2231,7 +2231,7 @@ func (suite *OneDriveCollectionsUnitSuite) TestGet() { } c := NewCollections( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), tenant, user, OneDriveSource, diff --git a/src/internal/connector/onedrive/data_collections.go b/src/internal/connector/onedrive/data_collections.go index a0c3e648f..90c7bf782 100644 --- a/src/internal/connector/onedrive/data_collections.go +++ b/src/internal/connector/onedrive/data_collections.go @@ -2,7 +2,6 @@ package onedrive import ( "context" - "net/http" "github.com/alcionai/clues" "golang.org/x/exp/maps" @@ -38,7 +37,7 @@ func DataCollections( user common.IDNamer, metadata []data.RestoreCollection, tenant string, - itemClient *http.Client, + itemClient graph.Requester, service graph.Servicer, su support.StatusUpdater, ctrlOpts control.Options, diff --git a/src/internal/connector/onedrive/drive_test.go b/src/internal/connector/onedrive/drive_test.go index 26f8c5c85..06b460cff 100644 --- a/src/internal/connector/onedrive/drive_test.go +++ b/src/internal/connector/onedrive/drive_test.go @@ -426,7 +426,7 @@ func (suite *OneDriveSuite) TestOneDriveNewCollections() { ) colls := NewCollections( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), creds.AzureTenantID, test.user, OneDriveSource, diff --git a/src/internal/connector/onedrive/item.go b/src/internal/connector/onedrive/item.go index 209cdce15..340746436 100644 --- a/src/internal/connector/onedrive/item.go +++ b/src/internal/connector/onedrive/item.go @@ -16,7 +16,6 @@ import ( "github.com/alcionai/corso/src/internal/connector/graph" "github.com/alcionai/corso/src/internal/connector/onedrive/api" "github.com/alcionai/corso/src/internal/connector/uploadsession" - "github.com/alcionai/corso/src/internal/version" "github.com/alcionai/corso/src/pkg/backup/details" "github.com/alcionai/corso/src/pkg/logger" ) @@ -33,12 +32,12 @@ const ( // TODO: Add metadata fetching to SharePoint func sharePointItemReader( ctx context.Context, - hc *http.Client, + client graph.Requester, item models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { - resp, err := downloadItem(ctx, hc, item) + resp, err := downloadItem(ctx, client, item) if err != nil { - return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item") + return details.ItemInfo{}, nil, clues.Wrap(err, "sharepoint reader") } dii := details.ItemInfo{ @@ -107,7 +106,7 @@ func baseItemMetaReader( // and using a http client to initialize a reader func oneDriveItemReader( ctx context.Context, - hc *http.Client, + client graph.Requester, item models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { var ( @@ -116,9 +115,9 @@ func oneDriveItemReader( ) if isFile { - resp, err := downloadItem(ctx, hc, item) + resp, err := downloadItem(ctx, client, item) if err != nil { - return details.ItemInfo{}, nil, clues.Wrap(err, "downloading item") + return details.ItemInfo{}, nil, clues.Wrap(err, "onedrive reader") } rc = resp.Body @@ -131,38 +130,26 @@ func oneDriveItemReader( return dii, rc, nil } -func downloadItem(ctx context.Context, hc *http.Client, item models.DriveItemable) (*http.Response, error) { +func downloadItem( + ctx context.Context, + client graph.Requester, + item models.DriveItemable, +) (*http.Response, error) { url, ok := item.GetAdditionalData()[downloadURLKey].(*string) if !ok { return nil, clues.New("extracting file url").With("item_id", ptr.Val(item.GetId())) } - req, err := http.NewRequest(http.MethodGet, *url, nil) + resp, err := client.Request(ctx, http.MethodGet, ptr.Val(url), nil, nil) if err != nil { - return nil, graph.Wrap(ctx, err, "new item download request") - } - - //nolint:lll - // Decorate the traffic - // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic - req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) - - resp, err := hc.Do(req) - if err != nil { - cerr := graph.Wrap(ctx, err, "downloading item") - - if graph.IsMalware(err) { - cerr = cerr.Label(graph.LabelsMalware) - } - - return nil, cerr + return nil, err } if (resp.StatusCode / 100) == 2 { return resp, nil } - if graph.IsMalwareResp(context.Background(), resp) { + if graph.IsMalwareResp(ctx, resp) { return nil, clues.New("malware detected").Label(graph.LabelsMalware) } diff --git a/src/internal/connector/onedrive/item_test.go b/src/internal/connector/onedrive/item_test.go index 89dbd4036..992b446d1 100644 --- a/src/internal/connector/onedrive/item_test.go +++ b/src/internal/connector/onedrive/item_test.go @@ -112,7 +112,7 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() { ) // Read data for the file - itemInfo, itemData, err := oneDriveItemReader(ctx, graph.HTTPClient(graph.NoTimeout()), driveItem) + itemInfo, itemData, err := oneDriveItemReader(ctx, graph.NewNoTimeoutHTTPWrapper(), driveItem) require.NoError(suite.T(), err, clues.ToCore(err)) require.NotNil(suite.T(), itemInfo.OneDrive) diff --git a/src/internal/connector/sharepoint/data_collections.go b/src/internal/connector/sharepoint/data_collections.go index 815f3b1bb..51364373f 100644 --- a/src/internal/connector/sharepoint/data_collections.go +++ b/src/internal/connector/sharepoint/data_collections.go @@ -2,7 +2,6 @@ package sharepoint import ( "context" - "net/http" "github.com/alcionai/clues" @@ -29,7 +28,7 @@ type statusUpdater interface { // for the specified user func DataCollections( ctx context.Context, - itemClient *http.Client, + itemClient graph.Requester, selector selectors.Selector, creds account.M365Config, serv graph.Servicer, @@ -182,7 +181,7 @@ func collectLists( // all the drives associated with the site. func collectLibraries( ctx context.Context, - itemClient *http.Client, + itemClient graph.Requester, serv graph.Servicer, tenantID, siteID string, scope selectors.SharePointScope, diff --git a/src/internal/connector/sharepoint/data_collections_test.go b/src/internal/connector/sharepoint/data_collections_test.go index b7411e059..e787aea41 100644 --- a/src/internal/connector/sharepoint/data_collections_test.go +++ b/src/internal/connector/sharepoint/data_collections_test.go @@ -109,7 +109,7 @@ func (suite *SharePointLibrariesUnitSuite) TestUpdateCollections() { ) c := onedrive.NewCollections( - graph.HTTPClient(graph.NoTimeout()), + graph.NewNoTimeoutHTTPWrapper(), tenant, site, onedrive.SharePointSource,