diff --git a/src/internal/common/jwt/jwt.go b/src/internal/common/jwt/jwt.go index 5d2aa6d2a..197ffa687 100644 --- a/src/internal/common/jwt/jwt.go +++ b/src/internal/common/jwt/jwt.go @@ -1,10 +1,13 @@ package jwt import ( + "context" "time" "github.com/alcionai/clues" jwt "github.com/golang-jwt/jwt/v5" + + "github.com/alcionai/corso/src/pkg/logger" ) // IsJWTExpired checks if the JWT token is past expiry by analyzing the @@ -37,3 +40,51 @@ func IsJWTExpired( return expired, nil } + +// GetJWTLifetime returns the issued at(iat) and expiration time(exp) claims +// present in the JWT token. These are optional claims and may not be present +// in the token. Absence is not reported as an error. +// +// An error is returned if the supplied token is malformed. Times are returned +// in UTC to have parity with graph responses. +func GetJWTLifetime( + ctx context.Context, + rawToken string, +) (time.Time, time.Time, error) { + var ( + issuedAt time.Time + expiresAt time.Time + ) + + p := jwt.NewParser() + + token, _, err := p.ParseUnverified(rawToken, &jwt.RegisteredClaims{}) + if err != nil { + logger.CtxErr(ctx, err).Debug("parsing jwt token") + return time.Time{}, time.Time{}, clues.Wrap(err, "invalid jwt") + } + + exp, err := token.Claims.GetExpirationTime() + if err != nil { + logger.CtxErr(ctx, err).Debug("extracting exp claim") + return time.Time{}, time.Time{}, clues.Wrap(err, "getting token expiry time") + } + + iat, err := token.Claims.GetIssuedAt() + if err != nil { + logger.CtxErr(ctx, err).Debug("extracting iat claim") + return time.Time{}, time.Time{}, clues.Wrap(err, "getting token issued at time") + } + + // Absence of iat or exp claims is not reported as an error by jwt library as these + // are optional as per spec. + if iat != nil { + issuedAt = iat.UTC() + } + + if exp != nil { + expiresAt = exp.UTC() + } + + return issuedAt, expiresAt, nil +} diff --git a/src/internal/common/jwt/jwt_test.go b/src/internal/common/jwt/jwt_test.go index 1b7f334f0..f9a6f2672 100644 --- a/src/internal/common/jwt/jwt_test.go +++ b/src/internal/common/jwt/jwt_test.go @@ -113,3 +113,134 @@ func (suite *JWTUnitSuite) TestIsJWTExpired() { }) } } + +func (suite *JWTUnitSuite) TestGetJWTLifetime() { + // Set of time values to be used in the tests. + // Truncate to seconds for comparisons since jwt tokens have second + // level precision. + idToTime := map[string]time.Time{ + "T0": time.Now().UTC().Add(-time.Hour).Truncate(time.Second), + "T1": time.Now().UTC().Truncate(time.Second), + "T2": time.Now().UTC().Add(time.Hour).Truncate(time.Second), + } + + table := []struct { + name string + getToken func() (string, error) + expectFunc func(t *testing.T, iat time.Time, exp time.Time) + expectErr assert.ErrorAssertionFunc + }{ + { + name: "alive token", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(idToTime["T0"]), + ExpiresAt: jwt.NewNumericDate(idToTime["T1"]), + }) + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, idToTime["T0"], iat) + assert.Equal(t, idToTime["T1"], exp) + }, + expectErr: assert.NoError, + }, + // Test with a token which is not generated using the go-jwt lib. + // This is a long lived token which is valid for 100 years. + { + name: "alive raw token with iat and exp claims", + getToken: func() (string, error) { + return rawToken, nil + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Less(t, iat, time.Now(), "iat should be in the past") + assert.Greater(t, exp, time.Now(), "exp should be in the future") + }, + expectErr: assert.NoError, + }, + // Regardless of whether the token is expired or not, we should be able to + // extract the iat and exp claims from it without error. + { + name: "expired token", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(idToTime["T1"]), + ExpiresAt: jwt.NewNumericDate(idToTime["T0"]), + }) + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, idToTime["T1"], iat) + assert.Equal(t, idToTime["T0"], exp) + }, + expectErr: assert.NoError, + }, + { + name: "missing iat claim", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(idToTime["T2"]), + }) + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, time.Time{}, iat) + assert.Equal(t, idToTime["T2"], exp) + }, + expectErr: assert.NoError, + }, + { + name: "missing exp claim", + getToken: func() (string, error) { + return createJWTToken( + jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(idToTime["T0"]), + }) + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, idToTime["T0"], iat) + assert.Equal(t, time.Time{}, exp) + }, + expectErr: assert.NoError, + }, + { + name: "both claims missing", + getToken: func() (string, error) { + return createJWTToken(jwt.RegisteredClaims{}) + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, time.Time{}, iat) + assert.Equal(t, time.Time{}, exp) + }, + expectErr: assert.NoError, + }, + { + name: "malformed token", + getToken: func() (string, error) { + return "header.claims.signature", nil + }, + expectFunc: func(t *testing.T, iat time.Time, exp time.Time) { + assert.Equal(t, time.Time{}, iat) + assert.Equal(t, time.Time{}, exp) + }, + expectErr: assert.Error, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + token, err := test.getToken() + require.NoError(t, err) + + iat, exp, err := GetJWTLifetime(ctx, token) + test.expectErr(t, err) + + test.expectFunc(t, iat, exp) + }) + } +} diff --git a/src/pkg/services/m365/api/graph/http_wrapper.go b/src/pkg/services/m365/api/graph/http_wrapper.go index 77f5766bf..a83206830 100644 --- a/src/pkg/services/m365/api/graph/http_wrapper.go +++ b/src/pkg/services/m365/api/graph/http_wrapper.go @@ -146,7 +146,7 @@ func (hw httpWrapper) Request( resp, err := hw.client.Do(req) if err == nil { - logResp(ictx, resp) + logResp(ictx, resp, req) return resp, nil } diff --git a/src/pkg/services/m365/api/graph/logging.go b/src/pkg/services/m365/api/graph/logging.go index 09283ec2b..7cc529f83 100644 --- a/src/pkg/services/m365/api/graph/logging.go +++ b/src/pkg/services/m365/api/graph/logging.go @@ -5,9 +5,12 @@ import ( "net/http" "net/http/httputil" "os" + "strings" + "time" "github.com/alcionai/clues" + "github.com/alcionai/corso/src/internal/common/jwt" "github.com/alcionai/corso/src/internal/common/pii" "github.com/alcionai/corso/src/pkg/logger" ) @@ -28,7 +31,7 @@ func shouldLogRespBody(resp *http.Response) bool { resp.StatusCode > 399 } -func logResp(ctx context.Context, resp *http.Response) { +func logResp(ctx context.Context, resp *http.Response, req *http.Request) { var ( log = logger.Ctx(ctx) respClass = resp.StatusCode / 100 @@ -45,6 +48,25 @@ func logResp(ctx context.Context, resp *http.Response) { return } + // Log bearer token iat and exp claims if we hit 401s. This is purely for + // debugging purposes and will be removed in the future. + if resp.StatusCode == http.StatusUnauthorized { + errs := []any{"graph api error: " + resp.Status} + + // As per MSFT docs, the token may have a special format and may not always + // validate as a JWT. Hence log token lifetime in a best effort manner only. + iat, exp, err := getTokenLifetime(ctx, req) + if err != nil { + errs = append(errs, " getting token lifetime: ", err) + } + + log.With("response", getRespDump(ctx, resp, logBody)). + With("token issued at", iat, "token expires at", exp). + Error(errs...) + + return + } + // Log api calls according to api debugging configurations. switch respClass { case 2: @@ -91,3 +113,32 @@ func getReqCtx(req *http.Request) context.Context { "url", logURL, "request_content_len", req.ContentLength) } + +// GetTokenLifetime extracts the JWT token embedded in the request and returns +// the token's issue and expiration times. The token is expected to be in the +// "Authorization" header, with a "Bearer " prefix. If the token is not present +// or is malformed, an error is returned. +func getTokenLifetime( + ctx context.Context, + req *http.Request, +) (time.Time, time.Time, error) { + if req == nil { + return time.Time{}, time.Time{}, clues.New("nil request") + } + + // Don't throw an error if auth header is absent. This is to prevent + // unnecessary noise in the logs for requests served by the http requestor + // client. These requests may be preauthenticated and may not carry auth headers. + rawToken := req.Header.Get("Authorization") + if len(rawToken) == 0 { + return time.Time{}, time.Time{}, nil + } + + // Strip the "Bearer " prefix from the token. This prefix is guaranteed to be + // present as per msft docs. But even if it's not, the jwt lib will handle + // malformed tokens gracefully and return an error. + rawToken = strings.TrimPrefix(rawToken, "Bearer ") + iat, exp, err := jwt.GetJWTLifetime(ctx, rawToken) + + return iat, exp, clues.Stack(err).OrNil() +} diff --git a/src/pkg/services/m365/api/graph/middleware.go b/src/pkg/services/m365/api/graph/middleware.go index 152665e9c..bba0f388b 100644 --- a/src/pkg/services/m365/api/graph/middleware.go +++ b/src/pkg/services/m365/api/graph/middleware.go @@ -130,7 +130,7 @@ func (mw *LoggingMiddleware) Intercept( "resp_status_code", resp.StatusCode, "resp_content_len", resp.ContentLength) - logResp(ctx, resp) + logResp(ctx, resp, req) return resp, err } diff --git a/src/pkg/services/m365/api/graph/middleware_test.go b/src/pkg/services/m365/api/graph/middleware_test.go index 1a1f93501..6f0210d99 100644 --- a/src/pkg/services/m365/api/graph/middleware_test.go +++ b/src/pkg/services/m365/api/graph/middleware_test.go @@ -505,3 +505,95 @@ func (suite *MiddlewareUnitSuite) TestLimiterConsumption() { }) } } + +const ( + // Raw test token valid for 100 years. + rawToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + + "eyJuYmYiOiIxNjkxODE5NTc5IiwiZXhwIjoiMzk0NTUyOTE3OSIsImVuZHBvaW50dXJsTGVuZ3RoIjoiMTYw" + + "IiwiaXNsb29wYmFjayI6IlRydWUiLCJ2ZXIiOiJoYXNoZWRwcm9vZnRva2VuIiwicm9sZXMiOiJhbGxmaWxl" + + "cy53cml0ZSBhbGxzaXRlcy5mdWxsY29udHJvbCBhbGxwcm9maWxlcy5yZWFkIiwidHQiOiIxIiwiYWxnIjoi" + + "SFMyNTYifQ" + + ".signature" +) + +// Tests getTokenLifetime +func (suite *MiddlewareUnitSuite) TestGetTokenLifetime() { + table := []struct { + name string + request *http.Request + expectErr assert.ErrorAssertionFunc + }{ + { + name: "nil request", + request: nil, + expectErr: assert.Error, + }, + // Test that we don't throw an error if auth header is absent. + // This is to prevent unnecessary noise in logs for requestor http client. + { + name: "no authorization header", + request: &http.Request{ + Header: http.Header{}, + }, + expectErr: assert.NoError, + }, + { + name: "well formed auth header with token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + rawToken}, + }, + }, + expectErr: assert.NoError, + }, + { + name: "Missing Bearer prefix but valid token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{rawToken}, + }, + }, + expectErr: assert.NoError, + }, + { + name: "invalid token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + "invalid"}, + }, + }, + expectErr: assert.Error, + }, + { + name: "valid prefix but empty token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer "}, + }, + }, + expectErr: assert.Error, + }, + { + name: "Invalid prefix but valid token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer" + rawToken}, + }, + }, + expectErr: assert.Error, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + + ctx, flush := tester.NewContext(t) + defer flush() + + // iat, exp specific tests are in jwt package. + _, _, err := getTokenLifetime(ctx, test.request) + test.expectErr(t, err, clues.ToCore(err)) + }) + } +}