diff --git a/src/cli/repo/s3.go b/src/cli/repo/s3.go index 947d8788b..06be910e3 100644 --- a/src/cli/repo/s3.go +++ b/src/cli/repo/s3.go @@ -1,6 +1,7 @@ package repo import ( + "context" "fmt" "os" @@ -75,10 +76,12 @@ func initS3Cmd(cmd *cobra.Command, args []string) { os.Exit(1) } - if _, err := repository.Initialize(cmd.Context(), a, s); err != nil { + r, err := repository.Connect(cmd.Context(), a, s) + if err != nil { fmt.Printf("Failed to initialize a new S3 repository: %v", err) os.Exit(1) } + defer closeRepo(cmd.Context(), r) fmt.Printf("Initialized a S3 repository within bucket %s.\n", s3Cfg.Bucket) } @@ -121,10 +124,12 @@ func connectS3Cmd(cmd *cobra.Command, args []string) { os.Exit(1) } - if _, err := repository.Connect(cmd.Context(), a, s); err != nil { + r, err := repository.Connect(cmd.Context(), a, s) + if err != nil { fmt.Printf("Failed to connect to the S3 repository: %v", err) os.Exit(1) } + defer closeRepo(cmd.Context(), r) fmt.Printf("Connected to S3 bucket %s.\n", s3Cfg.Bucket) } @@ -158,3 +163,9 @@ func makeS3Config() (storage.S3Config, storage.CommonConfig, error) { storage.CORSO_PASSWORD: corsoPasswd, }) } + +func closeRepo(ctx context.Context, r repository.Repository) { + if err := r.Close(ctx); err != nil { + fmt.Printf("Error closing repository: %v\n", err) + } +} diff --git a/src/internal/kopia/kopia.go b/src/internal/kopia/kopia.go index a24e7470d..3be5222c0 100644 --- a/src/internal/kopia/kopia.go +++ b/src/internal/kopia/kopia.go @@ -19,15 +19,16 @@ var ( errConnect = errors.New("connecting repo") ) -type kopiaWrapper struct { +type KopiaWrapper struct { storage storage.Storage + rep repo.Repository } -func New(s storage.Storage) kopiaWrapper { - return kopiaWrapper{s} +func New(s storage.Storage) KopiaWrapper { + return KopiaWrapper{storage: s} } -func (kw kopiaWrapper) Initialize(ctx context.Context) error { +func (kw KopiaWrapper) Initialize(ctx context.Context) error { bst, err := blobStoreByProvider(ctx, kw.storage) if err != nil { return errors.Wrap(err, errInit.Error()) @@ -55,10 +56,14 @@ func (kw kopiaWrapper) Initialize(ctx context.Context) error { return errors.Wrap(err, errConnect.Error()) } + if err := kw.open(ctx, cfg.CorsoPassword); err != nil { + return err + } + return nil } -func (kw kopiaWrapper) Connect(ctx context.Context) error { +func (kw KopiaWrapper) Connect(ctx context.Context) error { bst, err := blobStoreByProvider(ctx, kw.storage) if err != nil { return errors.Wrap(err, errInit.Error()) @@ -80,6 +85,11 @@ func (kw kopiaWrapper) Connect(ctx context.Context) error { ); err != nil { return errors.Wrap(err, errConnect.Error()) } + + if err := kw.open(ctx, cfg.CorsoPassword); err != nil { + return err + } + return nil } @@ -91,3 +101,29 @@ func blobStoreByProvider(ctx context.Context, s storage.Storage) (blob.Storage, return nil, errors.New("storage provider details are required") } } + +func (kw KopiaWrapper) Close(ctx context.Context) error { + if kw.rep == nil { + return nil + } + + err := kw.rep.Close(ctx) + kw.rep = nil + + if err != nil { + return errors.Wrap(err, "closing repository connection") + } + + return nil +} + +func (kw KopiaWrapper) open(ctx context.Context, password string) error { + // TODO(ashmrtnz): issue #75: nil here should be storage.ConnectionOptions(). + rep, err := repo.Open(ctx, defaultKopiaConfigFilePath, password, nil) + if err != nil { + return errors.Wrap(err, "opening repository connection") + } + + kw.rep = rep + return nil +} diff --git a/src/pkg/repository/repository.go b/src/pkg/repository/repository.go index 8e63d426d..6acd31e70 100644 --- a/src/pkg/repository/repository.go +++ b/src/pkg/repository/repository.go @@ -4,9 +4,11 @@ import ( "context" "time" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/alcionai/corso/internal/kopia" "github.com/alcionai/corso/pkg/storage" - "github.com/google/uuid" ) type repoProvider int @@ -23,8 +25,9 @@ type Repository struct { CreatedAt time.Time Version string // in case of future breaking changes - Account Account // the user's m365 account connection details - Storage storage.Storage // the storage provider details and configuration + Account Account // the user's m365 account connection details + Storage storage.Storage // the storage provider details and configuration + dataLayer *kopia.KopiaWrapper } // Account holds the user's m365 account details. @@ -52,10 +55,11 @@ func Initialize( return Repository{}, err } r := Repository{ - ID: uuid.New(), - Version: "v1", - Account: acct, - Storage: storage, + ID: uuid.New(), + Version: "v1", + Account: acct, + Storage: storage, + dataLayer: &k, } return r, nil } @@ -76,9 +80,25 @@ func Connect( } // todo: ID and CreatedAt should get retrieved from a stored kopia config. r := Repository{ - Version: "v1", - Account: acct, - Storage: storage, + Version: "v1", + Account: acct, + Storage: storage, + dataLayer: &k, } return r, nil } + +func (r Repository) Close(ctx context.Context) error { + if r.dataLayer == nil { + return nil + } + + err := r.dataLayer.Close(ctx) + r.dataLayer = nil + + if err != nil { + return errors.Wrap(err, "closing corso Repository") + } + + return nil +} diff --git a/src/pkg/repository/repository_test.go b/src/pkg/repository/repository_test.go index ebfdce3ea..b69200d80 100644 --- a/src/pkg/repository/repository_test.go +++ b/src/pkg/repository/repository_test.go @@ -125,7 +125,13 @@ func (suite *RepositoryIntegrationSuite) TestInitialize() { suite.T().Run(test.prefix, func(t *testing.T) { st, err := test.storage() assert.NoError(t, err) - _, err = repository.Initialize(ctx, test.account, st) + r, err := repository.Initialize(ctx, test.account, st) + if err == nil { + defer func() { + assert.NoError(t, r.Close(ctx)) + }() + } + test.errCheck(t, err) }) }