diff --git a/src/internal/connector/exchange/exchange_data_collection.go b/src/internal/connector/exchange/exchange_data_collection.go index 11ade10ab..0c2140caf 100644 --- a/src/internal/connector/exchange/exchange_data_collection.go +++ b/src/internal/connector/exchange/exchange_data_collection.go @@ -28,9 +28,10 @@ import ( ) var ( - _ data.Collection = &Collection{} - _ data.Stream = &Stream{} - _ data.StreamInfo = &Stream{} + _ data.Collection = &Collection{} + _ data.Stream = &Stream{} + _ data.StreamInfo = &Stream{} + _ data.StreamModTime = &Stream{} ) const ( @@ -222,6 +223,20 @@ func (col *Collection) finishPopulation(ctx context.Context, success int, totalB col.statusUpdater(status) } +type modTimer interface { + GetLastModifiedDateTime() *time.Time +} + +func getModTime(mt modTimer) time.Time { + res := time.Now() + + if t := mt.GetLastModifiedDateTime(); t != nil { + res = *t + } + + return res +} + // GraphSerializeFunc are class of functions that are used by Collections to transform GraphRetrievalFunc // responses into data.Stream items contained within the Collection type GraphSerializeFunc func( @@ -290,7 +305,12 @@ func eventToDataCollection( } if len(byteArray) > 0 { - dataChannel <- &Stream{id: *event.GetId(), message: byteArray, info: EventInfo(event, int64(len(byteArray)))} + dataChannel <- &Stream{ + id: *event.GetId(), + message: byteArray, + info: EventInfo(event, int64(len(byteArray))), + modTime: getModTime(event), + } } return len(byteArray), nil @@ -323,7 +343,12 @@ func contactToDataCollection( } if len(byteArray) > 0 { - dataChannel <- &Stream{id: *contact.GetId(), message: byteArray, info: ContactInfo(contact, int64(len(byteArray)))} + dataChannel <- &Stream{ + id: *contact.GetId(), + message: byteArray, + info: ContactInfo(contact, int64(len(byteArray))), + modTime: getModTime(contact), + } } return len(byteArray), nil @@ -382,7 +407,12 @@ func messageToDataCollection( return 0, support.SetNonRecoverableError(err) } - dataChannel <- &Stream{id: *aMessage.GetId(), message: byteArray, info: MessageInfo(aMessage, int64(len(byteArray)))} + dataChannel <- &Stream{ + id: *aMessage.GetId(), + message: byteArray, + info: MessageInfo(aMessage, int64(len(byteArray))), + modTime: getModTime(aMessage), + } return len(byteArray), nil } @@ -395,6 +425,9 @@ type Stream struct { // some structured type in here (serialization to []byte can be done in `Read`) message []byte info *details.ExchangeInfo // temporary change to bring populate function into directory + // TODO(ashmrtn): Can probably eventually be sourced from info as there's a + // request to provide modtime in ItemInfo structs. + modTime time.Time } func (od *Stream) UUID() string { @@ -409,11 +442,16 @@ func (od *Stream) Info() details.ItemInfo { return details.ItemInfo{Exchange: od.info} } +func (od *Stream) ModTime() time.Time { + return od.modTime +} + // NewStream constructor for exchange.Stream object -func NewStream(identifier string, dataBytes []byte, detail details.ExchangeInfo) Stream { +func NewStream(identifier string, dataBytes []byte, detail details.ExchangeInfo, modTime time.Time) Stream { return Stream{ id: identifier, message: dataBytes, info: &detail, + modTime: modTime, } } diff --git a/src/internal/connector/graph_connector_helper_test.go b/src/internal/connector/graph_connector_helper_test.go index e1bc6fd56..9fcf43d2b 100644 --- a/src/internal/connector/graph_connector_helper_test.go +++ b/src/internal/connector/graph_connector_helper_test.go @@ -633,6 +633,10 @@ func compareItem( category path.CategoryType, item data.Stream, ) { + if mt, ok := item.(data.StreamModTime); ok { + assert.NotZero(t, mt.ModTime()) + } + switch service { case path.ExchangeService: switch category { diff --git a/src/internal/connector/onedrive/collection.go b/src/internal/connector/onedrive/collection.go index 4f65b0ac0..095c1c1f7 100644 --- a/src/internal/connector/onedrive/collection.go +++ b/src/internal/connector/onedrive/collection.go @@ -31,9 +31,10 @@ const ( ) var ( - _ data.Collection = &Collection{} - _ data.Stream = &Item{} - _ data.StreamInfo = &Item{} + _ data.Collection = &Collection{} + _ data.Stream = &Item{} + _ data.StreamInfo = &Item{} + _ data.StreamModTime = &Item{} ) // Collection represents a set of OneDrive objects retreived from M365 @@ -115,6 +116,10 @@ func (od *Item) Info() details.ItemInfo { return details.ItemInfo{OneDrive: od.info} } +func (od *Item) ModTime() time.Time { + return od.info.Modified +} + // populateItems iterates through items added to the collection // and uses the collection `itemReader` to read the item func (oc *Collection) populateItems(ctx context.Context) { diff --git a/src/internal/connector/onedrive/collection_test.go b/src/internal/connector/onedrive/collection_test.go index 2f7533eb3..2f1b85ff3 100644 --- a/src/internal/connector/onedrive/collection_test.go +++ b/src/internal/connector/onedrive/collection_test.go @@ -7,6 +7,7 @@ import ( "io" "sync" "testing" + "time" msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go" "github.com/stretchr/testify/assert" @@ -59,6 +60,7 @@ func (suite *OneDriveCollectionSuite) TestOneDriveCollection() { t := suite.T() wg := sync.WaitGroup{} collStatus := support.ConnectorOperationStatus{} + now := time.Now() folderPath, err := GetCanonicalPath("drive/driveID1/root:/dir1/dir2/dir3", "a-tenant", "a-user", OneDriveSource) require.NoError(t, err) @@ -77,7 +79,10 @@ func (suite *OneDriveCollectionSuite) TestOneDriveCollection() { coll.Add(testItemID) coll.itemReader = func(context.Context, graph.Service, string, string) (*details.OneDriveInfo, io.ReadCloser, error) { - return &details.OneDriveInfo{ItemName: testItemName}, io.NopCloser(bytes.NewReader(testItemData)), nil + return &details.OneDriveInfo{ + ItemName: testItemName, + Modified: now, + }, io.NopCloser(bytes.NewReader(testItemData)), nil } // Read items from the collection @@ -101,6 +106,11 @@ func (suite *OneDriveCollectionSuite) TestOneDriveCollection() { readItemInfo := readItem.(data.StreamInfo) assert.Equal(t, testItemName, readItem.UUID()) + + require.Implements(t, (*data.StreamModTime)(nil), readItem) + mt := readItem.(data.StreamModTime) + assert.Equal(t, now, mt.ModTime()) + readData, err := io.ReadAll(readItem.ToReader()) require.NoError(t, err) diff --git a/src/internal/connector/sharepoint/collection.go b/src/internal/connector/sharepoint/collection.go index afb78b8f2..fc9678b28 100644 --- a/src/internal/connector/sharepoint/collection.go +++ b/src/internal/connector/sharepoint/collection.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "time" kw "github.com/microsoft/kiota-serialization-json-go" @@ -27,8 +28,10 @@ const ( ) var ( - _ data.Collection = &Collection{} - _ data.Stream = &Item{} + _ data.Collection = &Collection{} + _ data.Stream = &Item{} + _ data.StreamInfo = &Item{} + _ data.StreamModTime = &Item{} ) type Collection struct { @@ -72,9 +75,10 @@ func (sc *Collection) Items() <-chan data.Stream { } type Item struct { - id string - data io.ReadCloser - info *details.SharePointInfo + id string + data io.ReadCloser + info *details.SharePointInfo + modTime time.Time } func (sd *Item) UUID() string { @@ -89,6 +93,10 @@ func (sd *Item) Info() details.ItemInfo { return details.ItemInfo{SharePoint: sd.info} } +func (sd *Item) ModTime() time.Time { + return sd.modTime +} + func (sc *Collection) finishPopulation(ctx context.Context, success int, totalBytes int64, errs error) { close(sc.data) attempted := len(sc.jobs) @@ -150,13 +158,19 @@ func (sc *Collection) populate(ctx context.Context) { arrayLength = int64(len(byteArray)) if arrayLength > 0 { + t := time.Now() + if t1 := lst.GetLastModifiedDateTime(); t1 != nil { + t = *t1 + } + totalBytes += arrayLength success++ sc.data <- &Item{ - id: *lst.GetId(), - data: io.NopCloser(bytes.NewReader(byteArray)), - info: sharePointListInfo(lst, arrayLength), + id: *lst.GetId(), + data: io.NopCloser(bytes.NewReader(byteArray)), + info: sharePointListInfo(lst, arrayLength), + modTime: t, } colProgress <- struct{}{}