refactors post list to accept deseriallized stored list (#5008)

refactors post list to accept deseriallized stored list

#### Does this PR need a docs update or release note?
- [x]  No

#### Type of change

<!--- Please check the type of change your PR introduces: --->
- [x] 🧹 Tech Debt/Cleanup

#### Issue(s)
#4754 

#### Test Plan

<!-- How will this be tested prior to merging.-->
- [x] 💪 Manual
- [x]  Unit test
- [x] 💚 E2E
This commit is contained in:
Hitesh Pattanayak 2024-01-11 19:22:48 +05:30 committed by GitHub
parent 26e851ed01
commit ad783172b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 101 additions and 76 deletions

View File

@ -43,7 +43,7 @@ type PostLister interface {
PostList( PostList(
ctx context.Context, ctx context.Context,
listName string, listName string,
storedListData []byte, storedList models.Listable,
errs *fault.Bus, errs *fault.Bus,
) (models.Listable, error) ) (models.Listable, error)
} }

View File

@ -67,10 +67,10 @@ func NewListsRestoreHandler(protectedResource string, ac api.Lists) listsRestore
func (rh listsRestoreHandler) PostList( func (rh listsRestoreHandler) PostList(
ctx context.Context, ctx context.Context,
listName string, listName string,
storedListData []byte, storedList models.Listable,
errs *fault.Bus, errs *fault.Bus,
) (models.Listable, error) { ) (models.Listable, error) {
return rh.ac.PostList(ctx, rh.protectedResource, listName, storedListData, errs) return rh.ac.PostList(ctx, rh.protectedResource, listName, storedList, errs)
} }
func (rh listsRestoreHandler) DeleteList( func (rh listsRestoreHandler) DeleteList(

View File

@ -8,8 +8,10 @@ import (
"runtime/trace" "runtime/trace"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/alcionai/corso/src/internal/common/idname" "github.com/alcionai/corso/src/internal/common/idname"
"github.com/alcionai/corso/src/internal/common/ptr"
"github.com/alcionai/corso/src/internal/data" "github.com/alcionai/corso/src/internal/data"
"github.com/alcionai/corso/src/internal/diagnostics" "github.com/alcionai/corso/src/internal/diagnostics"
"github.com/alcionai/corso/src/internal/m365/collection/drive" "github.com/alcionai/corso/src/internal/m365/collection/drive"
@ -139,25 +141,30 @@ func restoreListItem(
siteID, destName string, siteID, destName string,
errs *fault.Bus, errs *fault.Bus,
) (details.ItemInfo, error) { ) (details.ItemInfo, error) {
var (
dii = details.ItemInfo{}
itemID = itemData.ID()
)
ctx, end := diagnostics.Span(ctx, "m365:sharepoint:restoreList", diagnostics.Label("item_uuid", itemData.ID())) ctx, end := diagnostics.Span(ctx, "m365:sharepoint:restoreList", diagnostics.Label("item_uuid", itemData.ID()))
defer end() defer end()
ctx = clues.Add(ctx, "list_item_id", itemData.ID()) ctx = clues.Add(ctx, "list_item_id", itemID)
var (
dii = details.ItemInfo{}
listName = itemData.ID()
)
bytes, err := io.ReadAll(itemData.ToReader()) bytes, err := io.ReadAll(itemData.ToReader())
if err != nil { if err != nil {
return dii, clues.WrapWC(ctx, err, "reading backup data") return dii, clues.WrapWC(ctx, err, "reading backup data")
} }
newName := fmt.Sprintf("%s_%s", destName, listName) storedList, err := api.BytesToListable(bytes)
if err != nil {
return dii, clues.WrapWC(ctx, err, "generating list from stored bytes")
}
newName := formatListsRestoreDestination(destName, itemID, storedList)
// Restore to List base to M365 back store // Restore to List base to M365 back store
restoredList, err := rh.PostList(ctx, newName, bytes, errs) restoredList, err := rh.PostList(ctx, newName, storedList, errs)
if err != nil { if err != nil {
return dii, graph.Wrap(ctx, err, "restoring list") return dii, graph.Wrap(ctx, err, "restoring list")
} }
@ -328,3 +335,16 @@ func RestorePageCollection(
return metrics, el.Failure() return metrics, el.Failure()
} }
// newName is of format: destinationName_listID
// here we replace listID with displayName of list generated from stored list
func formatListsRestoreDestination(destName, itemID string, storedList models.Listable) string {
part1 := destName
part2 := itemID
if dispName, ok := ptr.ValOK(storedList.GetDisplayName()); ok {
part2 = dispName
}
return fmt.Sprintf("%s_%s", part1, part2)
}

View File

@ -25,11 +25,63 @@ import (
"github.com/alcionai/corso/src/pkg/control" "github.com/alcionai/corso/src/pkg/control"
"github.com/alcionai/corso/src/pkg/control/testdata" "github.com/alcionai/corso/src/pkg/control/testdata"
"github.com/alcionai/corso/src/pkg/count" "github.com/alcionai/corso/src/pkg/count"
"github.com/alcionai/corso/src/pkg/dttm"
"github.com/alcionai/corso/src/pkg/fault" "github.com/alcionai/corso/src/pkg/fault"
"github.com/alcionai/corso/src/pkg/services/m365/api" "github.com/alcionai/corso/src/pkg/services/m365/api"
"github.com/alcionai/corso/src/pkg/services/m365/api/graph" "github.com/alcionai/corso/src/pkg/services/m365/api/graph"
) )
type SharePointRestoreUnitSuite struct {
tester.Suite
}
func TestSharePointRestoreUnitSuite(t *testing.T) {
suite.Run(t, &SharePointRestoreUnitSuite{Suite: tester.NewUnitSuite(t)})
}
func (suite *SharePointCollectionUnitSuite) TestFormatListsRestoreDestination() {
t := suite.T()
dt := dttm.FormatNow(dttm.SafeForTesting)
tests := []struct {
name string
destName string
itemID string
getStoredList func() models.Listable
expectedName string
}{
{
name: "stored list has a display name",
destName: "Corso_Restore_" + dt,
itemID: "someid",
getStoredList: func() models.Listable {
list := models.NewList()
list.SetDisplayName(ptr.To("list1"))
return list
},
expectedName: "Corso_Restore_" + dt + "_list1",
},
{
name: "stored list does not have a display name",
destName: "Corso_Restore_" + dt,
itemID: "someid",
getStoredList: func() models.Listable {
return models.NewList()
},
expectedName: "Corso_Restore_" + dt + "_someid",
},
}
for _, test := range tests {
suite.Run(test.name, func() {
newName := formatListsRestoreDestination(test.destName, test.itemID, test.getStoredList())
assert.Equal(t, test.expectedName, newName, "new name for list")
})
}
}
type SharePointRestoreSuite struct { type SharePointRestoreSuite struct {
tester.Suite tester.Suite
siteID string siteID string

View File

@ -2,7 +2,6 @@ package api
import ( import (
"context" "context"
"strings"
"github.com/alcionai/clues" "github.com/alcionai/clues"
"github.com/microsoftgraph/msgraph-sdk-go/models" "github.com/microsoftgraph/msgraph-sdk-go/models"
@ -217,31 +216,13 @@ func (c Lists) PostList(
ctx context.Context, ctx context.Context,
siteID string, siteID string,
listName string, listName string,
oldListByteArray []byte, storedList models.Listable,
errs *fault.Bus, errs *fault.Bus,
) (models.Listable, error) { ) (models.Listable, error) {
var ( el := errs.Local()
newListName = listName
el = errs.Local()
)
oldList, err := BytesToListable(oldListByteArray)
if err != nil {
return nil, clues.WrapWC(ctx, err, "generating list from stored bytes")
}
// the input listName is of format: destinationName_listID
// here we replace listID with displayName of list generated from stored bytes
if name, ok := ptr.ValOK(oldList.GetDisplayName()); ok {
nameParts := strings.Split(listName, "_")
if len(nameParts) > 0 {
nameParts[len(nameParts)-1] = name
newListName = strings.Join(nameParts, "_")
}
}
// this ensure all columns, contentTypes are set to the newList // this ensure all columns, contentTypes are set to the newList
newList, columnNames := ToListable(oldList, newListName) newList, columnNames := ToListable(storedList, listName)
if newList.GetList() != nil && if newList.GetList() != nil &&
SkipListTemplates.HasKey(ptr.Val(newList.GetList().GetTemplate())) { SkipListTemplates.HasKey(ptr.Val(newList.GetList().GetTemplate())) {
@ -261,7 +242,7 @@ func (c Lists) PostList(
listItems := make([]models.ListItemable, 0) listItems := make([]models.ListItemable, 0)
for _, itm := range oldList.GetItems() { for _, itm := range storedList.GetItems() {
temp := CloneListItem(itm, columnNames) temp := CloneListItem(itm, columnNames)
listItems = append(listItems, temp) listItems = append(listItems, temp)
} }
@ -335,7 +316,7 @@ func BytesToListable(bytes []byte) (models.Listable, error) {
// not attached in this method. // not attached in this method.
// ListItems are not included in creation of new list, and have to be restored // ListItems are not included in creation of new list, and have to be restored
// in separate call. // in separate call.
func ToListable(orig models.Listable, displayName string) (models.Listable, map[string]any) { func ToListable(orig models.Listable, listName string) (models.Listable, map[string]any) {
newList := models.NewList() newList := models.NewList()
newList.SetContentTypes(orig.GetContentTypes()) newList.SetContentTypes(orig.GetContentTypes())
@ -343,7 +324,7 @@ func ToListable(orig models.Listable, displayName string) (models.Listable, map[
newList.SetCreatedByUser(orig.GetCreatedByUser()) newList.SetCreatedByUser(orig.GetCreatedByUser())
newList.SetCreatedDateTime(orig.GetCreatedDateTime()) newList.SetCreatedDateTime(orig.GetCreatedDateTime())
newList.SetDescription(orig.GetDescription()) newList.SetDescription(orig.GetDescription())
newList.SetDisplayName(&displayName) newList.SetDisplayName(ptr.To(listName))
newList.SetLastModifiedBy(orig.GetLastModifiedBy()) newList.SetLastModifiedBy(orig.GetLastModifiedBy())
newList.SetLastModifiedByUser(orig.GetLastModifiedByUser()) newList.SetLastModifiedByUser(orig.GetLastModifiedByUser())
newList.SetLastModifiedDateTime(orig.GetLastModifiedDateTime()) newList.SetLastModifiedDateTime(orig.GetLastModifiedDateTime())

View File

@ -746,17 +746,11 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList() {
fieldsData, list := getFieldsDataAndList() fieldsData, list := getFieldsDataAndList()
err := writer.WriteObjectValue("", list) newList, err := acl.PostList(ctx, siteID, listName, list, fault.New(true))
require.NoError(t, err)
oldListByteArray, err := writer.GetSerializedContent()
require.NoError(t, err)
newList, err := acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true))
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, listName, ptr.Val(newList.GetDisplayName())) assert.Equal(t, listName, ptr.Val(newList.GetDisplayName()))
_, err = acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true)) _, err = acl.PostList(ctx, siteID, listName, list, fault.New(true))
require.Error(t, err) require.Error(t, err)
newListItems := newList.GetItems() newListItems := newList.GetItems()
@ -767,10 +761,7 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList() {
newListItemsData := newListItemFields.GetAdditionalData() newListItemsData := newListItemFields.GetAdditionalData()
require.NotEmpty(t, newListItemsData) require.NotEmpty(t, newListItemsData)
assert.Equal(t, fieldsData, newListItemsData)
for k, v := range newListItemsData {
assert.Equal(t, fieldsData[k], ptr.Val(v.(*string)))
}
err = acl.DeleteList(ctx, siteID, ptr.Val(newList.GetId())) err = acl.DeleteList(ctx, siteID, ptr.Val(newList.GetId()))
require.NoError(t, err) require.NoError(t, err)
@ -819,11 +810,17 @@ func (suite *ListsAPIIntgSuite) TestLists_PostList_invalidTemplate() {
suite.Run(test.name, func() { suite.Run(test.name, func() {
t := suite.T() t := suite.T()
overrideListInfo := models.NewListInfo()
overrideListInfo.SetTemplate(ptr.To(test.template))
_, list := getFieldsDataAndList()
list.SetList(overrideListInfo)
_, err := acl.PostList( _, err := acl.PostList(
ctx, ctx,
siteID, siteID,
listName, listName,
getStoredListBytes(t, test.template), list,
fault.New(false)) fault.New(false))
require.Error(t, err) require.Error(t, err)
assert.Equal(t, ErrSkippableListTemplate.Error(), err.Error()) assert.Equal(t, ErrSkippableListTemplate.Error(), err.Error())
@ -848,13 +845,7 @@ func (suite *ListsAPIIntgSuite) TestLists_DeleteList() {
_, list := getFieldsDataAndList() _, list := getFieldsDataAndList()
err := writer.WriteObjectValue("", list) newList, err := acl.PostList(ctx, siteID, listName, list, fault.New(true))
require.NoError(t, err)
oldListByteArray, err := writer.GetSerializedContent()
require.NoError(t, err)
newList, err := acl.PostList(ctx, siteID, listName, oldListByteArray, fault.New(true))
require.NoError(t, err, clues.ToCore(err)) require.NoError(t, err, clues.ToCore(err))
assert.Equal(t, listName, ptr.Val(newList.GetDisplayName())) assert.Equal(t, listName, ptr.Val(newList.GetDisplayName()))
@ -881,7 +872,7 @@ func getFieldsDataAndList() (map[string]any, *models.List) {
fields := models.NewFieldValueSet() fields := models.NewFieldValueSet()
fieldsData := map[string]any{ fieldsData := map[string]any{
textColumnDefName: "item1", textColumnDefName: ptr.To("item1"),
} }
fields.SetAdditionalData(fieldsData) fields.SetAdditionalData(fieldsData)
@ -897,22 +888,3 @@ func getFieldsDataAndList() (map[string]any, *models.List) {
return fieldsData, list return fieldsData, list
} }
func getStoredListBytes(t *testing.T, template string) []byte {
writer := kjson.NewJsonSerializationWriter()
defer writer.Close()
overrideListInfo := models.NewListInfo()
overrideListInfo.SetTemplate(ptr.To(template))
_, list := getFieldsDataAndList()
list.SetList(overrideListInfo)
err := writer.WriteObjectValue("", list)
require.NoError(t, err)
storedListBytes, err := writer.GetSerializedContent()
require.NoError(t, err)
return storedListBytes
}