From 3cd82de23ada51bb16a405cb23eaaf33233bd8c7 Mon Sep 17 00:00:00 2001 From: Keepers Date: Thu, 26 Jan 2023 18:20:41 -0700 Subject: [PATCH] re-fetch file download url after expiration (#2283) ## Description If a drive item goes over its 1 hour jwt expiration to download the backing file, re-fetch the item and use the new download url to get the file. ## Does this PR need a docs update or release note? - [x] :no_entry: No ## Type of change - [x] :sunflower: Feature ## Issue(s) * #2267 ## Test Plan - [x] :muscle: Manual --- src/internal/connector/graph/errors.go | 109 +++++++++++++++--- .../connector/graph/service_helper.go | 41 +++++-- src/internal/connector/onedrive/collection.go | 30 ++++- src/internal/connector/onedrive/item.go | 77 +++++++++---- 4 files changed, 202 insertions(+), 55 deletions(-) diff --git a/src/internal/connector/graph/errors.go b/src/internal/connector/graph/errors.go index 86cec64bd..049660425 100644 --- a/src/internal/connector/graph/errors.go +++ b/src/internal/connector/graph/errors.go @@ -26,6 +26,15 @@ const ( errCodeMailboxNotEnabledForRESTAPI = "MailboxNotEnabledForRESTAPI" ) +var ( + Err401Unauthorized = errors.New("401 unauthorized") + // normally the graph client will catch this for us, but in case we + // run our own client Do(), we need to translate it to a timeout type + // failure locally. + Err429TooManyRequests = errors.New("429 too many requests") + Err503ServiceUnavailable = errors.New("503 Service Unavailable") +) + // The folder or item was deleted between the time we identified // it and when we tried to fetch data for it. type ErrDeletedInFlight struct { @@ -102,6 +111,89 @@ func asTimeout(err error) bool { return errors.As(err, &e) } +// isTimeoutErr is used to determine if the Graph error returned is +// because of Timeout. This is used to restrict retries to just +// timeouts as other errors are handled within a middleware in the +// client. +func isTimeoutErr(err error) bool { + if errors.Is(err, context.DeadlineExceeded) || os.IsTimeout(err) { + return true + } + + switch err := err.(type) { + case *url.Error: + return err.Timeout() + default: + return false + } +} + +type ErrThrottled struct { + common.Err +} + +func IsErrThrottled(err error) error { + if errors.Is(err, Err429TooManyRequests) { + return err + } + + if asThrottled(err) { + return err + } + + return nil +} + +func asThrottled(err error) bool { + e := ErrThrottled{} + return errors.As(err, &e) +} + +type ErrUnauthorized struct { + common.Err +} + +func IsErrUnauthorized(err error) error { + // TODO: refine this investigation. We don't currently know if + // a specific item download url expired, or if the full connection + // auth expired. + if errors.Is(err, Err401Unauthorized) { + return err + } + + if asUnauthorized(err) { + return err + } + + return nil +} + +func asUnauthorized(err error) bool { + e := ErrUnauthorized{} + return errors.As(err, &e) +} + +type ErrServiceUnavailable struct { + common.Err +} + +func IsSericeUnavailable(err error) error { + if errors.Is(err, Err503ServiceUnavailable) { + return err + } + + if asServiceUnavailable(err) { + return err + } + + return nil +} + +func asServiceUnavailable(err error) bool { + e := ErrUnauthorized{} + return errors.As(err, &e) +} + // --------------------------------------------------------------------------- // error parsers // --------------------------------------------------------------------------- @@ -122,20 +214,3 @@ func hasErrorCode(err error, codes ...string) bool { return slices.Contains(codes, *oDataError.GetError().GetCode()) } - -// isTimeoutErr is used to determine if the Graph error returned is -// because of Timeout. This is used to restrict retries to just -// timeouts as other errors are handled within a middleware in the -// client. -func isTimeoutErr(err error) bool { - if errors.Is(err, context.DeadlineExceeded) || os.IsTimeout(err) { - return true - } - - switch err := err.(type) { - case *url.Error: - return err.Timeout() - default: - return false - } -} diff --git a/src/internal/connector/graph/service_helper.go b/src/internal/connector/graph/service_helper.go index 900919406..bf39fe194 100644 --- a/src/internal/connector/graph/service_helper.go +++ b/src/internal/connector/graph/service_helper.go @@ -94,31 +94,48 @@ func (handler *LoggingMiddleware) Intercept( } if (resp.StatusCode / 100) == 2 { + if logger.DebugAPI || os.Getenv(logGraphRequestsEnvKey) != "" { + respDump, _ := httputil.DumpResponse(resp, false) + + metadata := []any{ + "idx", middlewareIndex, + "method", req.Method, + "status", resp.Status, + "statusCode", resp.StatusCode, + "requestLen", req.ContentLength, + "url", req.URL, + "response", respDump, + } + + logger.Ctx(ctx).Debugw("2xx graph api resp", metadata...) + } + return resp, err } - // special case for supportability: log all throttling cases. - if resp.StatusCode == http.StatusTooManyRequests { - logger.Ctx(ctx).Infow("graph api throttling", "method", req.Method, "url", req.URL) - } - - if resp.StatusCode != http.StatusTooManyRequests && (resp.StatusCode/100) != 2 { - logger.Ctx(ctx).Infow("graph api error", "method", req.Method, "url", req.URL) - } - if logger.DebugAPI || os.Getenv(logGraphRequestsEnvKey) != "" { respDump, _ := httputil.DumpResponse(resp, true) metadata := []any{ + "idx", middlewareIndex, "method", req.Method, - "url", req.URL, - "requestLen", req.ContentLength, "status", resp.Status, "statusCode", resp.StatusCode, - "request", string(respDump), + "requestLen", req.ContentLength, + "url", req.URL, + "response", string(respDump), } logger.Ctx(ctx).Errorw("non-2xx graph api response", metadata...) + } else { + // special case for supportability: log all throttling cases. + if resp.StatusCode == http.StatusTooManyRequests { + logger.Ctx(ctx).Infow("graph api throttling", "method", req.Method, "url", req.URL) + } + + if resp.StatusCode != http.StatusTooManyRequests && (resp.StatusCode/100) != 2 { + logger.Ctx(ctx).Infow("graph api error", "status", resp.Status, "method", req.Method, "url", req.URL) + } } return resp, err diff --git a/src/internal/connector/onedrive/collection.go b/src/internal/connector/onedrive/collection.go index 68ad71a89..77ac88c63 100644 --- a/src/internal/connector/onedrive/collection.go +++ b/src/internal/connector/onedrive/collection.go @@ -224,6 +224,7 @@ func (oc *Collection) populateItems(ctx context.Context) { defer func() { <-semaphoreCh }() var ( + itemID = *item.GetId() itemName = *item.GetName() itemSize = *item.GetSize() itemInfo details.ItemInfo @@ -251,7 +252,32 @@ func (oc *Collection) populateItems(ctx context.Context) { for i := 1; i <= maxRetries; i++ { _, itemData, err = oc.itemReader(oc.itemClient, item) - if err == nil || graph.IsErrTimeout(err) == nil { + if err == nil { + break + } + + if graph.IsErrUnauthorized(err) != nil { + // assume unauthorized requests are a sign of an expired + // jwt token, and that we've overrun the available window + // to download the actual file. Re-downloading the item + // will refresh that download url. + di, diErr := getDriveItem(ctx, oc.service, oc.driveID, itemID) + if diErr != nil { + err = errors.Wrap(diErr, "retrieving expired item") + break + } + + item = di + + continue + + } else if graph.IsErrTimeout(err) == nil && + graph.IsErrThrottled(err) == nil && + graph.IsSericeUnavailable(err) == nil { + // TODO: graphAPI will provides headers that state the duration to wait + // in order to succeed again. The one second sleep won't cut it here. + // + // for all non-timeout, non-unauth, non-throttling errors, do not retry break } @@ -262,7 +288,7 @@ func (oc *Collection) populateItems(ctx context.Context) { // check for errors following retries if err != nil { - errUpdater(*item.GetId(), err) + errUpdater(itemID, err) return nil, err } diff --git a/src/internal/connector/onedrive/item.go b/src/internal/connector/onedrive/item.go index 3e4e9e516..c4fd1b380 100644 --- a/src/internal/connector/onedrive/item.go +++ b/src/internal/connector/onedrive/item.go @@ -25,6 +25,15 @@ const ( downloadURLKey = "@microsoft.graph.downloadUrl" ) +// generic drive item getter +func getDriveItem( + ctx context.Context, + srv graph.Servicer, + driveID, itemID string, +) (models.DriveItemable, error) { + return srv.Client().DrivesById(driveID).ItemsById(itemID).Get(ctx, nil) +} + // sharePointItemReader will return a io.ReadCloser for the specified item // It crafts this by querying M365 for a download URL for the item // and using a http client to initialize a reader @@ -32,14 +41,9 @@ func sharePointItemReader( hc *http.Client, item models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { - url, ok := item.GetAdditionalData()[downloadURLKey].(*string) - if !ok { - return details.ItemInfo{}, nil, fmt.Errorf("failed to get url for %s", *item.GetName()) - } - - resp, err := hc.Get(*url) + resp, err := downloadItem(hc, item) if err != nil { - return details.ItemInfo{}, nil, err + return details.ItemInfo{}, nil, errors.Wrap(err, "downloading item") } dii := details.ItemInfo{ @@ -56,24 +60,9 @@ func oneDriveItemReader( hc *http.Client, item models.DriveItemable, ) (details.ItemInfo, io.ReadCloser, error) { - url, ok := item.GetAdditionalData()[downloadURLKey].(*string) - if !ok { - return details.ItemInfo{}, nil, fmt.Errorf("failed to get url for %s", *item.GetName()) - } - - req, err := http.NewRequest(http.MethodGet, *url, nil) + resp, err := downloadItem(hc, item) if err != nil { - return details.ItemInfo{}, nil, err - } - - // Decorate the traffic - //nolint:lll - // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic - req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) - - resp, err := hc.Do(req) - if err != nil { - return details.ItemInfo{}, nil, err + return details.ItemInfo{}, nil, errors.Wrap(err, "downloading item") } dii := details.ItemInfo{ @@ -83,6 +72,46 @@ func oneDriveItemReader( return dii, resp.Body, nil } +func downloadItem(hc *http.Client, item models.DriveItemable) (*http.Response, error) { + url, ok := item.GetAdditionalData()[downloadURLKey].(*string) + if !ok { + return nil, fmt.Errorf("extracting file url: file %s", *item.GetId()) + } + + req, err := http.NewRequest(http.MethodGet, *url, nil) + if err != nil { + return nil, errors.Wrap(err, "new request") + } + + //nolint:lll + // Decorate the traffic + // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic + req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) + + resp, err := hc.Do(req) + if err != nil { + return nil, err + } + + if (resp.StatusCode / 100) == 2 { + return resp, nil + } + + if resp.StatusCode == http.StatusTooManyRequests { + return resp, graph.Err429TooManyRequests + } + + if resp.StatusCode == http.StatusUnauthorized { + return resp, graph.Err401Unauthorized + } + + if resp.StatusCode == http.StatusServiceUnavailable { + return resp, graph.Err503ServiceUnavailable + } + + return resp, errors.New("non-2xx http response: " + resp.Status) +} + // oneDriveItemInfo will populate a details.OneDriveInfo struct // with properties from the drive item. ItemSize is specified // separately for restore processes because the local itemable