Tags: python, jwt, jwe, jwcrypto

How to add expiry to JWE?


I am trying to add an expiry time to a JWE that I generate with the jwcrypto library, as follows:

from jwcrypto import jwe, jwk, jwt
from datetime import datetime, timedelta
import time

# create JWK from existing key
jwk_str = '{"k":"29Js2yXM6P_4v9K1mHDlYVHw8Xvm_GEhvMTvKTRLRzY","kty":"oct"}'
jwk_key = jwk.JWK.from_json(jwk_str)

# calculate expiry time
d = datetime.now() + timedelta(seconds=5)
epoch = datetime.utcfromtimestamp(0)
total_seconds =  (d - epoch).total_seconds()
# Add exp to the claims
claims={"exp": total_seconds, "sub": "Some random payload"}
print(claims)
jwttoken = jwt.JWT(header={"alg": "A256KW", "enc": "A256CBC-HS512"}, claims=claims)
jwttoken.make_encrypted_token(jwk_key)
jwetokenstr = jwttoken.serialize()
print(jwetokenstr)

# wait for 10 seconds to cross the expiry time
time.sleep(10)

jwttoken = jwt.JWT()
jwttoken.deserialize(jwetokenstr, jwk_key)  # Ideally this line should fail because the expiry time has passed, but it doesn't
print(jwttoken.claims)

I get the payload back, but the exp claim does not seem to be enforced: deserialization doesn't fail after the expiry time. What am I doing wrong?


Solution

  • This ends up reducing to a datetime manipulation bug.

    The exp claim of a JSON Web Token should be set to the expiration time expressed as the number of seconds since the Unix epoch (1970-01-01T00:00:00Z, in UTC).

    datetime.now() returns a naive datetime.datetime in local time, not UTC, while datetime.utcfromtimestamp(0) returns a naive datetime.datetime for the epoch in UTC. The code above subtracts the UTC epoch from the local-time value and uses the total seconds between them as the expiry time. Because this compares a local-time datetime with a UTC datetime, the result is off from the true epoch time by a constant offset equal to your local time zone's difference from UTC.

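    For example, if I live in a place that is 5 hours behind UTC (UTC-5), this code produces an exp that is 5 * 60 * 60 = 18000 seconds earlier than the expiry I actually want. In a zone ahead of UTC the error goes the other way: exp lands hours in the future, so the token will not expire after a 10-second wait.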
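    To make the offset concrete, here is a minimal sketch that computes exp both ways for one fixed instant. It assumes a hypothetical machine whose local zone is a fixed UTC-5 offset (real code would get the local zone from the OS), and it rebuilds by hand the naive local value that datetime.now() would return there.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical fixed zone standing in for a machine set to UTC-5.
LOCAL = timezone(timedelta(hours=-5))

# One instant in time, as a timezone-aware UTC datetime (arbitrary example value).
instant = datetime(2020, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

# Correct: an aware datetime converts straight to seconds since the Unix epoch.
correct_exp = instant.timestamp()

# The question's pattern: naive local wall-clock time minus the naive UTC epoch.
naive_local = instant.astimezone(LOCAL).replace(tzinfo=None)  # what datetime.now() would return on that machine
epoch = datetime(1970, 1, 1)  # naive; same value as datetime.utcfromtimestamp(0)
buggy_exp = (naive_local - epoch).total_seconds()

print(correct_exp - buggy_exp)  # 18000.0, i.e. 5 * 60 * 60 seconds
```

    The difference is exactly the zone's UTC offset, independent of which instant you pick.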

    Instead, you can simply use round(time.time()) + x, where x is the number of seconds in the future at which the JWT should expire. time.time() already returns the number of seconds since the epoch, but as a float, so round it to get an integer timestamp.
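    A minimal sketch of two ways to compute exp without the time-zone error: the time.time() approach described above, and, as an alternative not in the original answer, a timezone-aware datetime. Both function names are hypothetical helpers for illustration; either result can be used as the exp claim.

```python
import time
from datetime import datetime, timedelta, timezone


def exp_from_time(valid_seconds: int) -> int:
    # time.time() is already seconds since the Unix epoch, so no time-zone math is involved.
    return round(time.time()) + valid_seconds


def exp_from_datetime(valid_seconds: int) -> int:
    # Aware-datetime alternative: datetime.now(timezone.utc) carries its UTC offset,
    # so .timestamp() yields true epoch seconds regardless of the machine's local zone.
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=valid_seconds)
    return round(expires_at.timestamp())


claims = {"exp": exp_from_time(3), "sub": "Some random payload"}
print(claims)
```

    The datetime version is handy if the rest of your code already works with datetime objects; the key point is that the datetime must be timezone-aware (or UTC-based) before converting it to epoch seconds.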

    For example:

    from jwcrypto import jwe, jwk, jwt
    from datetime import datetime, timedelta
    import time
    
    jwk_str = '{"k":"29Js2yXM6P_4v9K1mHDlYVHw8Xvm_GEhvMTvKTRLRzY","kty":"oct"}'
    jwk_key = jwk.JWK.from_json(jwk_str)
    
    jwt_valid_seconds = 3
    expiry_time = round(time.time()) + jwt_valid_seconds
    claims={"exp": expiry_time, "sub": "Some random payload"}
    jwttoken = jwt.JWT(header={"alg": "A256KW", "enc": "A256CBC-HS512"}, claims=claims)
    jwttoken.make_encrypted_token(jwk_key)
    jwetokenstr = jwttoken.serialize()
    
    jwttoken2 = jwt.JWT()
    jwttoken2.deserialize(jwetokenstr, jwk_key)
    print('This should succeed because we are deserializing immediately before the JWT has expired:')
    print(jwttoken2.claims)
    
    # Wait for the JWT to expire, plus extra time for jwcrypto's clock-skew leeway (60 seconds here, as the traceback below shows).
    leeway = 60
    time.sleep(leeway + jwt_valid_seconds + 1)
    
    jwttoken2 = jwt.JWT()
    print('\nThis should fail due to the JWT expiring:')
    jwttoken2.deserialize(jwetokenstr, jwk_key)
    

    This gives the following output:

    (env) $ python jwe_expiry.py
    This should succeed because we are deserializing immediately before the JWT has expired:
    {"exp":1576737332,"sub":"Some random payload"}
    
    This should fail due to the JWT expiring:
    Traceback (most recent call last):
      File "jwe_expiry.py", line 26, in <module>
        jwttoken2.deserialize(jwetokenstr, jwk_key)
      File "... python3.7/site-packages/jwcrypto/jwt.py", line 493, in deserialize
        self._check_provided_claims()
      File "... python3.7/site-packages/jwcrypto/jwt.py", line 370, in _check_provided_claims
        self._check_default_claims(claims)
      File "... python3.7/site-packages/jwcrypto/jwt.py", line 351, in _check_default_claims
        self._check_exp(claims['exp'], time.time(), self._leeway)
      File "... python3.7/site-packages/jwcrypto/jwt.py", line 333, in _check_exp
        claim, limit, leeway))
    jwcrypto.jwt.JWTExpired: Expired at 1576737332, time: 1576737392(leeway: 60)
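    The traceback also explains why the question's 10-second sleep could not trigger the error even with a correct exp: jwcrypto allowed 60 seconds of clock skew here. Below is a minimal sketch of the comparison implied by the numbers in the traceback. `is_expired` is a hypothetical helper for illustration, not part of jwcrypto's API, and jwcrypto's internal check may differ in detail.

```python
LEEWAY = 60  # the value shown as "leeway: 60" in the traceback above


def is_expired(exp: float, now: float, leeway: int = LEEWAY) -> bool:
    # Hypothetical helper: the token counts as expired only once `now` is more
    # than `leeway` seconds past `exp`, which is consistent with the traceback's numbers.
    return exp < now - leeway


# A few seconds after exp (like the question's 10-second sleep on a 5-second token): still accepted.
print(is_expired(1576737332, 1576737337))    # False
# After the answer's longer sleep (fractional part assumed; the traceback prints whole seconds): rejected.
print(is_expired(1576737332, 1576737392.5))  # True
```

    So to see the failure, either wait longer than exp plus the leeway, as the answer's script does, or reduce the leeway.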