diff --git a/src/pkg/services/m365/api/teams.go b/src/pkg/services/m365/api/teams.go index 883ddee01..8eef3fba3 100644 --- a/src/pkg/services/m365/api/teams.go +++ b/src/pkg/services/m365/api/teams.go @@ -7,8 +7,16 @@ import ( msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core" "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/alcionai/corso/src/internal/common/str" + "github.com/alcionai/corso/src/internal/common/tform" "github.com/alcionai/corso/src/internal/m365/graph" "github.com/alcionai/corso/src/pkg/fault" + "github.com/alcionai/corso/src/pkg/logger" +) + +const ( + teamsAdditionalDataLabel = "Team" + ResourceProvisioningOptions = "resourceProvisioningOptions" ) // --------------------------------------------------------------------------- @@ -19,74 +27,118 @@ func (c Client) Teams() Teams { return Teams{c} } +// On creation of each Teams team a corrsponding group gets created. +// The group acts as the protected resource, and all teams data like events, +// drive and mail messages are owned by that group. + // Teams is an interface-compliant provider of the client. type Teams struct { Client } -// GetAllTeams retrieves all teams. +// GetAllTeams retrieves all groups. func (c Teams) GetAll( ctx context.Context, errs *fault.Bus, -) ([]models.Teamable, error) { +) ([]models.Groupable, error) { service, err := c.Service() if err != nil { return nil, err } - resp, err := service.Client().Teams().Get(ctx, nil) + return getGroups(ctx, true, errs, service) +} + +// GetAll retrieves all groups. +func getGroups( + ctx context.Context, + getOnlyTeams bool, + errs *fault.Bus, + service graph.Servicer, +) ([]models.Groupable, error) { + resp, err := service.Client().Groups().Get(ctx, nil) if err != nil { - return nil, graph.Wrap(ctx, err, "getting all teams") + return nil, graph.Wrap(ctx, err, "getting all groups") } - iter, err := msgraphgocore.NewPageIterator[models.Teamable]( + iter, err := msgraphgocore.NewPageIterator[models.Groupable]( resp, service.Adapter(), models.CreateTeamCollectionResponseFromDiscriminatorValue) if err != nil { - return nil, graph.Wrap(ctx, err, "creating teams iterator") + return nil, graph.Wrap(ctx, err, "creating groups iterator") } var ( - teams = make([]models.Teamable, 0) - el = errs.Local() + groups = make([]models.Groupable, 0) + el = errs.Local() ) - iterator := func(item models.Teamable) bool { + iterator := func(item models.Groupable) bool { if el.Failure() != nil { return false } - err := ValidateTeams(item) + err := ValidateGroup(item) if err != nil { - el.AddRecoverable(ctx, graph.Wrap(ctx, err, "validating teams")) + el.AddRecoverable(ctx, graph.Wrap(ctx, err, "validating groups")) } else { - teams = append(teams, item) + isTeam := IsTeam(ctx, item) + if !getOnlyTeams || isTeam { + groups = append(groups, item) + } } return true } if err := iter.Iterate(ctx, iterator); err != nil { - return nil, graph.Wrap(ctx, err, "iterating all teams") + return nil, graph.Wrap(ctx, err, "iterating all groups") } - return teams, el.Failure() + return groups, el.Failure() } -// GetID retrieves team by teamID. +func IsTeam(ctx context.Context, g models.Groupable) bool { + log := logger.Ctx(ctx) + + if g.GetAdditionalData()[ResourceProvisioningOptions] != nil { + val, _ := tform.AnyValueToT[[]any](ResourceProvisioningOptions, g.GetAdditionalData()) + for _, v := range val { + s, err := str.AnyToString(v) + if err != nil { + log.Debug("could not be converted to string value: ", ResourceProvisioningOptions) + return false + } + + if s == teamsAdditionalDataLabel { + return true + } + } + } + + return false +} + +// GetID retrieves team by groupID/teamID. func (c Teams) GetByID( ctx context.Context, identifier string, -) (models.Teamable, error) { +) (models.Groupable, error) { service, err := c.Service() if err != nil { return nil, err } - resp, err := service.Client().Teams().ByTeamId(identifier).Get(ctx, nil) + resp, err := service.Client().Groups().ByGroupId(identifier).Get(ctx, nil) if err != nil { - err := graph.Wrap(ctx, err, "getting team by id") + err := graph.Wrap(ctx, err, "getting group by id") + + return nil, err + } + + if !IsTeam(ctx, resp) { + err := clues.New("given teamID is not related to any team") return nil, err } @@ -98,9 +150,9 @@ func (c Teams) GetByID( // helpers // --------------------------------------------------------------------------- -// ValidateTeams ensures the item is a Teamable, and contains the necessary -// identifiers that we handle with all teams. -func ValidateTeams(item models.Teamable) error { +// ValidateGroup ensures the item is a Groupable, and contains the necessary +// identifiers that we handle with all groups. +func ValidateGroup(item models.Groupable) error { if item.GetId() == nil { return clues.New("missing ID") } diff --git a/src/pkg/services/m365/api/teams_test.go b/src/pkg/services/m365/api/teams_test.go index 89aab93b5..dcb039dc5 100644 --- a/src/pkg/services/m365/api/teams_test.go +++ b/src/pkg/services/m365/api/teams_test.go @@ -25,21 +25,21 @@ func TestTeamsUnitSuite(t *testing.T) { suite.Run(t, &TeamsUnitSuite{Suite: tester.NewUnitSuite(t)}) } -func (suite *TeamsUnitSuite) TestValidateTeams() { +func (suite *TeamsUnitSuite) TestValidateGroup() { team := models.NewTeam() - team.SetDisplayName(ptr.To("testteam")) + team.SetDisplayName(ptr.To("testgroup")) team.SetId(ptr.To("testID")) tests := []struct { name string - args models.Teamable + args models.Groupable errCheck assert.ErrorAssertionFunc errIsSkippable bool }{ { - name: "Valid Team", - args: func() *models.Team { - s := models.NewTeam() + name: "Valid group ", + args: func() *models.Group { + s := models.NewGroup() s.SetId(ptr.To("id")) s.SetDisplayName(ptr.To("testTeam")) return s @@ -48,8 +48,8 @@ func (suite *TeamsUnitSuite) TestValidateTeams() { }, { name: "No name", - args: func() *models.Team { - s := models.NewTeam() + args: func() *models.Group { + s := models.NewGroup() s.SetId(ptr.To("id")) return s }(), @@ -57,8 +57,8 @@ func (suite *TeamsUnitSuite) TestValidateTeams() { }, { name: "No ID", - args: func() *models.Team { - s := models.NewTeam() + args: func() *models.Group { + s := models.NewGroup() s.SetDisplayName(ptr.To("testTeam")) return s }(), @@ -70,7 +70,7 @@ func (suite *TeamsUnitSuite) TestValidateTeams() { suite.Run(test.name, func() { t := suite.T() - err := api.ValidateTeams(test.args) + err := api.ValidateGroup(test.args) test.errCheck(t, err, clues.ToCore(err)) if test.errIsSkippable { @@ -108,6 +108,10 @@ func (suite *TeamsIntgSuite) TestGetAllTeams() { GetAll(ctx, fault.New(true)) require.NoError(t, err) require.NotZero(t, len(teams), "must have at least one team") + + for _, team := range teams { + assert.True(t, api.IsTeam(ctx, team), "must not return non teams groups") + } } func (suite *TeamsIntgSuite) TestTeams_GetByID() {