diff --git a/src/internal/tester/tconfig/protected_resources.go b/src/internal/tester/tconfig/protected_resources.go index caac0c586..26c0187ac 100644 --- a/src/internal/tester/tconfig/protected_resources.go +++ b/src/internal/tester/tconfig/protected_resources.go @@ -223,11 +223,11 @@ func UnlicensedM365UserID(t *testing.T) string { // Teams -// M365TeamsID returns a teamID string representing the m365TeamsID described +// M365TeamID returns a teamID string representing the m365TeamsID described // by either the env var CORSO_M365_TEST_TEAM_ID, the corso_test.toml config // file or the default value (in that order of priority). The default is a // last-attempt fallback that will only work on alcion's testing org. -func M365TeamsID(t *testing.T) string { +func M365TeamID(t *testing.T) string { cfg, err := ReadTestConfig() require.NoError(t, err, "retrieving m365 team id from test configuration: %+v", clues.ToCore(err)) diff --git a/src/pkg/services/m365/api/groups.go b/src/pkg/services/m365/api/groups.go index c2a27dad3..3d036e610 100644 --- a/src/pkg/services/m365/api/groups.go +++ b/src/pkg/services/m365/api/groups.go @@ -49,24 +49,6 @@ func (c Groups) GetAll( return getGroups(ctx, errs, service) } -// GetTeams retrieves all Teams. -func (c Groups) GetTeams( - ctx context.Context, - errs *fault.Bus, -) ([]models.Groupable, error) { - service, err := c.Service() - if err != nil { - return nil, err - } - - groups, err := getGroups(ctx, errs, service) - if err != nil { - return nil, err - } - - return OnlyTeams(ctx, groups), nil -} - // GetAll retrieves all groups. func getGroups( ctx context.Context, @@ -113,31 +95,6 @@ func getGroups( return groups, el.Failure() } -func OnlyTeams(ctx context.Context, groups []models.Groupable) []models.Groupable { - log := logger.Ctx(ctx) - - var teams []models.Groupable - - for _, g := range groups { - if g.GetAdditionalData()[ResourceProvisioningOptions] != nil { - val, _ := tform.AnyValueToT[[]any](ResourceProvisioningOptions, g.GetAdditionalData()) - for _, v := range val { - s, err := str.AnyToString(v) - if err != nil { - log.Debug("could not be converted to string value: ", ResourceProvisioningOptions) - continue - } - - if s == teamsAdditionalDataLabel { - teams = append(teams, g) - } - } - } - } - - return teams -} - // GetID retrieves group by groupID. func (c Groups) GetByID( ctx context.Context, @@ -158,34 +115,6 @@ func (c Groups) GetByID( return resp, graph.Stack(ctx, err).OrNil() } -// GetTeamByID retrieves group by groupID. -func (c Groups) GetTeamByID( - ctx context.Context, - identifier string, -) (models.Groupable, error) { - service, err := c.Service() - if err != nil { - return nil, err - } - - resp, err := service.Client().Groups().ByGroupId(identifier).Get(ctx, nil) - if err != nil { - err := graph.Wrap(ctx, err, "getting group by id") - - return nil, err - } - - groups := []models.Groupable{resp} - - if len(OnlyTeams(ctx, groups)) == 0 { - err := clues.New("given teamID is not related to any team") - - return nil, err - } - - return resp, graph.Stack(ctx, err).OrNil() -} - // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- @@ -203,3 +132,38 @@ func ValidateGroup(item models.Groupable) error { return nil } + +func OnlyTeams(ctx context.Context, groups []models.Groupable) []models.Groupable { + var teams []models.Groupable + + for _, g := range groups { + if IsTeam(ctx, g) { + teams = append(teams, g) + } + } + + return teams +} + +func IsTeam(ctx context.Context, mg models.Groupable) bool { + log := logger.Ctx(ctx) + + if mg.GetAdditionalData()[ResourceProvisioningOptions] == nil { + return false + } + + val, _ := tform.AnyValueToT[[]any](ResourceProvisioningOptions, mg.GetAdditionalData()) + for _, v := range val { + s, err := str.AnyToString(v) + if err != nil { + log.Debug("could not be converted to string value: ", ResourceProvisioningOptions) + continue + } + + if s == teamsAdditionalDataLabel { + return true + } + } + + return false +} diff --git a/src/pkg/services/m365/api/groups_test.go b/src/pkg/services/m365/api/groups_test.go index 8ce0f8f6b..ae435168a 100644 --- a/src/pkg/services/m365/api/groups_test.go +++ b/src/pkg/services/m365/api/groups_test.go @@ -97,7 +97,7 @@ func (suite *GroupsIntgSuite) SetupSuite() { suite.its = newIntegrationTesterSetup(suite.T()) } -func (suite *GroupsIntgSuite) TestGetAllGroups() { +func (suite *GroupsIntgSuite) TestGetAll() { t := suite.T() ctx, flush := tester.NewContext(t) @@ -107,100 +107,15 @@ func (suite *GroupsIntgSuite) TestGetAllGroups() { Groups(). GetAll(ctx, fault.New(true)) require.NoError(t, err) - require.NotZero(t, len(groups), "must have at least one group") -} - -func (suite *GroupsIntgSuite) TestGetAllTeams() { - t := suite.T() - - ctx, flush := tester.NewContext(t) - defer flush() - - teams, err := suite.its.ac. - Groups(). - GetTeams(ctx, fault.New(true)) - require.NoError(t, err) - require.NotZero(t, len(teams), "must have at least one teams") - - groups, err := suite.its.ac. - Groups(). - GetAll(ctx, fault.New(true)) - require.NoError(t, err) - require.NotZero(t, len(groups), "must have at least one group") - - var isTeam bool - - if len(groups) > len(teams) { - isTeam = true - } - - assert.True(t, isTeam, "must only return teams") -} - -func (suite *GroupsIntgSuite) TestTeams_GetByID() { - var ( - t = suite.T() - teamID = tconfig.M365TeamsID(t) - ) - - teamsAPI := suite.its.ac.Groups() - - table := []struct { - name string - id string - expectErr func(*testing.T, error) - }{ - { - name: "3 part id", - id: teamID, - expectErr: func(t *testing.T, err error) { - assert.NoError(t, err, clues.ToCore(err)) - }, - }, - { - name: "malformed id", - id: uuid.NewString(), - expectErr: func(t *testing.T, err error) { - assert.Error(t, err, clues.ToCore(err)) - }, - }, - { - name: "random id", - id: uuid.NewString() + "," + uuid.NewString(), - expectErr: func(t *testing.T, err error) { - assert.Error(t, err, clues.ToCore(err)) - }, - }, - - { - name: "malformed url", - id: "barunihlda", - expectErr: func(t *testing.T, err error) { - assert.Error(t, err, clues.ToCore(err)) - }, - }, - } - for _, test := range table { - suite.Run(test.name, func() { - t := suite.T() - - ctx, flush := tester.NewContext(t) - defer flush() - - _, err := teamsAPI.GetTeamByID(ctx, test.id) - test.expectErr(t, err) - }) - } + require.NotZero(t, len(groups), "must find at least one group") } func (suite *GroupsIntgSuite) TestGroups_GetByID() { var ( - t = suite.T() - groupID = tconfig.M365GroupID(t) + groupID = suite.its.groupID + groupsAPI = suite.its.ac.Groups() ) - groupsAPI := suite.its.ac.Groups() - table := []struct { name string id string diff --git a/src/pkg/services/m365/api/helper_test.go b/src/pkg/services/m365/api/helper_test.go index a9c12324f..8e8c760c0 100644 --- a/src/pkg/services/m365/api/helper_test.go +++ b/src/pkg/services/m365/api/helper_test.go @@ -83,7 +83,7 @@ type intgTesterSetup struct { siteID string siteDriveID string siteDriveRootFolderID string - teamID string + groupID string } func newIntegrationTesterSetup(t *testing.T) intgTesterSetup { @@ -132,13 +132,16 @@ func newIntegrationTesterSetup(t *testing.T) intgTesterSetup { its.siteDriveRootFolderID = ptr.Val(siteDriveRootFolder.GetId()) - // teams - its.teamID = tconfig.M365TeamsID(t) + // group - team, err := its.ac.Groups().GetTeamByID(ctx, its.teamID) + // use of the TeamID is intentional here, so that we are assured + // the group has full usage of the teams api. + its.groupID = tconfig.M365TeamID(t) + + team, err := its.ac.Groups().GetByID(ctx, its.groupID) require.NoError(t, err, clues.ToCore(err)) - its.teamID = ptr.Val(team.GetId()) + its.groupID = ptr.Val(team.GetId()) return its } diff --git a/src/pkg/services/m365/groups.go b/src/pkg/services/m365/groups.go new file mode 100644 index 000000000..f4924be22 --- /dev/null +++ b/src/pkg/services/m365/groups.go @@ -0,0 +1,97 @@ +package m365 + +import ( + "context" + + "github.com/alcionai/clues" + "github.com/microsoftgraph/msgraph-sdk-go/models" + + "github.com/alcionai/corso/src/internal/common/ptr" + "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/path" + "github.com/alcionai/corso/src/pkg/services/m365/api" +) + +// Group is the minimal information required to identify and display a M365 Group. +type Group struct { + ID string + + // DisplayName is the human-readable name of the group. Normally the plaintext name that the + // user provided when they created the group, or the updated name if it was changed. + // Ex: displayName: "My Group" + DisplayName string + + // IsTeam is true if the group qualifies as a Teams resource, and is able to backup and restore + // teams data. + IsTeam bool +} + +// GroupsCompat returns a list of groups in the specified M365 tenant. +func GroupsCompat(ctx context.Context, acct account.Account) ([]*Group, error) { + errs := fault.New(true) + + us, err := Groups(ctx, acct, errs) + if err != nil { + return nil, err + } + + return us, errs.Failure() +} + +// Groups returns a list of groups in the specified M365 tenant +func Groups( + ctx context.Context, + acct account.Account, + errs *fault.Bus, +) ([]*Group, error) { + ac, err := makeAC(ctx, acct, path.GroupsService) + if err != nil { + return nil, clues.Stack(err).WithClues(ctx) + } + + return getAllGroups(ctx, ac.Groups()) +} + +func getAllGroups( + ctx context.Context, + ga getAller[models.Groupable], +) ([]*Group, error) { + groups, err := ga.GetAll(ctx, fault.New(true)) + if err != nil { + return nil, clues.Wrap(err, "retrieving groups") + } + + ret := make([]*Group, 0, len(groups)) + + for _, g := range groups { + t, err := parseGroup(ctx, g) + if err != nil { + return nil, clues.Wrap(err, "parsing groups") + } + + ret = append(ret, t) + } + + return ret, nil +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// parseUser extracts information from `models.Groupable` we care about +func parseGroup(ctx context.Context, mg models.Groupable) (*Group, error) { + if mg.GetDisplayName() == nil { + return nil, clues.New("group missing display name"). + With("group_id", ptr.Val(mg.GetId())) + } + + u := &Group{ + ID: ptr.Val(mg.GetId()), + DisplayName: ptr.Val(mg.GetDisplayName()), + IsTeam: api.IsTeam(ctx, mg), + } + + return u, nil +} diff --git a/src/pkg/services/m365/groups_test.go b/src/pkg/services/m365/groups_test.go new file mode 100644 index 000000000..8fa650a98 --- /dev/null +++ b/src/pkg/services/m365/groups_test.go @@ -0,0 +1,108 @@ +package m365_test + +import ( + "testing" + + "github.com/alcionai/clues" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/m365/graph" + "github.com/alcionai/corso/src/internal/tester" + "github.com/alcionai/corso/src/internal/tester/tconfig" + "github.com/alcionai/corso/src/pkg/account" + "github.com/alcionai/corso/src/pkg/credentials" + "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/services/m365" +) + +type GroupsIntgSuite struct { + tester.Suite + acct account.Account +} + +func TestGroupsIntgSuite(t *testing.T) { + suite.Run(t, &GroupsIntgSuite{ + Suite: tester.NewIntegrationSuite( + t, + [][]string{tconfig.M365AcctCredEnvs}), + }) +} + +func (suite *GroupsIntgSuite) SetupSuite() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + graph.InitializeConcurrencyLimiter(ctx, true, 4) + + suite.acct = tconfig.NewM365Account(t) +} + +func (suite *GroupsIntgSuite) TestGroups() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + graph.InitializeConcurrencyLimiter(ctx, true, 4) + + groups, err := m365.Groups(ctx, suite.acct, fault.New(true)) + assert.NoError(t, err, clues.ToCore(err)) + assert.NotEmpty(t, groups) + + for _, group := range groups { + suite.Run("group_"+group.ID, func() { + t := suite.T() + + assert.NotEmpty(t, group.ID) + assert.NotEmpty(t, group.DisplayName) + + // at least one known group should be a team + if group.ID == tconfig.M365TeamID(t) { + assert.True(t, group.IsTeam) + } + }) + } +} + +func (suite *GroupsIntgSuite) TestGroups_InvalidCredentials() { + table := []struct { + name string + acct func(t *testing.T) account.Account + }{ + { + name: "Invalid Credentials", + acct: func(t *testing.T) account.Account { + a, err := account.NewAccount( + account.ProviderM365, + account.M365Config{ + M365: credentials.M365{ + AzureClientID: "Test", + AzureClientSecret: "without", + }, + AzureTenantID: "data", + }, + ) + require.NoError(t, err, clues.ToCore(err)) + + return a + }, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + groups, err := m365.Groups(ctx, test.acct(t), fault.New(true)) + assert.Empty(t, groups, "returned no groups") + assert.NotNil(t, err) + }) + } +} diff --git a/src/pkg/services/m365/m365.go b/src/pkg/services/m365/m365.go index 5b61885e5..469f4d08f 100644 --- a/src/pkg/services/m365/m365.go +++ b/src/pkg/services/m365/m365.go @@ -24,6 +24,10 @@ type getDefaultDriver interface { GetDefaultDrive(ctx context.Context, userID string) (models.Driveable, error) } +type getAller[T any] interface { + GetAll(ctx context.Context, errs *fault.Bus) ([]T, error) +} + // --------------------------------------------------------------------------- // Users // --------------------------------------------------------------------------- @@ -253,12 +257,11 @@ func Sites(ctx context.Context, acct account.Account, errs *fault.Bus) ([]*Site, return getAllSites(ctx, ac.Sites()) } -type getAllSiteser interface { - GetAll(ctx context.Context, errs *fault.Bus) ([]models.Siteable, error) -} - -func getAllSites(ctx context.Context, gas getAllSiteser) ([]*Site, error) { - sites, err := gas.GetAll(ctx, fault.New(true)) +func getAllSites( + ctx context.Context, + ga getAller[models.Siteable], +) ([]*Site, error) { + sites, err := ga.GetAll(ctx, fault.New(true)) if err != nil { if clues.HasLabel(err, graph.LabelsNoSharePointLicense) { return nil, clues.Stack(graph.ErrServiceNotEnabled, err) diff --git a/src/pkg/services/m365/m365_test.go b/src/pkg/services/m365/m365_test.go index 1eafa67f2..0124f13f2 100644 --- a/src/pkg/services/m365/m365_test.go +++ b/src/pkg/services/m365/m365_test.go @@ -276,25 +276,25 @@ func (suite *m365UnitSuite) TestCheckUserHasDrives() { } } -type mockGAS struct { +type mockGASites struct { response []models.Siteable err error } -func (m mockGAS) GetAll(context.Context, *fault.Bus) ([]models.Siteable, error) { +func (m mockGASites) GetAll(context.Context, *fault.Bus) ([]models.Siteable, error) { return m.response, m.err } func (suite *m365UnitSuite) TestGetAllSites() { table := []struct { name string - mock func(context.Context) getAllSiteser + mock func(context.Context) getAller[models.Siteable] expectErr func(*testing.T, error) }{ { name: "ok", - mock: func(ctx context.Context) getAllSiteser { - return mockGAS{[]models.Siteable{}, nil} + mock: func(ctx context.Context) getAller[models.Siteable] { + return mockGASites{[]models.Siteable{}, nil} }, expectErr: func(t *testing.T, err error) { assert.NoError(t, err, clues.ToCore(err)) @@ -302,14 +302,14 @@ func (suite *m365UnitSuite) TestGetAllSites() { }, { name: "no sharepoint license", - mock: func(ctx context.Context) getAllSiteser { + mock: func(ctx context.Context) getAller[models.Siteable] { odErr := odataerrors.NewODataError() merr := odataerrors.NewMainError() merr.SetCode(ptr.To("code")) merr.SetMessage(ptr.To(string(graph.NoSPLicense))) odErr.SetErrorEscaped(merr) - return mockGAS{nil, graph.Stack(ctx, odErr)} + return mockGASites{nil, graph.Stack(ctx, odErr)} }, expectErr: func(t *testing.T, err error) { assert.ErrorIs(t, err, graph.ErrServiceNotEnabled, clues.ToCore(err)) @@ -317,14 +317,14 @@ func (suite *m365UnitSuite) TestGetAllSites() { }, { name: "arbitrary error", - mock: func(ctx context.Context) getAllSiteser { + mock: func(ctx context.Context) getAller[models.Siteable] { odErr := odataerrors.NewODataError() merr := odataerrors.NewMainError() merr.SetCode(ptr.To("code")) merr.SetMessage(ptr.To("message")) odErr.SetErrorEscaped(merr) - return mockGAS{nil, graph.Stack(ctx, odErr)} + return mockGASites{nil, graph.Stack(ctx, odErr)} }, expectErr: func(t *testing.T, err error) { assert.Error(t, err, clues.ToCore(err))