diff --git a/src/internal/common/str/str.go b/src/internal/common/str/str.go index 41919b2b7..201d241e0 100644 --- a/src/internal/common/str/str.go +++ b/src/internal/common/str/str.go @@ -59,6 +59,19 @@ func First(vs ...string) string { return "" } +// FirstIn returns the first entry in the map with a non-zero value +// when iterating the provided list of keys. +func FirstIn(m map[string]any, keys ...string) string { + for _, key := range keys { + v, err := AnyValueToString(key, m) + if err == nil && len(v) > 0 { + return v + } + } + + return "" +} + // Preview reduces the string to the specified size. // If the string is longer than the size, the last three // characters are replaced with an ellipsis. Size < 4 diff --git a/src/internal/common/str/str_test.go b/src/internal/common/str/str_test.go index 11af84e93..879c16aab 100644 --- a/src/internal/common/str/str_test.go +++ b/src/internal/common/str/str_test.go @@ -118,3 +118,96 @@ func TestGenerateHash(t *testing.T) { } } } + +func TestFirstIn(t *testing.T) { + table := []struct { + name string + m map[string]any + keys []string + expect string + }{ + { + name: "nil map", + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "empty map", + m: map[string]any{}, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "no match", + m: map[string]any{ + "baz": "baz", + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "no keys", + m: map[string]any{ + "baz": "baz", + }, + keys: []string{}, + expect: "", + }, + { + name: "nil match", + m: map[string]any{ + "foo": nil, + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "empty match", + m: map[string]any{ + "foo": "", + }, + keys: []string{"foo", "bar"}, + expect: "", + }, + { + name: "matches first key", + m: map[string]any{ + "foo": "fnords", + }, + keys: []string{"foo", "bar"}, + expect: "fnords", + }, + { + name: "matches second key", + m: map[string]any{ + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + { + name: "matches second key with nil first match", + m: map[string]any{ + "foo": nil, + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + { + name: "matches second key with empty first match", + m: map[string]any{ + "foo": "", + "bar": "smarf", + }, + keys: []string{"foo", "bar"}, + expect: "smarf", + }, + } + for _, test := range table { + t.Run(test.name, func(t *testing.T) { + result := FirstIn(test.m, test.keys...) + assert.Equal(t, test.expect, result) + }) + } +} diff --git a/src/internal/m365/collection/drive/collection.go b/src/internal/m365/collection/drive/collection.go index 7ef897fed..2ba53de42 100644 --- a/src/internal/m365/collection/drive/collection.go +++ b/src/internal/m365/collection/drive/collection.go @@ -366,7 +366,7 @@ func downloadContent( itemID := ptr.Val(item.GetId()) ctx = clues.Add(ctx, "item_id", itemID) - content, err := downloadItem(ctx, iaag, item) + content, err := downloadItem(ctx, iaag, driveID, item) if err == nil { return content, nil } else if !graph.IsErrUnauthorizedOrBadToken(err) { @@ -395,7 +395,7 @@ func downloadContent( cdi := custom.ToCustomDriveItem(di) - content, err = downloadItem(ctx, iaag, cdi) + content, err = downloadItem(ctx, iaag, driveID, cdi) if err != nil { return nil, clues.Wrap(err, "content download retry") } @@ -426,7 +426,7 @@ func readItemContents( return nil, core.ErrNotFound } - rc, err := downloadFile(ctx, iaag, props.downloadURL) + rc, err := downloadFile(ctx, iaag, props.downloadURL, false) if graph.IsErrUnauthorizedOrBadToken(err) { logger.CtxErr(ctx, err).Debug("stale item in cache") } diff --git a/src/internal/m365/collection/drive/helper_test.go b/src/internal/m365/collection/drive/helper_test.go index 8220f5ed0..a26c6aa01 100644 --- a/src/internal/m365/collection/drive/helper_test.go +++ b/src/internal/m365/collection/drive/helper_test.go @@ -795,7 +795,12 @@ func (h mockBackupHandler[T]) AugmentItemInfo( return h.ItemInfo } -func (h *mockBackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) { +func (h *mockBackupHandler[T]) Get( + context.Context, + string, + map[string]string, + bool, +) (*http.Response, error) { c := h.getCall h.getCall++ diff --git a/src/internal/m365/collection/drive/item.go b/src/internal/m365/collection/drive/item.go index 02ac6010a..be5c255bc 100644 --- a/src/internal/m365/collection/drive/item.go +++ b/src/internal/m365/collection/drive/item.go @@ -21,8 +21,10 @@ import ( ) const ( - acceptHeaderKey = "Accept" - acceptHeaderValue = "*/*" + acceptHeaderKey = "Accept" + acceptHeaderValue = "*/*" + gigabyte = 1024 * 1024 * 1024 + largeFileDownloadLimit = 15 * gigabyte ) // downloadUrlKeys is used to find the download URL in a DriveItem response. @@ -33,7 +35,8 @@ var downloadURLKeys = []string{ func downloadItem( ctx context.Context, - ag api.Getter, + getter api.Getter, + driveID string, item *custom.DriveItem, ) (io.ReadCloser, error) { if item == nil { @@ -41,36 +44,37 @@ func downloadItem( } var ( - rc io.ReadCloser - isFile = item.GetFile() != nil - err error + // very large file content needs to be downloaded through a different endpoint, or else + // the download could take longer than the lifespan of the download token in the cached + // url, which will cause us to timeout on every download request, even if we refresh the + // download url right before the query. + url = "https://graph.microsoft.com/v1.0/drives/" + driveID + "/items/" + ptr.Val(item.GetId()) + "/content" + reader io.ReadCloser + err error + isLargeFile = ptr.Val(item.GetSize()) > largeFileDownloadLimit ) - if isFile { - var ( - url string - ad = item.GetAdditionalData() - ) - - for _, key := range downloadURLKeys { - if v, err := str.AnyValueToString(key, ad); err == nil { - url = v - break - } - } - - rc, err = downloadFile(ctx, ag, url) - if err != nil { - return nil, clues.Stack(err) - } + // if this isn't a file, no content is available for download + if item.GetFile() == nil { + return reader, nil } - return rc, nil + // smaller files will maintain our current behavior (prefetching the download url with the + // url cache). That pattern works for us in general, and we only need to deviate for very + // large file sizes. + if !isLargeFile { + url = str.FirstIn(item.GetAdditionalData(), downloadURLKeys...) + } + + reader, err = downloadFile(ctx, getter, url, isLargeFile) + + return reader, clues.StackWC(ctx, err).OrNil() } type downloadWithRetries struct { - getter api.Getter - url string + getter api.Getter + requireAuth bool + url string } func (dg *downloadWithRetries) SupportsRange() bool { @@ -86,7 +90,7 @@ func (dg *downloadWithRetries) Get( // wouldn't work without it (get 416 responses instead of 206). headers[acceptHeaderKey] = acceptHeaderValue - resp, err := dg.getter.Get(ctx, dg.url, headers) + resp, err := dg.getter.Get(ctx, dg.url, headers, dg.requireAuth) if err != nil { return nil, clues.Wrap(err, "getting file") } @@ -96,7 +100,7 @@ func (dg *downloadWithRetries) Get( resp.Body.Close() } - return nil, clues.New("malware detected").Label(graph.LabelsMalware) + return nil, clues.NewWC(ctx, "malware detected").Label(graph.LabelsMalware) } if resp != nil && (resp.StatusCode/100) != 2 { @@ -107,7 +111,7 @@ func (dg *downloadWithRetries) Get( // upstream error checks can compare the status with // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) return nil, clues. - Wrap(clues.New(resp.Status), "non-2xx http response"). + Wrap(clues.NewWC(ctx, resp.Status), "non-2xx http response"). Label(graph.LabelStatus(resp.StatusCode)) } @@ -118,6 +122,7 @@ func downloadFile( ctx context.Context, ag api.Getter, url string, + requireAuth bool, ) (io.ReadCloser, error) { if len(url) == 0 { return nil, clues.NewWC(ctx, "empty file url") @@ -141,8 +146,9 @@ func downloadFile( rc, err := readers.NewResetRetryHandler( ctx, &downloadWithRetries{ - getter: ag, - url: url, + getter: ag, + requireAuth: requireAuth, + url: url, }) return rc, clues.Stack(err).OrNil() diff --git a/src/internal/m365/collection/drive/item_test.go b/src/internal/m365/collection/drive/item_test.go index 33dbc9ae0..23de93deb 100644 --- a/src/internal/m365/collection/drive/item_test.go +++ b/src/internal/m365/collection/drive/item_test.go @@ -109,7 +109,11 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() { } // Read data for the file - itemData, err := downloadItem(ctx, bh, custom.ToCustomDriveItem(driveItem)) + itemData, err := downloadItem( + ctx, + bh, + suite.m365.User.DriveID, + custom.ToCustomDriveItem(driveItem)) require.NoError(t, err, clues.ToCore(err)) size, err := io.Copy(io.Discard, itemData) @@ -292,6 +296,7 @@ func (m mockGetter) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { return m.GetFunc(ctx, url) } @@ -379,7 +384,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { return nil, clues.New("test error") }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, { name: "download url is empty", @@ -416,7 +421,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { }, nil }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, { name: "non-2xx http response", @@ -435,7 +440,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { }, nil }, errorExpected: require.Error, - rcExpected: require.Nil, + rcExpected: require.NotNil, }, } @@ -448,9 +453,78 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() { mg := mockGetter{ GetFunc: test.GetFunc, } - rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(test.itemFunc())) + rc, err := downloadItem( + ctx, + mg, + "driveID", + custom.ToCustomDriveItem(test.itemFunc())) test.errorExpected(t, err, clues.ToCore(err)) - test.rcExpected(t, rc) + test.rcExpected(t, rc, "reader should only be nil if item is nil") + }) + } +} + +func (suite *ItemUnitTestSuite) TestDownloadItem_urlByFileSize() { + var ( + testRc = io.NopCloser(bytes.NewReader([]byte("test"))) + url = "https://example.com" + okResp = &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + } + ) + + table := []struct { + name string + itemFunc func() models.DriveItemable + GetFunc func(ctx context.Context, url string) (*http.Response, error) + errorExpected require.ErrorAssertionFunc + rcExpected require.ValueAssertionFunc + label string + }{ + { + name: "big file", + itemFunc: func() models.DriveItemable { + di := api.NewDriveItem("test", false) + di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url}) + di.SetSize(ptr.To[int64](20 * gigabyte)) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + assert.Contains(suite.T(), url, "/content") + return okResp, nil + }, + }, + { + name: "small file", + itemFunc: func() models.DriveItemable { + di := api.NewDriveItem("test", false) + di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url}) + di.SetSize(ptr.To[int64](2 * gigabyte)) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + assert.NotContains(suite.T(), url, "/content") + return okResp, nil + }, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := downloadItem( + ctx, + mockGetter{GetFunc: test.GetFunc}, + "driveID", + custom.ToCustomDriveItem(test.itemFunc())) + require.NoError(t, err, clues.ToCore(err)) }) } } @@ -507,7 +581,11 @@ func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead mg := mockGetter{ GetFunc: GetFunc, } - rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(itemFunc())) + rc, err := downloadItem( + ctx, + mg, + "driveID", + custom.ToCustomDriveItem(itemFunc())) errorExpected(t, err, clues.ToCore(err)) rcExpected(t, rc) diff --git a/src/internal/m365/collection/drive/site_handler.go b/src/internal/m365/collection/drive/site_handler.go index 189131e04..e268be921 100644 --- a/src/internal/m365/collection/drive/site_handler.go +++ b/src/internal/m365/collection/drive/site_handler.go @@ -93,8 +93,9 @@ func (h siteBackupHandler) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return h.ac.Get(ctx, url, headers) + return h.ac.Get(ctx, url, headers, requireAuth) } func (h siteBackupHandler) PathPrefix( diff --git a/src/internal/m365/collection/drive/url_cache_test.go b/src/internal/m365/collection/drive/url_cache_test.go index 2c47e67c3..ac8eabb92 100644 --- a/src/internal/m365/collection/drive/url_cache_test.go +++ b/src/internal/m365/collection/drive/url_cache_test.go @@ -154,7 +154,8 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() { http.MethodGet, props.downloadURL, nil, - nil) + nil, + false) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, resp) diff --git a/src/internal/m365/collection/drive/user_drive_handler.go b/src/internal/m365/collection/drive/user_drive_handler.go index fcf3943d0..42e802f1b 100644 --- a/src/internal/m365/collection/drive/user_drive_handler.go +++ b/src/internal/m365/collection/drive/user_drive_handler.go @@ -93,8 +93,9 @@ func (h userDriveBackupHandler) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return h.ac.Get(ctx, url, headers) + return h.ac.Get(ctx, url, headers, requireAuth) } func (h userDriveBackupHandler) PathPrefix( diff --git a/src/internal/m365/service/onedrive/mock/handlers.go b/src/internal/m365/service/onedrive/mock/handlers.go index e56c7bab5..2a54749f3 100644 --- a/src/internal/m365/service/onedrive/mock/handlers.go +++ b/src/internal/m365/service/onedrive/mock/handlers.go @@ -197,7 +197,12 @@ func (h BackupHandler[T]) AugmentItemInfo( return h.ItemInfo } -func (h *BackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) { +func (h *BackupHandler[T]) Get( + context.Context, + string, + map[string]string, + bool, +) (*http.Response, error) { c := h.getCall h.getCall++ diff --git a/src/pkg/services/m365/api/access.go b/src/pkg/services/m365/api/access.go index 710430490..f6b3ee1d7 100644 --- a/src/pkg/services/m365/api/access.go +++ b/src/pkg/services/m365/api/access.go @@ -47,7 +47,7 @@ func (c Access) GetToken( c.Credentials.AzureClientSecret)) ) - resp, err := c.Post(ctx, rawURL, headers, body) + resp, err := c.Post(ctx, rawURL, headers, body, false) if err != nil { return clues.Stack(err) } diff --git a/src/pkg/services/m365/api/client.go b/src/pkg/services/m365/api/client.go index 972439b75..f454a5648 100644 --- a/src/pkg/services/m365/api/client.go +++ b/src/pkg/services/m365/api/client.go @@ -63,7 +63,14 @@ func NewClient( return Client{}, err } - rqr := graph.NewNoTimeoutHTTPWrapper(counter) + azureAuth, err := graph.NewAzureAuth(creds) + if err != nil { + return Client{}, clues.Wrap(err, "generating azure authorizer") + } + + rqr := graph.NewNoTimeoutHTTPWrapper( + counter, + graph.AuthorizeRequester(azureAuth)) if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize { co.DeltaPageSize = maxDeltaPageSize @@ -124,11 +131,7 @@ func newLargeItemService( counter *count.Bus, ) (*graph.Service, error) { a, err := NewService(creds, counter, graph.NoTimeout()) - if err != nil { - return nil, clues.Wrap(err, "generating no-timeout graph adapter") - } - - return a, nil + return a, clues.Wrap(err, "generating no-timeout graph adapter").OrNil() } type Getter interface { @@ -136,6 +139,7 @@ type Getter interface { ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) } @@ -144,8 +148,9 @@ func (c Client) Get( ctx context.Context, url string, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { - return c.Requester.Request(ctx, http.MethodGet, url, nil, headers) + return c.Requester.Request(ctx, http.MethodGet, url, nil, headers, requireAuth) } // Get performs an ad-hoc get request using its graph.Requester @@ -154,8 +159,9 @@ func (c Client) Post( url string, headers map[string]string, body io.Reader, + requireAuth bool, ) (*http.Response, error) { - return c.Requester.Request(ctx, http.MethodGet, url, body, headers) + return c.Requester.Request(ctx, http.MethodGet, url, body, headers, requireAuth) } // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/api/graph/auth.go b/src/pkg/services/m365/api/graph/auth.go new file mode 100644 index 000000000..da4cb43ee --- /dev/null +++ b/src/pkg/services/m365/api/graph/auth.go @@ -0,0 +1,94 @@ +package graph + +import ( + "context" + "net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/alcionai/clues" + abstractions "github.com/microsoft/kiota-abstractions-go" + kauth "github.com/microsoft/kiota-authentication-azure-go" + + "github.com/alcionai/corso/src/pkg/account" +) + +func GetAuth(tenant, client, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) { + // Client Provider: Uses Secret for access to tenant-level data + cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil) + if err != nil { + return nil, clues.Wrap(err, "creating m365 client identity") + } + + auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes( + cred, + []string{"https://graph.microsoft.com/.default"}) + if err != nil { + return nil, clues.Wrap(err, "creating azure authentication") + } + + return auth, nil +} + +// --------------------------------------------------------------------------- +// requester authorization +// --------------------------------------------------------------------------- + +type authorizer interface { + addAuthToHeaders( + ctx context.Context, + urlStr string, + headers http.Header, + ) error +} + +// consumed by kiota +type authenticateRequester interface { + AuthenticateRequest( + ctx context.Context, + request *abstractions.RequestInformation, + additionalAuthenticationContext map[string]any, + ) error +} + +// --------------------------------------------------------------------------- +// Azure Authorizer +// --------------------------------------------------------------------------- + +type azureAuth struct { + auth authenticateRequester +} + +func NewAzureAuth(creds account.M365Config) (*azureAuth, error) { + auth, err := GetAuth( + creds.AzureTenantID, + creds.AzureClientID, + creds.AzureClientSecret) + + return &azureAuth{auth}, clues.Stack(err).OrNil() +} + +func (aa azureAuth) addAuthToHeaders( + ctx context.Context, + urlStr string, + headers http.Header, +) error { + requestInfo := abstractions.NewRequestInformation() + + uri, err := url.Parse(urlStr) + if err != nil { + return clues.WrapWC(ctx, err, "parsing url").OrNil() + } + + requestInfo.SetUri(*uri) + + err = aa.auth.AuthenticateRequest(ctx, requestInfo, nil) + + for _, k := range requestInfo.Headers.ListKeys() { + for _, v := range requestInfo.Headers.Get(k) { + headers.Add(k, v) + } + } + + return clues.WrapWC(ctx, err, "authorizing request").OrNil() +} diff --git a/src/pkg/services/m365/api/graph/concurrency_middleware.go b/src/pkg/services/m365/api/graph/concurrency_middleware.go index ee9d62f73..3694a9f2e 100644 --- a/src/pkg/services/m365/api/graph/concurrency_middleware.go +++ b/src/pkg/services/m365/api/graph/concurrency_middleware.go @@ -240,7 +240,7 @@ func (mw *RateLimiterMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - QueueRequest(req.Context()) + QueueRequest(getReqCtx(req)) return pipeline.Next(req, middlewareIndex) } @@ -339,7 +339,7 @@ func (mw *throttlingMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - err := mw.tf.Block(req.Context()) + err := mw.tf.Block(getReqCtx(req)) if err != nil { return nil, err } diff --git a/src/pkg/services/m365/api/graph/http_wrapper.go b/src/pkg/services/m365/api/graph/http_wrapper.go index 17cf51d3a..77f5766bf 100644 --- a/src/pkg/services/m365/api/graph/http_wrapper.go +++ b/src/pkg/services/m365/api/graph/http_wrapper.go @@ -36,6 +36,7 @@ type Requester interface { method, url string, body io.Reader, headers map[string]string, + requireAuth bool, ) (*http.Response, error) } @@ -58,12 +59,8 @@ func NewHTTPWrapper( transport: defaultTransport(), }, } - redirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } hc = &http.Client{ - CheckRedirect: redirect, - Transport: rt, + Transport: rt, } ) @@ -100,6 +97,7 @@ func (hw httpWrapper) Request( method, url string, body io.Reader, headers map[string]string, + requireAuth bool, ) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -115,6 +113,17 @@ func (hw httpWrapper) Request( // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) + if requireAuth { + if hw.config.requesterAuth == nil { + return nil, clues.Wrap(err, "http wrapper misconfigured: missing required authorization") + } + + err := hw.config.requesterAuth.addAuthToHeaders(ctx, url, req.Header) + if err != nil { + return nil, clues.Wrap(err, "setting request auth headers") + } + } + retriedErrors := []string{} var e error diff --git a/src/pkg/services/m365/api/graph/http_wrapper_test.go b/src/pkg/services/m365/api/graph/http_wrapper_test.go index 555af8ffd..12fbaa7af 100644 --- a/src/pkg/services/m365/api/graph/http_wrapper_test.go +++ b/src/pkg/services/m365/api/graph/http_wrapper_test.go @@ -40,9 +40,10 @@ func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() { resp, err := hw.Request( ctx, http.MethodGet, - "https://www.corsobackup.io", + "https://www.google.com", nil, - nil) + nil, + false) require.NoError(t, err, clues.ToCore(err)) defer resp.Body.Close() @@ -76,6 +77,56 @@ func (mw *mwForceResp) Intercept( return mw.resp, mw.err } +func (suite *HTTPWrapperIntgSuite) TestHTTPWrapper_Request_withAuth() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + a := tconfig.NewM365Account(t) + m365, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + azureAuth, err := NewAzureAuth(m365) + require.NoError(t, err, clues.ToCore(err)) + + hw := NewHTTPWrapper(count.New(), AuthorizeRequester(azureAuth)) + + // any request that requires authorization will do + resp, err := hw.Request( + ctx, + http.MethodGet, + "https://graph.microsoft.com/v1.0/users", + nil, + nil, + true) + require.NoError(t, err, clues.ToCore(err)) + + defer resp.Body.Close() + + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // also validate that non-auth'd endpoints succeed + resp, err = hw.Request( + ctx, + http.MethodGet, + "https://www.google.com", + nil, + nil, + true) + require.NoError(t, err, clues.ToCore(err)) + + defer resp.Body.Close() + + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) +} + +// --------------------------------------------------------------------------- +// unit +// --------------------------------------------------------------------------- + type HTTPWrapperUnitSuite struct { tester.Suite } @@ -84,26 +135,25 @@ func TestHTTPWrapperUnitSuite(t *testing.T) { suite.Run(t, &HTTPWrapperUnitSuite{Suite: tester.NewUnitSuite(t)}) } -func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() { +func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_redirect() { t := suite.T() ctx, flush := tester.NewContext(t) defer flush() - url := "https://graph.microsoft.com/fnords/beaux/regard" - - hdr := http.Header{} - hdr.Set("Location", "localhost:99999999/smarfs") + respHdr := http.Header{} + respHdr.Set("Location", "localhost:99999999/smarfs") toResp := &http.Response{ StatusCode: http.StatusFound, - Header: hdr, + Header: respHdr, } mwResp := mwForceResp{ resp: toResp, alternate: func(req *http.Request) (bool, *http.Response, error) { if strings.HasSuffix(req.URL.String(), "smarfs") { + assert.Equal(t, req.Header.Get("X-Test-Val"), "should-be-copied-to-redirect") return true, &http.Response{StatusCode: http.StatusOK}, nil } @@ -113,17 +163,22 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() { hw := NewHTTPWrapper(count.New(), appendMiddleware(&mwResp)) - resp, err := hw.Request(ctx, http.MethodGet, url, nil, nil) + resp, err := hw.Request( + ctx, + http.MethodGet, + "https://graph.microsoft.com/fnords/beaux/regard", + nil, + map[string]string{"X-Test-Val": "should-be-copied-to-redirect"}, + false) require.NoError(t, err, clues.ToCore(err)) defer resp.Body.Close() require.NotNil(t, resp) - // require.Equal(t, 1, calledCorrectly, "test server was called with expected path") require.Equal(t, http.StatusOK, resp.StatusCode) } -func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() { +func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_http2StreamErrorRetries() { var ( url = "https://graph.microsoft.com/fnords/beaux/regard" streamErr = http2.StreamError{ @@ -188,7 +243,7 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() // the test middleware. hw.retryDelay = 0 - _, err := hw.Request(ctx, http.MethodGet, url, nil, nil) + _, err := hw.Request(ctx, http.MethodGet, url, nil, nil, false) require.ErrorAs(t, err, &http2.StreamError{}, clues.ToCore(err)) require.Equal(t, test.expectRetries, tries, "count of retries") }) diff --git a/src/pkg/services/m365/api/graph/logging.go b/src/pkg/services/m365/api/graph/logging.go index 2b6e536c9..09283ec2b 100644 --- a/src/pkg/services/m365/api/graph/logging.go +++ b/src/pkg/services/m365/api/graph/logging.go @@ -6,6 +6,9 @@ import ( "net/http/httputil" "os" + "github.com/alcionai/clues" + + "github.com/alcionai/corso/src/internal/common/pii" "github.com/alcionai/corso/src/pkg/logger" ) @@ -69,3 +72,22 @@ func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string return string(respDump) } + +func getReqCtx(req *http.Request) context.Context { + if req == nil { + return context.Background() + } + + var logURL pii.SafeURL + + if req.URL != nil { + logURL = LoggableURL(req.URL.String()) + } + + return clues.AddTraceName( + req.Context(), + "graph-http-middleware", + "method", req.Method, + "url", logURL, + "request_content_len", req.ContentLength) +} diff --git a/src/pkg/services/m365/api/graph/middleware.go b/src/pkg/services/m365/api/graph/middleware.go index 0c5156ed0..2a56b1e07 100644 --- a/src/pkg/services/m365/api/graph/middleware.go +++ b/src/pkg/services/m365/api/graph/middleware.go @@ -125,10 +125,7 @@ func (mw *LoggingMiddleware) Intercept( } ctx := clues.Add( - req.Context(), - "method", req.Method, - "url", LoggableURL(req.URL.String()), - "request_content_len", req.ContentLength, + getReqCtx(req), "resp_status", resp.Status, "resp_status_code", resp.StatusCode, "resp_content_len", resp.ContentLength) @@ -156,7 +153,7 @@ func (mw RetryMiddleware) Intercept( middlewareIndex int, req *http.Request, ) (*http.Response, error) { - ctx := req.Context() + ctx := getReqCtx(req) resp, err := pipeline.Next(req, middlewareIndex) retriable := IsErrTimeout(err) || @@ -249,7 +246,9 @@ func (mw RetryMiddleware) retryRequest( return resp, Wrap(ctx, err, "resetting request body reader") } } else { - logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body") + logger. + Ctx(getReqCtx(req)). + Error("body is not an io.Seeker: unable to reset request body") } } diff --git a/src/pkg/services/m365/api/graph/service.go b/src/pkg/services/m365/api/graph/service.go index 748e82f1a..ebb8a854d 100644 --- a/src/pkg/services/m365/api/graph/service.go +++ b/src/pkg/services/m365/api/graph/service.go @@ -6,11 +6,9 @@ import ( "net/http" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/alcionai/clues" abstractions "github.com/microsoft/kiota-abstractions-go" "github.com/microsoft/kiota-abstractions-go/serialization" - kauth "github.com/microsoft/kiota-authentication-azure-go" khttp "github.com/microsoft/kiota-http-go" msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core" @@ -127,23 +125,6 @@ func CreateAdapter( return wrapAdapter(adpt, cc), nil } -func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) { - // Client Provider: Uses Secret for access to tenant-level data - cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil) - if err != nil { - return nil, clues.Wrap(err, "creating m365 client identity") - } - - auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes( - cred, - []string{"https://graph.microsoft.com/.default"}) - if err != nil { - return nil, clues.Wrap(err, "creating azure authentication") - } - - return auth, nil -} - // KiotaHTTPClient creates a httpClient with middlewares and timeout configured // for use in the graph adapter. // @@ -200,6 +181,11 @@ type clientConfig struct { maxRetries int // The minimum delay in seconds between retries minDelay time.Duration + // requesterAuth sets the authorization step for requester-compliant clients. + // if non-nil, it will ensure calls are authorized before querying. + // does not get consumed by the standard graph client, which already comes + // packaged with an auth protocol. + requesterAuth authorizer appendMiddleware []khttp.Middleware } @@ -287,6 +273,12 @@ func MaxConnectionRetries(max int) Option { } } +func AuthorizeRequester(a authorizer) Option { + return func(c *clientConfig) { + c.requesterAuth = a + } +} + // --------------------------------------------------------------------------- // Middleware Control // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/api/graph/uploadsession.go b/src/pkg/services/m365/api/graph/uploadsession.go index a511d81b2..8a439f3c3 100644 --- a/src/pkg/services/m365/api/graph/uploadsession.go +++ b/src/pkg/services/m365/api/graph/uploadsession.go @@ -77,7 +77,8 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) { http.MethodPut, iw.url, bytes.NewReader(p), - headers) + headers, + false) if err != nil { return 0, clues.Wrap(err, "uploading item").With( "upload_id", iw.parentID,