add authentication to requester (#5198)

the graph requester for large item downloads now includes the option to authenticate requests.  The option is configured at the time of creating the requester, therefore all requests using that servier are either authenticatd or not. In our case, we're opting to authenticate all requests, since we do not use this requester for non-graph api calls, and even if we did the addition of auth headers is likely benign.

---

#### Does this PR need a docs update or release note?

- [x]  No

#### Type of change

- [x] 🌻 Feature

#### Test Plan

- [x] 💚 E2E
This commit is contained in:
Keepers 2024-02-14 10:50:36 -07:00 committed by GitHub
parent 5e8407a970
commit bb2bd6df3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 481 additions and 100 deletions

View File

@ -59,6 +59,19 @@ func First(vs ...string) string {
return "" return ""
} }
// FirstIn returns the first entry in the map with a non-zero value
// when iterating the provided list of keys.
func FirstIn(m map[string]any, keys ...string) string {
for _, key := range keys {
v, err := AnyValueToString(key, m)
if err == nil && len(v) > 0 {
return v
}
}
return ""
}
// Preview reduces the string to the specified size. // Preview reduces the string to the specified size.
// If the string is longer than the size, the last three // If the string is longer than the size, the last three
// characters are replaced with an ellipsis. Size < 4 // characters are replaced with an ellipsis. Size < 4

View File

@ -118,3 +118,96 @@ func TestGenerateHash(t *testing.T) {
} }
} }
} }
func TestFirstIn(t *testing.T) {
table := []struct {
name string
m map[string]any
keys []string
expect string
}{
{
name: "nil map",
keys: []string{"foo", "bar"},
expect: "",
},
{
name: "empty map",
m: map[string]any{},
keys: []string{"foo", "bar"},
expect: "",
},
{
name: "no match",
m: map[string]any{
"baz": "baz",
},
keys: []string{"foo", "bar"},
expect: "",
},
{
name: "no keys",
m: map[string]any{
"baz": "baz",
},
keys: []string{},
expect: "",
},
{
name: "nil match",
m: map[string]any{
"foo": nil,
},
keys: []string{"foo", "bar"},
expect: "",
},
{
name: "empty match",
m: map[string]any{
"foo": "",
},
keys: []string{"foo", "bar"},
expect: "",
},
{
name: "matches first key",
m: map[string]any{
"foo": "fnords",
},
keys: []string{"foo", "bar"},
expect: "fnords",
},
{
name: "matches second key",
m: map[string]any{
"bar": "smarf",
},
keys: []string{"foo", "bar"},
expect: "smarf",
},
{
name: "matches second key with nil first match",
m: map[string]any{
"foo": nil,
"bar": "smarf",
},
keys: []string{"foo", "bar"},
expect: "smarf",
},
{
name: "matches second key with empty first match",
m: map[string]any{
"foo": "",
"bar": "smarf",
},
keys: []string{"foo", "bar"},
expect: "smarf",
},
}
for _, test := range table {
t.Run(test.name, func(t *testing.T) {
result := FirstIn(test.m, test.keys...)
assert.Equal(t, test.expect, result)
})
}
}

View File

@ -366,7 +366,7 @@ func downloadContent(
itemID := ptr.Val(item.GetId()) itemID := ptr.Val(item.GetId())
ctx = clues.Add(ctx, "item_id", itemID) ctx = clues.Add(ctx, "item_id", itemID)
content, err := downloadItem(ctx, iaag, item) content, err := downloadItem(ctx, iaag, driveID, item)
if err == nil { if err == nil {
return content, nil return content, nil
} else if !graph.IsErrUnauthorizedOrBadToken(err) { } else if !graph.IsErrUnauthorizedOrBadToken(err) {
@ -395,7 +395,7 @@ func downloadContent(
cdi := custom.ToCustomDriveItem(di) cdi := custom.ToCustomDriveItem(di)
content, err = downloadItem(ctx, iaag, cdi) content, err = downloadItem(ctx, iaag, driveID, cdi)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "content download retry") return nil, clues.Wrap(err, "content download retry")
} }
@ -426,7 +426,7 @@ func readItemContents(
return nil, core.ErrNotFound return nil, core.ErrNotFound
} }
rc, err := downloadFile(ctx, iaag, props.downloadURL) rc, err := downloadFile(ctx, iaag, props.downloadURL, false)
if graph.IsErrUnauthorizedOrBadToken(err) { if graph.IsErrUnauthorizedOrBadToken(err) {
logger.CtxErr(ctx, err).Debug("stale item in cache") logger.CtxErr(ctx, err).Debug("stale item in cache")
} }

