diff --git a/CHANGELOG.md b/CHANGELOG.md index 8608857df..4860687ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix Exchange cli args for filtering items - Skip OneNote items bigger than 2GB (Graph API prevents us from downloading them) - ParentPath of json output for Exchange calendar now shows names instead of IDs. +- Fixed failure when downloading huge amount of attachments ## [v0.6.1] (beta) - 2023-03-21 diff --git a/src/go.mod b/src/go.mod index 16791cb53..02353d1c5 100644 --- a/src/go.mod +++ b/src/go.mod @@ -43,6 +43,8 @@ require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/dnaeon/go-vcr v1.2.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/h2non/gock v1.2.0 // indirect + github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect diff --git a/src/go.sum b/src/go.sum index 92f343ee9..fbf659be2 100644 --- a/src/go.sum +++ b/src/go.sum @@ -198,6 +198,10 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= +github.com/h2non/gock v1.2.0 h1:K6ol8rfrRkUOefooBC8elXoaNGYkpp7y2qcxGG6BzUE= +github.com/h2non/gock v1.2.0/go.mod h1:tNhoxHYW2W42cYkYb1WqzdbYIieALC99kpYr7rH/BQk= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/hanwen/go-fuse/v2 v2.2.0 h1:jo5QZYmBLNcl9ovypWaQ5yXMSSV+Ch68xoC3rtZvvBM= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= @@ -303,6 +307,7 @@ github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3P github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= diff --git a/src/internal/connector/exchange/api/api.go b/src/internal/connector/exchange/api/api.go index 05e90728b..444251d6b 100644 --- a/src/internal/connector/exchange/api/api.go +++ b/src/internal/connector/exchange/api/api.go @@ -54,20 +54,20 @@ type GraphRetrievalFunc func( type Client struct { Credentials account.M365Config - // The stable service is re-usable for any non-paged request. + // The Stable service is re-usable for any non-paged request. // This allows us to maintain performance across async requests. - stable graph.Servicer + Stable graph.Servicer - // The largeItem graph servicer is configured specifically for + // The LargeItem graph servicer is configured specifically for // downloading large items. Specifically for use when handling // attachments, and for no other use. - largeItem graph.Servicer + LargeItem graph.Servicer } // NewClient produces a new exchange api client. Must be used in // place of creating an ad-hoc client struct. func NewClient(creds account.M365Config) (Client, error) { - s, err := newService(creds) + s, err := NewService(creds) if err != nil { return Client{}, err } @@ -84,33 +84,30 @@ func NewClient(creds account.M365Config) (Client, error) { // requests instead of the client's stable service, so that in-flight state // within the adapter doesn't get clobbered func (c Client) service() (*graph.Service, error) { - s, err := newService(c.Credentials) + s, err := NewService(c.Credentials) return s, err } -func newService(creds account.M365Config) (*graph.Service, error) { +func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) { a, err := graph.CreateAdapter( creds.AzureTenantID, creds.AzureClientID, - creds.AzureClientSecret) + creds.AzureClientSecret, + opts...) if err != nil { - return nil, clues.Wrap(err, "generating no-timeout graph adapter") + return nil, clues.Wrap(err, "generating graph adapter") } return graph.NewService(a), nil } func newLargeItemService(creds account.M365Config) (*graph.Service, error) { - a, err := graph.CreateAdapter( - creds.AzureTenantID, - creds.AzureClientID, - creds.AzureClientSecret, - graph.NoTimeout()) + a, err := NewService(creds, graph.NoTimeout()) if err != nil { return nil, clues.Wrap(err, "generating no-timeout graph adapter") } - return graph.NewService(a), nil + return a, nil } // --------------------------------------------------------------------------- diff --git a/src/internal/connector/exchange/api/contacts.go b/src/internal/connector/exchange/api/contacts.go index f4f768519..de5c6ca60 100644 --- a/src/internal/connector/exchange/api/contacts.go +++ b/src/internal/connector/exchange/api/contacts.go @@ -47,7 +47,7 @@ func (c Contacts) CreateContactFolder( temp := folderName requestBody.SetDisplayName(&temp) - mdl, err := c.stable.Client().UsersById(user).ContactFolders().Post(ctx, requestBody, nil) + mdl, err := c.Stable.Client().UsersById(user).ContactFolders().Post(ctx, requestBody, nil) if err != nil { return nil, graph.Wrap(ctx, err, "creating contact folder") } @@ -62,7 +62,7 @@ func (c Contacts) DeleteContainer( ) error { // deletes require unique http clients // https://github.com/alcionai/corso/issues/2707 - srv, err := newService(c.Credentials) + srv, err := NewService(c.Credentials) if err != nil { return graph.Stack(ctx, err) } @@ -81,7 +81,7 @@ func (c Contacts) GetItem( user, itemID string, _ *fault.Bus, // no attachments to iterate over, so this goes unused ) (serialization.Parsable, *details.ExchangeInfo, error) { - cont, err := c.stable.Client().UsersById(user).ContactsById(itemID).Get(ctx, nil) + cont, err := c.Stable.Client().UsersById(user).ContactsById(itemID).Get(ctx, nil) if err != nil { return nil, nil, graph.Stack(ctx, err) } @@ -98,7 +98,7 @@ func (c Contacts) GetContainerByID( return nil, graph.Wrap(ctx, err, "setting contact folder options") } - resp, err := c.stable.Client().UsersById(userID).ContactFoldersById(dirID).Get(ctx, ofcf) + resp, err := c.Stable.Client().UsersById(userID).ContactFoldersById(dirID).Get(ctx, ofcf) if err != nil { return nil, graph.Stack(ctx, err) } diff --git a/src/internal/connector/exchange/api/events.go b/src/internal/connector/exchange/api/events.go index dfa4d8541..c63e6d458 100644 --- a/src/internal/connector/exchange/api/events.go +++ b/src/internal/connector/exchange/api/events.go @@ -48,7 +48,7 @@ func (c Events) CreateCalendar( requestbody := models.NewCalendar() requestbody.SetName(&calendarName) - mdl, err := c.stable.Client().UsersById(user).Calendars().Post(ctx, requestbody, nil) + mdl, err := c.Stable.Client().UsersById(user).Calendars().Post(ctx, requestbody, nil) if err != nil { return nil, graph.Wrap(ctx, err, "creating calendar") } @@ -64,7 +64,7 @@ func (c Events) DeleteContainer( ) error { // deletes require unique http clients // https://github.com/alcionai/corso/issues/2707 - srv, err := newService(c.Credentials) + srv, err := NewService(c.Credentials) if err != nil { return graph.Stack(ctx, err) } @@ -110,7 +110,7 @@ func (c Events) GetItem( event models.Eventable ) - event, err = c.stable.Client().UsersById(user).EventsById(itemID).Get(ctx, nil) + event, err = c.Stable.Client().UsersById(user).EventsById(itemID).Get(ctx, nil) if err != nil { return nil, nil, graph.Stack(ctx, err) } @@ -122,7 +122,7 @@ func (c Events) GetItem( }, } - attached, err := c.largeItem. + attached, err := c.LargeItem. Client(). UsersById(user). EventsById(itemID). diff --git a/src/internal/connector/exchange/api/mail.go b/src/internal/connector/exchange/api/mail.go index 606c1987c..860d89efb 100644 --- a/src/internal/connector/exchange/api/mail.go +++ b/src/internal/connector/exchange/api/mail.go @@ -48,7 +48,7 @@ func (c Mail) CreateMailFolder( requestBody.SetDisplayName(&folder) requestBody.SetIsHidden(&isHidden) - mdl, err := c.stable.Client().UsersById(user).MailFolders().Post(ctx, requestBody, nil) + mdl, err := c.Stable.Client().UsersById(user).MailFolders().Post(ctx, requestBody, nil) if err != nil { return nil, graph.Wrap(ctx, err, "creating mail folder") } @@ -91,7 +91,7 @@ func (c Mail) DeleteContainer( ) error { // deletes require unique http clients // https://github.com/alcionai/corso/issues/2707 - srv, err := newService(c.Credentials) + srv, err := NewService(c.Credentials) if err != nil { return graph.Stack(ctx, err) } @@ -133,31 +133,82 @@ func (c Mail) GetItem( user, itemID string, errs *fault.Bus, ) (serialization.Parsable, *details.ExchangeInfo, error) { - mail, err := c.stable.Client().UsersById(user).MessagesById(itemID).Get(ctx, nil) + mail, err := c.Stable.Client().UsersById(user).MessagesById(itemID).Get(ctx, nil) if err != nil { return nil, nil, graph.Stack(ctx, err) } - if ptr.Val(mail.GetHasAttachments()) || HasAttachments(mail.GetBody()) { - options := &users.ItemMessagesItemAttachmentsRequestBuilderGetRequestConfiguration{ - QueryParameters: &users.ItemMessagesItemAttachmentsRequestBuilderGetQueryParameters{ + if !ptr.Val(mail.GetHasAttachments()) && !HasAttachments(mail.GetBody()) { + return mail, MailInfo(mail), nil + } + + options := &users.ItemMessagesItemAttachmentsRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemMessagesItemAttachmentsRequestBuilderGetQueryParameters{ + Expand: []string{"microsoft.graph.itemattachment/item"}, + }, + } + + attached, err := c.LargeItem. + Client(). + UsersById(user). + MessagesById(itemID). + Attachments(). + Get(ctx, options) + if err == nil { + mail.SetAttachments(attached.GetValue()) + return mail, MailInfo(mail), nil + } + + // A failure can be caused by having a lot of attachments as + // we are trying to fetch the data within the attachments as + // well in the request. We instead fetch all the attachment + // ids and fetch each item individually. + // NOTE: Maybe filter for specific error: + // graph.IsErrTimeout(err) || graph.IsServiceUnavailable(err) + // TODO: Once MS Graph fixes pagination for this, we can + // probably paginate and fetch items. + // https://learn.microsoft.com/en-us/answers/questions/1227026/pagination-not-working-when-fetching-message-attac + logger.CtxErr(ctx, err).Info("fetching all attachments by id") + + // Getting size just to log in case of error + options.QueryParameters.Select = []string{"id", "size"} + + attachments, err := c.LargeItem. + Client(). + UsersById(user). + MessagesById(itemID). + Attachments(). + Get(ctx, options) + if err != nil { + return nil, nil, graph.Wrap(ctx, err, "getting mail attachment ids") + } + + atts := []models.Attachmentable{} + + for _, a := range attachments.GetValue() { + options := &users.ItemMessagesItemAttachmentsAttachmentItemRequestBuilderGetRequestConfiguration{ + QueryParameters: &users.ItemMessagesItemAttachmentsAttachmentItemRequestBuilderGetQueryParameters{ Expand: []string{"microsoft.graph.itemattachment/item"}, }, } - attached, err := c.largeItem. + att, err := c.Stable. Client(). UsersById(user). MessagesById(itemID). - Attachments(). + AttachmentsById(ptr.Val(a.GetId())). Get(ctx, options) if err != nil { - return nil, nil, graph.Wrap(ctx, err, "mail attachment download") + return nil, nil, + graph.Wrap(ctx, err, "getting mail attachment"). + With("attachment_id", ptr.Val(a.GetId()), "attachment_size", ptr.Val(a.GetSize())) } - mail.SetAttachments(attached.GetValue()) + atts = append(atts, att) } + mail.SetAttachments(atts) + return mail, MailInfo(mail), nil } diff --git a/src/internal/connector/exchange/api/mail_test.go b/src/internal/connector/exchange/api/mail_test.go index 505db67d3..2e904ac76 100644 --- a/src/internal/connector/exchange/api/mail_test.go +++ b/src/internal/connector/exchange/api/mail_test.go @@ -1,15 +1,26 @@ -package api +package api_test import ( + "encoding/json" "testing" "time" + "github.com/alcionai/clues" + "github.com/h2non/gock" + "github.com/microsoft/kiota-abstractions-go/serialization" + kjson "github.com/microsoft/kiota-serialization-json-go" "github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/internal/connector/exchange/api" + "github.com/alcionai/corso/src/internal/connector/exchange/api/mock" "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/backup/details" + "github.com/alcionai/corso/src/pkg/fault" ) type MailAPIUnitSuite struct { @@ -141,7 +152,205 @@ func (suite *MailAPIUnitSuite) TestMailInfo() { for _, tt := range tests { suite.Run(tt.name, func() { msg, expected := tt.msgAndRP() - assert.Equal(suite.T(), expected, MailInfo(msg)) + assert.Equal(suite.T(), expected, api.MailInfo(msg)) + }) + } +} + +type MailAPIE2ESuite struct { + tester.Suite + credentials account.M365Config + ac api.Client + user string +} + +// We do end up mocking the actual request, but creating the rest +// similar to E2E suite +func TestMailAPIE2ESuite(t *testing.T) { + suite.Run(t, &MailAPIE2ESuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tester.M365AcctCredEnvs}, + ), + }) +} + +func (suite *MailAPIE2ESuite) SetupSuite() { + t := suite.T() + + a := tester.NewM365Account(t) + m365, err := a.M365Config() + require.NoError(t, err, clues.ToCore(err)) + + suite.credentials = m365 + suite.ac, err = mock.NewClient(m365) + require.NoError(t, err, clues.ToCore(err)) + + suite.user = tester.M365UserID(suite.T()) +} + +func getJSONObject(t *testing.T, thing serialization.Parsable) map[string]interface{} { + sw := kjson.NewJsonSerializationWriter() + + err := sw.WriteObjectValue("", thing) + require.NoError(t, err, "serialize") + + content, err := sw.GetSerializedContent() + require.NoError(t, err, "serialize") + + var out map[string]interface{} + err = json.Unmarshal([]byte(content), &out) + require.NoError(t, err, "unmarshall") + + return out +} + +func (suite *MailAPIE2ESuite) TestHugeAttachmentListDownload() { + mid := "fake-message-id" + aid := "fake-attachment-id" + + tests := []struct { + name string + setupf func() + attachmentCount int + expect assert.ErrorAssertionFunc + }{ + { + name: "no attachments", + setupf: func() { + mitem := models.NewMessage() + mitem.SetId(&mid) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid). + Reply(200). + JSON(getJSONObject(suite.T(), mitem)) + }, + expect: assert.NoError, + }, + { + name: "fetch with attachment", + setupf: func() { + mitem := models.NewMessage() + mitem.SetId(&mid) + mitem.SetHasAttachments(ptr.To(true)) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid). + Reply(200). + JSON(getJSONObject(suite.T(), mitem)) + + atts := models.NewAttachmentCollectionResponse() + aitem := models.NewAttachment() + atts.SetValue([]models.Attachmentable{aitem}) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments"). + Reply(200). + JSON(getJSONObject(suite.T(), atts)) + }, + attachmentCount: 1, + expect: assert.NoError, + }, + { + name: "fetch individual attachment", + setupf: func() { + truthy := true + mitem := models.NewMessage() + mitem.SetId(&mid) + mitem.SetHasAttachments(&truthy) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid). + Reply(200). + JSON(getJSONObject(suite.T(), mitem)) + + atts := models.NewAttachmentCollectionResponse() + aitem := models.NewAttachment() + aitem.SetId(&aid) + + asize := int32(200) + aitem.SetSize(&asize) + + atts.SetValue([]models.Attachmentable{aitem}) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments"). + Reply(503) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments"). + Reply(200). + JSON(getJSONObject(suite.T(), atts)) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments/" + aid). + Reply(200). + JSON(getJSONObject(suite.T(), aitem)) + }, + attachmentCount: 1, + expect: assert.NoError, + }, + { + name: "fetch multiple individual attachments", + setupf: func() { + truthy := true + mitem := models.NewMessage() + mitem.SetId(&mid) + mitem.SetHasAttachments(&truthy) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid). + Reply(200). + JSON(getJSONObject(suite.T(), mitem)) + + atts := models.NewAttachmentCollectionResponse() + aitem := models.NewAttachment() + aitem.SetId(&aid) + + asize := int32(200) + aitem.SetSize(&asize) + + atts.SetValue([]models.Attachmentable{aitem, aitem, aitem, aitem, aitem}) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments"). + Reply(503) + + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments"). + Reply(200). + JSON(getJSONObject(suite.T(), atts)) + + for i := 0; i < 5; i++ { + gock.New("https://graph.microsoft.com"). + Get("/v1.0/users/user/messages/" + mid + "/attachments/" + aid). + Reply(200). + JSON(getJSONObject(suite.T(), aitem)) + } + }, + attachmentCount: 5, + expect: assert.NoError, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + ctx, flush := tester.NewContext() + defer flush() + + defer gock.Off() + tt.setupf() + + item, _, err := suite.ac.Mail().GetItem(ctx, "user", mid, fault.New(true)) + tt.expect(suite.T(), err) + + it, ok := item.(models.Messageable) + require.True(suite.T(), ok, "convert to messageable") + + assert.Equal(suite.T(), *it.GetId(), mid) + assert.Equal(suite.T(), tt.attachmentCount, len(it.GetAttachments()), "attachment count") + assert.True(suite.T(), gock.IsDone(), "made all requests") }) } } diff --git a/src/internal/connector/exchange/api/mock/mail.go b/src/internal/connector/exchange/api/mock/mail.go new file mode 100644 index 000000000..43f6f8d5c --- /dev/null +++ b/src/internal/connector/exchange/api/mock/mail.go @@ -0,0 +1,43 @@ +package mock + +import ( + "github.com/alcionai/clues" + + "github.com/alcionai/corso/src/internal/connector/exchange/api" + "github.com/alcionai/corso/src/internal/connector/graph" + "github.com/alcionai/corso/src/internal/connector/graph/mock" + "github.com/alcionai/corso/src/pkg/account" +) + +func NewService(creds account.M365Config, opts ...graph.Option) (*graph.Service, error) { + a, err := mock.CreateAdapter( + creds.AzureTenantID, + creds.AzureClientID, + creds.AzureClientSecret, + opts...) + if err != nil { + return nil, clues.Wrap(err, "generating graph adapter") + } + + return graph.NewService(a), nil +} + +// NewClient produces a new exchange api client that can be +// mocked using gock. +func NewClient(creds account.M365Config) (api.Client, error) { + s, err := NewService(creds) + if err != nil { + return api.Client{}, err + } + + li, err := NewService(creds, graph.NoTimeout()) + if err != nil { + return api.Client{}, err + } + + return api.Client{ + Credentials: creds, + Stable: s, + LargeItem: li, + }, nil +} diff --git a/src/internal/connector/graph/mock/service.go b/src/internal/connector/graph/mock/service.go new file mode 100644 index 000000000..9a2a9b292 --- /dev/null +++ b/src/internal/connector/graph/mock/service.go @@ -0,0 +1,31 @@ +package mock + +import ( + "github.com/h2non/gock" + msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go" + + "github.com/alcionai/corso/src/internal/connector/graph" +) + +// CreateAdapter is similar to graph.CreateAdapter, but with option to +// enable interceptions via gock to make it mockable. +func CreateAdapter( + tenant, client, secret string, + opts ...graph.Option, +) (*msgraphsdkgo.GraphRequestAdapter, error) { + auth, err := graph.GetAuth(tenant, client, secret) + if err != nil { + return nil, err + } + + httpClient := graph.HTTPClient(opts...) + + // This makes sure that we are able to intercept any requests via + // gock. Only necessary for testing. + gock.InterceptClient(httpClient) + + return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( + auth, + nil, nil, + httpClient) +} diff --git a/src/internal/connector/graph/service.go b/src/internal/connector/graph/service.go index 9c5ed7d21..7167f7186 100644 --- a/src/internal/connector/graph/service.go +++ b/src/internal/connector/graph/service.go @@ -106,8 +106,22 @@ func (s Service) Serialize(object serialization.Parsable) ([]byte, error) { // to create *msgraphsdk.GraphServiceClient func CreateAdapter( tenant, client, secret string, - opts ...option, + opts ...Option, ) (*msgraphsdkgo.GraphRequestAdapter, error) { + auth, err := GetAuth(tenant, client, secret) + if err != nil { + return nil, err + } + + httpClient := HTTPClient(opts...) + + return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( + auth, + nil, nil, + httpClient) +} + +func GetAuth(tenant string, client string, secret string) (*kauth.AzureIdentityAuthenticationProvider, error) { // Client Provider: Uses Secret for access to tenant-level data cred, err := azidentity.NewClientSecretCredential(tenant, client, secret, nil) if err != nil { @@ -122,12 +136,7 @@ func CreateAdapter( return nil, clues.Wrap(err, "creating azure authentication") } - httpClient := HTTPClient(opts...) - - return msgraphsdkgo.NewGraphRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient( - auth, - nil, nil, - httpClient) + return auth, nil } // HTTPClient creates the httpClient with middlewares and timeout configured @@ -136,7 +145,7 @@ func CreateAdapter( // and consume relatively unbound socket connections. It is important // to centralize this client to be passed downstream where api calls // can utilize it on a per-download basis. -func HTTPClient(opts ...option) *http.Client { +func HTTPClient(opts ...Option) *http.Client { clientOptions := msgraphsdkgo.GetDefaultClientOptions() clientconfig := (&clientConfig{}).populate(opts...) noOfRetries, minRetryDelay := clientconfig.applyMiddlewareConfig() @@ -162,10 +171,10 @@ type clientConfig struct { overrideRetryCount bool } -type option func(*clientConfig) +type Option func(*clientConfig) // populate constructs a clientConfig according to the provided options. -func (c *clientConfig) populate(opts ...option) *clientConfig { +func (c *clientConfig) populate(opts ...Option) *clientConfig { for _, opt := range opts { opt(c) } @@ -203,20 +212,20 @@ func (c *clientConfig) apply(hc *http.Client) { // The resulting client isn't suitable for most queries, due to the // capacity for a call to persist forever. This configuration should // only be used when downloading very large files. -func NoTimeout() option { +func NoTimeout() Option { return func(c *clientConfig) { c.noTimeout = true } } -func MaxRetries(max int) option { +func MaxRetries(max int) Option { return func(c *clientConfig) { c.overrideRetryCount = true c.maxRetries = max } } -func MinimumBackoff(dur time.Duration) option { +func MinimumBackoff(dur time.Duration) Option { return func(c *clientConfig) { c.minDelay = dur } diff --git a/src/internal/connector/graph/service_test.go b/src/internal/connector/graph/service_test.go index d4dbdec59..4565efca1 100644 --- a/src/internal/connector/graph/service_test.go +++ b/src/internal/connector/graph/service_test.go @@ -47,19 +47,19 @@ func (suite *GraphUnitSuite) TestCreateAdapter() { func (suite *GraphUnitSuite) TestHTTPClient() { table := []struct { name string - opts []option + opts []Option check func(*testing.T, *http.Client) }{ { name: "no options", - opts: []option{}, + opts: []Option{}, check: func(t *testing.T, c *http.Client) { assert.Equal(t, defaultHTTPClientTimeout, c.Timeout, "default timeout") }, }, { name: "no timeout", - opts: []option{NoTimeout()}, + opts: []Option{NoTimeout()}, check: func(t *testing.T, c *http.Client) { // FIXME: Change to 0 one upstream issue is fixed assert.Equal(t, time.Duration(48*time.Hour), c.Timeout, "unlimited timeout")