corso/src/internal/connector/graph/middleware.go
Keepers d5fac8a480
custom http wrapper for item downloads (#3132)
The HTTP client that the graph client customizes for downloading files doesn't include our middleware as expected. This change introduces a new HTTP client wrapper that populates the round tripper with a middleware wrapper, so that our kiota middleware is applied in the same way as it is by the graph client.

---

#### Does this PR need a docs update or release note?

- [x]  No

#### Type of change

- [x] 🌻 Feature
- [x] 🐛 Bugfix
- [x] 🧹 Tech Debt/Cleanup

#### Issue(s)

* #3129

#### Test Plan

- [x] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
2023-04-20 23:20:23 +00:00

378 lines
9.8 KiB
Go

package graph
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"os"
"strconv"
"strings"
"time"
"github.com/alcionai/clues"
backoff "github.com/cenkalti/backoff/v4"
khttp "github.com/microsoft/kiota-http-go"
"golang.org/x/time/rate"
"github.com/alcionai/corso/src/internal/common/pii"
"github.com/alcionai/corso/src/internal/events"
"github.com/alcionai/corso/src/pkg/logger"
)
// nexter is the minimal interface for anything that can hand a request to
// the next middleware in a chain. Its single method has the same shape as
// the pipeline.Next(req, middlewareIndex) calls made by the middleware in
// this file.
// NOTE(review): no visible code in this file consumes nexter; presumably
// it is used by the http wrapper elsewhere in the package - confirm.
type nexter interface {
// Next forwards req to the middleware at middlewareIndex and returns that
// middleware's response and error.
Next(req *http.Request, middlewareIndex int) (*http.Response, error)
}
// ---------------------------------------------------------------------------
// Logging
// ---------------------------------------------------------------------------
// LoggingMiddleware can be used to log the http requests sent by the graph
// client, along with the status of their responses. It carries no state;
// all behavior lives in its Intercept method.
type LoggingMiddleware struct{}
// SafeURLPathParams holds the well-known path names used by graph api calls.
// It is passed to pii.SafeURL (see LoggableURL) as the set of path elements
// that are safe to leave un-hidden when a url is logged.
// NOTE(review): pii.MapWithPlurals presumably also registers the plural form
// of every entry listed here - confirm against the pii package.
var SafeURLPathParams = pii.MapWithPlurals(
//nolint:misspell
"alltime",
"analytics",
"archive",
"beta",
"calendargroup",
"calendar",
"calendarview",
"channel",
"childfolder",
"children",
"clone",
"column",
"contactfolder",
"contact",
"contenttype",
"delta",
"drive",
"event",
"group",
"inbox",
"instance",
"invitation",
"item",
"joinedteam",
"label",
"list",
"mailfolder",
"member",
"message",
"notification",
"page",
"primarychannel",
"root",
"security",
"site",
"subscription",
"team",
"unarchive",
"user",
"v1.0")
// SafeURLQueryParams holds the well-known, safe query parameter keys used by
// graph api calls.
//
// It is passed to pii.SafeURL (see LoggableURL) as the set of query keys
// that are safe to leave un-hidden when a url is logged. The map is used as
// a set: only the keys carry meaning.
var SafeURLQueryParams = map[string]struct{}{
"deltatoken": {},
"startdatetime": {},
"enddatetime": {},
"$count": {},
"$expand": {},
"$filter": {},
"$select": {},
"$top": {},
}
// LoggableURL wraps a raw url string in a pii.SafeURL that is configured
// with the graph api's well-known path elements (SafeURLPathParams) and
// query keys (SafeURLQueryParams).
func LoggableURL(url string) pii.SafeURL {
	var safe pii.SafeURL

	safe.URL = url
	safe.SafePathElems = SafeURLPathParams
	safe.SafeQueryKeys = SafeURLQueryParams

	return safe
}
// Intercept forwards the request to the next middleware in the pipeline and
// logs the outcome. The response and error produced by the rest of the
// pipeline are always returned unmodified.
//
// Logging rules, in order:
//   - a url containing "users//" is reported as malformed, whatever the result.
//   - a nil response is returned without further logging.
//   - 2xx responses are logged at debug level, and only when api debugging
//     or the 2xx request env var is set; the dump includes the body only if
//     the 2xx response env var is also set.
//   - when api debugging or the graph request env var is set, every non-2xx
//     response is logged as an error with a full response dump.
//   - otherwise non-2xx responses are logged at info level: throttling (429)
//     with its rate-limit headers, 4xx and 503 with a full response dump,
//     and everything else with only the status.
func (handler *LoggingMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	ctx := clues.Add(
		req.Context(),
		"method", req.Method,
		"url", LoggableURL(req.URL.String()),
		"request_len", req.ContentLength)

	// hand the request to the rest of the pipeline
	resp, err := pipeline.Next(req, middlewareIndex)

	if malformed := strings.Contains(req.URL.String(), "users//"); malformed {
		logger.Ctx(ctx).Error("malformed request url: missing resource")
	}

	// nothing more to inspect without a response
	if resp == nil {
		return nil, err
	}

	ctx = clues.Add(ctx, "status", resp.Status, "statusCode", resp.StatusCode)

	var (
		entry       = logger.Ctx(ctx)
		statusClass = resp.StatusCode / 100
	)

	switch {
	case statusClass == 2:
		// Good responses are only logged when explicitly requested.
		if logger.DebugAPIFV || os.Getenv(log2xxGraphRequestsEnvKey) != "" {
			entry.Debugw(
				"2xx graph api resp",
				"response", getRespDump(ctx, resp, os.Getenv(log2xxGraphResponseEnvKey) != ""))
		}

	case logger.DebugAPIFV || os.Getenv(logGraphRequestsEnvKey) != "":
		// Debugging is on: record every non-2xx with the full dump.
		entry.Errorw("non-2xx graph api response", "response", getRespDump(ctx, resp, true))

	default:
		// Debugging is off: keep a slimmer record for telemetry and
		// supportability purposes.
		msg := fmt.Sprintf("graph api error: %s", resp.Status)

		switch {
		case resp.StatusCode == http.StatusTooManyRequests:
			// supportability special case: always describe throttling.
			entry = entry.With(
				"limit", resp.Header.Get(rateLimitHeader),
				"remaining", resp.Header.Get(rateRemainingHeader),
				"reset", resp.Header.Get(rateResetHeader),
				"retry-after", resp.Header.Get(retryAfterHeader))
		case statusClass == 4, resp.StatusCode == http.StatusServiceUnavailable:
			entry = entry.With("response", getRespDump(ctx, resp, true))
		}

		entry.Info(msg)
	}

	return resp, err
}
// getRespDump renders resp as a string using httputil.DumpResponse. The
// response body is included only when getBody is true. A failure to dump
// the response is logged rather than returned; whatever bytes DumpResponse
// produced in that case are still converted and returned.
func getRespDump(ctx context.Context, resp *http.Response, getBody bool) string {
	var (
		dump    []byte
		dumpErr error
	)

	if dump, dumpErr = httputil.DumpResponse(resp, getBody); dumpErr != nil {
		logger.CtxErr(ctx, dumpErr).Error("dumping http response")
	}

	return string(dump)
}
// ---------------------------------------------------------------------------
// Retry & Backoff
// ---------------------------------------------------------------------------
// RetryHandler handles transient HTTP responses and retries the request given the retry options
type RetryHandler struct {
// The maximum number of times a request can be retried.
// Retrying stops once the retry count reaches this value.
MaxRetries int
// The initial interval of the exponential backoff applied between retries.
// This is a time.Duration, not a bare count of seconds. A numeric
// Retry-After response header takes precedence over the backoff.
Delay time.Duration
}
// retryRequest re-sends req through the rest of the pipeline for as long as
// the previous attempt (resp, respErr) is retriable and retries remain.
//
// An attempt is retried when all of the following hold:
//   - it produced an error, or a response with a retriable status code,
//   - the request itself can be safely re-sent (see isRetriableRequest),
//   - fewer than MaxRetries retries have been made so far.
//
// Before each retry the handler waits for the delay chosen by getRetryDelay
// and stamps the attempt number into the retry-attempt header. If ctx is
// done before the delay elapses, the last response is returned together
// with ctx.Err().
//
// A retry that fails with anything other than a timeout or a connection
// reset ends the loop immediately with that error. When retries are
// exhausted, the last response is returned, or a stacked respErr if the
// last attempt failed. Errors are annotated with "retry_count".
//
// NOTE(review): the request body is not rewound before it is re-sent;
// confirm that bodied requests reaching this handler can be replayed.
func (middleware RetryHandler) retryRequest(
	ctx context.Context,
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
	resp *http.Response,
	executionCount int,
	cumulativeDelay time.Duration,
	exponentialBackoff *backoff.ExponentialBackOff,
	respErr error,
) (*http.Response, error) {
	// A nil response can only be judged by its error. Guard the status code
	// lookup so that a (nil, nil) result from the pipeline cannot panic.
	retriableResult := respErr != nil ||
		(resp != nil && middleware.isRetriableErrorCode(req, resp.StatusCode))

	if retriableResult &&
		middleware.isRetriableRequest(req) &&
		executionCount < middleware.MaxRetries {
		executionCount++

		delay := middleware.getRetryDelay(req, resp, exponentialBackoff)
		cumulativeDelay += delay

		req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount))

		timer := time.NewTimer(delay)

		select {
		case <-ctx.Done():
			// Don't retry if the context is marked as done, it will just error out
			// when we attempt to send the retry anyway. Release the timer
			// instead of leaving it pending until the delay expires.
			timer.Stop()

			return resp, ctx.Err()

		// An empty case exits the select-block so the remainder of the code
		// doesn't need to be indented.
		case <-timer.C:
		}

		// The previous response is superseded by the retry. Close its body
		// so the underlying connection is released instead of leaked. The
		// close error is irrelevant because the response is discarded.
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}

		response, err := pipeline.Next(req, middlewareIndex)
		if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
			return response, Stack(ctx, err).With("retry_count", executionCount)
		}

		return middleware.retryRequest(ctx,
			pipeline,
			middlewareIndex,
			req,
			response,
			executionCount,
			cumulativeDelay,
			exponentialBackoff,
			err)
	}

	if respErr != nil {
		return nil, Stack(ctx, respErr).With("retry_count", executionCount)
	}

	return resp, nil
}
// isRetriableErrorCode reports whether a response status code should be
// retried. Only 500 (internal server error) and 503 (service unavailable)
// qualify. The req parameter is not consulted.
func (middleware RetryHandler) isRetriableErrorCode(req *http.Request, code int) bool {
	switch code {
	case http.StatusInternalServerError, http.StatusServiceUnavailable:
		return true
	default:
		return false
	}
}
// isRetriableRequest reports whether req can be sent again. Every request
// is retriable except a POST, PUT, or PATCH that carries a body whose
// length is unknown (ContentLength of -1).
func (middleware RetryHandler) isRetriableRequest(req *http.Request) bool {
	switch req.Method {
	case "POST", "PUT", "PATCH":
		bodyOfUnknownLength := req.Body != nil && req.ContentLength == -1
		return !bodyOfUnknownLength
	default:
		return true
	}
}
// getRetryDelay chooses how long to wait before the next retry.
//
// If resp carries a usable Retry-After header, that value wins. Both forms
// of the header are understood:
//   - a number of seconds (fractional values keep their precision),
//   - an HTTP-date, which yields the time remaining until that date.
//
// A missing, negative, unparseable, or already-elapsed Retry-After value
// falls back to the next interval of exponentialBackoff. The req parameter
// is not consulted.
func (middleware RetryHandler) getRetryDelay(
	req *http.Request,
	resp *http.Response,
	exponentialBackoff *backoff.ExponentialBackOff,
) time.Duration {
	var retryAfter string
	if resp != nil {
		retryAfter = resp.Header.Get(retryAfterHeader)
	}

	if retryAfter != "" {
		secs, err := strconv.ParseFloat(retryAfter, 64)

		switch {
		case err == nil && secs >= 0:
			// Scale before converting so "1.5" waits 1.5s rather than being
			// truncated to 1s. The >= 0 check also rejects NaN.
			return time.Duration(secs * float64(time.Second))

		case err != nil:
			// Not a number: the header may be in its HTTP-date form.
			if at, dateErr := http.ParseTime(retryAfter); dateErr == nil {
				if wait := time.Until(at); wait > 0 {
					return wait
				}
			}
		}
	}

	return exponentialBackoff.NextBackOff()
}
// Intercept implements the middleware interface and evaluates whether to
// retry a failed request.
//
// The request is first sent through the rest of the pipeline. An error that
// is neither a timeout nor a connection reset is returned immediately,
// without retrying; these are the same two error kinds that retryRequest
// keeps retrying on later attempts. Every other outcome is handed to
// retryRequest along with a fresh exponential backoff whose initial
// interval is middleware.Delay.
//
// On failure the response is dropped and a stacked error is returned.
func (middleware RetryHandler) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	ctx := req.Context()

	response, err := pipeline.Next(req, middlewareIndex)
	// Treat connection resets like timeouts so the first attempt follows
	// the same retry rules as every subsequent attempt.
	if err != nil && !IsErrTimeout(err) && !IsErrConnectionReset(err) {
		return response, Stack(ctx, err)
	}

	exponentialBackOff := backoff.NewExponentialBackOff()
	exponentialBackOff.InitialInterval = middleware.Delay
	// Reset so the new initial interval is the one the backoff starts from.
	exponentialBackOff.Reset()

	response, err = middleware.retryRequest(
		ctx,
		pipeline,
		middlewareIndex,
		req,
		response,
		0,
		0,
		exponentialBackOff,
		err)
	if err != nil {
		return nil, Stack(ctx, err)
	}

	return response, nil
}
// We're trying to keep calls below the 10k-per-10-minute threshold.
// 15 tokens every second nets 900 per minute. That's 9000 every 10 minutes,
// which is a bit below the mark.
// But suppose we have a minute-long dry spell followed by a 10 minute tsunami.
// We'll have built up 900 tokens in reserve, so the first 900 calls go through
// immediately. Over the next 10 minutes, we'll partition out the other calls
// at a rate of 900-per-minute, ending at a total of 9900. Theoretically, if
// the volume keeps up after that, we'll always stay between 9000 and 9900 out
// of 10k.
const (
// perSecond is the number of tokens added to the limiter each second.
perSecond = 15
// maxCap is the largest number of tokens the limiter can hold in reserve,
// which is also the largest burst that can go through without waiting.
maxCap = 900
)
// Single, global rate limiter at this time. Refinements for method (creates,
// versus reads) or service can come later.
// It refills at perSecond tokens per second and allows bursts of up to
// maxCap tokens.
var limiter = rate.NewLimiter(perSecond, maxCap)
// QueueRequest blocks until the global limiter grants a token. The call
// returns immediately while tokens are in reserve (a burst of up to maxCap);
// otherwise it waits for the limiter to refill at perSecond tokens per
// second.
//
// A failed wait is logged and then ignored: the function returns nothing,
// so the caller proceeds either way.
func QueueRequest(ctx context.Context) {
	err := limiter.Wait(ctx)
	if err == nil {
		return
	}

	logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter")
}
// ---------------------------------------------------------------------------
// Rate Limiting
// ---------------------------------------------------------------------------
// ThrottleControlMiddleware is used to ensure we don't overstep 10k-per-10-min
// request limits. It carries no state of its own; its Intercept method
// defers to QueueRequest before forwarding each request.
type ThrottleControlMiddleware struct{}
// Intercept holds the request in QueueRequest until the rate limiter lets it
// through, then forwards it to the next middleware and returns that
// middleware's response and error unchanged.
func (handler *ThrottleControlMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	ctx := req.Context()

	// wait our turn before sending anything downstream
	QueueRequest(ctx)

	resp, err := pipeline.Next(req, middlewareIndex)

	return resp, err
}
// MetricsMiddleware aggregates per-request metrics on the events bus.
// It carries no state; its Intercept method reports each api call through
// events.Inc and events.Since.
type MetricsMiddleware struct{}
// Intercept forwards the request to the next middleware and then reports
// the call on the events bus through events.Inc and events.Since. Each is
// invoked twice: once for api calls overall and once keyed by the response
// status. A nil response is reported under the status "nil-resp". The
// response and error are returned unchanged.
func (handler *MetricsMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	// start the clock before the request leaves this middleware
	start := time.Now()

	resp, err := pipeline.Next(req, middlewareIndex)

	var status string
	if resp == nil {
		status = "nil-resp"
	} else {
		status = resp.Status
	}

	events.Inc(events.APICall)
	events.Inc(events.APICall, status)
	events.Since(start, events.APICall)
	events.Since(start, events.APICall, status)

	return resp, err
}