View File

@ -795,7 +795,12 @@ func (h mockBackupHandler[T]) AugmentItemInfo(
return h.ItemInfo return h.ItemInfo
} }
func (h *mockBackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) { func (h *mockBackupHandler[T]) Get(
context.Context,
string,
map[string]string,
bool,
) (*http.Response, error) {
c := h.getCall c := h.getCall
h.getCall++ h.getCall++

View File

@ -21,8 +21,10 @@ import (
) )
const ( const (
acceptHeaderKey = "Accept" acceptHeaderKey = "Accept"
acceptHeaderValue = "*/*" acceptHeaderValue = "*/*"
gigabyte = 1024 * 1024 * 1024
largeFileDownloadLimit = 15 * gigabyte
) )
// downloadUrlKeys is used to find the download URL in a DriveItem response. // downloadUrlKeys is used to find the download URL in a DriveItem response.
@ -33,7 +35,8 @@ var downloadURLKeys = []string{
func downloadItem( func downloadItem(
ctx context.Context, ctx context.Context,
ag api.Getter, getter api.Getter,
driveID string,
item *custom.DriveItem, item *custom.DriveItem,
) (io.ReadCloser, error) { ) (io.ReadCloser, error) {
if item == nil { if item == nil {
@ -41,36 +44,37 @@ func downloadItem(
} }
var ( var (
rc io.ReadCloser // very large file content needs to be downloaded through a different endpoint, or else
isFile = item.GetFile() != nil // the download could take longer than the lifespan of the download token in the cached
err error // url, which will cause us to timeout on every download request, even if we refresh the
// download url right before the query.
url = "https://graph.microsoft.com/v1.0/drives/" + driveID + "/items/" + ptr.Val(item.GetId()) + "/content"
reader io.ReadCloser
err error
isLargeFile = ptr.Val(item.GetSize()) > largeFileDownloadLimit
) )
if isFile { // if this isn't a file, no content is available for download
var ( if item.GetFile() == nil {
url string return reader, nil
ad = item.GetAdditionalData()
)
for _, key := range downloadURLKeys {
if v, err := str.AnyValueToString(key, ad); err == nil {
url = v
break
}
}
rc, err = downloadFile(ctx, ag, url)
if err != nil {
return nil, clues.Stack(err)
}
} }
return rc, nil // smaller files will maintain our current behavior (prefetching the download url with the
// url cache). That pattern works for us in general, and we only need to deviate for very
// large file sizes.
if !isLargeFile {
url = str.FirstIn(item.GetAdditionalData(), downloadURLKeys...)
}
reader, err = downloadFile(ctx, getter, url, isLargeFile)
return reader, clues.StackWC(ctx, err).OrNil()
} }
type downloadWithRetries struct { type downloadWithRetries struct {
getter api.Getter getter api.Getter
url string requireAuth bool
url string
} }
func (dg *downloadWithRetries) SupportsRange() bool { func (dg *downloadWithRetries) SupportsRange() bool {
@ -86,7 +90,7 @@ func (dg *downloadWithRetries) Get(
// wouldn't work without it (get 416 responses instead of 206). // wouldn't work without it (get 416 responses instead of 206).
headers[acceptHeaderKey] = acceptHeaderValue headers[acceptHeaderKey] = acceptHeaderValue
resp, err := dg.getter.Get(ctx, dg.url, headers) resp, err := dg.getter.Get(ctx, dg.url, headers, dg.requireAuth)
if err != nil { if err != nil {
return nil, clues.Wrap(err, "getting file") return nil, clues.Wrap(err, "getting file")
} }
@ -96,7 +100,7 @@ func (dg *downloadWithRetries) Get(
resp.Body.Close() resp.Body.Close()
} }
return nil, clues.New("malware detected").Label(graph.LabelsMalware) return nil, clues.NewWC(ctx, "malware detected").Label(graph.LabelsMalware)
} }
if resp != nil && (resp.StatusCode/100) != 2 { if resp != nil && (resp.StatusCode/100) != 2 {
@ -107,7 +111,7 @@ func (dg *downloadWithRetries) Get(
// upstream error checks can compare the status with // upstream error checks can compare the status with
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
return nil, clues. return nil, clues.
Wrap(clues.New(resp.Status), "non-2xx http response"). Wrap(clues.NewWC(ctx, resp.Status), "non-2xx http response").
Label(graph.LabelStatus(resp.StatusCode)) Label(graph.LabelStatus(resp.StatusCode))
} }
@ -118,6 +122,7 @@ func downloadFile(
ctx context.Context, ctx context.Context,
ag api.Getter, ag api.Getter,
url string, url string,
requireAuth bool,
) (io.ReadCloser, error) { ) (io.ReadCloser, error) {
if len(url) == 0 { if len(url) == 0 {
return nil, clues.NewWC(ctx, "empty file url") return nil, clues.NewWC(ctx, "empty file url")
@ -141,8 +146,9 @@ func downloadFile(
rc, err := readers.NewResetRetryHandler( rc, err := readers.NewResetRetryHandler(
ctx, ctx,
&downloadWithRetries{ &downloadWithRetries{
getter: ag, getter: ag,
url: url, requireAuth: requireAuth,
url: url,
}) })
return rc, clues.Stack(err).OrNil() return rc, clues.Stack(err).OrNil()

View File

@ -109,7 +109,11 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
} }
// Read data for the file // Read data for the file
itemData, err := downloadItem(ctx, bh, custom.ToCustomDriveItem(driveItem)) itemData, err := downloadItem(
ctx,
bh,
suite.m365.User.DriveID,
custom.ToCustomDriveItem(driveItem))
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
size, err := io.Copy(io.Discard, itemData) size, err := io.Copy(io.Discard, itemData)
@ -292,6 +296,7 @@ func (m mockGetter) Get(
ctx context.Context, ctx context.Context,
url string, url string,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
return m.GetFunc(ctx, url) return m.GetFunc(ctx, url)
} }
@ -379,7 +384,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
return nil, clues.New("test error") return nil, clues.New("test error")
}, },
errorExpected: require.Error, errorExpected: require.Error,
rcExpected: require.Nil, rcExpected: require.NotNil,
}, },
{ {
name: "download url is empty", name: "download url is empty",
@ -416,7 +421,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
}, nil }, nil
}, },
errorExpected: require.Error, errorExpected: require.Error,
rcExpected: require.Nil, rcExpected: require.NotNil,
}, },
{ {
name: "non-2xx http response", name: "non-2xx http response",
@ -435,7 +440,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
}, nil }, nil
}, },
errorExpected: require.Error, errorExpected: require.Error,
rcExpected: require.Nil, rcExpected: require.NotNil,
}, },
} }
@ -448,9 +453,78 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
mg := mockGetter{ mg := mockGetter{
GetFunc: test.GetFunc, GetFunc: test.GetFunc,
} }
rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(test.itemFunc())) rc, err := downloadItem(
ctx,
mg,
"driveID",
custom.ToCustomDriveItem(test.itemFunc()))
test.errorExpected(t, err, clues.ToCore(err)) test.errorExpected(t, err, clues.ToCore(err))
test.rcExpected(t, rc) test.rcExpected(t, rc, "reader should only be nil if item is nil")
})
}
}
func (suite *ItemUnitTestSuite) TestDownloadItem_urlByFileSize() {
var (
testRc = io.NopCloser(bytes.NewReader([]byte("test")))
url = "https://example.com"
okResp = &http.Response{
StatusCode: http.StatusOK,
Body: testRc,
}
)
table := []struct {
name string
itemFunc func() models.DriveItemable
GetFunc func(ctx context.Context, url string) (*http.Response, error)
errorExpected require.ErrorAssertionFunc
rcExpected require.ValueAssertionFunc
label string
}{
{
name: "big file",
itemFunc: func() models.DriveItemable {
di := api.NewDriveItem("test", false)
di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url})
di.SetSize(ptr.To[int64](20 * gigabyte))
return di
},
GetFunc: func(ctx context.Context, url string) (*http.Response, error) {
assert.Contains(suite.T(), url, "/content")
return okResp, nil
},
},
{
name: "small file",
itemFunc: func() models.DriveItemable {
di := api.NewDriveItem("test", false)
di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url})
di.SetSize(ptr.To[int64](2 * gigabyte))
return di
},
GetFunc: func(ctx context.Context, url string) (*http.Response, error) {
assert.NotContains(suite.T(), url, "/content")
return okResp, nil
},
},
}
for _, test := range table {
suite.Run(test.name, func() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
_, err := downloadItem(
ctx,
mockGetter{GetFunc: test.GetFunc},
"driveID",
custom.ToCustomDriveItem(test.itemFunc()))
require.NoError(t, err, clues.ToCore(err))
}) })
} }
} }
@ -507,7 +581,11 @@ func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead
mg := mockGetter{ mg := mockGetter{
GetFunc: GetFunc, GetFunc: GetFunc,
} }
rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(itemFunc())) rc, err := downloadItem(
ctx,
mg,
"driveID",
custom.ToCustomDriveItem(itemFunc()))
errorExpected(t, err, clues.ToCore(err)) errorExpected(t, err, clues.ToCore(err))
rcExpected(t, rc) rcExpected(t, rc)

