diff --git a/src/internal/connector/data_collections.go b/src/internal/connector/data_collections.go index 9f0f738e5..e66846fef 100644 --- a/src/internal/connector/data_collections.go +++ b/src/internal/connector/data_collections.go @@ -49,6 +49,8 @@ func (gc *GraphConnector) ProduceBackupCollections( diagnostics.Index("service", sels.Service.String())) defer end() + ctx = graph.BindRateLimiterConfig(ctx, graph.LimiterCfg{Service: sels.PathService()}) + // Limit the max number of active requests to graph from this collection. ctrlOpts.Parallelism.ItemFetch = graph.Parallelism(sels.PathService()). ItemOverride(ctx, ctrlOpts.Parallelism.ItemFetch) @@ -194,7 +196,7 @@ func (gc *GraphConnector) ConsumeRestoreCollections( ctx context.Context, backupVersion int, acct account.Account, - selector selectors.Selector, + sels selectors.Selector, dest control.RestoreDestination, opts control.Options, dcs []data.RestoreCollection, @@ -203,6 +205,8 @@ func (gc *GraphConnector) ConsumeRestoreCollections( ctx, end := diagnostics.Span(ctx, "connector:restore") defer end() + ctx = graph.BindRateLimiterConfig(ctx, graph.LimiterCfg{Service: sels.PathService()}) + var ( status *support.ConnectorOperationStatus deets = &details.Builder{} @@ -213,7 +217,7 @@ func (gc *GraphConnector) ConsumeRestoreCollections( return nil, clues.Wrap(err, "malformed azure credentials") } - switch selector.Service { + switch sels.Service { case selectors.ServiceExchange: status, err = exchange.RestoreExchangeDataCollections(ctx, creds, gc.Service, dest, dcs, deets, errs) case selectors.ServiceOneDrive: @@ -221,7 +225,7 @@ func (gc *GraphConnector) ConsumeRestoreCollections( case selectors.ServiceSharePoint: status, err = sharepoint.RestoreCollections(ctx, backupVersion, creds, gc.Service, dest, dcs, deets, errs) default: - err = clues.Wrap(clues.New(selector.Service.String()), "service not supported") + err = clues.Wrap(clues.New(sels.Service.String()), "service not supported") } gc.incrementAwaitingMessages() diff --git a/src/internal/connector/graph/middleware.go b/src/internal/connector/graph/middleware.go index b1d4ad99f..004798cad 100644 --- a/src/internal/connector/graph/middleware.go +++ b/src/internal/connector/graph/middleware.go @@ -20,6 +20,7 @@ import ( "github.com/alcionai/corso/src/internal/common/pii" "github.com/alcionai/corso/src/internal/events" "github.com/alcionai/corso/src/pkg/logger" + "github.com/alcionai/corso/src/pkg/path" ) type nexter interface { @@ -369,18 +370,61 @@ func (mw RetryMiddleware) getRetryDelay( // the volume keeps up after that, we'll always stay between 9000 and 9900 out // of 10k. const ( - perSecond = 15 - maxCap = 900 + defaultPerSecond = 15 + defaultMaxCap = 900 + drivePerSecond = 15 + driveMaxCap = 1100 ) -// Single, global rate limiter at this time. Refinements for method (creates, -// versus reads) or service can come later. -var limiter = rate.NewLimiter(perSecond, maxCap) +var ( + driveLimiter = rate.NewLimiter(defaultPerSecond, defaultMaxCap) + // also used as the exchange service limiter + defaultLimiter = rate.NewLimiter(defaultPerSecond, defaultMaxCap) +) + +type LimiterCfg struct { + Service path.ServiceType +} + +type limiterCfgKey string + +const limiterCfgCtxKey limiterCfgKey = "corsoGraphRateLimiterCfg" + +func ctxLimiter(ctx context.Context) *rate.Limiter { + lc, ok := extractRateLimiterConfig(ctx) + if !ok { + return defaultLimiter + } + + switch lc.Service { + case path.OneDriveService, path.SharePointService: + return driveLimiter + default: + return defaultLimiter + } +} + +func BindRateLimiterConfig(ctx context.Context, lc LimiterCfg) context.Context { + return context.WithValue(ctx, limiterCfgCtxKey, lc) +} + +func extractRateLimiterConfig(ctx context.Context) (LimiterCfg, bool) { + l := ctx.Value(limiterCfgCtxKey) + if l == nil { + return LimiterCfg{}, false + } + + lc, ok := l.(LimiterCfg) + + return lc, ok +} // QueueRequest will allow the request to occur immediately if we're under the // 1k-calls-per-minute rate. Otherwise, the call will wait in a queue until // the next token set is available. func QueueRequest(ctx context.Context) { + limiter := ctxLimiter(ctx) + if err := limiter.Wait(ctx); err != nil { logger.CtxErr(ctx, err).Error("graph middleware waiting on the limiter") } diff --git a/src/internal/connector/graph/middleware_test.go b/src/internal/connector/graph/middleware_test.go index 0874a38f6..6ca660231 100644 --- a/src/internal/connector/graph/middleware_test.go +++ b/src/internal/connector/graph/middleware_test.go @@ -17,10 +17,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "golang.org/x/time/rate" "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/path" ) type mwReturns struct { @@ -132,9 +134,9 @@ func (suite *RetryMWIntgSuite) SetupSuite() { func (suite *RetryMWIntgSuite) TestRetryMiddleware_Intercept_byStatusCode() { var ( - uri = "https://graph.microsoft.com" - path = "/v1.0/users/user/messages/foo" - url = uri + path + uri = "https://graph.microsoft.com" + urlPath = "/v1.0/users/user/messages/foo" + url = uri + urlPath ) tests := []struct { @@ -230,3 +232,63 @@ func (suite *RetryMWIntgSuite) TestRetryMiddleware_RetryRequest_resetBodyAfter50 Post(ctx, body, nil) require.NoError(t, err, clues.ToCore(err)) } + +type MiddlewareUnitSuite struct { + tester.Suite +} + +func TestMiddlewareUnitSuite(t *testing.T) { + suite.Run(t, &MiddlewareUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *MiddlewareUnitSuite) TestBindExtractLimiterConfig() { + ctx, flush := tester.NewContext() + defer flush() + + // an unpopulated ctx should produce the default limiter + assert.Equal(suite.T(), defaultLimiter, ctxLimiter(ctx)) + + table := []struct { + name string + service path.ServiceType + expectOK require.BoolAssertionFunc + expectLimiter *rate.Limiter + }{ + { + name: "exchange", + service: path.ExchangeService, + expectLimiter: defaultLimiter, + }, + { + name: "oneDrive", + service: path.OneDriveService, + expectLimiter: driveLimiter, + }, + { + name: "sharePoint", + service: path.SharePointService, + expectLimiter: driveLimiter, + }, + { + name: "unknownService", + service: path.UnknownService, + expectLimiter: defaultLimiter, + }, + { + name: "badService", + service: path.ServiceType(-1), + expectLimiter: defaultLimiter, + }, + } + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + tctx := BindRateLimiterConfig(ctx, LimiterCfg{Service: test.service}) + lc, ok := extractRateLimiterConfig(tctx) + require.True(t, ok, "found rate limiter in ctx") + assert.Equal(t, test.service, lc.Service) + assert.Equal(t, test.expectLimiter, ctxLimiter(tctx)) + }) + } +}