diff --git a/src/go.mod b/src/go.mod index 17ad919ad..4c31c0e20 100644 --- a/src/go.mod +++ b/src/go.mod @@ -68,6 +68,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.48.0 // indirect go.opentelemetry.io/otel/metric v1.19.0 // indirect + go.uber.org/goleak v1.3.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect ) diff --git a/src/go.sum b/src/go.sum index 3e0a5574e..59f275545 100644 --- a/src/go.sum +++ b/src/go.sum @@ -470,6 +470,8 @@ go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1 go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= diff --git a/src/internal/common/limiters/limiter.go b/src/internal/common/limiters/limiter.go new file mode 100644 index 000000000..f842bb957 --- /dev/null +++ b/src/internal/common/limiters/limiter.go @@ -0,0 +1,8 @@ +package limiters + +import "context" + +type Limiter interface { + Wait(ctx context.Context) error + Shutdown() +} diff --git a/src/internal/common/limiters/sliding_window.go b/src/internal/common/limiters/sliding_window.go index 4a29f5fe0..b810adb26 100644 --- a/src/internal/common/limiters/sliding_window.go +++ b/src/internal/common/limiters/sliding_window.go @@ -4,92 +4,161 @@ import ( "context" "sync" "time" + + "github.com/alcionai/clues" ) -type ( - token struct{} - Limiter interface { - Wait(ctx context.Context) error - } -) +type token struct{} -// TODO: Expose interfaces for limiter and window -type window struct { - // TODO: See if we need to store start time. Without it there is no way - // to tell if the ticker is lagging behind ( due to contention from consumers or otherwise). - // Although with our use cases, at max we'd have 10k requests contending with the ticker which - // should be easily doable in fraction of 1 sec. Although we should benchmark this. - // start time.Time - count []int64 +type fixedWindow struct { + count []int } var _ Limiter = &slidingWindow{} type slidingWindow struct { - w time.Duration - slidingInterval time.Duration - capacity int64 - currentInterval int64 - numIntervals int64 - permits chan token - mu sync.Mutex - curr window - prev window + // capacity is the maximum number of requests allowed in a sliding window at + // any given time. + capacity int + // windowSize is the total duration of the sliding window. The limiter will + // allow at most capacity requests in this duration. + windowSize time.Duration + // slideInterval controls how frequently the window slides. A smaller interval + // provides better accuracy at the cost of more frequent sliding & more + // memory usage. + slideInterval time.Duration + + // numIntervals is the number of intervals in the window. Calculated as + // windowSize / slideInterval. + numIntervals int + // currentInterval tracks the current slide interval. + currentInterval int + + // Each request acquires a token from the permits channel.
If the channel + // is empty, the request is blocked until a permit is available or the + // context is cancelled. + permits chan token + + // curr and prev are fixed windows of size windowSize. Each window contains + // a slice of intervals which hold a count of the number of tokens granted + // during that interval. + curr fixedWindow + prev fixedWindow + + // mu synchronizes access to the curr and prev windows. + mu sync.Mutex + // stopTimer stops the recurring slide timer. + stopTimer chan struct{} } -// slidingInterval controls degree of movement of the sliding window from left to right -// Smaller slidingInterval means more frequent movement of the sliding window. -// TODO: Introduce an option to control token refresh frequency. Otherwise, if the sliding interval is -// large, it may slow down the token refresh rate. Not implementing this for simplicity, since for our -// use cases we are going to have a sliding interval of 1 sec which is good enough. -func NewLimiter(w time.Duration, slidingInterval time.Duration, capacity int64) Limiter { - ni := int64(w / slidingInterval) +func NewSlidingWindowLimiter( + windowSize, slideInterval time.Duration, + capacity int, +) (Limiter, error) { + if err := validate(windowSize, slideInterval, capacity); err != nil { + return nil, err + } - sw := &slidingWindow{ - w: w, - slidingInterval: slidingInterval, - capacity: capacity, - permits: make(chan token, capacity), - numIntervals: ni, - prev: window{ - count: make([]int64, ni), + ni := int(windowSize / slideInterval) + + s := &slidingWindow{ + windowSize: windowSize, + slideInterval: slideInterval, + capacity: capacity, + permits: make(chan token, capacity), + numIntervals: ni, + prev: fixedWindow{ + count: make([]int, ni), }, - curr: window{ - count: make([]int64, ni), + curr: fixedWindow{ + count: make([]int, ni), }, currentInterval: -1, + stopTimer: make(chan struct{}), } - // Initialize - sw.nextInterval() + s.initialize() - // Move the sliding window forward every slidingInterval - // TODO: fix leaking goroutine - go sw.run() - - // Prefill permits - for i := int64(0); i < capacity; i++ { - sw.permits <- token{} - } - - return sw + return s, nil } -// TODO: Implement stopping the ticker -func (s *slidingWindow) run() { - ticker := time.NewTicker(s.slidingInterval) +// Wait blocks a request until a token is available or the context is cancelled. +// TODO(pandeyabs): Implement WaitN. +func (s *slidingWindow) Wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.permits: + s.mu.Lock() + defer s.mu.Unlock() - for range ticker.C { - s.slide() + s.curr.count[s.currentInterval]++ + } + + return nil +} + +// Shutdown cleans up the slide goroutine. If Shutdown is not called, the slide +// goroutine will continue to run until the program exits. +func (s *slidingWindow) Shutdown() { + select { + case s.stopTimer <- struct{}{}: + default: } } -func (s *slidingWindow) slide() { - // Slide into the next interval +// initialize starts the slide goroutine and prefills tokens to full capacity. +func (s *slidingWindow) initialize() { + // OK to not hold the mutex here since nothing else is running yet. s.nextInterval() - // Remove permits from the previous window - for i := int64(0); i < s.prev.count[s.currentInterval]; i++ { + // Start a goroutine which runs every slideInterval. This goroutine will + // continue to run until the program exits or until Shutdown is called.
+ go func() { + ticker := time.NewTicker(s.slideInterval) + + for { + select { + case <-ticker.C: + s.slide() + case <-s.stopTimer: + ticker.Stop() + return + } + } + }() + + // Prefill permits to allow tokens to be granted immediately. + for i := 0; i < s.capacity; i++ { + s.permits <- token{} + } +} + +// nextInterval increments the current interval and slides the fixed +// windows if needed. Should be called with the mutex held. +func (s *slidingWindow) nextInterval() { + // Increment current interval + s.currentInterval = (s.currentInterval + 1) % s.numIntervals + + // Slide the fixed windows if windowSize time has elapsed. + if s.currentInterval == 0 { + s.prev = s.curr + s.curr = fixedWindow{ + count: make([]int, s.numIntervals), + } + } +} + +// slide moves the window forward by one interval. It reclaims tokens from the +// interval that we slid past and adds them back to available permits. If the +// permits are already at capacity, excess tokens are discarded. +func (s *slidingWindow) slide() { + s.mu.Lock() + defer s.mu.Unlock() + + s.nextInterval() + + for i := 0; i < s.prev.count[s.currentInterval]; i++ { select { case s.permits <- token{}: default: @@ -99,32 +168,30 @@ func (s *slidingWindow) slide() { } } -// next increments the current interval and resets the current window if needed -func (s *slidingWindow) nextInterval() { - s.mu.Lock() - // Increment current interval - s.currentInterval = (s.currentInterval + 1) % s.numIntervals - - // If it's the first interval, move curr window to prev window and reset curr window. - if s.currentInterval == 0 { - s.prev = s.curr - s.curr = window{ - count: make([]int64, s.numIntervals), - } +func validate( + windowSize, slideInterval time.Duration, + capacity int, +) error { + if windowSize <= 0 { + return clues.New("invalid window size") } - s.mu.Unlock() -} + if slideInterval <= 0 { + return clues.New("invalid slide interval") + } -// TODO: Implement WaitN -func (s *slidingWindow) Wait(ctx context.Context) error { - <-s.permits + // Allow capacity to be 0 for testing purposes + if capacity < 0 { + return clues.New("invalid window capacity") + } - // Acquire mutex and increment current interval's count - s.mu.Lock() - defer s.mu.Unlock() + if windowSize < slideInterval { + return clues.New("window too small to fit slide interval") + } - s.curr.count[s.currentInterval]++ + if windowSize%slideInterval != 0 { + return clues.New("window not divisible by slide interval") + } return nil } diff --git a/src/internal/common/limiters/sliding_window_test.go b/src/internal/common/limiters/sliding_window_test.go index 60903c5a0..9ee89994f 100644 --- a/src/internal/common/limiters/sliding_window_test.go +++ b/src/internal/common/limiters/sliding_window_test.go @@ -1,42 +1,249 @@ package limiters import ( + "context" "fmt" + "math/rand" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/goleak" + "github.com/alcionai/corso/src/internal/tester" ) -func BenchmarkSlidingWindowLimiter(b *testing.B) { - // 1 second window, 1 millisecond sliding interval, 1000 token capacity (1k per sec) - limiter := NewLimiter(1*time.Second, 1*time.Millisecond, 1000) - // If the allowed rate is 1k per sec, 4k goroutines should take 3.xx sec - numGoroutines := 4000 +type SlidingWindowUnitTestSuite struct { + tester.Suite +} - ctx, flush := tester.NewContext(b) +func TestSlidingWindowLimiterSuite(t *testing.T) { + suite.Run(t, &SlidingWindowUnitTestSuite{Suite:
tester.NewUnitSuite(t)}) } +func (suite *SlidingWindowUnitTestSuite) TestWaitBasic() { + var ( + t = suite.T() + windowSize = 1 * time.Second + // Assume slide interval is equal to window size for simplicity. + slideInterval = 1 * time.Second + capacity = 100 + startTime = time.Now() + numRequests = 3 * capacity + wg sync.WaitGroup + mu sync.Mutex + intervalToCount = make(map[time.Duration]int) + ) + + defer goleak.VerifyNone(t) + + ctx, flush := tester.NewContext(t) defer flush() - var wg sync.WaitGroup + s, err := NewSlidingWindowLimiter(windowSize, slideInterval, capacity) + require.NoError(t, err) - b.ResetTimer() - b.StartTimer() + defer s.Shutdown() - for i := 0; i < numGoroutines; i++ { + // Check if all tokens are available for use post initialization. + require.Equal(t, capacity, len(s.(*slidingWindow).permits)) + + // Make concurrent requests to the limiter. + for i := 0; i < numRequests; i++ { wg.Add(1) go func() { defer wg.Done() - _ = limiter.Wait(ctx) + err := s.Wait(ctx) + require.NoError(t, err) + + // Bucket requests by the window in which they were granted (elapsed + // time since startTime, truncated to windowSize). + bucket := time.Since(startTime).Truncate(windowSize) + + mu.Lock() + intervalToCount[bucket]++ + mu.Unlock() }() } wg.Wait() - b.StopTimer() - totalDuration := b.Elapsed() - - fmt.Printf("Total time taken: %v\n", totalDuration) + // Verify that the number of requests allowed in each window is less than or + // equal to the window capacity. + for _, c := range intervalToCount { + require.True(t, c <= capacity, "count: %d, capacity: %d", c, capacity) + } +} + +func (suite *SlidingWindowUnitTestSuite) TestWaitSliding() { + var ( + t = suite.T() + windowSize = 1 * time.Second + slideInterval = 10 * time.Millisecond + capacity = 100 + // Test will run for a duration of 2 * windowSize. + numRequests = 2 * capacity + wg sync.WaitGroup + ) + + defer goleak.VerifyNone(t) + + ctx, flush := tester.NewContext(t) + defer flush() + + s, err := NewSlidingWindowLimiter(windowSize, slideInterval, capacity) + require.NoError(t, err) + + defer s.Shutdown() + + // Make concurrent requests to the limiter. + for i := 0; i < numRequests; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + // Sleep for a random duration to spread out requests over multiple slide + // intervals & windows, so that we can test the sliding window logic better. + // Without this, the requests will be bunched up in the very first intervals + // of the 2 windows. The rest of the intervals will be empty. + time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond) + + err := s.Wait(ctx) + require.NoError(t, err) + }() + } + wg.Wait() + + // Verify that the number of requests allowed in each window is less than or + // equal to the window capacity. + sw := s.(*slidingWindow) + data := append(sw.prev.count, sw.curr.count...) + + sums := slidingSum(data, sw.numIntervals) + + for _, sum := range sums { + fmt.Printf("sum: %d\n", sum) + require.True(t, sum <= capacity, "sum: %d, capacity: %d", sum, capacity) + } +} + +func (suite *SlidingWindowUnitTestSuite) TestContextCancellation() { + var ( + t = suite.T() + windowSize = 100 * time.Millisecond + slideInterval = 10 * time.Millisecond + wg sync.WaitGroup + ) + + defer goleak.VerifyNone(t) + + ctx, flush := tester.NewContext(t) + defer flush() + + // Initialize limiter with capacity = 0 to test context cancellations.
+ s, err := NewSlidingWindowLimiter(windowSize, slideInterval, 0) + require.NoError(t, err) + + defer s.Shutdown() + + ctx, cancel := context.WithTimeout(ctx, 2*windowSize) + defer cancel() + + wg.Add(1) + + go func() { + defer wg.Done() + + err := s.Wait(ctx) + require.Equal(t, context.DeadlineExceeded, err) + }() + + wg.Wait() +} + +func (suite *SlidingWindowUnitTestSuite) TestNewSlidingWindowLimiter() { + tests := []struct { + name string + windowSize time.Duration + slideInterval time.Duration + capacity int + expectErr assert.ErrorAssertionFunc + }{ + { + name: "Invalid window size", + windowSize: 0, + slideInterval: 10 * time.Millisecond, + capacity: 100, + expectErr: assert.Error, + }, + { + name: "Invalid slide interval", + windowSize: 100 * time.Millisecond, + slideInterval: 0, + capacity: 100, + expectErr: assert.Error, + }, + { + name: "Slide interval > window size", + windowSize: 10 * time.Millisecond, + slideInterval: 100 * time.Millisecond, + capacity: 100, + expectErr: assert.Error, + }, + { + name: "Invalid capacity", + windowSize: 100 * time.Millisecond, + slideInterval: 10 * time.Millisecond, + capacity: -1, + expectErr: assert.Error, + }, + { + name: "Valid parameters", + windowSize: 100 * time.Millisecond, + slideInterval: 10 * time.Millisecond, + capacity: 100, + expectErr: assert.NoError, + }, + } + + for _, test := range tests { + suite.Run(test.name, func() { + t := suite.T() + + s, err := NewSlidingWindowLimiter( + test.windowSize, + test.slideInterval, + test.capacity) + test.expectErr(t, err) + + if s != nil { + s.Shutdown() + } + }) + } +} + +func slidingSum(data []int, w int) []int { + var ( + sum = 0 + res = make([]int, len(data)-w+1) + ) + + for i := 0; i < w; i++ { + sum += data[i] + } + + res[0] = sum + + for i := 1; i < len(data)-w+1; i++ { + sum = sum - data[i-1] + data[i+w-1] + res[i] = sum + } + + return res } diff --git a/src/internal/m365/graph/concurrency_middleware.go b/src/internal/m365/graph/concurrency_middleware.go index bec525558..60762ca32 100644 --- a/src/internal/m365/graph/concurrency_middleware.go +++ b/src/internal/m365/graph/concurrency_middleware.go @@ -11,7 +11,6 @@ import ( khttp "github.com/microsoft/kiota-http-go" "golang.org/x/time/rate" - "github.com/alcionai/corso/src/internal/common/limiters" "github.com/alcionai/corso/src/pkg/count" "github.com/alcionai/corso/src/pkg/logger" "github.com/alcionai/corso/src/pkg/path" @@ -103,9 +102,6 @@ var ( driveLimiter = rate.NewLimiter(drivePerSecond, driveMaxCap) // also used as the exchange service limiter defaultLimiter = rate.NewLimiter(defaultPerSecond, defaultMaxCap) - - // 10 min window, 1 second sliding interval, 10k capacity - exchangeLimiter = limiters.NewLimiter(10*time.Minute, 1*time.Second, 10000) ) type LimiterCfg struct { @@ -189,12 +185,10 @@ func ctxLimiterConsumption(ctx context.Context, defaultConsumption int) int { // calls-per-minute rate. Otherwise, the call will wait in a queue until // the next token set is available. func QueueRequest(ctx context.Context) { - // limiter := ctxLimiter(ctx) - // consume := ctxLimiterConsumption(ctx, defaultLC) - // if err := limiter.WaitN(ctx, consume); err != nil { - // logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter") - // } - if err := exchangeLimiter.Wait(ctx); err != nil { + limiter := ctxLimiter(ctx) + consume := ctxLimiterConsumption(ctx, defaultLC) + + if err := limiter.WaitN(ctx, consume); err != nil { logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter") } }
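Usage reference (not part of the diff): a minimal sketch of how a caller inside the repo might consume the new limiter. The window size, slide interval, and capacity below are the values the removed exchangeLimiter used (10 min window, 1 second slide interval, 10k capacity); the doRateLimitedWork function and its placement are hypothetical and for illustration only.

package example // hypothetical caller within src/, since limiters is an internal package

import (
	"context"
	"time"

	"github.com/alcionai/corso/src/internal/common/limiters"
)

func doRateLimitedWork(ctx context.Context) error {
	// Allow at most 10k requests in any 10 minute span, sliding every second.
	lim, err := limiters.NewSlidingWindowLimiter(10*time.Minute, 1*time.Second, 10000)
	if err != nil {
		return err
	}
	// Shutdown stops the slide goroutine once the limiter is no longer needed;
	// without it the goroutine runs until the program exits (goleak would flag it).
	defer lim.Shutdown()

	// Block until a token is available or the context is cancelled.
	if err := lim.Wait(ctx); err != nil {
		return err
	}

	// ... issue the rate limited request here ...
	return nil
}

The explicit Shutdown call is what resolves the old "TODO: fix leaking goroutine" note: callers now own the limiter's lifetime instead of relying on a ticker goroutine that can never be stopped.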