diff --git a/src/internal/m365/collection/site/handlers.go b/src/internal/m365/collection/site/handlers.go index 253883a0e..bb47c6d07 100644 --- a/src/internal/m365/collection/site/handlers.go +++ b/src/internal/m365/collection/site/handlers.go @@ -43,7 +43,7 @@ type PostLister interface { PostList( ctx context.Context, listName string, - storedListData []byte, + storedList models.Listable, errs *fault.Bus, ) (models.Listable, error) } diff --git a/src/internal/m365/collection/site/lists_handler.go b/src/internal/m365/collection/site/lists_handler.go index 5ba1b9f78..4e4d08cd3 100644 --- a/src/internal/m365/collection/site/lists_handler.go +++ b/src/internal/m365/collection/site/lists_handler.go @@ -67,10 +67,10 @@ func NewListsRestoreHandler(protectedResource string, ac api.Lists) listsRestore func (rh listsRestoreHandler) PostList( ctx context.Context, listName string, - storedListData []byte, + storedList models.Listable, errs *fault.Bus, ) (models.Listable, error) { - return rh.ac.PostList(ctx, rh.protectedResource, listName, storedListData, errs) + return rh.ac.PostList(ctx, rh.protectedResource, listName, storedList, errs) } func (rh listsRestoreHandler) DeleteList( diff --git a/src/internal/m365/collection/site/restore.go b/src/internal/m365/collection/site/restore.go index eafd47d65..5721dcbf5 100644 --- a/src/internal/m365/collection/site/restore.go +++ b/src/internal/m365/collection/site/restore.go @@ -8,8 +8,10 @@ import ( "runtime/trace" "github.com/alcionai/clues" + "github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/alcionai/corso/src/internal/common/idname" + "github.com/alcionai/corso/src/internal/common/ptr" "github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/internal/diagnostics" "github.com/alcionai/corso/src/internal/m365/collection/drive" @@ -139,25 +141,30 @@ func restoreListItem( siteID, destName string, errs *fault.Bus, ) (details.ItemInfo, error) { + var ( + dii = details.ItemInfo{} + itemID = itemData.ID() + ) + ctx, end := diagnostics.Span(ctx, "m365:sharepoint:restoreList", diagnostics.Label("item_uuid", itemData.ID())) defer end() - ctx = clues.Add(ctx, "list_item_id", itemData.ID()) - - var ( - dii = details.ItemInfo{} - listName = itemData.ID() - ) + ctx = clues.Add(ctx, "list_item_id", itemID) bytes, err := io.ReadAll(itemData.ToReader()) if err != nil { return dii, clues.WrapWC(ctx, err, "reading backup data") } - newName := fmt.Sprintf("%s_%s", destName, listName) + storedList, err := api.BytesToListable(bytes) + if err != nil { + return dii, clues.WrapWC(ctx, err, "generating list from stored bytes") + } + + newName := formatListsRestoreDestination(destName, itemID, storedList) // Restore to List base to M365 back store - restoredList, err := rh.PostList(ctx, newName, bytes, errs) + restoredList, err := rh.PostList(ctx, newName, storedList, errs) if err != nil { return dii, graph.Wrap(ctx, err, "restoring list") } @@ -328,3 +335,16 @@ func RestorePageCollection( return metrics, el.Failure() } + +// newName is of format: destinationName_listID +// here we replace listID with displayName of list generated from stored list +func formatListsRestoreDestination(destName, itemID string, storedList models.Listable) string { + part1 := destName + part2 := itemID + + if dispName, ok := ptr.ValOK(storedList.GetDisplayName()); ok { + part2 = dispName + } + + return fmt.Sprintf("%s_%s", part1, part2) +} diff --git a/src/internal/m365/collection/site/restore_test.go b/src/internal/m365/collection/site/restore_test.go index 1c9335ab2..2d6dae751 100644 --- a/src/internal/m365/collection/site/restore_test.go +++ b/src/internal/m365/collection/site/restore_test.go @@ -25,11 +25,63 @@ import ( "github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control/testdata" "github.com/alcionai/corso/src/pkg/count" + "github.com/alcionai/corso/src/pkg/dttm" "github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api/graph" ) +type SharePointRestoreUnitSuite struct { + tester.Suite +} + +func TestSharePointRestoreUnitSuite(t *testing.T) { + suite.Run(t, &SharePointRestoreUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *SharePointCollectionUnitSuite) TestFormatListsRestoreDestination() { + t := suite.T() + + dt := dttm.FormatNow(dttm.SafeForTesting) + + tests := []struct { + name string + destName string + itemID string + getStoredList func() models.Listable + expectedName string + }{ + { + name: "stored list has a display name", + destName: "Corso_Restore_" + dt, + itemID: "someid", + getStoredList: func() models.Listable { + list := models.NewList() + list.SetDisplayName(ptr.To("list1")) + + return list + }, + expectedName: "Corso_Restore_" + dt + "_list1", + }, + { + name: "stored list does not have a display name", + destName: "Corso_Restore_" + dt, + itemID: "someid", + getStoredList: func() models.Listable { + return models.NewList() + }, + expectedName: "Corso_Restore_" + dt + "_someid", + }, + } + + for _, test := range tests { + suite.Run(test.name, func() { + newName := formatListsRestoreDestination(test.destName, test.itemID, test.getStoredList()) + assert.Equal(t, test.expectedName, newName, "new name for list") + }) + } +} + type SharePointRestoreSuite struct { tester.Suite siteID string diff --git a/src/pkg/services/m365/api/lists.go b/src/pkg/services/m365/api/lists.go index d308cd78b..c9a1d18a0 100644 --- a/src/pkg/services/m365/api/lists.go +++ b/src/pkg/services/m365/api/lists.go @@ -2,7 +2,6 @@ package api import ( "context" - "strings" "github.com/alcionai/clues" "github.com/microsoftgraph/msgraph-sdk-go/models" @@ -217,31 +216,13 @@ func (c Lists) PostList( ctx context.Context, siteID string, listName string, - oldListByteArray []byte, + storedList models.Listable, errs *fault.Bus, ) (models.Listable, error) { - var ( - newListName = listName - el = errs.Local() - ) - - oldList, err := BytesToListable(oldListByteArray) - if err != nil { - return nil, clues.WrapWC(ctx, err, "generating list from stored bytes") - } - - // the input listName is of format: destinationName_listID - // here we replace listID with displayName of list generated from stored bytes - if name, ok := ptr.ValOK(oldList.GetDisplayName()); ok { - nameParts := strings.Split(listName, "_") - if len(nameParts) > 0 { - nameParts[len(nameParts)-1] = name - newListName = strings.Join(nameParts, "_") - } - } + el := errs.Local() // this ensure all columns, contentTypes are set to the newList - newList, columnNames := ToListable(oldList, newListName) + newList, columnNames := ToListable(storedList, listName) if newList.GetList() != nil && SkipListTemplates.HasKey(ptr.Val(newList.GetList().GetTemplate())) { @@ -261,7 +242,7 @@ func (c Lists) PostList( listItems := make([]models.ListItemable, 0) - for _, itm := range oldList.GetItems() { + for _, itm := range storedList.GetItems() { temp := CloneListItem(itm, columnNames) listItems = append(listItems, temp) } @@ -335,7 +316,7 @@ func BytesToListable(bytes []byte) (models.Listable, error) { // not attached in this method. // ListItems are not included in creation of new list, and have to be restored // in separate call. -func ToListable(orig models.Listable, displayName string) (models.Listable, map[string]any) { +func ToListable(orig models.Listable, listName string) (models.Listable, map[string]any) { newList := models.NewList() newList.SetContentTypes(orig.GetContentTypes()) @@ -343,7 +324,7 @@ func ToListable(orig models.Listable, displayName string) (models.Listable, map[ newList.SetCreatedByUser(orig.GetCreatedByUser()) newList.SetCreatedDateTime(orig.GetCreatedDateTime()) newList.SetDescription(orig.GetDescription()) - newList.SetDisplayName(&displayName) + newList.SetDisplayName(ptr.To(listName)) newList.SetLastModifiedBy(orig.GetLastModifiedBy()) newList.SetLastModifiedByUser(orig.GetLastModifiedByUser()) newList.SetLastModifiedDateTime(orig.GetLastModifiedDateTime()) diff --git a/src/pkg/services/m365/api/lists_test.go b/src/pkg/services/m365/api/lists_test.go index 39e6d11b9..040c085c1 100644 --- a/src/pkg/services/m365/api/lists_test.go +++ b/src/pkg/services/m365/api/lists_test.go @@ -746,17 +746,11 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList() { fieldsData, list := getFieldsDataAndList() - err := writer.WriteObjectValue("", list) - require.NoError(t, err) - - oldListByteArray, err := writer.GetSerializedContent() - require.NoError(t, err) - - newList, err := acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true)) + newList, err := acl.PostList(ctx, siteID, listName, list, fault.New(true)) require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, listName, ptr.Val(newList.GetDisplayName())) - _, err = acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true)) + _, err = acl.PostList(ctx, siteID, listName, list, fault.New(true)) require.Error(t, err) newListItems := newList.GetItems() @@ -767,10 +761,7 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList() { newListItemsData := newListItemFields.GetAdditionalData() require.NotEmpty(t, newListItemsData) - - for k, v := range newListItemsData { - assert.Equal(t, fieldsData[k], ptr.Val(v.(*string))) - } + assert.Equal(t, fieldsData, newListItemsData) err = acl.DeleteList(ctx, siteID, ptr.Val(newList.GetId())) require.NoError(t, err) @@ -819,11 +810,17 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList_invalidTemplate() { suite.Run(test.name, func() { t := suite.T() + overrideListInfo := models.NewListInfo() + overrideListInfo.SetTemplate(ptr.To(test.template)) + + _, list := getFieldsDataAndList() + list.SetList(overrideListInfo) + _, err := acl.PostList( ctx, siteID, listName, - getStoredListBytes(t, test.template), + list, fault.New(false)) require.Error(t, err) assert.Equal(t, ErrSkippableListTemplate.Error(), err.Error()) @@ -848,13 +845,7 @@ func (suite *ListsAPIIntgSuite) TestLists_DeleteList() { _, list := getFieldsDataAndList() - err := writer.WriteObjectValue("", list) - require.NoError(t, err) - - oldListByteArray, err := writer.GetSerializedContent() - require.NoError(t, err) - - newList, err := acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true)) + newList, err := acl.PostList(ctx, siteID, listName, list, fault.New(true)) require.NoError(t, err, clues.ToCore(err)) assert.Equal(t, listName, ptr.Val(newList.GetDisplayName())) @@ -881,7 +872,7 @@ func getFieldsDataAndList() (map[string]any, *models.List) { fields := models.NewFieldValueSet() fieldsData := map[string]any{ - textColumnDefName: "item1", + textColumnDefName: ptr.To("item1"), } fields.SetAdditionalData(fieldsData) @@ -897,22 +888,3 @@ func getFieldsDataAndList() (map[string]any, *models.List) { return fieldsData, list } - -func getStoredListBytes(t *testing.T, template string) []byte { - writer := kjson.NewJsonSerializationWriter() - defer writer.Close() - - overrideListInfo := models.NewListInfo() - overrideListInfo.SetTemplate(ptr.To(template)) - - _, list := getFieldsDataAndList() - list.SetList(overrideListInfo) - - err := writer.WriteObjectValue("", list) - require.NoError(t, err) - - storedListBytes, err := writer.GetSerializedContent() - require.NoError(t, err) - - return storedListBytes -}