From bb2bd6df3fd6c8ae7ac72414dcf22650dac428b7 Mon Sep 17 00:00:00 2001 From: Keepers Date: Wed, 14 Feb 2024 10:50:36 -0700 Subject: [PATCH] add authentication to requester (#5198) The graph requester for large item downloads now includes the option to authenticate requests. The option is configured at the time of creating the requester; therefore, all requests using that requester are either authenticated or not. In our case, we're opting to authenticate all requests, since we do not use this requester for non-Graph API calls, and even if we did, the addition of auth headers is likely benign. --- #### Does this PR need a docs update or release note? - [x] :no_entry: No #### Type of change - [x] :sunflower: Feature #### Test Plan - [x] :green_heart: E2E --- src/internal/common/str/str.go | 13 +++ src/internal/common/str/str_test.go | 93 ++++++++++++++++++ .../m365/collection/drive/collection.go | 6 +- .../m365/collection/drive/helper_test.go | 7 +- src/internal/m365/collection/drive/item.go | 68 ++++++++------ .../m365/collection/drive/item_test.go | 92 ++++++++++++++++-- .../m365/collection/drive/site_handler.go | 3 +- .../m365/collection/drive/url_cache_test.go | 3 +- .../collection/drive/user_drive_handler.go | 3 +- .../m365/service/onedrive/mock/handlers.go | 7 +- src/pkg/services/m365/api/access.go | 2 +- src/pkg/services/m365/api/client.go | 22 +++-- src/pkg/services/m365/api/graph/auth.go | 94 +++++++++++++++++++ .../m365/api/graph/concurrency_middleware.go | 4 +- .../services/m365/api/graph/http_wrapper.go | 19 +++- .../m365/api/graph/http_wrapper_test.go | 79 +++++++++++++--- src/pkg/services/m365/api/graph/logging.go | 22 +++++ src/pkg/services/m365/api/graph/middleware.go | 11 +-- src/pkg/services/m365/api/graph/service.go | 30 +++--- .../services/m365/api/graph/uploadsession.go | 3 +- 20 files changed, 481 insertions(+), 100 deletions(-) create mode 100644 src/pkg/services/m365/api/graph/auth.go diff --git a/src/internal/common/str/str.go 
b/src/internal/common/str/str.go index 41919b2b7..201d241e0 100644 --- a/src/internal/common/str/str.go +++ b/src/internal/common/str/str.go @@ -59,6 +59,19 @@ func First(vs ...string) string { return "" } +// FirstIn returns the first entry in the map with a non-zero value +// when iterating the provided list of keys. +func FirstIn(m map[string]any, keys ...string) string { + for _, key := range keys { + v, err := AnyValueToString(key, m) + if err == nil && len(v) > 0 { + return v + } + } + + return "" +} + // Preview reduces the string to the specified size. // If the string is longer than the size, the last three // characters are replaced with an ellipsis. Size < 4 diff --git a/src/internal/common/str/str_test.go b/src/internal/common/str/str_test.go index 11af84e93..879c16aab 100644 --- a/src/internal/common/str/str_test.go +++ b/src/internal/common/str/str_test.go @@ -118,3 +118,96 @@ func TestGenerateHash(t *testing.T) { } } } + +func TestFirstIn(t *testing.T) { + table := []struct { + name string + m map[string]any + keys []string + expect string + }{ + { + name: "nil map", + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "empty map", + m: map[string]any{}, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "no match", + m: map[string]any{ + "baz": "baz", + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "no keys", + m: map[string]any{ + "baz": "baz", + }, + keys: []string{}, + expect: "", + }, + { + name: "nil match", + m: map[string]any{ + "foo": nil, + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "empty match", + m: map[string]any{ + "foo": "", + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "matches first key", + m: map[string]any{ + "foo": "fnords", + }, + keys: []string{"foo", "bar"}, + expect: "fnords", + }, + { + name: "matches second key", + m: map[string]any{ + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + { + name: "matches second 
key with nil first match", + m: map[string]any{ + "foo": nil, + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + { + name: "matches second key with empty first match", + m: map[string]any{ + "foo": "", + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + } + for _, test := range table { + t.Run(test.name, func(t *testing.T) { + result := FirstIn(test.m, test.keys...) + assert.Equal(t, test.expect, result) + }) + } +} diff --git a/src/internal/m365/collection/drive/collection.go b/src/internal/m365/collection/drive/collection.go index 7ef897fed..2ba53de42 100644 --- a/src/internal/m365/collection/drive/collection.go +++ b/src/internal/m365/collection/drive/collection.go @@ -366,7 +366,7 @@ func downloadContent( itemID := ptr.Val(item.GetId()) ctx = clues.Add(ctx, "item_id", itemID) - content, err := downloadItem(ctx, iaag, item) + content, err := downloadItem(ctx, iaag, driveID, item) if err == nil { return content, nil } else if !graph.IsErrUnauthorizedOrBadToken(err) { @@ -395,7 +395,7 @@ func downloadContent( cdi := custom.ToCustomDriveItem(di) - content, err = downloadItem(ctx, iaag, cdi) + content, err = downloadItem(ctx, iaag, driveID, cdi) if err != nil { return nil, clues.Wrap(err, "content download retry") } @@ -426,7 +426,7 @@ func readItemContents( return nil, core.ErrNotFound } - rc, err := downloadFile(ctx, iaag, props.downloadURL) + rc, err := downloadFile(ctx, iaag, props.downloadURL, false) if graph.IsErrUnauthorizedOrBadToken(err) { logger.CtxErr(ctx, err).Debug("stale item in cache") } diff --git a/src/internal/m365/collection/drive/helper_test.go b/src/internal/m365/collection/drive/helper_test.go index 8220f5ed0..a26c6aa01 100644 --- a/src/internal/m365/collection/drive/helper_test.go +++ b/src/internal/m365/collection/drive/helper_test.go @@ -795,7 +795,12 @@ func (h mockBackupHandler[T]) AugmentItemInfo( return h.ItemInfo } -func (h *mockBackupHandler[T]) Get(context.Context, string, 
map[string]string) (*http.Response, error) { +func (h *mockBackupHandler[T]) Get( + context.Context, + string, + map[string]string, + bool, +) (*http.Response, error) { c := h.getCall h.getCall++ diff --git a/src/internal/m365/collection/drive/item.go b/src/internal/m365/collection/drive/item.go index 02ac6010a..be5c255bc 100644 --- a/src/internal/m365/collection/drive/item.go +++ b/src/internal/m365/collection/drive/item.go @@ -21,8 +21,10 @@ import ( ) const ( - acceptHeaderKey = "Accept" - acceptHeaderValue = "*/*" + acceptHeaderKey = "Accept" + acceptHeaderValue = "*/*" + gigabyte = 1024 * 1024 * 1024 + largeFileDownloadLimit = 15 * gigabyte ) // downloadUrlKeys is used to find the download URL in a DriveItem response. @@ -33,7 +35,8 @@ var downloadURLKeys = []string{ func downloadItem( ctx context.Context, - ag api.Getter, + getter api.Getter, + driveID string, item *custom.DriveItem, ) (io.ReadCloser, error) { if item == nil { @@ -41,36 +44,37 @@ func downloadItem( } var ( - rc io.ReadCloser - isFile = item.GetFile() != nil - err error + // very large file content needs to be downloaded through a different endpoint, or else + // the download could take longer than the lifespan of the download token in the cached + // url, which will cause us to timeout on every download request, even if we refresh the + // download url right before the query. 
+ url = "https://graph.microsoft.com/v1.0/drives/" + driveID + "/items/" + ptr.Val(item.GetId()) + "/content" + reader io.ReadCloser + err error + isLargeFile = ptr.Val(item.GetSize()) > largeFileDownloadLimit ) - if isFile { - var ( - url string - ad = item.GetAdditionalData() - ) - - for _, key := range downloadURLKeys { - if v, err := str.AnyValueToString(key, ad); err == nil { - url = v - break - } - } - - rc, err = downloadFile(ctx, ag, url) - if err != nil { - return nil, clues.Stack(err) - } + // if this isn't a file, no content is available for download + if item.GetFile() == nil { + return reader, nil } - return rc, nil + // smaller files will maintain our current behavior (prefetching the download url with the + // url cache). That pattern works for us in general, and we only need to deviate for very + // large file sizes. + if !isLargeFile { + url = str.FirstIn(item.GetAdditionalData(), downloadURLKeys...) + } + + reader, err = downloadFile(ctx, getter, url, isLargeFile) + + return reader, clues.StackWC(ctx, err).OrNil() } type downloadWithRetries struct { - getter api.Getter - url string + getter api.Getter + requireAuth bool + url string } func (dg *downloadWithRetries) SupportsRange() bool { @@ -86,7 +90,7 @@ func (dg *downloadWithRetries) Get( // wouldn't work without it (get 416 responses instead of 206). 
headers[acceptHeaderKey] = acceptHeaderValue - resp, err := dg.getter.Get(ctx, dg.url, headers) + resp, err := dg.getter.Get(ctx, dg.url, headers, dg.requireAuth) if err != nil { return nil, clues.Wrap(err, "getting file") } @@ -96,7 +100,7 @@ func (dg *downloadWithRetries) Get( resp.Body.Close() } - return nil, clues.New("malware detected").Label(graph.LabelsMalware) + return nil, clues.NewWC(ctx, "malware detected").Label(graph.LabelsMalware) } if resp != nil && (resp.StatusCode/100) != 2 { @@ -107,7 +111,7 @@ func (dg *downloadWithRetries) Get( // upstream error checks can compare the status with // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) return nil, clues. - Wrap(clues.New(resp.Status), "non-2xx http response"). + Wrap(clues.NewWC(ctx, resp.Status), "non-2xx http response"). Label(graph.LabelStatus(resp.StatusCode)) } @@ -118,6 +122,7 @@ func downloadFile( ctx context.Context, ag api.Getter, url string, + requireAuth bool, ) (io.ReadCloser, error) { if len(url) == 0 { return nil, clues.NewWC(ctx, "empty file url") @@ -141,8 +146,9 @@ func downloadFile( rc, err := readers.NewResetRetryHandler( ctx, &downloadWithRetries{ - getter: ag, - url: url, + getter: ag, + requireAuth: requireAuth, + url: url, }) return rc, clues.Stack(err).OrNil() diff --git a/src/internal/m365/collection/drive/item_test.go b/src/internal/m365/collection/drive/item_test.go index 33dbc9ae0..23de93deb 100644 --- a/src/internal/m365/collection/drive/item_test.go +++ b/src/internal/m365/collection/drive/item_test.go @@ -109,7 +109,11 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() { } // Read data for the file - itemData, err := downloadItem(ctx, bh, custom.ToCustomDriveItem(driveItem)) + itemData, err := downloadItem( + ctx, + bh, + suite.m365.User.DriveID, + custom.ToCustomDriveItem(driveItem)) require.NoError(t, err, clues.ToCore(err)) size, err := io.Copy(io.Discard, itemData) @@ -292,6 +296,7 @@ func (m mockGetter) Get( ctx context.Context, url string, 
headers map[string]string, + requireAuth bool, ) (*http.Response, error) { return m.GetFunc(ctx, url) } @@ -379,7 +384,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { return nil, clues.New("test error") }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, { name: "download url is empty", @@ -416,7 +421,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { }, nil }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, { name: "non-2xx http response", @@ -435,7 +440,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { }, nil }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, } @@ -448,9 +453,78 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { mg := mockGetter{ GetFunc: test.GetFunc, } - rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(test.itemFunc())) + rc, err := downloadItem( + ctx, + mg, + "driveID", + custom.ToCustomDriveItem(test.itemFunc())) test.errorExpected(t, err, clues.ToCore(err)) - test.rcExpected(t, rc) + test.rcExpected(t, rc, "reader should only be nil if item is nil") + }) + } +} + +func (suite *ItemUnitTestSuite) TestDownloadItem_urlByFileSize() { + var ( + testRc = io.NopCloser(bytes.NewReader([]byte("test"))) + url = "https://example.com" + okResp = &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + } + ) + + table := []struct { + name string + itemFunc func() models.DriveItemable + GetFunc func(ctx context.Context, url string) (*http.Response, error) + errorExpected require.ErrorAssertionFunc + rcExpected require.ValueAssertionFunc + label string + }{ + { + name: "big file", + itemFunc: func() models.DriveItemable { + di := api.NewDriveItem("test", false) + di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url}) + di.SetSize(ptr.To[int64](20 * gigabyte)) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + 
assert.Contains(suite.T(), url, "/content") + return okResp, nil + }, + }, + { + name: "small file", + itemFunc: func() models.DriveItemable { + di := api.NewDriveItem("test", false) + di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url}) + di.SetSize(ptr.To[int64](2 * gigabyte)) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + assert.NotContains(suite.T(), url, "/content") + return okResp, nil + }, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := downloadItem( + ctx, + mockGetter{GetFunc: test.GetFunc}, + "driveID", + custom.ToCustomDriveItem(test.itemFunc())) + require.NoError(t, err, clues.ToCore(err)) }) } } @@ -507,7 +581,11 @@ func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead mg := mockGetter{ GetFunc: GetFunc, } - rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(itemFunc())) + rc, err := downloadItem( + ctx, + mg, + "driveID", + custom.ToCustomDriveItem(itemFunc())) errorExpected(t, err, clues.ToCore(err)) rcExpected(t, rc) diff --git a/src/internal/m365/collection/drive/site_handler.go b/src/internal/m365/collection/drive/site_handler.go index 189131e04..e268be921 100644 --- a/src/internal/m365/collection/drive/site_handler.go +++ b/src/internal/m365/collection/drive/site_handler.go @@ -93,8 +93,9 @@ func (h siteBackupHandler) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return h.ac.Get(ctx, url, headers) + return h.ac.Get(ctx, url, headers, requireAuth) } func (h siteBackupHandler) PathPrefix( diff --git a/src/internal/m365/collection/drive/url_cache_test.go b/src/internal/m365/collection/drive/url_cache_test.go index 2c47e67c3..ac8eabb92 100644 --- a/src/internal/m365/collection/drive/url_cache_test.go +++ b/src/internal/m365/collection/drive/url_cache_test.go @@ -154,7 
+154,8 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { http.MethodGet, props.downloadURL, nil, - nil) + nil, + false) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, resp) diff --git a/src/internal/m365/collection/drive/user_drive_handler.go b/src/internal/m365/collection/drive/user_drive_handler.go index fcf3943d0..42e802f1b 100644 --- a/src/internal/m365/collection/drive/user_drive_handler.go +++ b/src/internal/m365/collection/drive/user_drive_handler.go @@ -93,8 +93,9 @@ func (h userDriveBackupHandler) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return h.ac.Get(ctx, url, headers) + return h.ac.Get(ctx, url, headers, requireAuth) } func (h userDriveBackupHandler) PathPrefix( diff --git a/src/internal/m365/service/onedrive/mock/handlers.go b/src/internal/m365/service/onedrive/mock/handlers.go index e56c7bab5..2a54749f3 100644 --- a/src/internal/m365/service/onedrive/mock/handlers.go +++ b/src/internal/m365/service/onedrive/mock/handlers.go @@ -197,7 +197,12 @@ func (h BackupHandler[T]) AugmentItemInfo( return h.ItemInfo } -func (h *BackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) { +func (h *BackupHandler[T]) Get( + context.Context, + string, + map[string]string, + bool, +) (*http.Response, error) { c := h.getCall h.getCall++ diff --git a/src/pkg/services/m365/api/access.go b/src/pkg/services/m365/api/access.go index 710430490..f6b3ee1d7 100644 --- a/src/pkg/services/m365/api/access.go +++ b/src/pkg/services/m365/api/access.go @@ -47,7 +47,7 @@ func (c Access) GetToken( c.Credentials.AzureClientSecret)) ) - resp, err := c.Post(ctx, rawURL, headers, body) + resp, err := c.Post(ctx, rawURL, headers, body, false) if err != nil { return clues.Stack(err) } diff --git a/src/pkg/services/m365/api/client.go b/src/pkg/services/m365/api/client.go index 972439b75..f454a5648 100644 --- a/src/pkg/services/m365/api/client.go +++ 
b/src/pkg/services/m365/api/client.go @@ -63,7 +63,14 @@ func NewClient( return Client{}, err } - rqr := graph.NewNoTimeoutHTTPWrapper(counter) + azureAuth, err := graph.NewAzureAuth(creds) + if err != nil { + return Client{}, clues.Wrap(err, "generating azure authorizer") + } + + rqr := graph.NewNoTimeoutHTTPWrapper( + counter, + graph.AuthorizeRequester(azureAuth)) if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize { co.DeltaPageSize = maxDeltaPageSize @@ -124,11 +131,7 @@ func newLargeItemService( counter *count.Bus, ) (*graph.Service, error) { a, err := NewService(creds, counter, graph.NoTimeout()) - if err != nil { - return nil, clues.Wrap(err, "generating no-timeout graph adapter") - } - - return a, nil + return a, clues.Wrap(err, "generating no-timeout graph adapter").OrNil() } type Getter interface { @@ -136,6 +139,7 @@ type Getter interface { ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) } @@ -144,8 +148,9 @@ func (c Client) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return c.Requester.Request(ctx, http.MethodGet, url, nil, headers) + return c.Requester.Request(ctx, http.MethodGet, url, nil, headers, requireAuth) } // Get performs an ad-hoc get request using its graph.Requester @@ -154,8 +159,9 @@ func (c Client) Post( url string, headers map[string]string, body io.Reader, + requireAuth bool, ) (*http.Response, error) { - return c.Requester.Request(ctx, http.MethodGet, url, body, headers) + return c.Requester.Request(ctx, http.MethodGet, url, body, headers, requireAuth) } // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/api/graph/auth.go b/src/pkg/services/m365/api/graph/auth.go new file mode 100644 index 000000000..da4cb43ee --- /dev/null +++ b/src/pkg/services/m365/api/graph/auth.go @@ -0,0 +1,94 @@ +package graph + +import ( + "context" + 
"net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/alcionai/clues" + abstractions "github.com/microsoft/kiota-abstractions-go" + kauth "github.com/microsoft/kiota-authentication-azure-go" + + "github.com/alcionai/corso/src/pkg/account" +) + +func GetAuth(tenant, client, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) { + // Client Provider: Uses Secret for access to tenant-level data + cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil) + if err != nil { + return nil, clues.Wrap(err, "creating m365 client identity") + } + + auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes( + cred, + []string{"https://graph.microsoft.com/.default"}) + if err != nil { + return nil, clues.Wrap(err, "creating azure authentication") + } + + return auth, nil +} + +// --------------------------------------------------------------------------- +// requester authorization +// --------------------------------------------------------------------------- + +type authorizer interface { + addAuthToHeaders( + ctx context.Context, + urlStr string, + headers http.Header, + ) error +} + +// consumed by kiota +type authenticateRequester interface { + AuthenticateRequest( + ctx context.Context, + request *abstractions.RequestInformation, + additionalAuthenticationContext map[string]any, + ) error +} + +// --------------------------------------------------------------------------- +// Azure Authorizer +// --------------------------------------------------------------------------- + +type azureAuth struct { + auth authenticateRequester +} + +func NewAzureAuth(creds account.M365Config) (*azureAuth, error) { + auth, err := GetAuth( + creds.AzureTenantID, + creds.AzureClientID, + creds.AzureClientSecret) + + return &azureAuth{auth}, clues.Stack(err).OrNil() +} + +func (aa azureAuth) addAuthToHeaders( + ctx context.Context, + urlStr string, + headers http.Header, +) error { + requestInfo := 
abstractions.NewRequestInformation() + + uri, err := url.Parse(urlStr) + if err != nil { + return clues.WrapWC(ctx, err, "parsing url").OrNil() + } + + requestInfo.SetUri(*uri) + + err = aa.auth.AuthenticateRequest(ctx, requestInfo, nil) + + for _, k := range requestInfo.Headers.ListKeys() { + for _, v := range requestInfo.Headers.Get(k) { + headers.Add(k, v) + } + } + + return clues.WrapWC(ctx, err, "authorizing request").OrNil() +} diff --git a/src/pkg/services/m365/api/graph/concurrency_middleware.go b/src/pkg/services/m365/api/graph/concurrency_middleware.go index ee9d62f73..3694a9f2e 100644 --- a/src/pkg/services/m365/api/graph/concurrency_middleware.go +++ b/src/pkg/services/m365/api/graph/concurrency_middleware.go @@ -240,7 +240,7 @@ func (mw *RateLimiterMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - QueueRequest(req.Context()) + QueueRequest(getReqCtx(req)) return pipeline.Next(req, middlewareIndex) } @@ -339,7 +339,7 @@ func (mw *throttlingMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - err := mw.tf.Block(req.Context()) + err := mw.tf.Block(getReqCtx(req)) if err != nil { return nil, err } diff --git a/src/pkg/services/m365/api/graph/http_wrapper.go b/src/pkg/services/m365/api/graph/http_wrapper.go index 17cf51d3a..77f5766bf 100644 --- a/src/pkg/services/m365/api/graph/http_wrapper.go +++ b/src/pkg/services/m365/api/graph/http_wrapper.go @@ -36,6 +36,7 @@ type Requester interface { method, url string, body io.Reader, headers map[string]string, + requireAuth bool, ) (*http.Response, error) } @@ -58,12 +59,8 @@ func NewHTTPWrapper( transport: defaultTransport(), }, } - redirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } hc = &http.Client{ - CheckRedirect: redirect, - Transport: rt, + Transport: rt, } ) @@ -100,6 +97,7 @@ func (hw httpWrapper) Request( method, url string, body io.Reader, headers map[string]string, + 
requireAuth bool, ) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -115,6 +113,17 @@ func (hw httpWrapper) Request( // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) + if requireAuth { + if hw.config.requesterAuth == nil { + return nil, clues.Wrap(err, "http wrapper misconfigured: missing required authorization") + } + + err := hw.config.requesterAuth.addAuthToHeaders(ctx, url, req.Header) + if err != nil { + return nil, clues.Wrap(err, "setting request auth headers") + } + } + retriedErrors := []string{} var e error diff --git a/src/pkg/services/m365/api/graph/http_wrapper_test.go b/src/pkg/services/m365/api/graph/http_wrapper_test.go index 555af8ffd..12fbaa7af 100644 --- a/src/pkg/services/m365/api/graph/http_wrapper_test.go +++ b/src/pkg/services/m365/api/graph/http_wrapper_test.go @@ -40,9 +40,10 @@ func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() { resp, err := hw.Request( ctx, http.MethodGet, - "https://www.corsobackup.io", + "https://www.google.com", nil, - nil) + nil, + false) require.NoError(t, err, clues.ToCore(err)) defer resp.Body.Close() @@ -76,6 +77,56 @@ func (mw *mwForceResp) Intercept( return mw.resp, mw.err } +func (suite *HTTPWrapperIntgSuite) TestHTTPWrapper_Request_withAuth() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + a := tconfig.NewM365Account(t) + m365, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + azureAuth, err := NewAzureAuth(m365) + require.NoError(t, err, clues.ToCore(err)) + + hw := NewHTTPWrapper(count.New(), AuthorizeRequester(azureAuth)) + + // any request that requires authorization will do + resp, err := hw.Request( + ctx, + http.MethodGet, + "https://graph.microsoft.com/v1.0/users", + nil, + nil, + true) + 
require.NoError(t, err, clues.ToCore(err)) + + defer resp.Body.Close() + + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // also validate that non-auth'd endpoints succeed + resp, err = hw.Request( + ctx, + http.MethodGet, + "https://www.google.com", + nil, + nil, + true) + require.NoError(t, err, clues.ToCore(err)) + + defer resp.Body.Close() + + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) +} + +// --------------------------------------------------------------------------- +// unit +// --------------------------------------------------------------------------- + type HTTPWrapperUnitSuite struct { tester.Suite } @@ -84,26 +135,25 @@ func TestHTTPWrapperUnitSuite(t *testing.T) { suite.Run(t, &HTTPWrapperUnitSuite{Suite: tester.NewUnitSuite(t)}) } -func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() { +func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_redirect() { t := suite.T() ctx, flush := tester.NewContext(t) defer flush() - url := "https://graph.microsoft.com/fnords/beaux/regard" - - hdr := http.Header{} - hdr.Set("Location", "localhost:99999999/smarfs") + respHdr := http.Header{} + respHdr.Set("Location", "localhost:99999999/smarfs") toResp := &http.Response{ StatusCode: http.StatusFound, - Header: hdr, + Header: respHdr, } mwResp := mwForceResp{ resp: toResp, alternate: func(req *http.Request) (bool, *http.Response, error) { if strings.HasSuffix(req.URL.String(), "smarfs") { + assert.Equal(t, req.Header.Get("X-Test-Val"), "should-be-copied-to-redirect") return true, &http.Response{StatusCode: http.StatusOK}, nil } @@ -113,17 +163,22 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() { hw := NewHTTPWrapper(count.New(), appendMiddleware(&mwResp)) - resp, err := hw.Request(ctx, http.MethodGet, url, nil, nil) + resp, err := hw.Request( + ctx, + http.MethodGet, + "https://graph.microsoft.com/fnords/beaux/regard", + nil, + 
map[string]string{"X-Test-Val": "should-be-copied-to-redirect"}, + false) require.NoError(t, err, clues.ToCore(err)) defer resp.Body.Close() require.NotNil(t, resp) - // require.Equal(t, 1, calledCorrectly, "test server was called with expected path") require.Equal(t, http.StatusOK, resp.StatusCode) } -func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() { +func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_http2StreamErrorRetries() { var ( url = "https://graph.microsoft.com/fnords/beaux/regard" streamErr = http2.StreamError{ @@ -188,7 +243,7 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() // the test middleware. hw.retryDelay = 0 - _, err := hw.Request(ctx, http.MethodGet, url, nil, nil) + _, err := hw.Request(ctx, http.MethodGet, url, nil, nil, false) require.ErrorAs(t, err, &http2.StreamError{}, clues.ToCore(err)) require.Equal(t, test.expectRetries, tries, "count of retries") }) diff --git a/src/pkg/services/m365/api/graph/logging.go b/src/pkg/services/m365/api/graph/logging.go index 2b6e536c9..09283ec2b 100644 --- a/src/pkg/services/m365/api/graph/logging.go +++ b/src/pkg/services/m365/api/graph/logging.go @@ -6,6 +6,9 @@ import ( "net/http/httputil" "os" + "github.com/alcionai/clues" + + "github.com/alcionai/corso/src/internal/common/pii" "github.com/alcionai/corso/src/pkg/logger" ) @@ -69,3 +72,22 @@ func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string return string(respDump) } + +func getReqCtx(req *http.Request) context.Context { + if req == nil { + return context.Background() + } + + var logURL pii.SafeURL + + if req.URL != nil { + logURL = LoggableURL(req.URL.String()) + } + + return clues.AddTraceName( + req.Context(), + "graph-http-middleware", + "method", req.Method, + "url", logURL, + "request_content_len", req.ContentLength) +} diff --git a/src/pkg/services/m365/api/graph/middleware.go b/src/pkg/services/m365/api/graph/middleware.go index 
0c5156ed0..2a56b1e07 100644 --- a/src/pkg/services/m365/api/graph/middleware.go +++ b/src/pkg/services/m365/api/graph/middleware.go @@ -125,10 +125,7 @@ func (mw *LoggingMiddleware) Intercept( } ctx := clues.Add( - req.Context(), - "method", req.Method, - "url", LoggableURL(req.URL.String()), - "request_content_len", req.ContentLength, + getReqCtx(req), "resp_status", resp.Status, "resp_status_code", resp.StatusCode, "resp_content_len", resp.ContentLength) @@ -156,7 +153,7 @@ func (mw RetryMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - ctx := req.Context() + ctx := getReqCtx(req) resp, err := pipeline.Next(req, middlewareIndex) retriable := IsErrTimeout(err) || @@ -249,7 +246,9 @@ func (mw RetryMiddleware) retryRequest( return resp, Wrap(ctx, err, "resetting request body reader") } } else { - logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body") + logger. + Ctx(getReqCtx(req)). + Error("body is not an io.Seeker: unable to reset request body") } } diff --git a/src/pkg/services/m365/api/graph/service.go b/src/pkg/services/m365/api/graph/service.go index 748e82f1a..ebb8a854d 100644 --- a/src/pkg/services/m365/api/graph/service.go +++ b/src/pkg/services/m365/api/graph/service.go @@ -6,11 +6,9 @@ import ( "net/http" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/alcionai/clues" abstractions "github.com/microsoft/kiota-abstractions-go" "github.com/microsoft/kiota-abstractions-go/serialization" - kauth "github.com/microsoft/kiota-authentication-azure-go" khttp "github.com/microsoft/kiota-http-go" msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core" @@ -127,23 +125,6 @@ func CreateAdapter( return wrapAdapter(adpt, cc), nil } -func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) { - // Client Provider: Uses Secret for access to tenant-level data - cred, err 
:= azidentity.NewClientSecretCredential(tenant, client, secret, nil) - if err != nil { - return nil, clues.Wrap(err, "creating m365 client identity") - } - - auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes( - cred, - []string{"https://graph.microsoft.com/.default"}) - if err != nil { - return nil, clues.Wrap(err, "creating azure authentication") - } - - return auth, nil -} - // KiotaHTTPClient creates a httpClient with middlewares and timeout configured // for use in the graph adapter. // @@ -200,6 +181,11 @@ type clientConfig struct { maxRetries int // The minimum delay in seconds between retries minDelay time.Duration + // requesterAuth sets the authorization step for requester-compliant clients. + // if non-nil, it will ensure calls are authorized before querying. + // does not get consumed by the standard graph client, which already comes + // packaged with an auth protocol. + requesterAuth authorizer appendMiddleware []khttp.Middleware } @@ -287,6 +273,12 @@ func MaxConnectionRetries(max int) Option { } } +func AuthorizeRequester(a authorizer) Option { + return func(c *clientConfig) { + c.requesterAuth = a + } +} + // --------------------------------------------------------------------------- // Middleware Control // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/api/graph/uploadsession.go b/src/pkg/services/m365/api/graph/uploadsession.go index a511d81b2..8a439f3c3 100644 --- a/src/pkg/services/m365/api/graph/uploadsession.go +++ b/src/pkg/services/m365/api/graph/uploadsession.go @@ -77,7 +77,8 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { http.MethodPut, iw.url, bytes.NewReader(p), - headers) + headers, + false) if err != nil { return 0, clues.Wrap(err, "uploading item").With( "upload_id", iw.parentID,