corso/src/internal/m365/graph/concurrency_middleware.go

package graph

import (
	"context"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/alcionai/clues"
	khttp "github.com/microsoft/kiota-http-go"
	"go.opentelemetry.io/otel/metric"
	"golang.org/x/time/rate"

	"github.com/alcionai/corso/src/internal/events"
	"github.com/alcionai/corso/src/pkg/count"
	"github.com/alcionai/corso/src/pkg/logger"
	"github.com/alcionai/corso/src/pkg/path"
)

// ---------------------------------------------------------------------------
// Concurrency Limiter
// "how many calls at one time"
// ---------------------------------------------------------------------------

// concurrencyLimiterMiddleware limits the number of concurrent requests
// made to the graph API.
type concurrencyLimiterMiddleware struct {
	semaphore chan struct{}
}

var (
	once                                sync.Once
	concurrencyLimitMiddlewareSingleton = &concurrencyLimiterMiddleware{}
	maxConcurrentRequests               = 4
)

// generateConcurrencyLimiter builds a limiter with the given capacity,
// clamping out-of-range values to maxConcurrentRequests.
func generateConcurrencyLimiter(capacity int) *concurrencyLimiterMiddleware {
	if capacity < 1 || capacity > maxConcurrentRequests {
		capacity = maxConcurrentRequests
	}

	return &concurrencyLimiterMiddleware{
		semaphore: make(chan struct{}, capacity),
	}
}

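// Clamping sketch: callers cannot exceed the package ceiling.
//
//	generateConcurrencyLimiter(0)  // capacity clamps to maxConcurrentRequests
//	generateConcurrencyLimiter(99) // capacity clamps to maxConcurrentRequests
//	generateConcurrencyLimiter(2)  // capacity stays at 2
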
// InitializeConcurrencyLimiter configures the singleton middleware. Thanks
// to the sync.Once, only the first call has any effect.
func InitializeConcurrencyLimiter(ctx context.Context, enable bool, capacity int) {
	once.Do(func() {
		if enable {
			logger.Ctx(ctx).Infow("turning on the concurrency limiter", "concurrency_limit", capacity)
			concurrencyLimitMiddlewareSingleton.semaphore = generateConcurrencyLimiter(capacity).semaphore
		} else {
			logger.Ctx(ctx).Info("turning off the concurrency limiter")
			concurrencyLimitMiddlewareSingleton = nil
		}
	})
}

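// A minimal usage sketch (hypothetical caller; the real wiring lives in
// this package's client construction code):
//
//	ctx := context.Background()
//	InitializeConcurrencyLimiter(ctx, true, 4) // at most 4 calls in flight
//	InitializeConcurrencyLimiter(ctx, true, 2) // no-op: the sync.Once already ran
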
// Intercept acquires a semaphore slot before forwarding the request,
// blocking when all slots are in use, and releases the slot afterward.
func (cl *concurrencyLimiterMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	if cl == nil || cl.semaphore == nil {
		return nil, clues.New("nil concurrency limiter")
	}

	cl.semaphore <- struct{}{}
	defer func() {
		<-cl.semaphore
	}()

	return pipeline.Next(req, middlewareIndex)
}

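// The buffered channel is Go's standard counting-semaphore pattern; the
// same idea in isolation (work() is a hypothetical stand-in):
//
//	sem := make(chan struct{}, 4) // at most 4 concurrent holders
//	sem <- struct{}{}             // acquire: blocks while 4 are held
//	work()                        // the guarded section
//	<-sem                         // release: frees a slot for a waiter
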
//nolint:lll
// ---------------------------------------------------------------------------
// Rate Limiter
// "how many calls in a minute"
// https://learn.microsoft.com/en-us/sharepoint/dev/general-development/how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online
// ---------------------------------------------------------------------------

const (
	// Default goal is to keep calls below the 10k-per-10-minute threshold.
	// 16 tokens every second nets 960 per minute. That's 9600 every 10 minutes,
	// which is a bit below the mark.
	// If the bucket is full, we can push out 200 calls immediately, which brings
	// the total in the first 10 minutes to 9800. We can toe that line if we want,
	// but doing so risks timeouts. It's better to give the limits breathing room.
	defaultPerSecond = 16  // 16 * 60 * 10 = 9600
	defaultMaxCap    = 200 // real cap is 10k-per-10-minutes

	// Since drive runs on a per-minute, rather than per-10-minute, bucket, we
	// have to keep the max cap equal to the per-second cap. A large maxCap pool
	// (say, 1200, similar to the per-minute cap) would allow us to make a flood
	// of 2400 calls in the first minute, putting us over the per-minute limit.
	// Keeping the cap at the per-second burst means we only dole out a max of
	// 1240 in one minute (20 cap + 1200 per minute + one burst of padding).
	drivePerSecond = 20 // 20 * 60 = 1200
	driveMaxCap    = 20 // real cap is 1250-per-minute
)

var (
	driveLimiter = rate.NewLimiter(drivePerSecond, driveMaxCap)
	// also used as the exchange service limiter
	defaultLimiter = rate.NewLimiter(defaultPerSecond, defaultMaxCap)
)

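// rate.NewLimiter(r, b) refills a b-token bucket at r tokens per second,
// and WaitN blocks until n tokens are available. Worst-case first minute
// for the drive limiter: 20 burst tokens up front plus 20/sec * 60s = 1200
// refills, comfortably under the 1250-per-minute service cap.
//
//	lim := rate.NewLimiter(drivePerSecond, driveMaxCap)
//	_ = lim.WaitN(context.Background(), driveDefaultLC) // blocks until 2 tokens free
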
type LimiterCfg struct {
	Service path.ServiceType
}

type limiterCfgKey string

const limiterCfgCtxKey limiterCfgKey = "corsoGraphRateLimiterCfg"

// BindRateLimiterConfig tags the ctx with a rate limiter config, which
// downstream middleware uses to pick the service-appropriate limiter.
func BindRateLimiterConfig(ctx context.Context, lc LimiterCfg) context.Context {
	return context.WithValue(ctx, limiterCfgCtxKey, lc)
}

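// Usage sketch (hypothetical call site): bind the config once per
// operation so every request issued under the ctx picks the right limiter.
//
//	ctx = BindRateLimiterConfig(ctx, LimiterCfg{Service: path.OneDriveService})
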
func ctxLimiter(ctx context.Context) *rate.Limiter {
	lc, ok := extractRateLimiterConfig(ctx)
	if !ok {
		return defaultLimiter
	}

	switch lc.Service {
	// FIXME: Handle based on category once we add chat backup
	case path.OneDriveService, path.SharePointService, path.GroupsService:
		return driveLimiter
	default:
		return defaultLimiter
	}
}

func extractRateLimiterConfig(ctx context.Context) (LimiterCfg, bool) {
	l := ctx.Value(limiterCfgCtxKey)
	if l == nil {
		return LimiterCfg{}, false
	}

	lc, ok := l.(LimiterCfg)

	return lc, ok
}

type limiterConsumptionKey string

const limiterConsumptionCtxKey limiterConsumptionKey = "corsoGraphRateLimiterConsumption"

const (
	// https://learn.microsoft.com/en-us/sharepoint/dev/general-development
	// /how-to-avoid-getting-throttled-or-blocked-in-sharepoint-online#application-throttling
	defaultLC      = 1
	driveDefaultLC = 2
	// limit consumption rate for single-item GET requests,
	// or delta-based multi-item GETs.
	SingleGetOrDeltaLC = 1
	// limit consumption rate for anything permissions related
	PermissionsLC = 5
)

var (
	twonce sync.Once
	token  int64
)

// RegisterRLMetrics registers an otel gauge callback that accumulates the
// observed limiter token counts. The sync.Once guards against duplicate
// gauge registration.
func RegisterRLMetrics(ctx context.Context) {
	twonce.Do(func() {
		cb := func(ctx context.Context, o metric.Observer) error {
			token += int64(ctxLimiter(ctx).Tokens())
			o.ObserveInt64(events.RLGauge, token)

			return nil
		}

		events.RegisterGauge(ctx, events.RLTokens, cb)
	})
}

// ConsumeNTokens ensures any calls using this context will consume
// n rate-limiter tokens. The default is 1; the value only needs to be
// set in the context when consuming more than the default.
// This should only get used on a per-call basis, to avoid cross-pollination.
func ConsumeNTokens(ctx context.Context, n int) context.Context {
	return context.WithValue(ctx, limiterConsumptionCtxKey, n)
}

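// Usage sketch: permission calls weigh more against the service quota, so
// draw extra tokens for that one request's ctx.
//
//	ctx = ConsumeNTokens(ctx, PermissionsLC)
//	// a request issued with this ctx debits 5 tokens instead of 1
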
func ctxLimiterConsumption(ctx context.Context, defaultConsumption int) int {
	l := ctx.Value(limiterConsumptionCtxKey)
	if l == nil {
		return defaultConsumption
	}

	lc, ok := l.(int)
	if !ok || lc < 1 {
		return defaultConsumption
	}

	return lc
}

// QueueRequest will allow the request to occur immediately if we're under
// the calls-per-minute rate. Otherwise, the call will wait in a queue until
// the next token set is available.
func QueueRequest(ctx context.Context) {
	limiter := ctxLimiter(ctx)
	defaultConsumed := defaultLC

	if limiter == driveLimiter {
		defaultConsumed = driveDefaultLC
	}

	consume := ctxLimiterConsumption(ctx, defaultConsumed)

	if err := limiter.WaitN(ctx, consume); err != nil {
		logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter")
	}
}

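// End-to-end sketch: both the limiter choice and the token count ride on
// the ctx.
//
//	ctx = BindRateLimiterConfig(ctx, LimiterCfg{Service: path.OneDriveService})
//	ctx = ConsumeNTokens(ctx, PermissionsLC)
//	QueueRequest(ctx) // waits until the drive limiter releases 5 tokens
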
// RateLimiterMiddleware is used to ensure we don't overstep per-minute
// request limits.
type RateLimiterMiddleware struct{}

func (mw *RateLimiterMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	RegisterRLMetrics(req.Context())
	QueueRequest(req.Context())

	return pipeline.Next(req, middlewareIndex)
}

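// Registration sketch (hypothetical wiring; corso's actual client setup
// lives elsewhere): kiota runs each middleware's Intercept in pipeline
// order, so these gates apply to every outbound graph request.
//
//	middlewares := []khttp.Middleware{
//		&RateLimiterMiddleware{},
//		concurrencyLimitMiddlewareSingleton,
//	}
//	httpClient := khttp.GetDefaultClient(middlewares...)
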
// ---------------------------------------------------------------------------
// global throttle fencing
// ---------------------------------------------------------------------------

// timedFence sets up a fence for a certain amount of time.
// The time can be extended arbitrarily. All processes blocked at
// the fence will be let through when all timer extensions conclude.
type timedFence struct {
	mu     sync.Mutex
	c      chan struct{}
	timers map[int64]*time.Timer
	// lastID hands out unique timer keys; a counter instead of a timestamp
	// avoids map-key collisions when two fences are raised within the same
	// millisecond (which would drop a fence early).
	lastID int64
}

func newTimedFence() *timedFence {
	return &timedFence{
		mu:     sync.Mutex{},
		c:      nil,
		timers: map[int64]*time.Timer{},
	}
}

// Block waits until the fence is let down.
// If no fence is up, it returns immediately.
// Returns an error if the ctx is canceled before the fence is let down.
func (tf *timedFence) Block(ctx context.Context) error {
	// snapshot tf.c under the lock; RaiseFence and the timer goroutines
	// mutate it concurrently, and a bare read would race with them. If c
	// gets closed after the snapshot, the select exits immediately, as
	// if we didn't block at all.
	tf.mu.Lock()
	c := tf.c
	tf.mu.Unlock()

	if c != nil {
		select {
		case <-ctx.Done():
			return clues.Wrap(ctx.Err(), "blocked on throttling fence")
		case <-c:
		}
	}

	return nil
}

// RaiseFence puts up a fence that blocks requests for the provided
// duration, which is always added to the current time.
// Multiple calls to RaiseFence are not additive. ie: calling
// `RaiseFence(5*time.Second); RaiseFence(1*time.Second)` keeps the fence
// up until now+5 seconds, not now+6 seconds. When the last remaining
// fence is dropped, all currently blocked calls are allowed through.
func (tf *timedFence) RaiseFence(dur time.Duration) {
	tf.mu.Lock()
	defer tf.mu.Unlock()

	if dur < 1 {
		return
	}

	if tf.c == nil {
		tf.c = make(chan struct{})
	}

	timer := time.NewTimer(dur)

	// key the timer with a monotonic counter; see the lastID field note.
	tf.lastID++
	tid := tf.lastID
	tf.timers[tid] = timer

	go func(c <-chan time.Time, id int64) {
		// wait for the timeout
		<-c

		tf.mu.Lock()
		defer tf.mu.Unlock()

		// remove the timer
		delete(tf.timers, id)

		// if no timers remain, close the channel to drop the fence
		// and set the fence channel to nil
		if len(tf.timers) == 0 && tf.c != nil {
			close(tf.c)
			tf.c = nil
		}
	}(timer.C, tid)
}

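// Timing sketch: overlapping fences extend blocking to the furthest
// deadline, not to the sum of durations.
//
//	tf := newTimedFence()
//	tf.RaiseFence(5 * time.Second) // fence up until t+5s
//	tf.RaiseFence(1 * time.Second) // subsumed; no effect past t+5s
//	_ = tf.Block(ctx)              // blocked callers resume together at t+5s
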
// throttlingMiddleware fences off all requests when graph responds with
// a 429 and a Retry-After header, instead of continuing to hammer an API
// that is already throttling us.
type throttlingMiddleware struct {
	tf *timedFence
}

func (mw *throttlingMiddleware) Intercept(
	pipeline khttp.Pipeline,
	middlewareIndex int,
	req *http.Request,
) (*http.Response, error) {
	err := mw.tf.Block(req.Context())
	if err != nil {
		return nil, err
	}

	resp, err := pipeline.Next(req, middlewareIndex)
	if resp == nil || err != nil {
		return resp, err
	}

	seconds := getRetryAfterHeader(resp)
	if seconds < 1 {
		return resp, nil
	}

	countBus := count.Ctx(req.Context())
	countBus.Inc(count.ThrottledAPICalls)

	// if all prior conditions pass, we need to add a fence that blocks
	// calls, globally, from progressing until the retry-after timeout
	// passes.
	mw.tf.RaiseFence(time.Duration(seconds) * time.Second)

	return resp, nil
}

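// Behavior sketch: one throttled response pauses every caller that shares
// this middleware instance, not just the request that got throttled.
//
//	// goroutine A receives 429 + "Retry-After: 3" -> RaiseFence(3s)
//	// goroutines B and C reach Block() -> park until the fence drops
//	// t+3s: the timer fires, the fence channel closes, B and C proceed
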
func getRetryAfterHeader(resp *http.Response) int {
	if resp == nil || len(resp.Header) == 0 {
		return -1
	}

	if resp.StatusCode != http.StatusTooManyRequests {
		return -1
	}

	rah := resp.Header.Get(retryAfterHeader)
	if len(rah) == 0 {
		return -1
	}

	seconds, err := strconv.Atoi(rah)
	if err != nil {
		// the error itself is irrelevant, we only want
		// to wait if we have a clear length of time to wait until.
		return -1
	}

	return seconds
}

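// Behavior sketch for getRetryAfterHeader (illustrative values):
//
//	// 429 + "Retry-After: 5" -> 5 (raises a 5s fence upstream)
//	// 429 + HTTP-date form   -> -1 (only integer seconds are parsed)
//	// 503 + "Retry-After: 5" -> -1 (only 429 responses raise a fence)
//	// 429, header absent     -> -1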