View File

@ -93,8 +93,9 @@ func (h siteBackupHandler) Get(
ctx context.Context, ctx context.Context,
url string, url string,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
return h.ac.Get(ctx, url, headers) return h.ac.Get(ctx, url, headers, requireAuth)
} }
func (h siteBackupHandler) PathPrefix( func (h siteBackupHandler) PathPrefix(

View File

@ -154,7 +154,8 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() {
http.MethodGet, http.MethodGet,
props.downloadURL, props.downloadURL,
nil, nil,
nil) nil,
false)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
require.NotNil(t, resp) require.NotNil(t, resp)

View File

@ -93,8 +93,9 @@ func (h userDriveBackupHandler) Get(
ctx context.Context, ctx context.Context,
url string, url string,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
return h.ac.Get(ctx, url, headers) return h.ac.Get(ctx, url, headers, requireAuth)
} }
func (h userDriveBackupHandler) PathPrefix( func (h userDriveBackupHandler) PathPrefix(

View File

@ -197,7 +197,12 @@ func (h BackupHandler[T]) AugmentItemInfo(
return h.ItemInfo return h.ItemInfo
} }
func (h *BackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) { func (h *BackupHandler[T]) Get(
context.Context,
string,
map[string]string,
bool,
) (*http.Response, error) {
c := h.getCall c := h.getCall
h.getCall++ h.getCall++

View File

@ -47,7 +47,7 @@ func (c Access) GetToken(
c.Credentials.AzureClientSecret)) c.Credentials.AzureClientSecret))
) )
resp, err := c.Post(ctx, rawURL, headers, body) resp, err := c.Post(ctx, rawURL, headers, body, false)
if err != nil { if err != nil {
return clues.Stack(err) return clues.Stack(err)
} }

