Skip graph call if the download url has expired (#4419)

<!-- PR description-->

Builds on top of earlier PR #4417 to skip graph API call if the token has already expired. This is a performance optimization.

---

#### Does this PR need a docs update or release note?

- [x]  Yes, it's included
- [ ] 🕐 Yes, but in a later PR
- [ ]  No

#### Type of change

<!--- Please check the type of change your PR introduces: --->
- [ ] 🌻 Feature
- [ ] 🐛 Bugfix
- [ ] 🗺️ Documentation
- [ ] 🤖 Supportability/Tests
- [ ] 💻 CI/Deployment
- [ ] 🧹 Tech Debt/Cleanup
- [x] Performance Opt

#### Issue(s)

<!-- Can reference multiple issues. Use one of the following "magic words" - "closes, fixes" to auto-close the Github issue. -->
* internal

#### Test Plan

<!-- How will this be tested prior to merging.-->
- [ ] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
This commit is contained in:
Abhishek Pandey 2023-10-09 19:12:54 +05:30 committed by GitHub
parent 2eae5b9f13
commit 757007e027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 120 additions and 19 deletions

View File

@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] (beta)
### Added
- Skips graph calls for expired item download URLs.
## [v0.14.0] (beta) - 2023-10-09
### Added
@ -16,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `--backups` flag to delete multiple backups in `corso backup delete` command.
- Backup now includes all sites that belongs to a team, not just the root site.
## Fixed
### Fixed
- Teams Channels that cannot support delta tokens (those without messages) fall back to non-delta enumeration and no longer fail a backup.
### Known issues

View File

