Compare commits

...

4 Commits

Author           SHA1        Message                 Date
Abhishek Pandey  196f73b082  graph limiter movement  2023-11-09 19:29:07 -08:00
Abhishek Pandey  f10a7bb7b7  Add logs                2023-11-08 22:11:50 -08:00
Abhishek Pandey  4df81e7fd7  Make lowercase arg      2023-11-08 21:38:21 -08:00
Abhishek Pandey  59c0e7d6d3  Implement WaitN         2023-11-08 21:22:04 -08:00
9 changed files with 184 additions and 87 deletions

View File

@@ -4,5 +4,6 @@ import "context"
 
 type Limiter interface {
 	Wait(ctx context.Context) error
+	WaitN(ctx context.Context, n int) error
 	Shutdown()
 }
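
The interface change is backwards compatible: Wait keeps its signature, and WaitN generalizes it to n tokens. As a quick illustration, a caller that needs several tokens before issuing a batched request might look like the sketch below (batchFetch and its body are hypothetical, not part of this change; context and limiters are the packages used throughout this diff):

// batchFetch is a hypothetical consumer of the Limiter interface. It
// reserves n tokens up front, then performs its n downstream calls
// under that budget.
func batchFetch(ctx context.Context, lim limiters.Limiter, n int) error {
	// Block until n tokens are available or ctx is cancelled.
	if err := lim.WaitN(ctx, n); err != nil {
		return err
	}

	// ... issue the n requests ...
	return nil
}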

View File

@@ -84,16 +84,28 @@ func NewSlidingWindowLimiter(
 }
 
 // Wait blocks a request until a token is available or the context is cancelled.
-// TODO(pandeyabs): Implement WaitN.
+// Equivalent to calling WaitN(ctx, 1).
 func (s *slidingWindow) Wait(ctx context.Context) error {
-	select {
-	case <-ctx.Done():
-		return clues.Stack(ctx.Err())
-	case <-s.permits:
-		s.mu.Lock()
-		defer s.mu.Unlock()
-		s.curr.count[s.currentInterval]++
+	return s.WaitN(ctx, 1)
+}
+
+// WaitN blocks a request until n tokens are available or the context gets
+// cancelled. WaitN should be called with n <= capacity, otherwise it will
+// block forever.
+//
+// TODO(pandeyabs): Enforce the n <= capacity check. Not adding it right now
+// because we rely on capacity = 0 for the ctx cancellation test, which would
+// need some refactoring.
+func (s *slidingWindow) WaitN(ctx context.Context, n int) error {
+	for i := 0; i < n; i++ {
+		select {
+		case <-ctx.Done():
+			return clues.Stack(ctx.Err())
+		case <-s.permits:
+			s.mu.Lock()
+			s.curr.count[s.currentInterval]++
+			s.mu.Unlock()
+		}
 	}
 
 	return nil
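
One behavioral note on this implementation: tokens are drawn one at a time inside the loop, so concurrent WaitN callers may interleave their acquisitions, and a context cancelled midway returns an error without refunding tokens already consumed. A minimal sketch of the cancellation path (the saturated limiter lim and the timeout value are illustrative):

// With a saturated limiter, WaitN returns once the deadline passes,
// even if some of the requested tokens were already acquired.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

if err := lim.WaitN(ctx, 10); err != nil {
	// err is clues.Stack(ctx.Err()) here
	return err
}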

View File

@@ -81,57 +81,81 @@ func (suite *SlidingWindowUnitTestSuite) TestWaitBasic() {
 }
 
 // TestWaitSliding tests the sliding window functionality of the limiter with
-// time distributed Wait() calls.
+// time distributed WaitN() calls.
 func (suite *SlidingWindowUnitTestSuite) TestWaitSliding() {
-	var (
-		t             = suite.T()
-		windowSize    = 1 * time.Second
-		slideInterval = 10 * time.Millisecond
-		capacity      = 100
-		// Test will run for duration of 2 windowSize.
-		numRequests = 2 * capacity
-		wg          sync.WaitGroup
-	)
-
-	defer goleak.VerifyNone(t)
-
-	ctx, flush := tester.NewContext(t)
-	defer flush()
-
-	s, err := NewSlidingWindowLimiter(windowSize, slideInterval, capacity)
-	require.NoError(t, err)
-
-	// Make concurrent requests to the limiter
-	for i := 0; i < numRequests; i++ {
-		wg.Add(1)
-
-		go func() {
-			defer wg.Done()
-			// Sleep for a random duration to spread out requests over multiple slide
-			// intervals & windows, so that we can test the sliding window logic better.
-			// Without this, the requests will be bunched up in the very first intervals
-			// of the 2 windows. Rest of the intervals will be empty.
-			time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)
-
-			err := s.Wait(ctx)
-			require.NoError(t, err)
-		}()
-	}
-
-	wg.Wait()
-
-	// Shutdown the ticker before accessing the internal limiter state.
-	s.Shutdown()
-
-	// Verify that number of requests allowed in each window is less than or equal
-	// to window capacity
-	sw := s.(*slidingWindow)
-
-	data := append(sw.prev.count, sw.curr.count...)
-	sums := slidingSums(data, sw.numIntervals)
-
-	for _, sum := range sums {
-		require.True(t, sum <= capacity, "sum: %d, capacity: %d", sum, capacity)
+	tests := []struct {
+		Name          string
+		windowSize    time.Duration
+		slideInterval time.Duration
+		capacity      int
+		numRequests   int
+		n             int
+	}{
+		{
+			Name:          "Request 1 token each",
+			windowSize:    1 * time.Second,
+			slideInterval: 10 * time.Millisecond,
+			capacity:      100,
+			numRequests:   200,
+			n:             1,
+		},
+		{
+			Name:          "Request 5 tokens each",
+			windowSize:    1 * time.Second,
+			slideInterval: 10 * time.Millisecond,
+			capacity:      100,
+			numRequests:   100,
+			n:             5,
+		},
+	}
+
+	for _, test := range tests {
+		suite.Run(test.Name, func() {
+			t := suite.T()
+
+			defer goleak.VerifyNone(t)
+
+			ctx, flush := tester.NewContext(t)
+			defer flush()
+
+			s, err := NewSlidingWindowLimiter(test.windowSize, test.slideInterval, test.capacity)
+			require.NoError(t, err)
+
+			var wg sync.WaitGroup
+
+			// Make concurrent requests to the limiter
+			for i := 0; i < test.numRequests; i++ {
+				wg.Add(1)
+
+				go func() {
+					defer wg.Done()
+					// Sleep for a random duration to spread out requests over multiple slide
+					// intervals & windows, so that we can test the sliding window logic better.
+					// Without this, the requests will be bunched up in the very first intervals
+					// of the 2 windows. The rest of the intervals will be empty.
+					time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)
+
+					err := s.WaitN(ctx, test.n)
+					require.NoError(t, err)
+				}()
+			}
+
+			wg.Wait()
+
+			// Shutdown the ticker before accessing the internal limiter state.
+			s.Shutdown()
+
+			// Verify that the number of requests allowed in each window is less
+			// than or equal to the window capacity.
+			sw := s.(*slidingWindow)
+
+			data := append(sw.prev.count, sw.curr.count...)
+			sums := slidingSums(data, sw.numIntervals)
+
+			for _, sum := range sums {
+				require.True(t, sum <= test.capacity, "sum: %d, capacity: %d", sum, test.capacity)
+			}
+		})
 	}
 }
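
The assertion at the end leans on a slidingSums helper that sits outside this diff; presumably it returns the sum of every run of numIntervals consecutive counts, i.e. the number of requests admitted at each possible window position. A sketch under that assumption (not the repository's actual helper):

// slidingSums returns the sum of each run of `window` consecutive
// counts in data: one entry per possible window position.
func slidingSums(data []int, window int) []int {
	if window <= 0 || len(data) < window {
		return nil
	}

	sums := make([]int, 0, len(data)-window+1)
	sum := 0

	for i, v := range data {
		sum += v

		// Drop the count that just slid out of the window.
		if i >= window {
			sum -= data[i-window]
		}

		if i >= window-1 {
			sums = append(sums, sum)
		}
	}

	return sums
}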

View File

@@ -0,0 +1,29 @@
+package limiters
+
+import (
+	"context"
+
+	"golang.org/x/time/rate"
+)
+
+var _ Limiter = &TokenBucket{}
+
+// Wrapper around the golang.org/x/time/rate token bucket rate limiter.
+type TokenBucket struct {
+	lim *rate.Limiter
+}
+
+func NewTokenBucketLimiter(r int, burst int) Limiter {
+	lim := rate.NewLimiter(rate.Limit(r), burst)
+	return &TokenBucket{lim: lim}
+}
+
+func (tb *TokenBucket) Wait(ctx context.Context) error {
+	return tb.lim.Wait(ctx)
+}
+
+func (tb *TokenBucket) WaitN(ctx context.Context, n int) error {
+	return tb.lim.WaitN(ctx, n)
+}
+
+func (tb *TokenBucket) Shutdown() {}
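
Usage of the wrapper mirrors golang.org/x/time/rate directly; a minimal sketch (the rate and burst values are illustrative, not the production constants):

// Admit roughly 20 requests per second, with bursts of up to 20.
lim := limiters.NewTokenBucketLimiter(20, 20)
defer lim.Shutdown() // no-op for the token bucket; kept for interface parity

if err := lim.Wait(ctx); err != nil {
	return err
}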

View File

@@ -8,6 +8,7 @@ import (
 	"github.com/alcionai/clues"
 
 	"github.com/alcionai/corso/src/internal/common/idname"
+	"github.com/alcionai/corso/src/internal/common/limiters"
 	"github.com/alcionai/corso/src/internal/data"
 	"github.com/alcionai/corso/src/internal/m365/graph"
 	"github.com/alcionai/corso/src/internal/m365/resource"
@@ -78,7 +79,24 @@ func NewController(
 		return nil, clues.Wrap(err, "retrieving m365 account configuration").WithClues(ctx)
 	}
 
-	ac, err := api.NewClient(creds, co, counter)
+	// Pick a rate limiter based on the service type.
+	var lim limiters.Limiter
+
+	switch pst {
+	case path.OneDriveService, path.SharePointService, path.GroupsService:
+		lim = limiters.NewTokenBucketLimiter(graph.DrivePerSecond, graph.DriveMaxCap)
+	default:
+		// TODO(pandeyabs): Change the default to a token bucket with exchange limits, as it exists today.
+		lim, err = limiters.NewSlidingWindowLimiter(
+			graph.ExchangeTimeLimit,
+			graph.ExchangeSlideInterval,
+			graph.ExchangeTokenQuota)
+		if err != nil {
+			return nil, clues.Wrap(err, "creating sliding window limiter").WithClues(ctx)
+		}
+	}
+
+	ac, err := api.NewClient(creds, co, counter, lim)
 	if err != nil {
 		return nil, clues.Wrap(err, "creating api client").WithClues(ctx)
 	}

View File

@@ -11,6 +11,7 @@ import (
 	khttp "github.com/microsoft/kiota-http-go"
 	"golang.org/x/time/rate"
 
+	"github.com/alcionai/corso/src/internal/common/limiters"
 	"github.com/alcionai/corso/src/pkg/count"
 	"github.com/alcionai/corso/src/pkg/logger"
 	"github.com/alcionai/corso/src/pkg/path"
@@ -88,18 +89,23 @@ const (
 	// but doing so risks timeouts. It's better to give the limits breathing room.
 	defaultPerSecond = 16  // 16 * 60 * 10 = 9600
 	defaultMaxCap    = 200 // real cap is 10k-per-10-minutes
 
+	ExchangeTimeLimit     = 10 * time.Minute
+	ExchangeTokenQuota    = 9600
+	ExchangeSlideInterval = 1 * time.Second
+
 	// since drive runs on a per-minute, rather than per-10-minute bucket, we have
 	// to keep the max cap equal to the per-second cap. A large maxCap pool (say,
 	// 1200, similar to the per-minute cap) would allow us to make a flood of 2400
 	// calls in the first minute, putting us over the per-minute limit. Keeping
 	// the cap at the per-second burst means we only dole out a max of 1240 in one
 	// minute (20 cap + 1200 per minute + one burst of padding).
-	drivePerSecond = 20 // 20 * 60 = 1200
-	driveMaxCap    = 20 // real cap is 1250-per-minute
+	DrivePerSecond = 20 // 20 * 60 = 1200
+	DriveMaxCap    = 20 // real cap is 1250-per-minute
 )
 
 var (
-	driveLimiter = rate.NewLimiter(drivePerSecond, driveMaxCap)
+	driveLimiter = rate.NewLimiter(DrivePerSecond, DriveMaxCap)
 	// also used as the exchange service limiter
 	defaultLimiter = rate.NewLimiter(defaultPerSecond, defaultMaxCap)
 )
@@ -116,21 +122,6 @@ func BindRateLimiterConfig(ctx context.Context, lc LimiterCfg) context.Context {
 	return context.WithValue(ctx, limiterCfgCtxKey, lc)
 }
 
-func ctxLimiter(ctx context.Context) *rate.Limiter {
-	lc, ok := extractRateLimiterConfig(ctx)
-	if !ok {
-		return defaultLimiter
-	}
-
-	switch lc.Service {
-	// FIXME: Handle based on category once we add chat backup
-	case path.OneDriveService, path.SharePointService, path.GroupsService:
-		return driveLimiter
-	default:
-		return defaultLimiter
-	}
-}
-
 func extractRateLimiterConfig(ctx context.Context) (LimiterCfg, bool) {
 	l := ctx.Value(limiterCfgCtxKey)
 	if l == nil {
@@ -184,24 +175,25 @@ func ctxLimiterConsumption(ctx context.Context, defaultConsumption int) int {
 // QueueRequest will allow the request to occur immediately if we're under the
 // calls-per-minute rate. Otherwise, the call will wait in a queue until
 // the next token set is available.
-func QueueRequest(ctx context.Context) {
-	limiter := ctxLimiter(ctx)
+func QueueRequest(ctx context.Context, lim limiters.Limiter) {
 	consume := ctxLimiterConsumption(ctx, defaultLC)
 
-	if err := limiter.WaitN(ctx, consume); err != nil {
+	if err := lim.WaitN(ctx, consume); err != nil {
 		logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter")
 	}
 }
 
 // RateLimiterMiddleware is used to ensure we don't overstep per-min request limits.
-type RateLimiterMiddleware struct{}
+type RateLimiterMiddleware struct {
+	lim limiters.Limiter
+}
 
 func (mw *RateLimiterMiddleware) Intercept(
 	pipeline khttp.Pipeline,
 	middlewareIndex int,
 	req *http.Request,
 ) (*http.Response, error) {
-	QueueRequest(req.Context())
+	QueueRequest(req.Context(), mw.lim)
 	return pipeline.Next(req, middlewareIndex)
 }
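
With the limiter injected, QueueRequest no longer selects a limiter from the context; the context only supplies the consumption count. A hedged sketch of a direct call (the constants and functions come from this diff; the surrounding setup is illustrative):

// Build the drive-shaped token bucket and hand it to the queueing helper.
lim := limiters.NewTokenBucketLimiter(graph.DrivePerSecond, graph.DriveMaxCap)

// Waits for ctxLimiterConsumption(ctx, defaultLC) tokens; wait errors are
// logged rather than returned.
graph.QueueRequest(ctx, lim)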

View File

@@ -11,6 +11,7 @@ import (
 	"github.com/pkg/errors"
 	"golang.org/x/net/http2"
 
+	"github.com/alcionai/corso/src/internal/common/limiters"
 	"github.com/alcionai/corso/src/internal/events"
 	"github.com/alcionai/corso/src/internal/version"
 	"github.com/alcionai/corso/src/pkg/count"
@@ -39,13 +40,14 @@ type Requester interface {
 // can utilize it on a per-download basis.
 func NewHTTPWrapper(
 	counter *count.Bus,
+	lim limiters.Limiter,
 	opts ...Option,
 ) *httpWrapper {
 	var (
 		cc = populateConfig(opts...)
 		rt = customTransport{
 			n: pipeline{
-				middlewares: internalMiddleware(cc, counter),
+				middlewares: internalMiddleware(cc, counter, lim),
 				transport:   defaultTransport(),
 			},
 		}
@@ -72,10 +74,11 @@ func NewHTTPWrapper(
 // can utilize it on a per-download basis.
 func NewNoTimeoutHTTPWrapper(
 	counter *count.Bus,
+	lim limiters.Limiter,
 	opts ...Option,
 ) *httpWrapper {
 	opts = append(opts, NoTimeout())
 
-	return NewHTTPWrapper(counter, opts...)
+	return NewHTTPWrapper(counter, lim, opts...)
 }
 
 // ---------------------------------------------------------------------------
@@ -188,6 +191,7 @@ func defaultTransport() http.RoundTripper {
 func internalMiddleware(
 	cc *clientConfig,
 	counter *count.Bus,
+	lim limiters.Limiter,
 ) []khttp.Middleware {
 	throttler := &throttlingMiddleware{
 		tf: newTimedFence(),
@@ -203,7 +207,9 @@ func internalMiddleware(
 		khttp.NewRedirectHandler(),
 		&LoggingMiddleware{},
 		throttler,
-		&RateLimiterMiddleware{},
+		&RateLimiterMiddleware{
+			lim: lim,
+		},
 		&MetricsMiddleware{
 			counter: counter,
 		},
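
Construction sites now thread the limiter alongside the counter; a minimal sketch of building the wrappers (count.New is assumed to be the count package's constructor):

counter := count.New()
lim := limiters.NewTokenBucketLimiter(graph.DrivePerSecond, graph.DriveMaxCap)

// Standard wrapper, and the no-timeout variant used for large downloads.
wrapper := graph.NewHTTPWrapper(counter, lim)
noTimeout := graph.NewNoTimeoutHTTPWrapper(counter, lim)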

View File

@@ -16,6 +16,7 @@ import (
 	"github.com/alcionai/corso/src/internal/common/crash"
 	"github.com/alcionai/corso/src/internal/common/idname"
+	"github.com/alcionai/corso/src/internal/common/limiters"
 	"github.com/alcionai/corso/src/internal/events"
 	"github.com/alcionai/corso/src/pkg/count"
 	"github.com/alcionai/corso/src/pkg/filters"
@@ -106,6 +107,7 @@ func (s Service) Serialize(object serialization.Parsable) ([]byte, error) {
 func CreateAdapter(
 	tenant, client, secret string,
 	counter *count.Bus,
+	lim limiters.Limiter,
 	opts ...Option,
 ) (abstractions.RequestAdapter, error) {
 	auth, err := GetAuth(tenant, client, secret)
@@ -113,7 +115,7 @@ func CreateAdapter(
 		return nil, err
 	}
 
-	httpClient, cc := KiotaHTTPClient(counter, opts...)
+	httpClient, cc := KiotaHTTPClient(counter, lim, opts...)
 
 	adpt, err := msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(
 		auth,
@@ -152,12 +154,13 @@ func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) {
 // can utilize it on a per-download basis.
 func KiotaHTTPClient(
 	counter *count.Bus,
+	lim limiters.Limiter,
 	opts ...Option,
 ) (*http.Client, *clientConfig) {
 	var (
 		clientOptions = msgraphsdkgo.GetDefaultClientOptions()
 		cc            = populateConfig(opts...)
-		middlewares   = kiotaMiddlewares(&clientOptions, cc, counter)
+		middlewares   = kiotaMiddlewares(&clientOptions, cc, counter, lim)
 		httpClient    = msgraphgocore.GetDefaultClient(&clientOptions, middlewares...)
 	)
@@ -277,6 +280,7 @@ func kiotaMiddlewares(
 	options *msgraphgocore.GraphClientOptions,
 	cc *clientConfig,
 	counter *count.Bus,
+	lim limiters.Limiter,
 ) []khttp.Middleware {
 	mw := []khttp.Middleware{
 		msgraphgocore.NewGraphTelemetryHandler(options),
@@ -305,7 +309,9 @@ func kiotaMiddlewares(
 	mw = append(
 		mw,
 		throttler,
-		&RateLimiterMiddleware{},
+		&RateLimiterMiddleware{
+			lim: lim,
+		},
 		&MetricsMiddleware{
 			counter: counter,
 		})

View File

@@ -7,6 +7,7 @@ import (
 	"github.com/alcionai/clues"
 
+	"github.com/alcionai/corso/src/internal/common/limiters"
 	"github.com/alcionai/corso/src/internal/m365/graph"
 	"github.com/alcionai/corso/src/pkg/account"
 	"github.com/alcionai/corso/src/pkg/control"
@@ -43,6 +44,9 @@ type Client struct {
 	counter *count.Bus
 
 	options control.Options
+
+	// rate limiter
+	lim limiters.Limiter
 }
 
 // NewClient produces a new exchange api client. Must be used in
@@ -51,18 +55,19 @@ func NewClient(
 	creds account.M365Config,
 	co control.Options,
 	counter *count.Bus,
+	lim limiters.Limiter,
 ) (Client, error) {
-	s, err := NewService(creds, counter)
+	s, err := NewService(creds, counter, lim)
 	if err != nil {
 		return Client{}, err
 	}
 
-	li, err := newLargeItemService(creds, counter)
+	li, err := newLargeItemService(creds, counter, lim)
 	if err != nil {
 		return Client{}, err
 	}
 
-	rqr := graph.NewNoTimeoutHTTPWrapper(counter)
+	rqr := graph.NewNoTimeoutHTTPWrapper(counter, lim)
 
 	if co.DeltaPageSize < 1 || co.DeltaPageSize > maxDeltaPageSize {
 		co.DeltaPageSize = maxDeltaPageSize
@@ -75,6 +80,7 @@ func NewClient(
 		Requester: rqr,
 		counter:   counter,
 		options:   co,
+		lim:       lim,
 	}
 
 	return cli, nil
@@ -100,6 +106,7 @@ func (c Client) Service(counter *count.Bus) (graph.Servicer, error) {
 func NewService(
 	creds account.M365Config,
 	counter *count.Bus,
+	lim limiters.Limiter,
 	opts ...graph.Option,
 ) (*graph.Service, error) {
 	a, err := graph.CreateAdapter(
@@ -107,6 +114,7 @@ func NewService(
 		creds.AzureClientID,
 		creds.AzureClientSecret,
 		counter,
+		lim,
 		opts...)
 	if err != nil {
 		return nil, clues.Wrap(err, "generating graph api adapter")
@@ -118,8 +126,9 @@ func NewService(
 func newLargeItemService(
 	creds account.M365Config,
 	counter *count.Bus,
+	lim limiters.Limiter,
 ) (*graph.Service, error) {
-	a, err := NewService(creds, counter, graph.NoTimeout())
+	a, err := NewService(creds, counter, lim, graph.NoTimeout())
 	if err != nil {
 		return nil, clues.Wrap(err, "generating no-timeout graph adapter")
 	}