View File

@ -63,7 +63,14 @@ func NewClient(
return Client{}, err return Client{}, err
} }
rqr := graph.NewNoTimeoutHTTPWrapper(counter) azureAuth, err := graph.NewAzureAuth(creds)
if err != nil {
return Client{}, clues.Wrap(err, "generating azure authorizer")
}
rqr := graph.NewNoTimeoutHTTPWrapper(
counter,
graph.AuthorizeRequester(azureAuth))
if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize { if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize {
co.DeltaPageSize = maxDeltaPageSize co.DeltaPageSize = maxDeltaPageSize
@ -124,11 +131,7 @@ func newLargeItemService(
counter *count.Bus, counter *count.Bus,
) (*graph.Service, error) { ) (*graph.Service, error) {
a, err := NewService(creds, counter, graph.NoTimeout()) a, err := NewService(creds, counter, graph.NoTimeout())
if err != nil { return a, clues.Wrap(err, "generating no-timeout graph adapter").OrNil()
return nil, clues.Wrap(err, "generating no-timeout graph adapter")
}
return a, nil
} }
type Getter interface { type Getter interface {
@ -136,6 +139,7 @@ type Getter interface {
ctx context.Context, ctx context.Context,
url string, url string,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) ) (*http.Response, error)
} }
@ -144,8 +148,9 @@ func (c Client) Get(
ctx context.Context, ctx context.Context,
url string, url string,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
return c.Requester.Request(ctx, http.MethodGet, url, nil, headers) return c.Requester.Request(ctx, http.MethodGet, url, nil, headers, requireAuth)
} }
// Get performs an ad-hoc get request using its graph.Requester // Get performs an ad-hoc get request using its graph.Requester
@ -154,8 +159,9 @@ func (c Client) Post(
url string, url string,
headers map[string]string, headers map[string]string,
body io.Reader, body io.Reader,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
return c.Requester.Request(ctx, http.MethodGet, url, body, headers) return c.Requester.Request(ctx, http.MethodGet, url, body, headers, requireAuth)
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@ -0,0 +1,94 @@
package graph
import (
"context"
"net/http"
"net/url"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/alcionai/clues"
abstractions "github.com/microsoft/kiota-abstractions-go"
kauth "github.com/microsoft/kiota-authentication-azure-go"
"github.com/alcionai/corso/src/pkg/account"
)
func GetAuth(tenant, client, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) {
// Client Provider: Uses Secret for access to tenant-level data
cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil)
if err != nil {
return nil, clues.Wrap(err, "creating m365 client identity")
}
auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes(
cred,
[]string{"https://graph.microsoft.com/.default"})
if err != nil {
return nil, clues.Wrap(err, "creating azure authentication")
}
return auth, nil
}
// ---------------------------------------------------------------------------
// requester authorization
// ---------------------------------------------------------------------------
type authorizer interface {
addAuthToHeaders(
ctx context.Context,
urlStr string,
headers http.Header,
) error
}
// consumed by kiota
type authenticateRequester interface {
AuthenticateRequest(
ctx context.Context,
request *abstractions.RequestInformation,
additionalAuthenticationContext map[string]any,
) error
}
// ---------------------------------------------------------------------------
// Azure Authorizer
// ---------------------------------------------------------------------------
type azureAuth struct {
auth authenticateRequester
}
func NewAzureAuth(creds account.M365Config) (*azureAuth, error) {
auth, err := GetAuth(
creds.AzureTenantID,
creds.AzureClientID,
creds.AzureClientSecret)
return &azureAuth{auth}, clues.Stack(err).OrNil()
}
func (aa azureAuth) addAuthToHeaders(
ctx context.Context,
urlStr string,
headers http.Header,
) error {
requestInfo := abstractions.NewRequestInformation()
uri, err := url.Parse(urlStr)
if err != nil {
return clues.WrapWC(ctx, err, "parsing url").OrNil()
}
requestInfo.SetUri(*uri)
err = aa.auth.AuthenticateRequest(ctx, requestInfo, nil)
for _, k := range requestInfo.Headers.ListKeys() {
for _, v := range requestInfo.Headers.Get(k) {
headers.Add(k, v)
}
}
return clues.WrapWC(ctx, err, "authorizing request").OrNil()
}

