diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f3d3ee0..39171e240 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Change file extension of messages export to json to match the content +- SDK consumption of the /services/m365 package has shifted from independent functions to a client-based api. +- SDK consumers can now configure the /services/m365 graph api client configuration when constructing a new m365 client. ### Fixed - Handle OneDrive folders being deleted and recreated midway through a backup diff --git a/src/cli/backup/groups.go b/src/cli/backup/groups.go index 75b22c161..0d9a6e3d3 100644 --- a/src/cli/backup/groups.go +++ b/src/cli/backup/groups.go @@ -156,7 +156,12 @@ func createGroupsCmd(cmd *cobra.Command, args []string) error { // TODO: log/print recoverable errors errs := fault.New(false) - ins, err := m365.GroupsMap(ctx, *acct, errs) + svcCli, err := m365.NewM365Client(ctx, *acct) + if err != nil { + return Only(ctx, clues.Stack(err)) + } + + ins, err := svcCli.GroupsMap(ctx, errs) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to retrieve M365 groups")) } diff --git a/src/cli/backup/sharepoint.go b/src/cli/backup/sharepoint.go index 012243856..0a5c06c04 100644 --- a/src/cli/backup/sharepoint.go +++ b/src/cli/backup/sharepoint.go @@ -157,7 +157,12 @@ func createSharePointCmd(cmd *cobra.Command, args []string) error { // TODO: log/print recoverable errors errs := fault.New(false) - ins, err := m365.SitesMap(ctx, *acct, errs) + svcCli, err := m365.NewM365Client(ctx, *acct) + if err != nil { + return Only(ctx, clues.Stack(err)) + } + + ins, err := svcCli.SitesMap(ctx, errs) if err != nil { return Only(ctx, clues.Wrap(err, "Failed to retrieve M365 sites")) } diff --git a/src/pkg/services/m365/api/client.go b/src/pkg/services/m365/api/client.go index 5b010d9b1..1140b3e85 100644 --- a/src/pkg/services/m365/api/client.go +++ b/src/pkg/services/m365/api/client.go @@ -51,8 +51,9 @@ func NewClient( creds account.M365Config, co control.Options, counter *count.Bus, + opts ...graph.Option, ) (Client, error) { - s, err := NewService(creds, counter) + s, err := NewService(creds, counter, opts...) if err != nil { return Client{}, err } diff --git a/src/pkg/services/m365/groups.go b/src/pkg/services/m365/groups.go index a9682c56c..5c608a661 100644 --- a/src/pkg/services/m365/groups.go +++ b/src/pkg/services/m365/groups.go @@ -8,9 +8,7 @@ import ( "github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/common/ptr" - "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/fault" - "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) @@ -29,19 +27,13 @@ type Group struct { } // GroupByID retrieves a specific group. -func GroupByID( +func (c client) GroupByID( ctx context.Context, - acct account.Account, id string, ) (*Group, error) { - ac, err := makeAC(ctx, acct, path.GroupsService) - if err != nil { - return nil, clues.Stack(err) - } - cc := api.CallConfig{} - g, err := ac.Groups().GetByID(ctx, id, cc) + g, err := c.ac.Groups().GetByID(ctx, id, cc) if err != nil { return nil, clues.Stack(err) } @@ -50,10 +42,10 @@ func GroupByID( } // GroupsCompat returns a list of groups in the specified M365 tenant. -func GroupsCompat(ctx context.Context, acct account.Account) ([]*Group, error) { +func (c client) GroupsCompat(ctx context.Context) ([]*Group, error) { errs := fault.New(true) - us, err := Groups(ctx, acct, errs) + us, err := c.Groups(ctx, errs) if err != nil { return nil, err } @@ -62,17 +54,11 @@ func GroupsCompat(ctx context.Context, acct account.Account) ([]*Group, error) { } // Groups returns a list of groups in the specified M365 tenant -func Groups( +func (c client) Groups( ctx context.Context, - acct account.Account, errs *fault.Bus, ) ([]*Group, error) { - ac, err := makeAC(ctx, acct, path.GroupsService) - if err != nil { - return nil, clues.Stack(err) - } - - return getAllGroups(ctx, ac.Groups()) + return getAllGroups(ctx, c.ac.Groups()) } func getAllGroups( @@ -98,18 +84,12 @@ func getAllGroups( return ret, nil } -func SitesInGroup( +func (c client) SitesInGroup( ctx context.Context, - acct account.Account, groupID string, errs *fault.Bus, ) ([]*Site, error) { - ac, err := makeAC(ctx, acct, path.GroupsService) - if err != nil { - return nil, clues.Stack(err) - } - - sites, err := ac.Groups().GetAllSites(ctx, groupID, errs) + sites, err := c.ac.Groups().GetAllSites(ctx, groupID, errs) if err != nil { return nil, clues.Stack(err) } @@ -144,12 +124,11 @@ func parseGroup(ctx context.Context, mg models.Groupable) (*Group, error) { } // GroupsMap retrieves an id-name cache of all groups in the tenant. -func GroupsMap( +func (c client) GroupsMap( ctx context.Context, - acct account.Account, errs *fault.Bus, ) (idname.Cacher, error) { - groups, err := Groups(ctx, acct, errs) + groups, err := c.Groups(ctx, errs) if err != nil { return idname.NewCache(nil), err } diff --git a/src/pkg/services/m365/groups_test.go b/src/pkg/services/m365/groups_test.go index 7b2a1d651..447e99391 100644 --- a/src/pkg/services/m365/groups_test.go +++ b/src/pkg/services/m365/groups_test.go @@ -1,4 +1,4 @@ -package m365_test +package m365 import ( "testing" @@ -11,17 +11,14 @@ import ( "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/tester/tconfig" - "github.com/alcionai/corso/src/pkg/account" - "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/errs" "github.com/alcionai/corso/src/pkg/fault" - "github.com/alcionai/corso/src/pkg/services/m365" "github.com/alcionai/corso/src/pkg/services/m365/api/graph" ) type GroupsIntgSuite struct { tester.Suite - acct account.Account + cli client } func TestGroupsIntgSuite(t *testing.T) { @@ -38,9 +35,13 @@ func (suite *GroupsIntgSuite) SetupSuite() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) + acct := tconfig.NewM365Account(t) - suite.acct = tconfig.NewM365Account(t) + var err error + + // will init the concurrency limiter + suite.cli, err = NewM365Client(ctx, acct) + require.NoError(t, err, clues.ToCore(err)) } func (suite *GroupsIntgSuite) TestGroupByID() { @@ -49,11 +50,9 @@ func (suite *GroupsIntgSuite) TestGroupByID() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - gid := tconfig.M365TeamID(t) - group, err := m365.GroupByID(ctx, suite.acct, gid) + group, err := suite.cli.GroupByID(ctx, gid) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, group) @@ -67,11 +66,9 @@ func (suite *GroupsIntgSuite) TestGroupByID_ByEmail() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - gid := tconfig.M365TeamID(t) - group, err := m365.GroupByID(ctx, suite.acct, gid) + group, err := suite.cli.GroupByID(ctx, gid) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, group) @@ -80,7 +77,7 @@ func (suite *GroupsIntgSuite) TestGroupByID_ByEmail() { gemail := tconfig.M365TeamEmail(t) - groupByEmail, err := m365.GroupByID(ctx, suite.acct, gemail) + groupByEmail, err := suite.cli.GroupByID(ctx, gemail) require.NoError(t, err, clues.ToCore(err)) require.NotNil(t, group) @@ -93,9 +90,7 @@ func (suite *GroupsIntgSuite) TestGroupByID_notFound() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - - group, err := m365.GroupByID(ctx, suite.acct, uuid.NewString()) + group, err := suite.cli.GroupByID(ctx, uuid.NewString()) require.Nil(t, group) require.ErrorIs(t, err, graph.ErrResourceOwnerNotFound, clues.ToCore(err)) require.True(t, errs.Is(err, errs.ResourceOwnerNotFound)) @@ -107,12 +102,7 @@ func (suite *GroupsIntgSuite) TestGroups() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - - groups, err := m365.Groups( - ctx, - suite.acct, - fault.New(true)) + groups, err := suite.cli.Groups(ctx, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, groups) @@ -137,15 +127,9 @@ func (suite *GroupsIntgSuite) TestSitesInGroup() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - gid := tconfig.M365TeamID(t) - sites, err := m365.SitesInGroup( - ctx, - suite.acct, - gid, - fault.New(true)) + sites, err := suite.cli.SitesInGroup(ctx, gid, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, sites) } @@ -156,12 +140,7 @@ func (suite *GroupsIntgSuite) TestGroupsMap() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - - gm, err := m365.GroupsMap( - ctx, - suite.acct, - fault.New(true)) + gm, err := suite.cli.GroupsMap(ctx, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, gm) @@ -177,44 +156,3 @@ func (suite *GroupsIntgSuite) TestGroupsMap() { }) } } - -func (suite *GroupsIntgSuite) TestGroups_InvalidCredentials() { - table := []struct { - name string - acct func(t *testing.T) account.Account - }{ - { - name: "Invalid Credentials", - acct: func(t *testing.T) account.Account { - a, err := account.NewAccount( - account.ProviderM365, - account.M365Config{ - M365: credentials.M365{ - AzureClientID: "Test", - AzureClientSecret: "without", - }, - AzureTenantID: "data", - }) - require.NoError(t, err, clues.ToCore(err)) - - return a - }, - }, - } - - for _, test := range table { - suite.Run(test.name, func() { - t := suite.T() - - ctx, flush := tester.NewContext(t) - defer flush() - - groups, err := m365.Groups( - ctx, - test.acct(t), - fault.New(true)) - assert.Empty(t, groups, "returned no groups") - assert.NotNil(t, err) - }) - } -} diff --git a/src/pkg/services/m365/m365.go b/src/pkg/services/m365/m365.go index 58d5fad71..a7cfcf1c0 100644 --- a/src/pkg/services/m365/m365.go +++ b/src/pkg/services/m365/m365.go @@ -11,8 +11,22 @@ import ( "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" + "github.com/alcionai/corso/src/pkg/services/m365/api/graph" ) +type client struct { + ac api.Client +} + +func NewM365Client( + ctx context.Context, + acct account.Account, + opts ...graph.Option, +) (client, error) { + ac, err := makeAC(ctx, acct, opts...) + return client{ac}, clues.Stack(err).OrNil() +} + // --------------------------------------------------------------------------- // interfaces & structs // --------------------------------------------------------------------------- @@ -28,9 +42,10 @@ type getAller[T any] interface { func makeAC( ctx context.Context, acct account.Account, - pst path.ServiceType, + opts ...graph.Option, ) (api.Client, error) { - api.InitConcurrencyLimit(ctx, pst) + // exchange service inits a limit to concurrency. + api.InitConcurrencyLimit(ctx, path.ExchangeService) creds, err := acct.M365Config() if err != nil { @@ -45,5 +60,10 @@ func makeAC( return api.Client{}, clues.WrapWC(ctx, err, "constructing api client") } + // run a test to ensure credentials work for the client + if err := cli.Access().GetToken(ctx); err != nil { + return api.Client{}, clues.Wrap(err, "checking client connection") + } + return cli, nil } diff --git a/src/pkg/services/m365/m365_test.go b/src/pkg/services/m365/m365_test.go new file mode 100644 index 000000000..3f5bb2737 --- /dev/null +++ b/src/pkg/services/m365/m365_test.go @@ -0,0 +1,61 @@ +package m365 + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/internal/tester/tconfig" + "github.com/alcionai/corso/src/pkg/account" +) + +type M365IntgSuite struct { + tester.Suite +} + +func TestM365IntgSuite(t *testing.T) { + suite.Run(t, &M365IntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{}), + }) +} + +func (suite *userIntegrationSuite) TestNewM365Client() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := NewM365Client(ctx, tconfig.NewM365Account(t)) + assert.NoError(t, err, clues.ToCore(err)) +} + +func (suite *userIntegrationSuite) TestNewM365Client_invalidCredentials() { + table := []struct { + name string + acct func(t *testing.T) account.Account + }{ + { + name: "Invalid Credentials", + acct: func(t *testing.T) account.Account { + return tconfig.NewFakeM365Account(t) + }, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + _, err := NewM365Client(ctx, test.acct(t)) + assert.Error(t, err, clues.ToCore(err)) + }) + } +} diff --git a/src/pkg/services/m365/sites.go b/src/pkg/services/m365/sites.go index ba19ad685..d5b91ccab 100644 --- a/src/pkg/services/m365/sites.go +++ b/src/pkg/services/m365/sites.go @@ -10,10 +10,8 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/common/str" "github.com/alcionai/corso/src/internal/common/tform" - "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/logger" - "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api/graph" ) @@ -52,21 +50,15 @@ type Site struct { } // SiteByID retrieves a specific site. -func SiteByID( +func (c client) SiteByID( ctx context.Context, - acct account.Account, id string, ) (*Site, error) { - ac, err := makeAC(ctx, acct, path.SharePointService) - if err != nil { - return nil, clues.Stack(err) - } - cc := api.CallConfig{ Expand: []string{"drive"}, } - return getSiteByID(ctx, ac.Sites(), id, cc) + return getSiteByID(ctx, c.ac.Sites(), id, cc) } func getSiteByID( @@ -84,13 +76,8 @@ func getSiteByID( } // Sites returns a list of Sites in a specified M365 tenant -func Sites(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*Site, error) { - ac, err := makeAC(ctx, acct, path.SharePointService) - if err != nil { - return nil, clues.Stack(err) - } - - return getAllSites(ctx, ac.Sites()) +func (c client) Sites(ctx context.Context, errs *fault.Bus) ([]*Site, error) { + return getAllSites(ctx, c.ac.Sites()) } func getAllSites( @@ -174,12 +161,11 @@ func ParseSite(ctx context.Context, item models.Siteable) *Site { // SitesMap retrieves all sites in the tenant, and returns two maps: one id-to-webURL, // and one webURL-to-id. -func SitesMap( +func (c client) SitesMap( ctx context.Context, - acct account.Account, errs *fault.Bus, ) (idname.Cacher, error) { - sites, err := Sites(ctx, acct, errs) + sites, err := c.Sites(ctx, errs) if err != nil { return idname.NewCache(nil), err } diff --git a/src/pkg/services/m365/sites_test.go b/src/pkg/services/m365/sites_test.go index a874a3292..3fa68d21c 100644 --- a/src/pkg/services/m365/sites_test.go +++ b/src/pkg/services/m365/sites_test.go @@ -14,8 +14,6 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/tester/tconfig" - "github.com/alcionai/corso/src/pkg/account" - "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api/graph" @@ -24,6 +22,7 @@ import ( type siteIntegrationSuite struct { tester.Suite + cli client } func TestSiteIntegrationSuite(t *testing.T) { @@ -35,10 +34,18 @@ func TestSiteIntegrationSuite(t *testing.T) { } func (suite *siteIntegrationSuite) SetupSuite() { - ctx, flush := tester.NewContext(suite.T()) + t := suite.T() + + ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) + acct := tconfig.NewM365Account(t) + + var err error + + // will init the concurrency limiter + suite.cli, err = NewM365Client(ctx, acct) + require.NoError(t, err, clues.ToCore(err)) } func (suite *siteIntegrationSuite) TestSites() { @@ -47,9 +54,7 @@ func (suite *siteIntegrationSuite) TestSites() { ctx, flush := tester.NewContext(t) defer flush() - acct := tconfig.NewM365Account(t) - - sites, err := Sites(ctx, acct, fault.New(true)) + sites, err := suite.cli.Sites(ctx, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, sites) @@ -68,16 +73,14 @@ func (suite *siteIntegrationSuite) TestSites_GetByID() { ctx, flush := tester.NewContext(t) defer flush() - acct := tconfig.NewM365Account(t) - - sites, err := Sites(ctx, acct, fault.New(true)) + sites, err := suite.cli.Sites(ctx, fault.New(true)) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, sites) for _, s := range sites { suite.Run("site_"+s.ID, func() { t := suite.T() - site, err := SiteByID(ctx, acct, s.ID) + site, err := suite.cli.SiteByID(ctx, s.ID) require.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, site.WebURL) assert.NotEmpty(t, site.ID) @@ -86,52 +89,6 @@ func (suite *siteIntegrationSuite) TestSites_GetByID() { } } -func (suite *siteIntegrationSuite) TestSites_InvalidCredentials() { - table := []struct { - name string - acct func(t *testing.T) account.Account - }{ - { - name: "Invalid Credentials", - acct: func(t *testing.T) account.Account { - a, err := account.NewAccount( - account.ProviderM365, - account.M365Config{ - M365: credentials.M365{ - AzureClientID: "Test", - AzureClientSecret: "without", - }, - AzureTenantID: "data", - }) - require.NoError(t, err, clues.ToCore(err)) - - return a - }, - }, - { - name: "Empty Credentials", - acct: func(t *testing.T) account.Account { - // intentionally swallowing the error here - a, _ := account.NewAccount(account.ProviderM365) - return a - }, - }, - } - - for _, test := range table { - suite.Run(test.name, func() { - t := suite.T() - - ctx, flush := tester.NewContext(t) - defer flush() - - sites, err := Sites(ctx, test.acct(t), fault.New(true)) - assert.Empty(t, sites, "returned some sites") - assert.NotNil(t, err) - }) - } -} - // --------------------------------------------------------------------------- // Unit // --------------------------------------------------------------------------- diff --git a/src/pkg/services/m365/users.go b/src/pkg/services/m365/users.go index ac54a9c32..cd8df7743 100644 --- a/src/pkg/services/m365/users.go +++ b/src/pkg/services/m365/users.go @@ -10,9 +10,7 @@ import ( "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/m365/service/exchange" "github.com/alcionai/corso/src/internal/m365/service/onedrive" - "github.com/alcionai/corso/src/pkg/account" "github.com/alcionai/corso/src/pkg/fault" - "github.com/alcionai/corso/src/pkg/path" "github.com/alcionai/corso/src/pkg/services/m365/api" ) @@ -27,10 +25,10 @@ type UserNoInfo struct { // UsersCompatNoInfo returns a list of users in the specified M365 tenant. // TODO(pandeyabs): Rename this to Users now that `Info` support has been removed. Would // need corresponding changes in SDK consumers. -func UsersCompatNoInfo(ctx context.Context, acct account.Account) ([]*UserNoInfo, error) { +func (c client) UsersCompatNoInfo(ctx context.Context) ([]*UserNoInfo, error) { errs := fault.New(true) - us, err := usersNoInfo(ctx, acct, errs) + us, err := usersNoInfo(ctx, c.ac, errs) if err != nil { return nil, err } @@ -40,46 +38,29 @@ func UsersCompatNoInfo(ctx context.Context, acct account.Account) ([]*UserNoInfo // UserHasMailbox returns true if the user has an exchange mailbox enabled // false otherwise, and a nil pointer and an error in case of error -func UserHasMailbox(ctx context.Context, acct account.Account, userID string) (bool, error) { - ac, err := makeAC(ctx, acct, path.ExchangeService) - if err != nil { - return false, clues.Stack(err) - } - - return exchange.IsServiceEnabled(ctx, ac.Users(), userID) +func (c client) UserHasMailbox(ctx context.Context, userID string) (bool, error) { + return exchange.IsServiceEnabled(ctx, c.ac.Users(), userID) } -func UserGetMailboxInfo( +func (c client) UserGetMailboxInfo( ctx context.Context, - acct account.Account, userID string, ) (api.MailboxInfo, error) { - ac, err := makeAC(ctx, acct, path.ExchangeService) - if err != nil { - return api.MailboxInfo{}, clues.Stack(err) - } - - return exchange.GetMailboxInfo(ctx, ac.Users(), userID) + return exchange.GetMailboxInfo(ctx, c.ac.Users(), userID) } // UserHasDrives returns true if the user has any drives // false otherwise, and a nil pointer and an error in case of error -func UserHasDrives(ctx context.Context, acct account.Account, userID string) (bool, error) { - ac, err := makeAC(ctx, acct, path.OneDriveService) - if err != nil { - return false, clues.Stack(err) - } - - return onedrive.IsServiceEnabled(ctx, ac.Users(), userID) +func (c client) UserHasDrives(ctx context.Context, userID string) (bool, error) { + return onedrive.IsServiceEnabled(ctx, c.ac.Users(), userID) } // usersNoInfo returns a list of users in the specified M365 tenant - with no info -func usersNoInfo(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*UserNoInfo, error) { - ac, err := makeAC(ctx, acct, path.UnknownService) - if err != nil { - return nil, clues.Stack(err) - } - +func usersNoInfo( + ctx context.Context, + ac api.Client, + errs *fault.Bus, +) ([]*UserNoInfo, error) { us, err := ac.Users().GetAll(ctx, errs) if err != nil { return nil, err @@ -105,13 +86,8 @@ func usersNoInfo(ctx context.Context, acct account.Account, errs *fault.Bus) ([] return ret, nil } -func UserAssignedLicenses(ctx context.Context, acct account.Account, userID string) (int, error) { - ac, err := makeAC(ctx, acct, path.UnknownService) - if err != nil { - return 0, clues.Stack(err) - } - - us, err := ac.Users().GetByID( +func (c client) UserAssignedLicenses(ctx context.Context, userID string) (int, error) { + us, err := c.ac.Users().GetByID( ctx, userID, api.CallConfig{Select: api.SelectProps("assignedLicenses")}) diff --git a/src/pkg/services/m365/users_test.go b/src/pkg/services/m365/users_test.go index d61fd458a..2d0d3ab53 100644 --- a/src/pkg/services/m365/users_test.go +++ b/src/pkg/services/m365/users_test.go @@ -11,15 +11,12 @@ import ( "github.com/alcionai/corso/src/internal/tester" "github.com/alcionai/corso/src/internal/tester/tconfig" - "github.com/alcionai/corso/src/pkg/account" - "github.com/alcionai/corso/src/pkg/credentials" "github.com/alcionai/corso/src/pkg/services/m365/api" - "github.com/alcionai/corso/src/pkg/services/m365/api/graph" ) type userIntegrationSuite struct { tester.Suite - acct account.Account + cli client } func TestUserIntegrationSuite(t *testing.T) { @@ -31,12 +28,18 @@ func TestUserIntegrationSuite(t *testing.T) { } func (suite *userIntegrationSuite) SetupSuite() { - ctx, flush := tester.NewContext(suite.T()) + t := suite.T() + + ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) + acct := tconfig.NewM365Account(t) - suite.acct = tconfig.NewM365Account(suite.T()) + var err error + + // will init the concurrency limiter + suite.cli, err = NewM365Client(ctx, acct) + require.NoError(t, err, clues.ToCore(err)) } func (suite *userIntegrationSuite) TestUsersCompat_HasNoInfo() { @@ -45,11 +48,7 @@ func (suite *userIntegrationSuite) TestUsersCompat_HasNoInfo() { ctx, flush := tester.NewContext(t) defer flush() - graph.InitializeConcurrencyLimiter(ctx, true, 4) - - acct := tconfig.NewM365Account(suite.T()) - - users, err := UsersCompatNoInfo(ctx, acct) + users, err := suite.cli.UsersCompatNoInfo(ctx) assert.NoError(t, err, clues.ToCore(err)) assert.NotEmpty(t, users) @@ -66,7 +65,6 @@ func (suite *userIntegrationSuite) TestUsersCompat_HasNoInfo() { func (suite *userIntegrationSuite) TestUserHasMailbox() { t := suite.T() - acct := tconfig.NewM365Account(t) userID := tconfig.M365UserID(t) table := []struct { @@ -92,7 +90,7 @@ func (suite *userIntegrationSuite) TestUserHasMailbox() { ctx, flush := tester.NewContext(t) defer flush() - enabled, err := UserHasMailbox(ctx, acct, test.user) + enabled, err := suite.cli.UserHasMailbox(ctx, test.user) require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, test.expect, enabled) }) @@ -101,7 +99,6 @@ func (suite *userIntegrationSuite) TestUserHasMailbox() { func (suite *userIntegrationSuite) TestUserHasDrive() { t := suite.T() - acct := tconfig.NewM365Account(t) userID := tconfig.M365UserID(t) table := []struct { @@ -130,7 +127,7 @@ func (suite *userIntegrationSuite) TestUserHasDrive() { ctx, flush := tester.NewContext(t) defer flush() - enabled, err := UserHasDrives(ctx, acct, test.user) + enabled, err := suite.cli.UserHasDrives(ctx, test.user) test.expectErr(t, err, clues.ToCore(err)) assert.Equal(t, test.expect, enabled) }) @@ -139,7 +136,6 @@ func (suite *userIntegrationSuite) TestUserHasDrive() { func (suite *userIntegrationSuite) TestUserGetMailboxInfo() { t := suite.T() - acct := tconfig.NewM365Account(t) userID := tconfig.M365UserID(t) table := []struct { @@ -195,55 +191,16 @@ func (suite *userIntegrationSuite) TestUserGetMailboxInfo() { ctx, flush := tester.NewContext(t) defer flush() - info, err := UserGetMailboxInfo(ctx, acct, test.user) + info, err := suite.cli.UserGetMailboxInfo(ctx, test.user) test.expectErr(t, err, clues.ToCore(err)) test.expect(t, info) }) } } -func (suite *userIntegrationSuite) TestUsers_InvalidCredentials() { - table := []struct { - name string - acct func(t *testing.T) account.Account - }{ - { - name: "Invalid Credentials", - acct: func(t *testing.T) account.Account { - a, err := account.NewAccount( - account.ProviderM365, - account.M365Config{ - M365: credentials.M365{ - AzureClientID: "Test", - AzureClientSecret: "without", - }, - AzureTenantID: "data", - }) - require.NoError(t, err, clues.ToCore(err)) - - return a - }, - }, - } - - for _, test := range table { - suite.Run(test.name, func() { - t := suite.T() - - ctx, flush := tester.NewContext(t) - defer flush() - - users, err := UsersCompatNoInfo(ctx, test.acct(t)) - assert.Empty(t, users, "returned some users") - assert.NotNil(t, err) - }) - } -} - func (suite *userIntegrationSuite) TestUserAssignedLicenses() { t := suite.T() ctx, flush := tester.NewContext(t) - graph.InitializeConcurrencyLimiter(ctx, true, 4) defer flush() @@ -275,10 +232,7 @@ func (suite *userIntegrationSuite) TestUserAssignedLicenses() { for _, run := range runs { t.Run(run.name, func(t *testing.T) { - user, err := UserAssignedLicenses( - ctx, - suite.acct, - run.userID) + user, err := suite.cli.UserAssignedLicenses(ctx, run.userID) run.expectErr(t, err, clues.ToCore(err)) assert.Equal(t, run.expect, user) })