diff --git a/src/internal/connector/exchange/data_collections.go b/src/internal/connector/exchange/data_collections.go
index fd9a2b883..8af42aee4 100644
--- a/src/internal/connector/exchange/data_collections.go
+++ b/src/internal/connector/exchange/data_collections.go
@@ -182,6 +182,9 @@ func DataCollections(
 		categories = map[path.CategoryType]struct{}{}
 	)
 
+	// TODO: Add hidden cli flag to disable this feature
+	graph.InitializeConcurrencyLimiter(ctrlOpts.Parallelism.ItemFetch)
+
 	cdps, err := parseMetadataCollections(ctx, metadata, errs)
 	if err != nil {
 		return nil, nil, err
diff --git a/src/internal/connector/graph/concurrency_limiter.go b/src/internal/connector/graph/concurrency_limiter.go
new file mode 100644
index 000000000..6fe1ea0cd
--- /dev/null
+++ b/src/internal/connector/graph/concurrency_limiter.go
@@ -0,0 +1,53 @@
+package graph
+
+import (
+	"net/http"
+	"sync"
+
+	"github.com/alcionai/clues"
+	khttp "github.com/microsoft/kiota-http-go"
+)
+
+// concurrencyLimiter middleware limits the number of concurrent requests to graph API
+type concurrencyLimiter struct {
+	semaphore chan struct{}
+}
+
+var (
+	once                  sync.Once
+	concurrencyLim        *concurrencyLimiter
+	maxConcurrentRequests = 4
+)
+
+func generateConcurrencyLimiter(capacity int) *concurrencyLimiter {
+	if capacity < 1 || capacity > maxConcurrentRequests {
+		capacity = maxConcurrentRequests
+	}
+
+	return &concurrencyLimiter{
+		semaphore: make(chan struct{}, capacity),
+	}
+}
+
+func InitializeConcurrencyLimiter(capacity int) {
+	once.Do(func() {
+		concurrencyLim = generateConcurrencyLimiter(capacity)
+	})
+}
+
+func (cl *concurrencyLimiter) Intercept(
+	pipeline khttp.Pipeline,
+	middlewareIndex int,
+	req *http.Request,
+) (*http.Response, error) {
+	if cl == nil || cl.semaphore == nil {
+		return nil, clues.New("nil concurrency limiter")
+	}
+
+	cl.semaphore <- struct{}{}
+	defer func() {
+		<-cl.semaphore
+	}()
+
+	return pipeline.Next(req, middlewareIndex)
+}
diff --git a/src/internal/connector/graph/concurrency_limiter_test.go b/src/internal/connector/graph/concurrency_limiter_test.go
new file mode 100644
index 000000000..4e7e57606
--- /dev/null
+++ b/src/internal/connector/graph/concurrency_limiter_test.go
@@ -0,0 +1,117 @@
+package graph
+
+import (
+	"math/rand"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	khttp "github.com/microsoft/kiota-http-go"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+
+	"github.com/alcionai/corso/src/internal/tester"
+)
+
+type ConcurrencyLimiterUnitTestSuite struct {
+	tester.Suite
+}
+
+func TestConcurrencyLimiterSuite(t *testing.T) {
+	suite.Run(t, &ConcurrencyLimiterUnitTestSuite{Suite: tester.NewUnitSuite(t)})
+}
+
+func (suite *ConcurrencyLimiterUnitTestSuite) TestConcurrencyLimiter() {
+	t := suite.T()
+
+	maxConcurrentRequests := 4
+	cl := generateConcurrencyLimiter(maxConcurrentRequests)
+	client := khttp.GetDefaultClient(cl)
+
+	// Server side handler to simulate 429s
+	sem := make(chan struct{}, maxConcurrentRequests)
+	reqHandler := func(w http.ResponseWriter, r *http.Request) {
+		select {
+		case sem <- struct{}{}:
+			defer func() {
+				<-sem
+			}()
+
+			time.Sleep(time.Duration(rand.Intn(50)+50) * time.Millisecond)
+			w.WriteHeader(http.StatusOK)
+
+			return
+		default:
+			w.WriteHeader(http.StatusTooManyRequests)
+			return
+		}
+	}
+
+	ts := httptest.NewServer(http.HandlerFunc(reqHandler))
+	defer ts.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 20; i++ {
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+
+			resp, err := client.Get(ts.URL)
+			require.NoError(t, err)
+			assert.Equal(t, http.StatusOK, resp.StatusCode)
+		}()
+	}
+	wg.Wait()
+}
+
+func (suite *ConcurrencyLimiterUnitTestSuite) TestInitializeConcurrencyLimiter() {
+	t := suite.T()
+
+	InitializeConcurrencyLimiter(2)
+	InitializeConcurrencyLimiter(4)
+
+	assert.Equal(t, cap(concurrencyLim.semaphore), 2, "singleton semaphore capacity changed")
+}
+
+func (suite *ConcurrencyLimiterUnitTestSuite) TestGenerateConcurrencyLimiter() {
+	tests := []struct {
+		name        string
+		cap         int
+		expectedCap int
+	}{
+		{
+			name:        "valid capacity",
+			cap:         2,
+			expectedCap: 2,
+		},
+		{
+			name:        "zero capacity",
+			cap:         0,
+			expectedCap: maxConcurrentRequests,
+		},
+		{
+			name:        "negative capacity",
+			cap:         -1,
+			expectedCap: maxConcurrentRequests,
+		},
+		{
+			name:        "out of bounds capacity",
+			cap:         10,
+			expectedCap: maxConcurrentRequests,
+		},
+	}
+
+	for _, test := range tests {
+		suite.Run(test.name, func() {
+			t := suite.T()
+
+			actual := generateConcurrencyLimiter(test.cap)
+			assert.Equal(t, cap(actual.semaphore), test.expectedCap,
+				"retrieved semaphore capacity vs expected capacity")
+		})
+	}
+}
diff --git a/src/internal/connector/graph/service.go b/src/internal/connector/graph/service.go
index 044af3ac6..42ef4440c 100644
--- a/src/internal/connector/graph/service.go
+++ b/src/internal/connector/graph/service.go
@@ -234,7 +234,14 @@ func kiotaMiddlewares(
 	options *msgraphgocore.GraphClientOptions,
 	cc *clientConfig,
 ) []khttp.Middleware {
-	return []khttp.Middleware{
+	mw := []khttp.Middleware{}
+
+	// Optionally add concurrency limiter middleware if it has been initialized
+	if concurrencyLim != nil {
+		mw = append(mw, concurrencyLim)
+	}
+
+	mw = append(mw, []khttp.Middleware{
 		msgraphgocore.NewGraphTelemetryHandler(options),
 		&RetryHandler{
 			MaxRetries: cc.maxRetries,
@@ -248,5 +255,7 @@ func kiotaMiddlewares(
 		&LoggingMiddleware{},
 		&ThrottleControlMiddleware{},
 		&MetricsMiddleware{},
-	}
+	}...)
+
+	return mw
 }
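For context, the middleware added here is essentially a counting semaphore wrapped around the kiota request pipeline: a buffered channel of capacity N blocks the (N+1)th concurrent request until an in-flight one completes. The sketch below demonstrates the same pattern in isolation against a plain `http.RoundTripper` instead of `khttp.Pipeline`; the `limitedTransport` type and the test server are illustrative only and are not part of this change.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
)

// limitedTransport caps the number of in-flight requests using a buffered
// channel as a counting semaphore, mirroring concurrencyLimiter.Intercept.
// This is a standalone illustration, not code from the PR.
type limitedTransport struct {
	inner http.RoundTripper
	sem   chan struct{}
}

func (t *limitedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	t.sem <- struct{}{}        // blocks while cap(t.sem) requests are already in flight
	defer func() { <-t.sem }() // release the slot when the request completes

	return t.inner.RoundTrip(req)
}

func main() {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	client := &http.Client{
		Transport: &limitedTransport{
			inner: http.DefaultTransport,
			sem:   make(chan struct{}, 4), // at most 4 concurrent requests
		},
	}

	var wg sync.WaitGroup
	for i := 0; i < 20; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			resp, err := client.Get(ts.URL)
			if err != nil {
				fmt.Println("request failed:", err)
				return
			}
			resp.Body.Close()
		}()
	}
	wg.Wait()

	fmt.Println("all requests completed with at most 4 in flight")
}
```

In the PR itself, `InitializeConcurrencyLimiter` guards the singleton with `sync.Once`, so the first caller's capacity wins (as `TestInitializeConcurrencyLimiter` asserts), and `kiotaMiddlewares` only prepends the limiter to the middleware chain once it has been initialized.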