View File

@ -240,7 +240,7 @@ func (mw *RateLimiterMiddleware) Intercept(
middlewareIndex int, middlewareIndex int,
req *http.Request, req *http.Request,
) (*http.Response, error) { ) (*http.Response, error) {
QueueRequest(req.Context()) QueueRequest(getReqCtx(req))
return pipeline.Next(req, middlewareIndex) return pipeline.Next(req, middlewareIndex)
} }
@ -339,7 +339,7 @@ func (mw *throttlingMiddleware) Intercept(
middlewareIndex int, middlewareIndex int,
req *http.Request, req *http.Request,
) (*http.Response, error) { ) (*http.Response, error) {
err := mw.tf.Block(req.Context()) err := mw.tf.Block(getReqCtx(req))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -36,6 +36,7 @@ type Requester interface {
method, url string, method, url string,
body io.Reader, body io.Reader,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) ) (*http.Response, error)
} }
@ -58,12 +59,8 @@ func NewHTTPWrapper(
transport: defaultTransport(), transport: defaultTransport(),
}, },
} }
redirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
hc = &http.Client{ hc = &http.Client{
CheckRedirect: redirect, Transport: rt,
Transport: rt,
} }
) )
@ -100,6 +97,7 @@ func (hw httpWrapper) Request(
method, url string, method, url string,
body io.Reader, body io.Reader,
headers map[string]string, headers map[string]string,
requireAuth bool,
) (*http.Response, error) { ) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body) req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil { if err != nil {
@ -115,6 +113,17 @@ func (hw httpWrapper) Request(
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic // See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version) req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
if requireAuth {
if hw.config.requesterAuth == nil {
return nil, clues.Wrap(err, "http wrapper misconfigured: missing required authorization")
}
err := hw.config.requesterAuth.addAuthToHeaders(ctx, url, req.Header)
if err != nil {
return nil, clues.Wrap(err, "setting request auth headers")
}
}
retriedErrors := []string{} retriedErrors := []string{}
var e error var e error

View File

@ -40,9 +40,10 @@ func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() {
resp, err := hw.Request( resp, err := hw.Request(
ctx, ctx,
http.MethodGet, http.MethodGet,
"https://www.corsobackup.io", "https://www.google.com",
nil, nil,
nil) nil,
false)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
defer resp.Body.Close() defer resp.Body.Close()
@ -76,6 +77,56 @@ func (mw *mwForceResp) Intercept(
return mw.resp, mw.err return mw.resp, mw.err
} }
func (suite *HTTPWrapperIntgSuite) TestHTTPWrapper_Request_withAuth() {
t := suite.T()
ctx, flush := tester.NewContext(t)
defer flush()
a := tconfig.NewM365Account(t)
m365, err := a.M365Config()
require.NoError(t, err, clues.ToCore(err))
azureAuth, err := NewAzureAuth(m365)
require.NoError(t, err, clues.ToCore(err))
hw := NewHTTPWrapper(count.New(), AuthorizeRequester(azureAuth))
// any request that requires authorization will do
resp, err := hw.Request(
ctx,
http.MethodGet,
"https://graph.microsoft.com/v1.0/users",
nil,
nil,
true)
require.NoError(t, err, clues.ToCore(err))
defer resp.Body.Close()
require.NotNil(t, resp)
require.Equal(t, http.StatusOK, resp.StatusCode)
// also validate that non-auth'd endpoints succeed
resp, err = hw.Request(
ctx,
http.MethodGet,
"https://www.google.com",
nil,
nil,
true)
require.NoError(t, err, clues.ToCore(err))
defer resp.Body.Close()
require.NotNil(t, resp)
require.Equal(t, http.StatusOK, resp.StatusCode)
}
// ---------------------------------------------------------------------------
// unit
// ---------------------------------------------------------------------------
type HTTPWrapperUnitSuite struct { type HTTPWrapperUnitSuite struct {
tester.Suite tester.Suite
} }
@ -84,26 +135,25 @@ func TestHTTPWrapperUnitSuite(t *testing.T) {
suite.Run(t, &HTTPWrapperUnitSuite{Suite: tester.NewUnitSuite(t)}) suite.Run(t, &HTTPWrapperUnitSuite{Suite: tester.NewUnitSuite(t)})
} }
func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() { func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_redirect() {
t := suite.T() t := suite.T()
ctx, flush := tester.NewContext(t) ctx, flush := tester.NewContext(t)
defer flush() defer flush()
url := "https://graph.microsoft.com/fnords/beaux/regard" respHdr := http.Header{}
respHdr.Set("Location", "localhost:99999999/smarfs")
hdr := http.Header{}
hdr.Set("Location", "localhost:99999999/smarfs")
toResp := &http.Response{ toResp := &http.Response{
StatusCode: http.StatusFound, StatusCode: http.StatusFound,
Header: hdr, Header: respHdr,
} }
mwResp := mwForceResp{ mwResp := mwForceResp{
resp: toResp, resp: toResp,
alternate: func(req *http.Request) (bool, *http.Response, error) { alternate: func(req *http.Request) (bool, *http.Response, error) {
if strings.HasSuffix(req.URL.String(), "smarfs") { if strings.HasSuffix(req.URL.String(), "smarfs") {
assert.Equal(t, req.Header.Get("X-Test-Val"), "should-be-copied-to-redirect")
return true, &http.Response{StatusCode: http.StatusOK}, nil return true, &http.Response{StatusCode: http.StatusOK}, nil
} }
@ -113,17 +163,22 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() {
hw := NewHTTPWrapper(count.New(), appendMiddleware(&mwResp)) hw := NewHTTPWrapper(count.New(), appendMiddleware(&mwResp))
resp, err := hw.Request(ctx, http.MethodGet, url, nil, nil) resp, err := hw.Request(
ctx,
http.MethodGet,
"https://graph.microsoft.com/fnords/beaux/regard",
nil,
map[string]string{"X-Test-Val": "should-be-copied-to-redirect"},
false)
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
defer resp.Body.Close() defer resp.Body.Close()
require.NotNil(t, resp) require.NotNil(t, resp)
// require.Equal(t, 1, calledCorrectly, "test server was called with expected path")
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
} }
func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() { func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_http2StreamErrorRetries() {
var ( var (
url = "https://graph.microsoft.com/fnords/beaux/regard" url = "https://graph.microsoft.com/fnords/beaux/regard"
streamErr = http2.StreamError{ streamErr = http2.StreamError{
@ -188,7 +243,7 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries()
// the test middleware. // the test middleware.
hw.retryDelay = 0 hw.retryDelay = 0
_, err := hw.Request(ctx, http.MethodGet, url, nil, nil) _, err := hw.Request(ctx, http.MethodGet, url, nil, nil, false)
require.ErrorAs(t, err, &http2.StreamError{}, clues.ToCore(err)) require.ErrorAs(t, err, &http2.StreamError{}, clues.ToCore(err))
require.Equal(t, test.expectRetries, tries, "count of retries") require.Equal(t, test.expectRetries, tries, "count of retries")
}) })

View File

@ -6,6 +6,9 @@ import (
"net/http/httputil" "net/http/httputil"
"os" "os"
"github.com/alcionai/clues"
"github.com/alcionai/corso/src/internal/common/pii"
"github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/logger"
) )
@ -69,3 +72,22 @@ func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string
return string(respDump) return string(respDump)
} }
func getReqCtx(req *http.Request) context.Context {
if req == nil {
return context.Background()
}
var logURL pii.SafeURL
if req.URL != nil {
logURL = LoggableURL(req.URL.String())
}
return clues.AddTraceName(
req.Context(),
"graph-http-middleware",
"method", req.Method,
"url", logURL,
"request_content_len", req.ContentLength)
}

View File

@ -125,10 +125,7 @@ func (mw *LoggingMiddleware) Intercept(
} }
ctx := clues.Add( ctx := clues.Add(
req.Context(), getReqCtx(req),
"method", req.Method,
"url", LoggableURL(req.URL.String()),
"request_content_len", req.ContentLength,
"resp_status", resp.Status, "resp_status", resp.Status,
"resp_status_code", resp.StatusCode, "resp_status_code", resp.StatusCode,
"resp_content_len", resp.ContentLength) "resp_content_len", resp.ContentLength)
@ -156,7 +153,7 @@ func (mw RetryMiddleware) Intercept(
middlewareIndex int, middlewareIndex int,
req *http.Request, req *http.Request,
) (*http.Response, error) { ) (*http.Response, error) {
ctx := req.Context() ctx := getReqCtx(req)
resp, err := pipeline.Next(req, middlewareIndex) resp, err := pipeline.Next(req, middlewareIndex)
retriable := IsErrTimeout(err) || retriable := IsErrTimeout(err) ||
@ -249,7 +246,9 @@ func (mw RetryMiddleware) retryRequest(
return resp, Wrap(ctx, err, "resetting request body reader") return resp, Wrap(ctx, err, "resetting request body reader")
} }
} else { } else {
logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body") logger.
Ctx(getReqCtx(req)).
Error("body is not an io.Seeker: unable to reset request body")
} }
} }

View File

@ -6,11 +6,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/alcionai/clues" "github.com/alcionai/clues"
abstractions "github.com/microsoft/kiota-abstractions-go" abstractions "github.com/microsoft/kiota-abstractions-go"
"github.com/microsoft/kiota-abstractions-go/serialization" "github.com/microsoft/kiota-abstractions-go/serialization"
kauth "github.com/microsoft/kiota-authentication-azure-go"
khttp "github.com/microsoft/kiota-http-go" khttp "github.com/microsoft/kiota-http-go"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core" msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core"
@ -127,23 +125,6 @@ func CreateAdapter(
return wrapAdapter(adpt, cc), nil return wrapAdapter(adpt, cc), nil
} }
func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) {
// Client Provider: Uses Secret for access to tenant-level data
cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil)
if err != nil {
return nil, clues.Wrap(err, "creating m365 client identity")
}
auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes(
cred,
[]string{"https://graph.microsoft.com/.default"})
if err != nil {
return nil, clues.Wrap(err, "creating azure authentication")
}
return auth, nil
}
// KiotaHTTPClient creates a httpClient with middlewares and timeout configured // KiotaHTTPClient creates a httpClient with middlewares and timeout configured
// for use in the graph adapter. // for use in the graph adapter.
// //
@ -200,6 +181,11 @@ type clientConfig struct {
maxRetries int maxRetries int
// The minimum delay in seconds between retries // The minimum delay in seconds between retries
minDelay time.Duration minDelay time.Duration
// requesterAuth sets the authorization step for requester-compliant clients.
// if non-nil, it will ensure calls are authorized before querying.
// does not get consumed by the standard graph client, which already comes
// packaged with an auth protocol.
requesterAuth authorizer
appendMiddleware []khttp.Middleware appendMiddleware []khttp.Middleware
} }
@ -287,6 +273,12 @@ func MaxConnectionRetries(max int) Option {
} }
} }
func AuthorizeRequester(a authorizer) Option {
return func(c *clientConfig) {
c.requesterAuth = a
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Middleware Control // Middleware Control
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@ -77,7 +77,8 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) {
http.MethodPut, http.MethodPut,
iw.url, iw.url,
bytes.NewReader(p), bytes.NewReader(p),
headers) headers,
false)
if err != nil { if err != nil {
return 0, clues.Wrap(err, "uploading item").With( return 0, clues.Wrap(err, "uploading item").With(
"upload_id", iw.parentID, "upload_id", iw.parentID,