diff --git a/src/go.mod b/src/go.mod index 146e144c6..24c5bd2dc 100644 --- a/src/go.mod +++ b/src/go.mod @@ -10,6 +10,7 @@ require ( github.com/armon/go-metrics v0.4.1 github.com/aws/aws-xray-sdk-go v1.8.2 github.com/cenkalti/backoff/v4 v4.2.1 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/google/uuid v1.3.1 github.com/h2non/gock v1.2.0 github.com/kopia/kopia v0.13.0 @@ -46,7 +47,6 @@ require ( github.com/aws/aws-sdk-go v1.45.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gofrs/flock v0.8.1 // indirect - github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/hashicorp/cronexpr v1.1.2 // indirect diff --git a/src/internal/common/jwt/jwt.go b/src/internal/common/jwt/jwt.go new file mode 100644 index 000000000..5d2aa6d2a --- /dev/null +++ b/src/internal/common/jwt/jwt.go @@ -0,0 +1,39 @@ +package jwt + +import ( + "time" + + "github.com/alcionai/clues" + jwt "github.com/golang-jwt/jwt/v5" +) + +// IsJWTExpired checks if the JWT token is past expiry by analyzing the +// "exp" claim present in the token. Token is considered expired if "exp" +// claim < current time. Missing "exp" claim is considered as non-expired. +// An error is returned if the supplied token is malformed. +func IsJWTExpired( + rawToken string, +) (bool, error) { + p := jwt.NewParser() + + // Note: Call to ParseUnverified is intentional since token verification is + // not our objective. We only care about the embed claims in the token. + // We assume the token signature is valid & verified by caller stack. + token, _, err := p.ParseUnverified(rawToken, &jwt.RegisteredClaims{}) + if err != nil { + return false, clues.Wrap(err, "invalid jwt") + } + + t, err := token.Claims.GetExpirationTime() + if err != nil { + return false, clues.Wrap(err, "getting token expiry time") + } + + if t == nil { + return false, nil + } + + expired := t.Before(time.Now()) + + return expired, nil +} diff --git a/src/internal/common/jwt/jwt_test.go b/src/internal/common/jwt/jwt_test.go new file mode 100644 index 000000000..1b7f334f0 --- /dev/null +++ b/src/internal/common/jwt/jwt_test.go @@ -0,0 +1,115 @@ +package jwt + +import ( + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/tester" +) + +type JWTUnitSuite struct { + tester.Suite +} + +func TestJWTUnitSuite(t *testing.T) { + suite.Run(t, &JWTUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +// createJWTToken creates a JWT token with the specified expiration time. +func createJWTToken( + claims jwt.RegisteredClaims, +) (string, error) { + // build claims from map + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + return token.SignedString([]byte("")) +} + +const ( + // Raw test token valid for 100 years. + rawToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + + "eyJuYmYiOiIxNjkxODE5NTc5IiwiZXhwIjoiMzk0NTUyOTE3OSIsImVuZHBvaW50dXJsTGVuZ3RoIjoiMTYw" + + "IiwiaXNsb29wYmFjayI6IlRydWUiLCJ2ZXIiOiJoYXNoZWRwcm9vZnRva2VuIiwicm9sZXMiOiJhbGxmaWxl" + + "cy53cml0ZSBhbGxzaXRlcy5mdWxsY29udHJvbCBhbGxwcm9maWxlcy5yZWFkIiwidHQiOiIxIiwiYWxnIjoi" + + "SFMyNTYifQ" + + ".signature" +) + +func (suite *JWTUnitSuite) TestIsJWTExpired() { + table := []struct { + name string + expect bool + getToken func() (string, error) + expectErr assert.ErrorAssertionFunc + }{ + { + name: "alive token", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + }, + expect: false, + expectErr: assert.NoError, + }, + { + name: "expired token", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + }) + }, + expect: true, + expectErr: assert.NoError, + }, + // Test with a raw token which is not generated with go-jwt lib. + { + name: "alive raw token", + getToken: func() (string, error) { + return rawToken, nil + }, + expect: false, + expectErr: assert.NoError, + }, + { + name: "alive token, missing exp claim", + getToken: func() (string, error) { + return createJWTToken(jwt.RegisteredClaims{}) + }, + expect: false, + expectErr: assert.NoError, + }, + { + name: "malformed token", + getToken: func() (string, error) { + return "header.claims.signature", nil + }, + expect: false, + expectErr: assert.Error, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + _, flush := tester.NewContext(t) + defer flush() + + token, err := test.getToken() + require.NoError(t, err) + + expired, err := IsJWTExpired(token) + test.expectErr(t, err) + + assert.Equal(t, test.expect, expired) + }) + } +} diff --git a/src/internal/common/url.go b/src/internal/common/url.go new file mode 100644 index 000000000..7efaf14ac --- /dev/null +++ b/src/internal/common/url.go @@ -0,0 +1,27 @@ +package common + +import ( + "net/url" + + "github.com/alcionai/clues" +) + +// GetQueryParamFromURL parses an URL and returns value of the specified +// query parameter. +func GetQueryParamFromURL( + rawURL, queryParam string, +) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", clues.Wrap(err, "parsing url") + } + + qp := u.Query() + + val := qp.Get(queryParam) + if len(val) == 0 { + return "", clues.New("query param not found").With("query_param", queryParam) + } + + return val, nil +} diff --git a/src/internal/common/url_test.go b/src/internal/common/url_test.go new file mode 100644 index 000000000..fa1d1cc20 --- /dev/null +++ b/src/internal/common/url_test.go @@ -0,0 +1,72 @@ +package common_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common" + "github.com/alcionai/corso/src/internal/tester" +) + +type URLUnitSuite struct { + tester.Suite +} + +func TestURLUnitSuite(t *testing.T) { + suite.Run(t, &URLUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *URLUnitSuite) TestGetQueryParamFromURL() { + qp := "tempauth" + table := []struct { + name string + rawURL string + queryParam string + expectedResult string + expect assert.ErrorAssertionFunc + }{ + { + name: "valid", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val", + queryParam: qp, + expectedResult: "h.c.s", + expect: assert.NoError, + }, + { + name: "query param not found", + rawURL: "http://localhost:8080?other=val", + queryParam: qp, + expect: assert.Error, + }, + { + name: "empty query param", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val", + queryParam: "", + expect: assert.Error, + }, + // In case of multiple occurrences, the first occurrence of param is returned. + { + name: "multiple occurrences", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val&" + qp + "=h1.c1.s1", + queryParam: qp, + expectedResult: "h.c.s", + expect: assert.NoError, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + _, flush := tester.NewContext(t) + defer flush() + + token, err := common.GetQueryParamFromURL(test.rawURL, test.queryParam) + test.expect(t, err) + + assert.Equal(t, test.expectedResult, token) + }) + } +}