Allow JWT Tokens if Expired, Provided User is from Trusted IP address


Using flask-jwt-extended, I have an API that must serve both human users and a series of web applications (one of the latter, for example, is a chatbot).

For the users, the package's out-of-the-box functionality is perfect. For the web applications, however, I would like the JWTs to behave more like API keys, which don't necessarily expire after a period of time.

So what I would like to do is suppress the expiry check whenever the request comes from a predefined, trusted IP address.

I have a SQLAlchemy model that stores trusted IP addresses. It has a foreign-key relationship with the users model, so a user can specify one or more whitelisted IP addresses.
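One detail worth getting right is how the request's address is compared with the stored entries: a plain string comparison treats "::1" and "0:0:0:0:0:0:0:1" as different addresses. The sketch below normalises both sides with Python's standard `ipaddress` module. It assumes, as the solution further down does, that each whitelist row exposes an `ip_address` string via `user.ip_whitelist`; the `WhitelistEntry`/`User` dataclasses stand in for the real SQLAlchemy models, and the helper name `is_ip_whitelisted` and the optional CIDR-range support are hypothetical additions, not part of the original model.

```python
import ipaddress
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class WhitelistEntry:
    """Hypothetical stand-in for one row of the IP-whitelist SQLAlchemy model."""
    ip_address: str


@dataclass
class User:
    """Hypothetical stand-in for the user model; ip_whitelist mimics the relationship."""
    username: str
    ip_whitelist: List[WhitelistEntry] = field(default_factory=list)


def is_ip_whitelisted(user: User, remote_addr: Optional[str]) -> bool:
    """Return True if remote_addr matches an address or CIDR range in user.ip_whitelist."""
    try:
        addr = ipaddress.ip_address(remote_addr)
    except ValueError:
        return False  # missing or malformed address: never trust it
    for entry in user.ip_whitelist:
        try:
            # strict=False lets both "10.0.0.0/8" and a single "10.0.0.5" parse as networks
            network = ipaddress.ip_network(entry.ip_address, strict=False)
        except ValueError:
            continue  # skip malformed rows instead of failing open
        # An IPv4 address is never "in" an IPv6 network (and vice versa); this is just False
        if addr in network:
            return True
    return False
```

Failing closed on anything that does not parse is the conservative choice here, since a match is what grants access to an expired token.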

The decode_token function:

https://flask-jwt-extended.readthedocs.io/en/stable/_modules/flask_jwt_extended/utils.html#decode_token

has an allow_expired argument that permits the expiry check to be overridden. However, that argument is not used anywhere in _decode_jwt_from_request(...), which appears to be the function that actually validates JWTs for the view decorators.

Ultimately, I am after a replacement for the @jwt_required decorator that accepts expired tokens, provided the request comes from a whitelisted IP address.
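The control flow such a decorator needs can be sketched independently of Flask: try a normal decode; on expiry, decode again with allow_expired=True (signature still verified), load the user, and accept the claims only if the caller's IP is whitelisted. This is a hypothetical, framework-agnostic sketch: `trusted_ip_or_fresh_jwt`, the injected `decode`/`load_user`/`is_trusted` callables and the `(token, remote_addr)` calling convention are assumptions, and `ExpiredSignatureError` is a local stand-in for `jwt.ExpiredSignatureError`.

```python
from functools import wraps
from typing import Any, Callable, Optional


class ExpiredSignatureError(Exception):
    """Local stand-in for jwt.ExpiredSignatureError (assumption for this sketch)."""


def trusted_ip_or_fresh_jwt(decode: Callable[..., dict],
                            load_user: Callable[[Any], Optional[Any]],
                            is_trusted: Callable[[Any, str], bool],
                            identity_claim: str = "sub"):
    """Build a decorator that accepts expired tokens only from trusted IPs.

    decode(token, allow_expired=...) is assumed to verify the signature and to
    raise ExpiredSignatureError for an expired token when allow_expired is False.
    """
    def decorator(view: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(view)
        def wrapper(token: str, remote_addr: str, *args, **kwargs):
            try:
                claims = decode(token, allow_expired=False)
            except ExpiredSignatureError:
                # Decode again ignoring only the exp claim; the signature is still checked
                claims = decode(token, allow_expired=True)
                user = load_user(claims.get(identity_claim))
                if user is None or not is_trusted(user, remote_addr):
                    raise  # re-raise the original ExpiredSignatureError
            # Fresh token, or expired token from a whitelisted IP: call the view
            return view(claims, *args, **kwargs)
        return wrapper
    return decorator
```

In a Flask app, `token` would typically come from the Authorization header and `remote_addr` from `request.remote_addr`; how the resulting claims are exposed to the view depends on the flask-jwt-extended version, so that wiring is left out here.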

My questions are twofold:

  1. Is the above approach sound from a security point of view? And,
  2. Without having to duplicate (and slightly modify) entire functions from the library, how might I implement it?

Solution

  • Unless anyone can suggest a better way, I ended up monkey-patching the decode_token function:

    I have marked the patched region, which intercepts the ExpiredSignatureError, checks whether the request's IP address is in the user's IP whitelist and, if so, returns the expired token's claims so the request proceeds as usual.

    # Targets flask-jwt-extended 3.x with PyJWT 1.x: jwt.decode(..., verify=False),
    # config.user_claims_key and ctx_stack are 3.x-era APIs that changed in 4.x.
    from warnings import warn

    import jwt
    from jwt import ExpiredSignatureError
    from flask import request

    try:
        from flask import _app_ctx_stack as ctx_stack
    except ImportError:  # older Flask versions
        from flask import _request_ctx_stack as ctx_stack

    import flask_jwt_extended.utils
    import flask_jwt_extended.view_decorators
    from flask_jwt_extended.config import config
    from flask_jwt_extended.tokens import decode_jwt
    from flask_jwt_extended.utils import _get_jwt_manager, has_user_loader, user_loader


    def decode_token(encoded_token, csrf_value=None, allow_expired=False):
        """
        Returns the decoded token (python dict) from an encoded JWT. This does all
        the checks to ensure that the decoded token is valid before returning it.

        :param encoded_token: The encoded JWT to decode into a python dict.
        :param csrf_value: Expected CSRF double submit value (optional)
        :param allow_expired: Option to ignore exp claim validation in token
        :return: Dictionary containing contents of the JWT
        """
        jwt_manager = _get_jwt_manager()
        unverified_claims = jwt.decode(
            encoded_token, verify=False, algorithms=config.decode_algorithms
        )
        unverified_headers = jwt.get_unverified_header(encoded_token)
        # Attempt to call callback with both claims and headers, but fall back to just claims
        # for backwards compatibility
        try:
            secret = jwt_manager._decode_key_callback(unverified_claims, unverified_headers)
        except TypeError:
            # Adjacent string literals (no commas) so msg is one string, not a tuple
            msg = (
                "The single-argument (unverified_claims) form of decode_key_callback "
                "is deprecated. Update your code to use the two-argument form "
                "(unverified_claims, unverified_headers)."
            )
            warn(msg, DeprecationWarning)
            secret = jwt_manager._decode_key_callback(unverified_claims)

        try:
            return decode_jwt(
                encoded_token=encoded_token,
                secret=secret,
                algorithms=config.decode_algorithms,
                identity_claim_key=config.identity_claim_key,
                user_claims_key=config.user_claims_key,
                csrf_value=csrf_value,
                audience=config.audience,
                issuer=config.issuer,
                leeway=config.leeway,
                allow_expired=allow_expired
            )
        except ExpiredSignatureError:
            # Decode again without the exp check; the signature is still verified
            expired_token = decode_jwt(
                encoded_token=encoded_token,
                secret=secret,
                algorithms=config.decode_algorithms,
                identity_claim_key=config.identity_claim_key,
                user_claims_key=config.user_claims_key,
                csrf_value=csrf_value,
                audience=config.audience,
                issuer=config.issuer,
                leeway=config.leeway,
                allow_expired=True
            )

            # ------------------------------------------------------------
            # Author:   Nicholas E. Hamilton
            # Date:     25th August 2019
            # Patch:    Check if ip address is in the whitelist,
            #           and if so, permit an expired token
            # ------------------------------------------------------------
            # Behind a reverse proxy, remote_addr is the proxy's address unless
            # the app is wrapped in werkzeug's ProxyFix middleware.
            ip_address = request.remote_addr
            # user_loader() needs a registered user loader callback, so check first
            if has_user_loader() and ip_address:
                user = user_loader(expired_token[config.identity_claim_key])
                if user:
                    ip_whitelist = [x.ip_address for x in user.ip_whitelist]
                    if ip_address in ip_whitelist:
                        return expired_token
            # >>>> END PATCH

            # Proceed as normal: record the expired token and re-raise the error
            ctx_stack.top.expired_jwt = expired_token
            raise


    # view_decorators imports decode_token by name, so both references are replaced
    flask_jwt_extended.view_decorators.decode_token = flask_jwt_extended.utils.decode_token = decode_token