diff --git a/src/internal/common/limiters/limiter.go b/src/internal/common/limiters/limiter.go
index e653a7aa1..f866fc41d 100644
--- a/src/internal/common/limiters/limiter.go
+++ b/src/internal/common/limiters/limiter.go
@@ -4,6 +4,7 @@ import "context"
 
 type Limiter interface {
 	Wait(ctx context.Context) error
+	WaitN(ctx context.Context, n int) error
 	Shutdown()
 	Reset()
 }
diff --git a/src/internal/common/limiters/sliding_window.go b/src/internal/common/limiters/sliding_window.go
index d657fa01e..27e21c957 100644
--- a/src/internal/common/limiters/sliding_window.go
+++ b/src/internal/common/limiters/sliding_window.go
@@ -89,21 +89,35 @@ func NewSlidingWindowLimiter(
 }
 
 // Wait blocks a request until a token is available or the context is cancelled.
-// TODO(pandeyabs): Implement WaitN.
 func (s *slidingWindow) Wait(ctx context.Context) error {
+	return s.WaitN(ctx, 1)
+}
+
+// WaitN blocks a request until n tokens are available or the context is
+// cancelled. WaitN should be called with n <= capacity; otherwise it will
+// block forever.
+//
+// TODO(pandeyabs): Enforce n <= capacity check. Not adding it right now because
+// we are relying on capacity = 0 for ctx cancellation test, which would need
+// some refactoring.
+func (s *slidingWindow) WaitN(ctx context.Context, n int) error {
+	// Acquire request mutex and slide mutex in order.
 	s.requestMu.Lock()
 	defer s.requestMu.Unlock()
 
-	select {
-	case <-ctx.Done():
-		return clues.Stack(ctx.Err())
-	case <-s.permits:
-		s.mu.Lock()
-		defer s.mu.Unlock()
-
-		s.curr.count[s.currentInterval]++
+	for i := 0; i < n; i++ {
+		select {
+		case <-ctx.Done():
+			return clues.Stack(ctx.Err())
+		case <-s.permits:
+		}
 	}
 
+	// Mark n tokens as granted in the current interval.
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.curr.count[s.currentInterval] += n
+
 	return nil
 }

diff --git a/src/internal/common/limiters/sliding_window_test.go b/src/internal/common/limiters/sliding_window_test.go
index def355ea5..f94440d60 100644
--- a/src/internal/common/limiters/sliding_window_test.go
+++ b/src/internal/common/limiters/sliding_window_test.go
@@ -81,57 +81,83 @@ func (suite *SlidingWindowUnitTestSuite) TestWaitBasic() {
 }
 
 // TestWaitSliding tests the sliding window functionality of the limiter with
-// time distributed Wait() calls.
-func (suite *SlidingWindowUnitTestSuite) TestWaitSliding() {
-	var (
-		t             = suite.T()
-		windowSize    = 1 * time.Second
-		slideInterval = 10 * time.Millisecond
-		capacity      = 100
-		// Test will run for duration of 2 windowSize.
-		numRequests = 2 * capacity
-		wg          sync.WaitGroup
-	)
-
-	defer goleak.VerifyNone(t)
-
-	ctx, flush := tester.NewContext(t)
-	defer flush()
-
-	s, err := NewSlidingWindowLimiter(windowSize, slideInterval, capacity)
-	require.NoError(t, err)
-
-	// Make concurrent requests to the limiter
-	for i := 0; i < numRequests; i++ {
-		wg.Add(1)
-
-		go func() {
-			defer wg.Done()
-
-			// Sleep for a random duration to spread out requests over multiple slide
-			// intervals & windows, so that we can test the sliding window logic better.
-			// Without this, the requests will be bunched up in the very first intervals
-			// of the 2 windows. Rest of the intervals will be empty.
-			time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)
-
-			err := s.Wait(ctx)
-			require.NoError(t, err)
-		}()
+// time distributed WaitN() calls.
+func (suite *SlidingWindowUnitTestSuite) TestWaitNSliding() {
+	tests := []struct {
+		Name          string
+		windowSize    time.Duration
+		slideInterval time.Duration
+		capacity      int
+		numRequests   int
+		n             int
+	}{
+		{
+			Name:          "Request 1 token each",
+			windowSize:    100 * time.Millisecond,
+			slideInterval: 10 * time.Millisecond,
+			capacity:      100,
+			numRequests:   200,
+			n:             1,
+		},
+		{
+			Name:          "Request N tokens each",
+			windowSize:    100 * time.Millisecond,
+			slideInterval: 10 * time.Millisecond,
+			capacity:      1000,
+			numRequests:   200,
+			n:             10,
+		},
 	}
-	wg.Wait()
 
-	// Shutdown the ticker before accessing the internal limiter state.
-	s.Shutdown()
+	for _, test := range tests {
+		suite.Run(test.Name, func() {
+			t := suite.T()
 
-	// Verify that number of requests allowed in each window is less than or equal
-	// to window capacity
-	sw := s.(*slidingWindow)
-	data := append(sw.prev.count, sw.curr.count...)
+			defer goleak.VerifyNone(t)
 
-	sums := slidingSums(data, sw.numIntervals)
+			ctx, flush := tester.NewContext(t)
+			defer flush()
 
-	for _, sum := range sums {
-		require.True(t, sum <= capacity, "sum: %d, capacity: %d", sum, capacity)
+			s, err := NewSlidingWindowLimiter(test.windowSize, test.slideInterval, test.capacity)
+			require.NoError(t, err)
+
+			var wg sync.WaitGroup
+
+			// Make concurrent requests to the limiter
+			for i := 0; i < test.numRequests; i++ {
+				wg.Add(1)
+
+				go func() {
+					defer wg.Done()
+
+					// Sleep for a random duration to spread out requests over
+					// multiple slide intervals & windows, so that we can test
+					// the sliding window logic better.
+					// Without this, the requests will be bunched up in the very
+					// first interval of the 2 windows. The rest of the intervals
+					// will be empty.
+					time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)
+
+					err := s.WaitN(ctx, test.n)
+					require.NoError(t, err)
+				}()
+			}
+			wg.Wait()
+
+			// Shutdown the ticker before accessing the internal limiter state.
+			s.Shutdown()
+
+			// Verify that the number of requests allowed in each window is less
+			// than or equal to the window capacity.
+			sw := s.(*slidingWindow)
+			data := append(sw.prev.count, sw.curr.count...)
+
+			sums := slidingSums(data, sw.numIntervals)
+
+			for _, sum := range sums {
+				require.True(t, sum <= test.capacity, "sum: %d, capacity: %d", sum, test.capacity)
+			}
+		})
 	}
 }
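For reviewers, a minimal usage sketch of the new WaitN call, not part of the change itself. It relies only on what the diff and test above show (NewSlidingWindowLimiter(windowSize, slideInterval, capacity) returning a Limiter and an error, plus Wait/WaitN/Shutdown on the interface); the import path is an assumption, and the internal package is only importable from within this module.

package main

import (
	"context"
	"fmt"
	"time"

	// Assumed import path for the limiters package in this repo.
	"github.com/alcionai/corso/src/internal/common/limiters"
)

func main() {
	// 100 tokens per 1s window, sliding every 10ms, mirroring the shape of the
	// parameters used in the tests.
	lim, err := limiters.NewSlidingWindowLimiter(time.Second, 10*time.Millisecond, 100)
	if err != nil {
		fmt.Println("failed to build limiter:", err)
		return
	}
	defer lim.Shutdown()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Acquire 10 tokens in one call for a batched request. WaitN blocks until
	// all 10 are granted or ctx is cancelled; n must be <= the window capacity.
	if err := lim.WaitN(ctx, 10); err != nil {
		fmt.Println("tokens not granted:", err)
		return
	}

	fmt.Println("10 tokens granted; safe to issue the batched request")
}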