@ -7,7 +7,7 @@ import (
)
// GetQueryParamFromURL parses an URL and returns value of the specified
// query parameter.
// query parameter. In case of multiple occurrences, first one is returned.
func GetQueryParamFromURL(
rawURL, queryParam string,
) (string, error) {

View File

@ -10,17 +10,24 @@ import (
"github.com/microsoftgraph/msgraph-sdk-go/models"
"golang.org/x/exp/maps"
"github.com/alcionai/corso/src/internal/common"
jwt "github.com/alcionai/corso/src/internal/common/jwt"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/common/readers"
"github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/internal/m365/collection/drive/metadata"
"github.com/alcionai/corso/src/internal/m365/graph"
"github.com/alcionai/corso/src/pkg/logger"
"github.com/alcionai/corso/src/pkg/services/m365/api"
)
const (
acceptHeaderKey = "Accept"
acceptHeaderValue = "*/*"
// JWTQueryParam is a query param embed in graph download URLs which holds
// JWT token.
JWTQueryParam = "tempauth"
)
// downloadUrlKeys is used to find the download URL in a DriveItem response.
@ -121,6 +128,19 @@ func downloadFile(
return nil, clues.New("empty file url").WithClues(ctx)
}
// Precheck for url expiry before we make a call to graph to download the
// file. If the url is expired, we can return early and save a call to graph.
//
// Ignore all errors encountered during the check. We can rely on graph to
// return errors on malformed urls. Ignoring errors also future proofs against
// any sudden graph changes, for e.g. if graph decides to embed the token in a
// new query param.
expired, err := isURLExpired(ctx, url)
if err == nil && expired {
logger.Ctx(ctx).Debug("expired item download url")
return nil, graph.ErrTokenExpired
}
rc, err := readers.NewResetRetryHandler(
ctx,
&downloadWithRetries{
@ -193,3 +213,27 @@ func setName(orig models.ItemReferenceable, driveName string) models.ItemReferen
return orig
}
// isURLExpired inspects the jwt token embed in the item download url
// and returns true if it is expired.
func isURLExpired(
ctx context.Context,
url string,
) (bool, error) {
// Extract the raw JWT string from the download url.
rawJWT, err := common.GetQueryParamFromURL(url, JWTQueryParam)
if err != nil {
logger.CtxErr(ctx, err).Info("query param not found")
return false, clues.Stack(err).WithClues(ctx)
}
expired, err := jwt.IsJWTExpired(rawJWT)
if err != nil {
logger.CtxErr(ctx, err).Info("checking jwt expiry")
return false, clues.Stack(err).WithClues(ctx)
}
return expired, nil
}

View File

@ -16,6 +16,8 @@ import (
"github.com/alcionai/corso/src/internal/common/dttm"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/common/str"
"github.com/alcionai/corso/src/internal/m365/graph"
"github.com/alcionai/corso/src/internal/tester"
"github.com/alcionai/corso/src/internal/tester/tconfig"
"github.com/alcionai/corso/src/pkg/control"
@ -49,6 +51,8 @@ func (suite *ItemIntegrationSuite) SetupSuite() {
suite.service = loadTestService(t)
suite.user = tconfig.SecondaryM365UserID(t)
graph.InitializeConcurrencyLimiter(ctx, true, 4)
pager := suite.service.ac.Drives().NewUserDrivePager(suite.user, nil)
odDrives, err := api.GetAllDrives(ctx, pager)
@ -60,19 +64,13 @@ func (suite *ItemIntegrationSuite) SetupSuite() {
suite.userDriveID = ptr.Val(odDrives[0].GetId())
}
// TestItemReader is an integration test that makes a few assumptions
// about the test environment
// 1) It assumes the test user has a drive
// 2) It assumes the drive has a file it can use to test `driveItemReader`
// The test checks these in below
func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
func getOneDriveItem(
ctx context.Context,
t *testing.T,
ac api.Client,
driveID string,
) models.DriveItemable {
var driveItem models.DriveItemable
// This item collector tries to find "a" drive item that is a non-empty
// file to test the reader function
itemCollector := func(
_ context.Context,
@ -99,14 +97,14 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
return nil
}
ip := suite.service.ac.
ip := ac.
Drives().
NewDriveItemDeltaPager(suite.userDriveID, "", api.DriveItemSelectDefault())
NewDriveItemDeltaPager(driveID, "", api.DriveItemSelectDefault())
_, _, _, err := collectItems(
ctx,
ip,
suite.userDriveID,
driveID,
"General",
itemCollector,
map[string]string{},
@ -114,6 +112,21 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
fault.New(true))
require.NoError(t, err, clues.ToCore(err))
return driveItem
}
// TestItemReader is an integration test that makes a few assumptions
// about the test environment
// 1) It assumes the test user has a drive
// 2) It assumes the drive has a file it can use to test `driveItemReader`
// The test checks these in below
func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
driveItem := getOneDriveItem(ctx, t, suite.service.ac, suite.userDriveID)
// Test Requirement 2: Need a file
require.NotEmpty(
t,
@ -137,6 +150,39 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
require.NotZero(t, size)
}
// In prod we consider any errors in isURLExpired as non-fatal and carry on
// with the download. This is a regression test to make sure we keep track
// of any graph changes to the download url scheme, including how graph
// embeds the jwt token.
func (suite *ItemIntegrationSuite) TestIsURLExpired() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
driveItem := getOneDriveItem(ctx, t, suite.service.ac, suite.userDriveID)
require.NotEmpty(
t,
driveItem,
"no file item found for user %s drive %s",
suite.user,
suite.userDriveID)
var url string
for _, key := range downloadURLKeys {
if v, err := str.AnyValueToString(key, driveItem.GetAdditionalData()); err == nil {
url = v
break
}
}
expired, err := isURLExpired(ctx, url)
require.NoError(t, err, clues.ToCore(err))
require.False(t, expired)
}
// TestItemWriter is an integration test for uploading data to OneDrive
// It creates a new folder with a new item and writes data to it
func (suite *ItemIntegrationSuite) TestItemWriter() {

View File

@ -124,6 +124,8 @@ var (
ErrTimeout = clues.New("communication timeout")
ErrResourceOwnerNotFound = clues.New("resource owner not found in tenant")
ErrTokenExpired = clues.New("jwt token expired")
)
func IsErrApplicationThrottled(err error) bool {
@ -224,7 +226,8 @@ func IsErrUnauthorized(err error) bool {
// TODO: refine this investigation. We don't currently know if
// a specific item download url expired, or if the full connection
// auth expired.
return clues.HasLabel(err, LabelStatus(http.StatusUnauthorized))
return clues.HasLabel(err, LabelStatus(http.StatusUnauthorized)) ||
errors.Is(err, ErrTokenExpired)
}
func IsErrItemAlreadyExistsConflict(err error) bool {

View File

@ -478,11 +478,16 @@ func (suite *GraphErrorsUnitSuite) TestIsErrUnauthorized() {
expect: assert.False,
},
{
name: "as",
name: "graph 401",
err: clues.Stack(assert.AnError).
Label(LabelStatus(http.StatusUnauthorized)),
expect: assert.True,
},
{
name: "token expired",
err: clues.Stack(assert.AnError, ErrTokenExpired),
expect: assert.True,
},
}
for _, test := range table {
suite.Run(test.name, func() {