Tags: python, firebase, firebase-authentication, jwt, pyjwt

Decode Firebase JWT in Python using PyJWT


I have written the following code:

import jwt
import requests

def check_token(token):
    response = requests.get("https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com")
    key_list = response.json()
    decoded_token = jwt.decode(token, key=key_list, algorithms=["RS256"])
    print(f"Decoded token : {decoded_token}")

I am trying to decode the ID token provided by the Firebase client so I can verify it server-side.
The code above throws the following exception:

TypeError: Expecting a PEM-formatted key.

I have also tried passing only the key content to jwt.decode instead of the whole list, but then I get a bigger error saying the library could not deserialize the key.
I was following this answer, but I am getting this error.

Is it a problem with how requests converts the response? What am I doing wrong?
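The reported TypeError can be reproduced offline with a minimal sketch, assuming PyJWT 2.x and the cryptography package are installed. A locally generated RSA key signs a test token, and a hypothetical `fake_certs` dict stands in for the `response.json()` result. Passing that dict as `key` makes PyJWT raise the TypeError, because for RS256 it expects a PEM string or bytes, or a key object, rather than a collection of keys.

```python
import jwt
from cryptography.hazmat.primitives.asymmetric import rsa

# A locally generated RSA key stands in for Google's signing key (assumption)
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
token = jwt.encode(
    {"sub": "user-123"}, private_key, algorithm="RS256", headers={"kid": "demo-kid"}
)

# Hypothetical stand-in for response.json(): key ID -> PEM certificate text
fake_certs = {"demo-kid": "<PEM certificate text>"}

caught = None
try:
    # Same call shape as in the question: the whole dict is passed as the key
    jwt.decode(token, key=fake_certs, algorithms=["RS256"])
except TypeError as exc:
    caught = exc
    print(f"TypeError: {exc}")
```

The placeholder certificate text is never parsed here: PyJWT rejects the dict type before it looks at any of its contents.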


Solution

  • The second parameter of decode(), key, expects a single key (a PEM-formatted string or a key object), not a list or dict. The Google endpoint returns a JSON object (a dict) that maps key IDs to several certificates. The flow goes like this:

    1. Fetch the public key certificates from the Google API endpoint.
    2. Read the token header without verification to get the kid (key ID) claim, then use it to look up the matching entry in that dict.
    3. That entry is an X.509 certificate, not a bare public key as in this answer, so you need to extract the public key from it.

    The following function worked for me:

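    import jwt
    import requests
    from cryptography import x509

    # Google publishes the X.509 certificates that sign Firebase ID tokens here
    CERTS_URL = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
    FIREBASE_PROJECT_ID = "<FIREBASE_PROJECT_ID>"

    def check_token(token):
        # Read the header without verifying the signature to find the signing key's ID
        kid_claim = jwt.get_unverified_header(token)["kid"]

        # The endpoint returns a dict mapping key IDs to PEM-encoded certificates
        response = requests.get(CERTS_URL, timeout=10)
        response.raise_for_status()
        x509_cert = response.json()[kid_claim]

        # Extract the public key from the certificate so PyJWT can use it
        cert = x509.load_pem_x509_certificate(x509_cert.encode("utf-8"))
        public_key = cert.public_key()

        # Verifies the signature and expiry, plus the audience (project ID) and issuer
        decoded_token = jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=FIREBASE_PROJECT_ID,
            issuer=f"https://securetoken.google.com/{FIREBASE_PROJECT_ID}",
        )
        print(f"Decoded token : {decoded_token}")

    check_token("FIREBASE_ID_TOKEN")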
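To exercise the same steps without a real Firebase ID token or network access, here is a minimal sketch that builds a self-signed certificate locally and uses it as a stand-in for the dict returned by the Google endpoint. `PROJECT_ID`, `KID`, `make_fake_cert_dict` and `verify_with_certs` are hypothetical names for illustration, and it assumes PyJWT 2.x and cryptography 3.1 or newer.

```python
import datetime

import jwt
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

PROJECT_ID = "demo-project"  # hypothetical Firebase project ID
KID = "demo-kid"  # hypothetical key ID


def make_fake_cert_dict():
    """Return a signing key and a {kid: PEM certificate} dict, like the endpoint's JSON."""
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "demo")])
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(name)
        .issuer_name(name)  # self-signed: subject and issuer are the same
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=1))
        .sign(key, hashes.SHA256())
    )
    pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    return key, {KID: pem}


def verify_with_certs(token, certs):
    """Same steps as check_token, but the certificate dict is passed in."""
    kid = jwt.get_unverified_header(token)["kid"]
    cert = x509.load_pem_x509_certificate(certs[kid].encode("utf-8"))
    return jwt.decode(
        token,
        cert.public_key(),
        algorithms=["RS256"],
        audience=PROJECT_ID,
        issuer=f"https://securetoken.google.com/{PROJECT_ID}",
    )


private_key, certs = make_fake_cert_dict()
now = datetime.datetime.now(datetime.timezone.utc)
claims = {
    "aud": PROJECT_ID,
    "iss": f"https://securetoken.google.com/{PROJECT_ID}",
    "sub": "user-123",
    "iat": now,
    "exp": now + datetime.timedelta(hours=1),
}
token = jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": KID})
decoded = verify_with_certs(token, certs)
print(decoded["sub"])  # → user-123
```

With real tokens, Google rotates these certificates, so rather than downloading them on every call you can cache the dict for as long as the `max-age` in the endpoint's `Cache-Control` response header allows.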