From f03eeefd903d287c7189adab4feed43331aa3498 Mon Sep 17 00:00:00 2001 From: Abhishek Pandey Date: Mon, 2 Oct 2023 12:22:07 +0530 Subject: [PATCH] Check if JWT token has expired --- src/go.mod | 1 + src/go.sum | 2 + src/internal/common/jwt/jwt.go | 29 +++++++++ src/internal/common/jwt/jwt_test.go | 99 +++++++++++++++++++++++++++++ src/internal/common/url.go | 27 ++++++++ src/internal/common/url_test.go | 72 +++++++++++++++++++++ 6 files changed, 230 insertions(+) create mode 100644 src/internal/common/jwt/jwt.go create mode 100644 src/internal/common/jwt/jwt_test.go create mode 100644 src/internal/common/url.go create mode 100644 src/internal/common/url_test.go diff --git a/src/go.mod b/src/go.mod index d20785bf6..9f2968a84 100644 --- a/src/go.mod +++ b/src/go.mod @@ -79,6 +79,7 @@ require ( github.com/edsrzf/mmap-go v1.1.0 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/src/go.sum b/src/go.sum index cda31894a..8b13fe422 100644 --- a/src/go.sum +++ b/src/go.sum @@ -138,6 +138,8 @@ github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= diff --git a/src/internal/common/jwt/jwt.go b/src/internal/common/jwt/jwt.go new file mode 100644 index 000000000..2e74ab677 --- /dev/null +++ b/src/internal/common/jwt/jwt.go @@ -0,0 +1,29 @@ +package jwt + +import ( + "github.com/alcionai/clues" + jwt "github.com/golang-jwt/jwt" +) + +// IsJWTExpired checks if the JWT token is past expiry by analyzing the +// "exp" claim present in the token. Token is considered alive if : +// 1. time.now <= "exp" claim. +// 2. "exp" claim is missing. +// An error is returned if the supplied token is malformed. +func IsJWTExpired( + rawToken string, +) (bool, error) { + // Note: Call to ParseUnverified is intentional since token verification is + // not our objective. We assume the token signature is valid & verified + // by caller stack. We only care about the embed claims in the token. + token, _, err := new(jwt.Parser).ParseUnverified(rawToken, jwt.MapClaims{}) + if err != nil { + return false, clues.Wrap(err, "invalid jwt") + } + + claims, _ := token.Claims.(jwt.MapClaims) + // If "exp" claim is missing, token is considered alive. + expired := !claims.VerifyExpiresAt(jwt.TimeFunc().Unix(), false) + + return expired, nil +} diff --git a/src/internal/common/jwt/jwt_test.go b/src/internal/common/jwt/jwt_test.go new file mode 100644 index 000000000..e5a4602b8 --- /dev/null +++ b/src/internal/common/jwt/jwt_test.go @@ -0,0 +1,99 @@ +package jwt + +import ( + "testing" + "time" + + jwt "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/tester" +) + +type JWTUnitSuite struct { + tester.Suite +} + +func TestJWTUnitSuite(t *testing.T) { + suite.Run(t, &JWTUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +// createJWTToken creates a JWT token with the specified expiration time. +func createJWTToken( + expiration time.Time, + claims jwt.MapClaims, +) (string, error) { + // build claims from map + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + return token.SignedString([]byte("")) +} + +func (suite *JWTUnitSuite) TestIsJWTExpired() { + table := []struct { + name string + expect bool + getToken func() (string, error) + expectErr assert.ErrorAssertionFunc + }{ + { + name: "alive token", + getToken: func() (string, error) { + return createJWTToken( + time.Now().Add(time.Hour), + jwt.MapClaims{ + "exp": time.Now().Add(time.Hour).Unix(), + }) + }, + expect: false, + expectErr: assert.NoError, + }, + { + name: "expired token", + getToken: func() (string, error) { + return createJWTToken( + time.Now().Add(time.Hour), + jwt.MapClaims{ + "exp": time.Now().Add(-time.Hour).Unix(), + }) + }, + expect: true, + expectErr: assert.NoError, + }, + { + name: "alive token, missing exp claim", + getToken: func() (string, error) { + return createJWTToken(time.Now().Add(time.Hour), jwt.MapClaims{}) + }, + expect: false, + expectErr: assert.NoError, + }, + { + name: "malformed token", + getToken: func() (string, error) { + return "header.claims.signature", nil + }, + expect: false, + expectErr: assert.Error, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + _, flush := tester.NewContext(t) + defer flush() + + token, err := test.getToken() + require.NoError(t, err) + + expired, err := IsJWTExpired(token) + test.expectErr(t, err) + + assert.Equal(t, test.expect, expired) + }) + } +} diff --git a/src/internal/common/url.go b/src/internal/common/url.go new file mode 100644 index 000000000..7efaf14ac --- /dev/null +++ b/src/internal/common/url.go @@ -0,0 +1,27 @@ +package common + +import ( + "net/url" + + "github.com/alcionai/clues" +) + +// GetQueryParamFromURL parses an URL and returns value of the specified +// query parameter. +func GetQueryParamFromURL( + rawURL, queryParam string, +) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", clues.Wrap(err, "parsing url") + } + + qp := u.Query() + + val := qp.Get(queryParam) + if len(val) == 0 { + return "", clues.New("query param not found").With("query_param", queryParam) + } + + return val, nil +} diff --git a/src/internal/common/url_test.go b/src/internal/common/url_test.go new file mode 100644 index 000000000..fa1d1cc20 --- /dev/null +++ b/src/internal/common/url_test.go @@ -0,0 +1,72 @@ +package common_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/alcionai/corso/src/internal/common" + "github.com/alcionai/corso/src/internal/tester" +) + +type URLUnitSuite struct { + tester.Suite +} + +func TestURLUnitSuite(t *testing.T) { + suite.Run(t, &URLUnitSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *URLUnitSuite) TestGetQueryParamFromURL() { + qp := "tempauth" + table := []struct { + name string + rawURL string + queryParam string + expectedResult string + expect assert.ErrorAssertionFunc + }{ + { + name: "valid", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val", + queryParam: qp, + expectedResult: "h.c.s", + expect: assert.NoError, + }, + { + name: "query param not found", + rawURL: "http://localhost:8080?other=val", + queryParam: qp, + expect: assert.Error, + }, + { + name: "empty query param", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val", + queryParam: "", + expect: assert.Error, + }, + // In case of multiple occurrences, the first occurrence of param is returned. + { + name: "multiple occurrences", + rawURL: "http://localhost:8080?" + qp + "=h.c.s&other=val&" + qp + "=h1.c1.s1", + queryParam: qp, + expectedResult: "h.c.s", + expect: assert.NoError, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + _, flush := tester.NewContext(t) + defer flush() + + token, err := common.GetQueryParamFromURL(test.rawURL, test.queryParam) + test.expect(t, err) + + assert.Equal(t, test.expectedResult, token) + }) + } +}