diff --git a/CHANGELOG.md b/CHANGELOG.md index ddf9ecde0..b68e7c43a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Handle OneDrive folders being deleted and recreated midway through a backup - Automatically re-run a full delta query on incrmental if the prior backup is found to have malformed prior-state information. +- Retry drive item permission downloads during long-running backups after the jwt token expires and refreshes. ## [v0.15.0] (beta) - 2023-10-31 diff --git a/src/internal/m365/collection/drive/collection.go b/src/internal/m365/collection/drive/collection.go index 0b2203e47..00f99abe2 100644 --- a/src/internal/m365/collection/drive/collection.go +++ b/src/internal/m365/collection/drive/collection.go @@ -341,7 +341,7 @@ func downloadContent( content, err := downloadItem(ctx, iaag, item) if err == nil { return content, nil - } else if !graph.IsErrUnauthorized(err) { + } else if !graph.IsErrUnauthorizedOrBadToken(err) { return nil, err } @@ -397,7 +397,7 @@ func readItemContents( } rc, err := downloadFile(ctx, iaag, props.downloadURL) - if graph.IsErrUnauthorized(err) { + if graph.IsErrUnauthorizedOrBadToken(err) { logger.CtxErr(ctx, err).Info("stale item in cache") } diff --git a/src/internal/m365/collection/drive/item.go b/src/internal/m365/collection/drive/item.go index f5143d388..9fad7cd8e 100644 --- a/src/internal/m365/collection/drive/item.go +++ b/src/internal/m365/collection/drive/item.go @@ -10,8 +10,6 @@ import ( "github.com/microsoftgraph/msgraph-sdk-go/models" "golang.org/x/exp/maps" - "github.com/alcionai/corso/src/internal/common" - jwt "github.com/alcionai/corso/src/internal/common/jwt" "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/common/readers" "github.com/alcionai/corso/src/internal/common/str" @@ -25,10 +23,6 @@ import ( const ( acceptHeaderKey = "Accept" acceptHeaderValue = "*/*" - - 
// JWTQueryParam is a query param embed in graph download URLs which holds - // JWT token. - JWTQueryParam = "tempauth" ) // downloadUrlKeys is used to find the download URL in a DriveItem response. @@ -130,16 +124,18 @@ func downloadFile( } // Precheck for url expiry before we make a call to graph to download the - // file. If the url is expired, we can return early and save a call to graph. + // file. If the url has expired, we can return early and save a call to graph. // // Ignore all errors encountered during the check. We can rely on graph to // return errors on malformed urls. Ignoring errors also future proofs against // any sudden graph changes, for e.g. if graph decides to embed the token in a // new query param. - expired, err := isURLExpired(ctx, url) - if err == nil && expired { - logger.Ctx(ctx).Debug("expired item download url") - return nil, graph.ErrTokenExpired + expiredErr, err := graph.IsURLExpired(ctx, url) + if expiredErr != nil { + logger.CtxErr(ctx, expiredErr).Debug("expired item download url") + return nil, clues.Stack(expiredErr) + } else if err != nil { + logger.CtxErr(ctx, err).Info("checking item download url for expiration") } rc, err := readers.NewResetRetryHandler( @@ -154,20 +150,19 @@ func downloadFile( func downloadItemMeta( ctx context.Context, - gip GetItemPermissioner, + getter GetItemPermissioner, driveID string, item models.DriveItemable, ) (io.ReadCloser, int, error) { - meta := metadata.Metadata{FileName: ptr.Val(item.GetName())} - - if item.GetShared() == nil { - meta.SharingMode = metadata.SharingModeInherited - } else { - meta.SharingMode = metadata.SharingModeCustom + meta := metadata.Metadata{ + FileName: ptr.Val(item.GetName()), + SharingMode: metadata.SharingModeInherited, } - if meta.SharingMode == metadata.SharingModeCustom { - perm, err := gip.GetItemPermission(ctx, driveID, ptr.Val(item.GetId())) + if item.GetShared() != nil { + meta.SharingMode = metadata.SharingModeCustom + + perm, err := 
getter.GetItemPermission(ctx, driveID, ptr.Val(item.GetId())) if err != nil { return nil, 0, err } @@ -219,27 +214,3 @@ func setName(orig models.ItemReferenceable, driveName string) models.ItemReferen return orig } - -// isURLExpired inspects the jwt token embed in the item download url -// and returns true if it is expired. -func isURLExpired( - ctx context.Context, - url string, -) (bool, error) { - // Extract the raw JWT string from the download url. - rawJWT, err := common.GetQueryParamFromURL(url, JWTQueryParam) - if err != nil { - logger.CtxErr(ctx, err).Info("query param not found") - - return false, clues.StackWC(ctx, err) - } - - expired, err := jwt.IsJWTExpired(rawJWT) - if err != nil { - logger.CtxErr(ctx, err).Info("checking jwt expiry") - - return false, clues.StackWC(ctx, err) - } - - return expired, nil -} diff --git a/src/internal/m365/collection/drive/item_test.go b/src/internal/m365/collection/drive/item_test.go index 29f63cd5b..5be66f7c1 100644 --- a/src/internal/m365/collection/drive/item_test.go +++ b/src/internal/m365/collection/drive/item_test.go @@ -158,10 +158,9 @@ func (suite *ItemIntegrationSuite) TestIsURLExpired() { } } - expired, err := isURLExpired(ctx, url) + expired, err := graph.IsURLExpired(ctx, url) require.NoError(t, err, clues.ToCore(err)) - - require.False(t, expired) + require.NoError(t, expired, clues.ToCore(err)) } // TestItemWriter is an integration test for uploading data to OneDrive diff --git a/src/internal/operations/backup.go b/src/internal/operations/backup.go index 3644d259a..e6e5a5490 100644 --- a/src/internal/operations/backup.go +++ b/src/internal/operations/backup.go @@ -198,13 +198,13 @@ func (op *BackupOperation) Run(ctx context.Context) (err error) { }() ctx, end := diagnostics.Span(ctx, "operations:backup:run") - defer func() { - end() - }() + defer end() ctx, flushMetrics := events.NewMetrics(ctx, logger.Writer{Ctx: ctx}) defer flushMetrics() + ctx = clues.AddTrace(ctx) + // Check if the protected resource 
has the service enabled in order for us // to run a backup. enabled, err := op.bp.IsServiceEnabled( diff --git a/src/internal/operations/restore.go b/src/internal/operations/restore.go index 4fb62d050..de763351d 100644 --- a/src/internal/operations/restore.go +++ b/src/internal/operations/restore.go @@ -129,13 +129,13 @@ func (op *RestoreOperation) Run(ctx context.Context) (restoreDetails *details.De // ----- ctx, end := diagnostics.Span(ctx, "operations:restore:run") - defer func() { - end() - }() + defer end() ctx, flushMetrics := events.NewMetrics(ctx, logger.Writer{Ctx: ctx}) defer flushMetrics() + ctx = clues.AddTrace(ctx) + cats, err := op.Selectors.AllHumanPathCategories() if err != nil { // No need to exit over this, we'll just be missing a bit of info in the diff --git a/src/pkg/fault/fault.go b/src/pkg/fault/fault.go index 97816405b..2beb55f6f 100644 --- a/src/pkg/fault/fault.go +++ b/src/pkg/fault/fault.go @@ -149,7 +149,7 @@ func (e *Bus) logAndAddRecoverable(ctx context.Context, err error, skip int) { isFail := e.addRecoverableErr(err) if isFail { - log.Errorf("recoverable error: %v", err) + log.Errorf("failed on recoverable error: %v", err) } else { log.Infof("recoverable error: %v", err) } diff --git a/src/pkg/services/m365/api/drive.go b/src/pkg/services/m365/api/drive.go index c8c2eb464..934dfcc7c 100644 --- a/src/pkg/services/m365/api/drive.go +++ b/src/pkg/services/m365/api/drive.go @@ -274,11 +274,8 @@ func (c Drives) GetItemPermission( ByDriveItemId(itemID). Permissions(). 
Get(ctx, nil) - if err != nil { - return nil, graph.Wrap(ctx, err, "getting item permission").With("item_id", itemID) - } - return perm, nil + return perm, graph.Wrap(ctx, err, "getting item permissions").OrNil() } func (c Drives) PostItemPermissionUpdate( diff --git a/src/pkg/services/m365/api/graph/errors.go b/src/pkg/services/m365/api/graph/errors.go index ac87b29a3..d242449e9 100644 --- a/src/pkg/services/m365/api/graph/errors.go +++ b/src/pkg/services/m365/api/graph/errors.go @@ -14,6 +14,8 @@ import ( "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors" "github.com/pkg/errors" + "github.com/alcionai/corso/src/internal/common" + "github.com/alcionai/corso/src/internal/common/jwt" "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/pkg/fault" @@ -45,6 +47,7 @@ const ( // Some datacenters are returning this when we try to get the inbox of a user // that doesn't exist. invalidUser errorCode = "ErrorInvalidUser" + invalidAuthenticationToken errorCode = "InvalidAuthenticationToken" itemNotFound errorCode = "itemNotFound" MailboxNotEnabledForRESTAPI errorCode = "MailboxNotEnabledForRESTAPI" malwareDetected errorCode = "malwareDetected" @@ -138,6 +141,7 @@ var ( ErrResourceOwnerNotFound = clues.New("resource owner not found in tenant") ErrTokenExpired = clues.New("jwt token expired") + ErrTokenInvalid = clues.New("jwt token invalid") ) func IsErrApplicationThrottled(err error) bool { @@ -234,12 +238,17 @@ func IsErrConnectionReset(err error) bool { return errors.Is(err, syscall.ECONNRESET) } -func IsErrUnauthorized(err error) bool { - // TODO: refine this investigation. We don't currently know if - // a specific item download url expired, or if the full connection - // auth expired. 
+func IsErrUnauthorizedOrBadToken(err error) bool { return clues.HasLabel(err, LabelStatus(http.StatusUnauthorized)) || - errors.Is(err, ErrTokenExpired) + hasErrorCode(err, invalidAuthenticationToken) || + errors.Is(err, ErrTokenExpired) || + errors.Is(err, ErrTokenInvalid) +} + +func IsErrBadJWTToken(err error) bool { + return hasErrorCode(err, invalidAuthenticationToken) || + errors.Is(err, ErrTokenExpired) || + errors.Is(err, ErrTokenInvalid) } func IsErrItemAlreadyExistsConflict(err error) bool { @@ -558,3 +567,38 @@ func ItemInfo(item models.DriveItemable) map[string]any { return m } + +// --------------------------------------------------------------------------- +// other helpers +// --------------------------------------------------------------------------- + +// JWTQueryParam is the query param embedded in graph download URLs which holds +// the JWT token. +const JWTQueryParam = "tempauth" + +// IsURLExpired inspects the jwt token embedded in the item download url and +// returns ErrTokenExpired as expiredErr if it is expired; err reports a failed check. +func IsURLExpired( + ctx context.Context, + urlStr string, +) ( + expiredErr error, + err error, +) { + // Extract the raw JWT string from the download url. 
+ rawJWT, err := common.GetQueryParamFromURL(urlStr, JWTQueryParam) + if err != nil { + return nil, clues.WrapWC(ctx, err, "jwt query param not found") + } + + expired, err := jwt.IsJWTExpired(rawJWT) + if err != nil { + return nil, clues.WrapWC(ctx, err, "checking jwt expiry") + } + + if expired { + return clues.StackWC(ctx, ErrTokenExpired), nil + } + + return nil, nil +} diff --git a/src/pkg/services/m365/api/graph/errors_test.go b/src/pkg/services/m365/api/graph/errors_test.go index e46955035..a585e13fd 100644 --- a/src/pkg/services/m365/api/graph/errors_test.go +++ b/src/pkg/services/m365/api/graph/errors_test.go @@ -461,7 +461,7 @@ func (suite *GraphErrorsUnitSuite) TestIsErrTimeout() { } } -func (suite *GraphErrorsUnitSuite) TestIsErrUnauthorized() { +func (suite *GraphErrorsUnitSuite) TestIsErrUnauthorizedOrBadToken() { table := []struct { name string err error @@ -477,6 +477,11 @@ func (suite *GraphErrorsUnitSuite) TestIsErrUnauthorized() { err: assert.AnError, expect: assert.False, }, + { + name: "non-matching oDataErr", + err: odErr("folder doesn't exist"), + expect: assert.False, + }, { name: "graph 401", err: clues.Stack(assert.AnError). 
@@ -484,14 +489,74 @@ func (suite *GraphErrorsUnitSuite) TestIsErrUnauthorized() { expect: assert.True, }, { - name: "token expired", + name: "err token expired", err: clues.Stack(assert.AnError, ErrTokenExpired), expect: assert.True, }, + { + name: "oDataErr code invalid auth token ", + err: odErr(string(invalidAuthenticationToken)), + expect: assert.True, + }, + { + name: "err token invalid", + err: clues.Stack(assert.AnError, ErrTokenInvalid), + expect: assert.True, + }, } for _, test := range table { suite.Run(test.name, func() { - test.expect(suite.T(), IsErrUnauthorized(test.err)) + test.expect(suite.T(), IsErrUnauthorizedOrBadToken(test.err)) + }) + } +} + +func (suite *GraphErrorsUnitSuite) TestIsErrIsErrBadJWTToken() { + table := []struct { + name string + err error + expect assert.BoolAssertionFunc + }{ + { + name: "nil", + err: nil, + expect: assert.False, + }, + { + name: "non-matching", + err: assert.AnError, + expect: assert.False, + }, + { + name: "non-matching oDataErr", + err: odErr("folder doesn't exist"), + expect: assert.False, + }, + { + name: "graph 401", + err: clues.Stack(assert.AnError). 
+ Label(LabelStatus(http.StatusUnauthorized)), + expect: assert.False, + }, + { + name: "err token expired", + err: clues.Stack(assert.AnError, ErrTokenExpired), + expect: assert.True, + }, + { + name: "oDataErr code invalid auth token ", + err: odErr(string(invalidAuthenticationToken)), + expect: assert.True, + }, + { + name: "err token invalid", + err: clues.Stack(assert.AnError, ErrTokenInvalid), + expect: assert.True, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + test.expect(suite.T(), IsErrBadJWTToken(test.err)) }) } } diff --git a/src/pkg/services/m365/api/graph/service.go b/src/pkg/services/m365/api/graph/service.go index 169becee4..78d4f84ba 100644 --- a/src/pkg/services/m365/api/graph/service.go +++ b/src/pkg/services/m365/api/graph/service.go @@ -356,12 +356,12 @@ func (aw *adapterWrap) Send( } }() - // stream errors from http/2 will fail before we reach - // client middleware handling, therefore we don't get to - // make use of the retry middleware. This external - // retry wrapper is unsophisticated, but should only - // retry in the event of a `stream error`, which is not - // a common expectation. + // This external retry wrapper is unsophisticated, but should + // only retry under certain circumstances + // 1. stream errors from http/2, which will fail before we reach + // client middleware handling. + // 2. jwt token invalidation, which requires a re-auth that's handled + // in the Send() call, before reaching client middleware. for i := 0; i < aw.config.maxConnectionRetries+1; i++ { ictx := clues.Add(ctx, "request_retry_iter", i) @@ -370,19 +370,27 @@ func (aw *adapterWrap) Send( break } + // force an early exit on throttling issues. + // those retries are well handled in middleware already. We want to ensure + // that the error gets wrapped with the appropriate sentinel here. 
if IsErrApplicationThrottled(err) { return nil, clues.StackWC(ictx, ErrApplicationThrottled, err).WithTrace(1) } - if !IsErrConnectionReset(err) && !connectionEnded.Compare(err.Error()) { + // exit most errors without retry + switch { + case IsErrConnectionReset(err) || connectionEnded.Compare(err.Error()): + logger.Ctx(ictx).Debug("http connection error") + events.Inc(events.APICall, "connectionerror") + case IsErrBadJWTToken(err): + logger.Ctx(ictx).Debug("bad jwt token") + events.Inc(events.APICall, "badjwttoken") + default: return nil, clues.StackWC(ictx, err).WithTrace(1) } - logger.Ctx(ictx).Debug("http connection error") - events.Inc(events.APICall, "connectionerror") - time.Sleep(3 * time.Second) } - return sp, err + return sp, clues.Stack(err).OrNil() } diff --git a/src/pkg/services/m365/api/graph/service_test.go b/src/pkg/services/m365/api/graph/service_test.go index 3aa800fb1..aa963ea69 100644 --- a/src/pkg/services/m365/api/graph/service_test.go +++ b/src/pkg/services/m365/api/graph/service_test.go @@ -1,12 +1,17 @@ package graph import ( + "bytes" + "io" "net/http" + "strconv" "syscall" "testing" "time" "github.com/alcionai/clues" + "github.com/microsoft/kiota-abstractions-go/serialization" + kjson "github.com/microsoft/kiota-serialization-json-go" "github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/users" "github.com/stretchr/testify/assert" @@ -245,3 +250,75 @@ func (suite *GraphIntgSuite) TestAdapterWrap_retriesConnectionClose() { require.ErrorIs(t, err, syscall.ECONNRESET, clues.ToCore(err)) require.Equal(t, 16, retryInc, "number of retries") } + +func requireParseableToReader(t *testing.T, thing serialization.Parsable) (int64, io.ReadCloser) { + sw := kjson.NewJsonSerializationWriter() + + err := sw.WriteObjectValue("", thing) + require.NoError(t, err, "serialize") + + content, err := sw.GetSerializedContent() + require.NoError(t, err, "deserialize") + + return int64(len(content)), 
io.NopCloser(bytes.NewReader(content)) +} + +func (suite *GraphIntgSuite) TestAdapterWrap_retriesBadJWTToken() { + var ( + t = suite.T() + retryInc = 0 + odErr = odErrMsg(string(invalidAuthenticationToken), string(invalidAuthenticationToken)) + ) + + ctx, flush := tester.NewContext(t) + defer flush() + + // respond to every request with an invalidAuthenticationToken odata error body + alwaysBadJWT := mwForceResp{ + alternate: func(req *http.Request) (bool, *http.Response, error) { + retryInc++ + + l, b := requireParseableToReader(t, odErr) + + header := http.Header{} + header.Set("Content-Length", strconv.Itoa(int(l))) + header.Set("Content-Type", "application/json") + + resp := &http.Response{ + Body: b, + ContentLength: l, + Header: header, + Proto: req.Proto, + Request: req, + // avoiding 401 for the test to escape extraneous code paths in graph client + // shouldn't affect the result + StatusCode: http.StatusMethodNotAllowed, + } + + return true, resp, nil + }, + } + + adpt, err := CreateAdapter( + suite.credentials.AzureTenantID, + suite.credentials.AzureClientID, + suite.credentials.AzureClientSecret, + count.New(), + appendMiddleware(&alwaysBadJWT)) + require.NoError(t, err, clues.ToCore(err)) + + // When run locally this may fail. Not sure why it works in github but not locally. + // Pester keepers if it bothers you. + _, err = users. + NewItemCalendarsItemEventsDeltaRequestBuilder("https://graph.microsoft.com/fnords/beaux/regard", adpt). + Get(ctx, nil) + assert.True(t, IsErrBadJWTToken(err), clues.ToCore(err)) + assert.Equal(t, 4, retryInc, "number of retries") + + retryInc = 0 + + // the query doesn't matter + _, err = NewService(adpt).Client().Users().Get(ctx, nil) + assert.True(t, IsErrBadJWTToken(err), clues.ToCore(err)) + assert.Equal(t, 4, retryInc, "number of retries") +}