add authentication to requester (#5198)
the graph requester for large item downloads now includes the option to authenticate requests. The option is configured at the time of creating the requester, therefore all requests using that servier are either authenticatd or not. In our case, we're opting to authenticate all requests, since we do not use this requester for non-graph api calls, and even if we did the addition of auth headers is likely benign. --- #### Does this PR need a docs update or release note? - [x] ⛔ No #### Type of change - [x] 🌻 Feature #### Test Plan - [x] 💚 E2E
This commit is contained in:
parent
5e8407a970
commit
bb2bd6df3f
@ -59,6 +59,19 @@ func First(vs ...string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// FirstIn returns the first entry in the map with a non-zero value
|
||||
// when iterating the provided list of keys.
|
||||
func FirstIn(m map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
v, err := AnyValueToString(key, m)
|
||||
if err == nil && len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Preview reduces the string to the specified size.
|
||||
// If the string is longer than the size, the last three
|
||||
// characters are replaced with an ellipsis. Size < 4
|
||||
|
||||
@ -118,3 +118,96 @@ func TestGenerateHash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstIn(t *testing.T) {
|
||||
table := []struct {
|
||||
name string
|
||||
m map[string]any
|
||||
keys []string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "nil map",
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
m: map[string]any{},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
m: map[string]any{
|
||||
"baz": "baz",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "no keys",
|
||||
m: map[string]any{
|
||||
"baz": "baz",
|
||||
},
|
||||
keys: []string{},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "nil match",
|
||||
m: map[string]any{
|
||||
"foo": nil,
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "empty match",
|
||||
m: map[string]any{
|
||||
"foo": "",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "matches first key",
|
||||
m: map[string]any{
|
||||
"foo": "fnords",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "fnords",
|
||||
},
|
||||
{
|
||||
name: "matches second key",
|
||||
m: map[string]any{
|
||||
"bar": "smarf",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "smarf",
|
||||
},
|
||||
{
|
||||
name: "matches second key with nil first match",
|
||||
m: map[string]any{
|
||||
"foo": nil,
|
||||
"bar": "smarf",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "smarf",
|
||||
},
|
||||
{
|
||||
name: "matches second key with empty first match",
|
||||
m: map[string]any{
|
||||
"foo": "",
|
||||
"bar": "smarf",
|
||||
},
|
||||
keys: []string{"foo", "bar"},
|
||||
expect: "smarf",
|
||||
},
|
||||
}
|
||||
for _, test := range table {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := FirstIn(test.m, test.keys...)
|
||||
assert.Equal(t, test.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -366,7 +366,7 @@ func downloadContent(
|
||||
itemID := ptr.Val(item.GetId())
|
||||
ctx = clues.Add(ctx, "item_id", itemID)
|
||||
|
||||
content, err := downloadItem(ctx, iaag, item)
|
||||
content, err := downloadItem(ctx, iaag, driveID, item)
|
||||
if err == nil {
|
||||
return content, nil
|
||||
} else if !graph.IsErrUnauthorizedOrBadToken(err) {
|
||||
@ -395,7 +395,7 @@ func downloadContent(
|
||||
|
||||
cdi := custom.ToCustomDriveItem(di)
|
||||
|
||||
content, err = downloadItem(ctx, iaag, cdi)
|
||||
content, err = downloadItem(ctx, iaag, driveID, cdi)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "content download retry")
|
||||
}
|
||||
@ -426,7 +426,7 @@ func readItemContents(
|
||||
return nil, core.ErrNotFound
|
||||
}
|
||||
|
||||
rc, err := downloadFile(ctx, iaag, props.downloadURL)
|
||||
rc, err := downloadFile(ctx, iaag, props.downloadURL, false)
|
||||
if graph.IsErrUnauthorizedOrBadToken(err) {
|
||||
logger.CtxErr(ctx, err).Debug("stale item in cache")
|
||||
}
|
||||
|
||||
@ -795,7 +795,12 @@ func (h mockBackupHandler[T]) AugmentItemInfo(
|
||||
return h.ItemInfo
|
||||
}
|
||||
|
||||
func (h *mockBackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) {
|
||||
func (h *mockBackupHandler[T]) Get(
|
||||
context.Context,
|
||||
string,
|
||||
map[string]string,
|
||||
bool,
|
||||
) (*http.Response, error) {
|
||||
c := h.getCall
|
||||
h.getCall++
|
||||
|
||||
|
||||
@ -23,6 +23,8 @@ import (
|
||||
const (
|
||||
acceptHeaderKey = "Accept"
|
||||
acceptHeaderValue = "*/*"
|
||||
gigabyte = 1024 * 1024 * 1024
|
||||
largeFileDownloadLimit = 15 * gigabyte
|
||||
)
|
||||
|
||||
// downloadUrlKeys is used to find the download URL in a DriveItem response.
|
||||
@ -33,7 +35,8 @@ var downloadURLKeys = []string{
|
||||
|
||||
func downloadItem(
|
||||
ctx context.Context,
|
||||
ag api.Getter,
|
||||
getter api.Getter,
|
||||
driveID string,
|
||||
item *custom.DriveItem,
|
||||
) (io.ReadCloser, error) {
|
||||
if item == nil {
|
||||
@ -41,35 +44,36 @@ func downloadItem(
|
||||
}
|
||||
|
||||
var (
|
||||
rc io.ReadCloser
|
||||
isFile = item.GetFile() != nil
|
||||
// very large file content needs to be downloaded through a different endpoint, or else
|
||||
// the download could take longer than the lifespan of the download token in the cached
|
||||
// url, which will cause us to timeout on every download request, even if we refresh the
|
||||
// download url right before the query.
|
||||
url = "https://graph.microsoft.com/v1.0/drives/" + driveID + "/items/" + ptr.Val(item.GetId()) + "/content"
|
||||
reader io.ReadCloser
|
||||
err error
|
||||
isLargeFile = ptr.Val(item.GetSize()) > largeFileDownloadLimit
|
||||
)
|
||||
|
||||
if isFile {
|
||||
var (
|
||||
url string
|
||||
ad = item.GetAdditionalData()
|
||||
)
|
||||
|
||||
for _, key := range downloadURLKeys {
|
||||
if v, err := str.AnyValueToString(key, ad); err == nil {
|
||||
url = v
|
||||
break
|
||||
}
|
||||
// if this isn't a file, no content is available for download
|
||||
if item.GetFile() == nil {
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
rc, err = downloadFile(ctx, ag, url)
|
||||
if err != nil {
|
||||
return nil, clues.Stack(err)
|
||||
}
|
||||
// smaller files will maintain our current behavior (prefetching the download url with the
|
||||
// url cache). That pattern works for us in general, and we only need to deviate for very
|
||||
// large file sizes.
|
||||
if !isLargeFile {
|
||||
url = str.FirstIn(item.GetAdditionalData(), downloadURLKeys...)
|
||||
}
|
||||
|
||||
return rc, nil
|
||||
reader, err = downloadFile(ctx, getter, url, isLargeFile)
|
||||
|
||||
return reader, clues.StackWC(ctx, err).OrNil()
|
||||
}
|
||||
|
||||
type downloadWithRetries struct {
|
||||
getter api.Getter
|
||||
requireAuth bool
|
||||
url string
|
||||
}
|
||||
|
||||
@ -86,7 +90,7 @@ func (dg *downloadWithRetries) Get(
|
||||
// wouldn't work without it (get 416 responses instead of 206).
|
||||
headers[acceptHeaderKey] = acceptHeaderValue
|
||||
|
||||
resp, err := dg.getter.Get(ctx, dg.url, headers)
|
||||
resp, err := dg.getter.Get(ctx, dg.url, headers, dg.requireAuth)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "getting file")
|
||||
}
|
||||
@ -96,7 +100,7 @@ func (dg *downloadWithRetries) Get(
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil, clues.New("malware detected").Label(graph.LabelsMalware)
|
||||
return nil, clues.NewWC(ctx, "malware detected").Label(graph.LabelsMalware)
|
||||
}
|
||||
|
||||
if resp != nil && (resp.StatusCode/100) != 2 {
|
||||
@ -107,7 +111,7 @@ func (dg *downloadWithRetries) Get(
|
||||
// upstream error checks can compare the status with
|
||||
// clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode))
|
||||
return nil, clues.
|
||||
Wrap(clues.New(resp.Status), "non-2xx http response").
|
||||
Wrap(clues.NewWC(ctx, resp.Status), "non-2xx http response").
|
||||
Label(graph.LabelStatus(resp.StatusCode))
|
||||
}
|
||||
|
||||
@ -118,6 +122,7 @@ func downloadFile(
|
||||
ctx context.Context,
|
||||
ag api.Getter,
|
||||
url string,
|
||||
requireAuth bool,
|
||||
) (io.ReadCloser, error) {
|
||||
if len(url) == 0 {
|
||||
return nil, clues.NewWC(ctx, "empty file url")
|
||||
@ -142,6 +147,7 @@ func downloadFile(
|
||||
ctx,
|
||||
&downloadWithRetries{
|
||||
getter: ag,
|
||||
requireAuth: requireAuth,
|
||||
url: url,
|
||||
})
|
||||
|
||||
|
||||
@ -109,7 +109,11 @@ func (suite *ItemIntegrationSuite) TestItemReader_oneDrive() {
|
||||
}
|
||||
|
||||
// Read data for the file
|
||||
itemData, err := downloadItem(ctx, bh, custom.ToCustomDriveItem(driveItem))
|
||||
itemData, err := downloadItem(
|
||||
ctx,
|
||||
bh,
|
||||
suite.m365.User.DriveID,
|
||||
custom.ToCustomDriveItem(driveItem))
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
size, err := io.Copy(io.Discard, itemData)
|
||||
@ -292,6 +296,7 @@ func (m mockGetter) Get(
|
||||
ctx context.Context,
|
||||
url string,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
return m.GetFunc(ctx, url)
|
||||
}
|
||||
@ -379,7 +384,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
||||
return nil, clues.New("test error")
|
||||
},
|
||||
errorExpected: require.Error,
|
||||
rcExpected: require.Nil,
|
||||
rcExpected: require.NotNil,
|
||||
},
|
||||
{
|
||||
name: "download url is empty",
|
||||
@ -416,7 +421,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
||||
}, nil
|
||||
},
|
||||
errorExpected: require.Error,
|
||||
rcExpected: require.Nil,
|
||||
rcExpected: require.NotNil,
|
||||
},
|
||||
{
|
||||
name: "non-2xx http response",
|
||||
@ -435,7 +440,7 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
||||
}, nil
|
||||
},
|
||||
errorExpected: require.Error,
|
||||
rcExpected: require.Nil,
|
||||
rcExpected: require.NotNil,
|
||||
},
|
||||
}
|
||||
|
||||
@ -448,9 +453,78 @@ func (suite *ItemUnitTestSuite) TestDownloadItem() {
|
||||
mg := mockGetter{
|
||||
GetFunc: test.GetFunc,
|
||||
}
|
||||
rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(test.itemFunc()))
|
||||
rc, err := downloadItem(
|
||||
ctx,
|
||||
mg,
|
||||
"driveID",
|
||||
custom.ToCustomDriveItem(test.itemFunc()))
|
||||
test.errorExpected(t, err, clues.ToCore(err))
|
||||
test.rcExpected(t, rc)
|
||||
test.rcExpected(t, rc, "reader should only be nil if item is nil")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ItemUnitTestSuite) TestDownloadItem_urlByFileSize() {
|
||||
var (
|
||||
testRc = io.NopCloser(bytes.NewReader([]byte("test")))
|
||||
url = "https://example.com"
|
||||
okResp = &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: testRc,
|
||||
}
|
||||
)
|
||||
|
||||
table := []struct {
|
||||
name string
|
||||
itemFunc func() models.DriveItemable
|
||||
GetFunc func(ctx context.Context, url string) (*http.Response, error)
|
||||
errorExpected require.ErrorAssertionFunc
|
||||
rcExpected require.ValueAssertionFunc
|
||||
label string
|
||||
}{
|
||||
{
|
||||
name: "big file",
|
||||
itemFunc: func() models.DriveItemable {
|
||||
di := api.NewDriveItem("test", false)
|
||||
di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url})
|
||||
di.SetSize(ptr.To[int64](20 * gigabyte))
|
||||
|
||||
return di
|
||||
},
|
||||
GetFunc: func(ctx context.Context, url string) (*http.Response, error) {
|
||||
assert.Contains(suite.T(), url, "/content")
|
||||
return okResp, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "small file",
|
||||
itemFunc: func() models.DriveItemable {
|
||||
di := api.NewDriveItem("test", false)
|
||||
di.SetAdditionalData(map[string]any{"@microsoft.graph.downloadUrl": url})
|
||||
di.SetSize(ptr.To[int64](2 * gigabyte))
|
||||
|
||||
return di
|
||||
},
|
||||
GetFunc: func(ctx context.Context, url string) (*http.Response, error) {
|
||||
assert.NotContains(suite.T(), url, "/content")
|
||||
return okResp, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range table {
|
||||
suite.Run(test.name, func() {
|
||||
t := suite.T()
|
||||
|
||||
ctx, flush := tester.NewContext(t)
|
||||
defer flush()
|
||||
|
||||
_, err := downloadItem(
|
||||
ctx,
|
||||
mockGetter{GetFunc: test.GetFunc},
|
||||
"driveID",
|
||||
custom.ToCustomDriveItem(test.itemFunc()))
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -507,7 +581,11 @@ func (suite *ItemUnitTestSuite) TestDownloadItem_ConnectionResetErrorOnFirstRead
|
||||
mg := mockGetter{
|
||||
GetFunc: GetFunc,
|
||||
}
|
||||
rc, err := downloadItem(ctx, mg, custom.ToCustomDriveItem(itemFunc()))
|
||||
rc, err := downloadItem(
|
||||
ctx,
|
||||
mg,
|
||||
"driveID",
|
||||
custom.ToCustomDriveItem(itemFunc()))
|
||||
errorExpected(t, err, clues.ToCore(err))
|
||||
rcExpected(t, rc)
|
||||
|
||||
|
||||
@ -93,8 +93,9 @@ func (h siteBackupHandler) Get(
|
||||
ctx context.Context,
|
||||
url string,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
return h.ac.Get(ctx, url, headers)
|
||||
return h.ac.Get(ctx, url, headers, requireAuth)
|
||||
}
|
||||
|
||||
func (h siteBackupHandler) PathPrefix(
|
||||
|
||||
@ -154,7 +154,8 @@ func (suite *URLCacheIntegrationSuite) TestURLCacheBasic() {
|
||||
http.MethodGet,
|
||||
props.downloadURL,
|
||||
nil,
|
||||
nil)
|
||||
nil,
|
||||
false)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
require.NotNil(t, resp)
|
||||
|
||||
@ -93,8 +93,9 @@ func (h userDriveBackupHandler) Get(
|
||||
ctx context.Context,
|
||||
url string,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
return h.ac.Get(ctx, url, headers)
|
||||
return h.ac.Get(ctx, url, headers, requireAuth)
|
||||
}
|
||||
|
||||
func (h userDriveBackupHandler) PathPrefix(
|
||||
|
||||
@ -197,7 +197,12 @@ func (h BackupHandler[T]) AugmentItemInfo(
|
||||
return h.ItemInfo
|
||||
}
|
||||
|
||||
func (h *BackupHandler[T]) Get(context.Context, string, map[string]string) (*http.Response, error) {
|
||||
func (h *BackupHandler[T]) Get(
|
||||
context.Context,
|
||||
string,
|
||||
map[string]string,
|
||||
bool,
|
||||
) (*http.Response, error) {
|
||||
c := h.getCall
|
||||
h.getCall++
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ func (c Access) GetToken(
|
||||
c.Credentials.AzureClientSecret))
|
||||
)
|
||||
|
||||
resp, err := c.Post(ctx, rawURL, headers, body)
|
||||
resp, err := c.Post(ctx, rawURL, headers, body, false)
|
||||
if err != nil {
|
||||
return clues.Stack(err)
|
||||
}
|
||||
|
||||
@ -63,7 +63,14 @@ func NewClient(
|
||||
return Client{}, err
|
||||
}
|
||||
|
||||
rqr := graph.NewNoTimeoutHTTPWrapper(counter)
|
||||
azureAuth, err := graph.NewAzureAuth(creds)
|
||||
if err != nil {
|
||||
return Client{}, clues.Wrap(err, "generating azure authorizer")
|
||||
}
|
||||
|
||||
rqr := graph.NewNoTimeoutHTTPWrapper(
|
||||
counter,
|
||||
graph.AuthorizeRequester(azureAuth))
|
||||
|
||||
if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize {
|
||||
co.DeltaPageSize = maxDeltaPageSize
|
||||
@ -124,11 +131,7 @@ func newLargeItemService(
|
||||
counter *count.Bus,
|
||||
) (*graph.Service, error) {
|
||||
a, err := NewService(creds, counter, graph.NoTimeout())
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "generating no-timeout graph adapter")
|
||||
}
|
||||
|
||||
return a, nil
|
||||
return a, clues.Wrap(err, "generating no-timeout graph adapter").OrNil()
|
||||
}
|
||||
|
||||
type Getter interface {
|
||||
@ -136,6 +139,7 @@ type Getter interface {
|
||||
ctx context.Context,
|
||||
url string,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error)
|
||||
}
|
||||
|
||||
@ -144,8 +148,9 @@ func (c Client) Get(
|
||||
ctx context.Context,
|
||||
url string,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
return c.Requester.Request(ctx, http.MethodGet, url, nil, headers)
|
||||
return c.Requester.Request(ctx, http.MethodGet, url, nil, headers, requireAuth)
|
||||
}
|
||||
|
||||
// Get performs an ad-hoc get request using its graph.Requester
|
||||
@ -154,8 +159,9 @@ func (c Client) Post(
|
||||
url string,
|
||||
headers map[string]string,
|
||||
body io.Reader,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
return c.Requester.Request(ctx, http.MethodGet, url, body, headers)
|
||||
return c.Requester.Request(ctx, http.MethodGet, url, body, headers, requireAuth)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
94
src/pkg/services/m365/api/graph/auth.go
Normal file
94
src/pkg/services/m365/api/graph/auth.go
Normal file
@ -0,0 +1,94 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/alcionai/clues"
|
||||
abstractions "github.com/microsoft/kiota-abstractions-go"
|
||||
kauth "github.com/microsoft/kiota-authentication-azure-go"
|
||||
|
||||
"github.com/alcionai/corso/src/pkg/account"
|
||||
)
|
||||
|
||||
func GetAuth(tenant, client, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) {
|
||||
// Client Provider: Uses Secret for access to tenant-level data
|
||||
cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "creating m365 client identity")
|
||||
}
|
||||
|
||||
auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes(
|
||||
cred,
|
||||
[]string{"https://graph.microsoft.com/.default"})
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "creating azure authentication")
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// requester authorization
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type authorizer interface {
|
||||
addAuthToHeaders(
|
||||
ctx context.Context,
|
||||
urlStr string,
|
||||
headers http.Header,
|
||||
) error
|
||||
}
|
||||
|
||||
// consumed by kiota
|
||||
type authenticateRequester interface {
|
||||
AuthenticateRequest(
|
||||
ctx context.Context,
|
||||
request *abstractions.RequestInformation,
|
||||
additionalAuthenticationContext map[string]any,
|
||||
) error
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Azure Authorizer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type azureAuth struct {
|
||||
auth authenticateRequester
|
||||
}
|
||||
|
||||
func NewAzureAuth(creds account.M365Config) (*azureAuth, error) {
|
||||
auth, err := GetAuth(
|
||||
creds.AzureTenantID,
|
||||
creds.AzureClientID,
|
||||
creds.AzureClientSecret)
|
||||
|
||||
return &azureAuth{auth}, clues.Stack(err).OrNil()
|
||||
}
|
||||
|
||||
func (aa azureAuth) addAuthToHeaders(
|
||||
ctx context.Context,
|
||||
urlStr string,
|
||||
headers http.Header,
|
||||
) error {
|
||||
requestInfo := abstractions.NewRequestInformation()
|
||||
|
||||
uri, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return clues.WrapWC(ctx, err, "parsing url").OrNil()
|
||||
}
|
||||
|
||||
requestInfo.SetUri(*uri)
|
||||
|
||||
err = aa.auth.AuthenticateRequest(ctx, requestInfo, nil)
|
||||
|
||||
for _, k := range requestInfo.Headers.ListKeys() {
|
||||
for _, v := range requestInfo.Headers.Get(k) {
|
||||
headers.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
return clues.WrapWC(ctx, err, "authorizing request").OrNil()
|
||||
}
|
||||
@ -240,7 +240,7 @@ func (mw *RateLimiterMiddleware) Intercept(
|
||||
middlewareIndex int,
|
||||
req *http.Request,
|
||||
) (*http.Response, error) {
|
||||
QueueRequest(req.Context())
|
||||
QueueRequest(getReqCtx(req))
|
||||
return pipeline.Next(req, middlewareIndex)
|
||||
}
|
||||
|
||||
@ -339,7 +339,7 @@ func (mw *throttlingMiddleware) Intercept(
|
||||
middlewareIndex int,
|
||||
req *http.Request,
|
||||
) (*http.Response, error) {
|
||||
err := mw.tf.Block(req.Context())
|
||||
err := mw.tf.Block(getReqCtx(req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -36,6 +36,7 @@ type Requester interface {
|
||||
method, url string,
|
||||
body io.Reader,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error)
|
||||
}
|
||||
|
||||
@ -58,11 +59,7 @@ func NewHTTPWrapper(
|
||||
transport: defaultTransport(),
|
||||
},
|
||||
}
|
||||
redirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
hc = &http.Client{
|
||||
CheckRedirect: redirect,
|
||||
Transport: rt,
|
||||
}
|
||||
)
|
||||
@ -100,6 +97,7 @@ func (hw httpWrapper) Request(
|
||||
method, url string,
|
||||
body io.Reader,
|
||||
headers map[string]string,
|
||||
requireAuth bool,
|
||||
) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
@ -115,6 +113,17 @@ func (hw httpWrapper) Request(
|
||||
// See https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#how-to-decorate-your-http-traffic
|
||||
req.Header.Set("User-Agent", "ISV|Alcion|Corso/"+version.Version)
|
||||
|
||||
if requireAuth {
|
||||
if hw.config.requesterAuth == nil {
|
||||
return nil, clues.Wrap(err, "http wrapper misconfigured: missing required authorization")
|
||||
}
|
||||
|
||||
err := hw.config.requesterAuth.addAuthToHeaders(ctx, url, req.Header)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "setting request auth headers")
|
||||
}
|
||||
}
|
||||
|
||||
retriedErrors := []string{}
|
||||
|
||||
var e error
|
||||
|
||||
@ -40,9 +40,10 @@ func (suite *HTTPWrapperIntgSuite) TestNewHTTPWrapper() {
|
||||
resp, err := hw.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
"https://www.corsobackup.io",
|
||||
"https://www.google.com",
|
||||
nil,
|
||||
nil)
|
||||
nil,
|
||||
false)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
defer resp.Body.Close()
|
||||
@ -76,6 +77,56 @@ func (mw *mwForceResp) Intercept(
|
||||
return mw.resp, mw.err
|
||||
}
|
||||
|
||||
func (suite *HTTPWrapperIntgSuite) TestHTTPWrapper_Request_withAuth() {
|
||||
t := suite.T()
|
||||
|
||||
ctx, flush := tester.NewContext(t)
|
||||
defer flush()
|
||||
|
||||
a := tconfig.NewM365Account(t)
|
||||
m365, err := a.M365Config()
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
azureAuth, err := NewAzureAuth(m365)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
hw := NewHTTPWrapper(count.New(), AuthorizeRequester(azureAuth))
|
||||
|
||||
// any request that requires authorization will do
|
||||
resp, err := hw.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
"https://graph.microsoft.com/v1.0/users",
|
||||
nil,
|
||||
nil,
|
||||
true)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// also validate that non-auth'd endpoints succeed
|
||||
resp, err = hw.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
"https://www.google.com",
|
||||
nil,
|
||||
nil,
|
||||
true)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// unit
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type HTTPWrapperUnitSuite struct {
|
||||
tester.Suite
|
||||
}
|
||||
@ -84,26 +135,25 @@ func TestHTTPWrapperUnitSuite(t *testing.T) {
|
||||
suite.Run(t, &HTTPWrapperUnitSuite{Suite: tester.NewUnitSuite(t)})
|
||||
}
|
||||
|
||||
func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() {
|
||||
func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_redirect() {
|
||||
t := suite.T()
|
||||
|
||||
ctx, flush := tester.NewContext(t)
|
||||
defer flush()
|
||||
|
||||
url := "https://graph.microsoft.com/fnords/beaux/regard"
|
||||
|
||||
hdr := http.Header{}
|
||||
hdr.Set("Location", "localhost:99999999/smarfs")
|
||||
respHdr := http.Header{}
|
||||
respHdr.Set("Location", "localhost:99999999/smarfs")
|
||||
|
||||
toResp := &http.Response{
|
||||
StatusCode: http.StatusFound,
|
||||
Header: hdr,
|
||||
Header: respHdr,
|
||||
}
|
||||
|
||||
mwResp := mwForceResp{
|
||||
resp: toResp,
|
||||
alternate: func(req *http.Request) (bool, *http.Response, error) {
|
||||
if strings.HasSuffix(req.URL.String(), "smarfs") {
|
||||
assert.Equal(t, req.Header.Get("X-Test-Val"), "should-be-copied-to-redirect")
|
||||
return true, &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}
|
||||
|
||||
@ -113,17 +163,22 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_redirectMiddleware() {
|
||||
|
||||
hw := NewHTTPWrapper(count.New(), appendMiddleware(&mwResp))
|
||||
|
||||
resp, err := hw.Request(ctx, http.MethodGet, url, nil, nil)
|
||||
resp, err := hw.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
"https://graph.microsoft.com/fnords/beaux/regard",
|
||||
nil,
|
||||
map[string]string{"X-Test-Val": "should-be-copied-to-redirect"},
|
||||
false)
|
||||
require.NoError(t, err, clues.ToCore(err))
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.NotNil(t, resp)
|
||||
// require.Equal(t, 1, calledCorrectly, "test server was called with expected path")
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries() {
|
||||
func (suite *HTTPWrapperUnitSuite) TestHTTPWrapper_Request_http2StreamErrorRetries() {
|
||||
var (
|
||||
url = "https://graph.microsoft.com/fnords/beaux/regard"
|
||||
streamErr = http2.StreamError{
|
||||
@ -188,7 +243,7 @@ func (suite *HTTPWrapperUnitSuite) TestNewHTTPWrapper_http2StreamErrorRetries()
|
||||
// the test middleware.
|
||||
hw.retryDelay = 0
|
||||
|
||||
_, err := hw.Request(ctx, http.MethodGet, url, nil, nil)
|
||||
_, err := hw.Request(ctx, http.MethodGet, url, nil, nil, false)
|
||||
require.ErrorAs(t, err, &http2.StreamError{}, clues.ToCore(err))
|
||||
require.Equal(t, test.expectRetries, tries, "count of retries")
|
||||
})
|
||||
|
||||
@ -6,6 +6,9 @@ import (
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
|
||||
"github.com/alcionai/clues"
|
||||
|
||||
"github.com/alcionai/corso/src/internal/common/pii"
|
||||
"github.com/alcionai/corso/src/pkg/logger"
|
||||
)
|
||||
|
||||
@ -69,3 +72,22 @@ func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string
|
||||
|
||||
return string(respDump)
|
||||
}
|
||||
|
||||
func getReqCtx(req *http.Request) context.Context {
|
||||
if req == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
var logURL pii.SafeURL
|
||||
|
||||
if req.URL != nil {
|
||||
logURL = LoggableURL(req.URL.String())
|
||||
}
|
||||
|
||||
return clues.AddTraceName(
|
||||
req.Context(),
|
||||
"graph-http-middleware",
|
||||
"method", req.Method,
|
||||
"url", logURL,
|
||||
"request_content_len", req.ContentLength)
|
||||
}
|
||||
|
||||
@ -125,10 +125,7 @@ func (mw *LoggingMiddleware) Intercept(
|
||||
}
|
||||
|
||||
ctx := clues.Add(
|
||||
req.Context(),
|
||||
"method", req.Method,
|
||||
"url", LoggableURL(req.URL.String()),
|
||||
"request_content_len", req.ContentLength,
|
||||
getReqCtx(req),
|
||||
"resp_status", resp.Status,
|
||||
"resp_status_code", resp.StatusCode,
|
||||
"resp_content_len", resp.ContentLength)
|
||||
@ -156,7 +153,7 @@ func (mw RetryMiddleware) Intercept(
|
||||
middlewareIndex int,
|
||||
req *http.Request,
|
||||
) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
ctx := getReqCtx(req)
|
||||
resp, err := pipeline.Next(req, middlewareIndex)
|
||||
|
||||
retriable := IsErrTimeout(err) ||
|
||||
@ -249,7 +246,9 @@ func (mw RetryMiddleware) retryRequest(
|
||||
return resp, Wrap(ctx, err, "resetting request body reader")
|
||||
}
|
||||
} else {
|
||||
logger.Ctx(ctx).Error("body is not an io.Seeker: unable to reset request body")
|
||||
logger.
|
||||
Ctx(getReqCtx(req)).
|
||||
Error("body is not an io.Seeker: unable to reset request body")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -6,11 +6,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/alcionai/clues"
|
||||
abstractions "github.com/microsoft/kiota-abstractions-go"
|
||||
"github.com/microsoft/kiota-abstractions-go/serialization"
|
||||
kauth "github.com/microsoft/kiota-authentication-azure-go"
|
||||
khttp "github.com/microsoft/kiota-http-go"
|
||||
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
|
||||
msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core"
|
||||
@ -127,23 +125,6 @@ func CreateAdapter(
|
||||
return wrapAdapter(adpt, cc), nil
|
||||
}
|
||||
|
||||
func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) {
|
||||
// Client Provider: Uses Secret for access to tenant-level data
|
||||
cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil)
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "creating m365 client identity")
|
||||
}
|
||||
|
||||
auth, err := kauth.NewAzureIdentityAuthenticationProviderWithScopes(
|
||||
cred,
|
||||
[]string{"https://graph.microsoft.com/.default"})
|
||||
if err != nil {
|
||||
return nil, clues.Wrap(err, "creating azure authentication")
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// KiotaHTTPClient creates a httpClient with middlewares and timeout configured
|
||||
// for use in the graph adapter.
|
||||
//
|
||||
@ -200,6 +181,11 @@ type clientConfig struct {
|
||||
maxRetries int
|
||||
// The minimum delay in seconds between retries
|
||||
minDelay time.Duration
|
||||
// requesterAuth sets the authorization step for requester-compliant clients.
|
||||
// if non-nil, it will ensure calls are authorized before querying.
|
||||
// does not get consumed by the standard graph client, which already comes
|
||||
// packaged with an auth protocol.
|
||||
requesterAuth authorizer
|
||||
|
||||
appendMiddleware []khttp.Middleware
|
||||
}
|
||||
@ -287,6 +273,12 @@ func MaxConnectionRetries(max int) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func AuthorizeRequester(a authorizer) Option {
|
||||
return func(c *clientConfig) {
|
||||
c.requesterAuth = a
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Middleware Control
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@ -77,7 +77,8 @@ func (iw *largeItemWriter) Write(p []byte) (int, error) {
|
||||
http.MethodPut,
|
||||
iw.url,
|
||||
bytes.NewReader(p),
|
||||
headers)
|
||||
headers,
|
||||
false)
|
||||
if err != nil {
|
||||
return 0, clues.Wrap(err, "uploading item").With(
|
||||
"upload_id", iw.parentID,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user