Tags: python, oauth-2.0, authlib

Does authorize_access_token also verify an id token?


For retrieving an OAuth / OpenID Connect token, the authlib docs use the function authorize_access_token. OAuth providers like Google strongly advise verifying these tokens manually, for example by checking the expiry date.

Where is the documentation on authorize_access_token? I can't find anything on the website. Does the function verify the token automatically or do I have to do that myself?


Solution

  • You are correct that the documentation for authlib is light on what the function authorize_access_token() does. There is some indication in the Web OAuth Client section of the documentation, but that section only mentions the return value of authorize_access_token().

    # this only returns a token
    token = oauth.providername.authorize_access_token()
    

    In order to fully understand the functionality of authorize_access_token() we need to explore authlib’s GitHub repository.

    Here is a section of code from authlib's Django client that handles the OAuth 2 authorization callback.

    def authorize_access_token(self, request, **kwargs):
            """Fetch access token in one step.
            :param request: HTTP request instance from Django view.
            :return: A token dict.
            """
            if request.method == 'GET':
                error = request.GET.get('error')
                if error:
                    description = request.GET.get('error_description')
                    raise OAuthError(error=error, description=description)
                params = {
                    'code': request.GET.get('code'),
                    'state': request.GET.get('state'),
                }
            else:
                params = {
                    'code': request.POST.get('code'),
                    'state': request.POST.get('state'),
                }
    
            state_data = self.framework.get_state_data(request.session, params.get('state'))
            params = self._format_state_params(state_data, params)
            token = self.fetch_access_token(**params, **kwargs)
    
            if 'id_token' in token and 'nonce' in state_data:
                userinfo = self.parse_id_token(token, nonce=state_data['nonce'])
                token['userinfo'] = userinfo
            return token
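
    The first part of authorize_access_token() only inspects the callback query string: an error parameter aborts the flow, otherwise just code and state are carried forward. A minimal stdlib sketch of that step on a plain dict; OAuthCallbackError and extract_callback_params() are hypothetical stand-ins, not authlib names:

```python
from typing import Mapping, Optional


class OAuthCallbackError(Exception):
    """Hypothetical stand-in for authlib's OAuthError."""

    def __init__(self, error: str, description: Optional[str] = None):
        super().__init__(f"{error}: {description}" if description else error)
        self.error = error
        self.description = description


def extract_callback_params(query: Mapping[str, str]) -> dict:
    """Mirror the GET branch of authorize_access_token() on a plain mapping."""
    # The provider signals failure (e.g. the user denied consent) via ?error=...
    error = query.get("error")
    if error:
        raise OAuthCallbackError(error, query.get("error_description"))
    # Otherwise only 'code' and 'state' are carried forward; other keys are ignored.
    return {"code": query.get("code"), "state": query.get("state")}
```

    For example, extract_callback_params({"code": "abc", "state": "xyz"}) returns both values unchanged, while a query containing error=access_denied raises instead of returning.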
    
    

    In the authorize_access_token() code above we have this call:

    error = request.GET.get('error')
    

    This line only reads an error parameter that the provider may have added to the redirect URL. The remaining steps rely on authlib's base_client code, which defines errors such as:

    • MismatchingStateError
    • MissingRequestTokenError
    • MissingTokenError

    The base_client code also declares hooks such as this one, which the framework integrations are expected to override:

    def _on_update_token(self, token, refresh_token=None, access_token=None):
        raise NotImplementedError()
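
    The NotImplementedError marks _on_update_token() as an abstract hook; as far as I can tell it comes into play when a token is refreshed rather than during authorize_access_token() itself. The general pattern can be sketched like this; BaseAppSketch, FrameworkAppSketch and the update_token callback wiring are hypothetical illustrations, not authlib's classes:

```python
from typing import Callable, Optional


class BaseAppSketch:
    """Hypothetical base class: declares the hook but does not implement it."""

    def _on_update_token(self, token: dict, refresh_token: Optional[str] = None,
                         access_token: Optional[str] = None) -> None:
        raise NotImplementedError()


class FrameworkAppSketch(BaseAppSketch):
    """Hypothetical framework integration that forwards refreshed tokens."""

    def __init__(self, update_token: Optional[Callable[..., None]] = None):
        self._update_token = update_token

    def _on_update_token(self, token, refresh_token=None, access_token=None):
        # Delegate to a user-supplied callback (e.g. one that saves to a database).
        if self._update_token:
            self._update_token(token, refresh_token=refresh_token,
                               access_token=access_token)
```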
    

    If the provider redirects back with an error parameter, authorize_access_token() raises an OAuthError:

    if error:
        description = request.GET.get('error_description')
        raise OAuthError(error=error, description=description)
    

    Each of the following calls within authorize_access_token() also performs checks, relying on authlib's extensive base_client code:

    state_data = self.framework.get_state_data(request.session, params.get('state'))
    params = self._format_state_params(state_data, params)
    token = self.fetch_access_token(**params, **kwargs)
    

    The first call, which assigns state_data, invokes this function:

    def get_state_data(self, session, state):
            key = f'_state_{self.name}_{state}'
            if self.cache:
                value = self._get_cache_data(key)
            else:
                value = session.get(key)
            if value:
                return value.get('data')
            return None
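
    get_state_data() is where the state value from the redirect is matched against what was stored when the user was sent to the provider; a forged or stale state finds nothing and yields None. A minimal stdlib sketch of that round trip, using a plain dict as the session (no cache) and assuming the stored value is wrapped as {"data": ...}, as the value.get('data') line implies; save_state_data() is a hypothetical stand-in for the redirect step:

```python
import secrets
from typing import Optional


def save_state_data(session: dict, app_name: str, data: dict) -> str:
    """Hypothetical redirect step: store data under a fresh random state."""
    state = secrets.token_urlsafe(16)
    session[f"_state_{app_name}_{state}"] = {"data": data}
    return state


def get_state_data(session: dict, app_name: str, state: Optional[str]) -> Optional[dict]:
    """Same key scheme as the authlib code above, session-only."""
    value = session.get(f"_state_{app_name}_{state}")
    if value:
        return value.get("data")
    return None
```

    A None result here is exactly what makes the next step raise MismatchingStateError.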
    
    

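    The second call, which reassigns params, invokes this function. It raises MismatchingStateError when no stored state data matches the returned state: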

    @staticmethod
    def _format_state_params(state_data, params):
            if state_data is None:
                raise MismatchingStateError()
    
            code_verifier = state_data.get('code_verifier')
            if code_verifier:
                params['code_verifier'] = code_verifier
    
            redirect_uri = state_data.get('redirect_uri')
            if redirect_uri:
                params['redirect_uri'] = redirect_uri
            return params
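
    _format_state_params() forwards a stored code_verifier to the token request. That is the PKCE verifier: at redirect time only its hash (the code_challenge) is sent, and the authorization server recomputes the hash when the code is exchanged. A minimal stdlib sketch of the S256 method from RFC 7636; the function names are hypothetical, not authlib's:

```python
import base64
import hashlib
import secrets


def make_code_verifier() -> str:
    # 32 random bytes -> 43 URL-safe characters, within RFC 7636's 43-128 range.
    return secrets.token_urlsafe(32)


def s256_code_challenge(code_verifier: str) -> str:
    # RFC 7636: BASE64URL-ENCODE(SHA256(ASCII(code_verifier))), without '=' padding.
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
```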
    

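    The third call, which assigns token, invokes fetch_access_token(). Note that the version quoted here is authlib's OAuth 1 implementation (it expects a request_token); the OAuth 2 client used in this flow has its own fetch_access_token(), which exchanges the authorization code for a token at the provider's token endpoint: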

    def fetch_access_token(self, request_token=None, **kwargs):
            """Fetch access token in one step.
            :param request_token: A previous request token for OAuth 1.
            :param kwargs: Extra parameters to fetch access token.
            :return: A token dict.
            """
            with self._get_oauth_client() as client:
                if request_token is None:
                    raise MissingRequestTokenError()
                # merge request token with verifier
                token = {}
                token.update(request_token)
                token.update(kwargs)
                client.token = token
                params = self.access_token_params or {}
                token = client.fetch_access_token(self.access_token_url, **params)
            return token
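
    On the OAuth 2 side, the request that finally reaches the provider's token endpoint is a form-encoded POST for the authorization_code grant (RFC 6749, section 4.1.3). A minimal sketch of its body, assuming client credentials are sent separately (for example via HTTP Basic); build_token_request_body() is a hypothetical helper, not an authlib function:

```python
from typing import Optional
from urllib.parse import urlencode


def build_token_request_body(code: str, redirect_uri: Optional[str] = None,
                             code_verifier: Optional[str] = None) -> str:
    """Hypothetical helper: form body for the authorization_code grant."""
    params = {"grant_type": "authorization_code", "code": code}
    # redirect_uri must match the one used in the authorization request.
    if redirect_uri:
        params["redirect_uri"] = redirect_uri
    # PKCE verifier, when the flow used one.
    if code_verifier:
        params["code_verifier"] = code_verifier
    return urlencode(params)
```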
    

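    The final call in authorize_access_token() is to parse_id_token(), which only runs when the token response contains an id_token and a nonce was stored in the session state. Here is its implementation in the base_client code (this is the async variant; the synchronous version used by the Django client follows the same steps without await):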

    async def parse_id_token(self, token, nonce, claims_options=None):
            """Return an instance of UserInfo from token's ``id_token``."""
            claims_params = dict(
                nonce=nonce,
                client_id=self.client_id,
            )
            if 'access_token' in token:
                claims_params['access_token'] = token['access_token']
                claims_cls = CodeIDToken
            else:
                claims_cls = ImplicitIDToken
    
            metadata = await self.load_server_metadata()
            if claims_options is None and 'issuer' in metadata:
                claims_options = {'iss': {'values': [metadata['issuer']]}}
    
            alg_values = metadata.get('id_token_signing_alg_values_supported')
            if not alg_values:
                alg_values = ['RS256']
    
            jwt = JsonWebToken(alg_values)
    
            jwk_set = await self.fetch_jwk_set()
            try:
                claims = jwt.decode(
                    token['id_token'],
                    key=JsonWebKey.import_key_set(jwk_set),
                    claims_cls=claims_cls,
                    claims_options=claims_options,
                    claims_params=claims_params,
                )
            except ValueError:
                jwk_set = await self.fetch_jwk_set(force=True)
                claims = jwt.decode(
                    token['id_token'],
                    key=JsonWebKey.import_key_set(jwk_set),
                    claims_cls=claims_cls,
                    claims_options=claims_options,
                    claims_params=claims_params,
                )
    
            # https://github.com/lepture/authlib/issues/259
            if claims.get('nonce_supported') is False:
                claims.params['nonce'] = None
            claims.validate(leeway=120)
            return UserInfo(claims)
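
    This is the part that answers the question: jwt.decode() checks the signature against the provider's published keys (JWKS), and claims.validate(leeway=120) then checks the claims. The core claim checks can be approximated with the stdlib sketch below. It assumes the token has already been decoded and its signature verified; validate_id_token_claims() and ClaimError are hypothetical names, and authlib's CodeIDToken performs further checks (such as at_hash) that are omitted here:

```python
import time
from typing import Optional


class ClaimError(ValueError):
    """Hypothetical error type for a failed claim check."""


def validate_id_token_claims(claims: dict, issuer: str, client_id: str,
                             nonce: Optional[str], leeway: int = 120,
                             now: Optional[float] = None) -> None:
    now = time.time() if now is None else now
    if claims.get("iss") != issuer:
        raise ClaimError("iss does not match the provider's issuer")
    # 'aud' may be a single string or a list of audiences.
    aud = claims.get("aud")
    audiences = aud if isinstance(aud, list) else [aud]
    if client_id not in audiences:
        raise ClaimError("aud does not contain this client_id")
    exp = claims.get("exp")
    # Expired only if exp lies more than `leeway` seconds in the past.
    if exp is None or exp < now - leeway:
        raise ClaimError("token has expired")
    if nonce is not None and claims.get("nonce") != nonce:
        raise ClaimError("nonce does not match the one stored in the session")
```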
    

    Within all the functions above there are calls to other parts of authlib. For instance, here is the requests-based OAuth2Session that authorize_access_token() reaches indirectly when it performs the HTTP token request:

    from requests import Session
    from requests.auth import AuthBase
    from authlib.oauth2.client import OAuth2Client
    from authlib.oauth2.auth import ClientAuth, TokenAuth
    from ..base_client import (
        OAuthError,
        InvalidTokenError,
        MissingTokenError,
        UnsupportedTokenTypeError,
    )
    from .utils import update_session_configure
    
    __all__ = ['OAuth2Session', 'OAuth2Auth']
    
    
    class OAuth2Auth(AuthBase, TokenAuth):
        """Sign requests for OAuth 2.0, currently only bearer token is supported."""
    
        def ensure_active_token(self):
            if self.client and not self.client.ensure_active_token(self.token):
                raise InvalidTokenError()
    
        def __call__(self, req):
            self.ensure_active_token()
            try:
                req.url, req.headers, req.body = self.prepare(
                    req.url, req.headers, req.body)
            except KeyError as error:
                description = 'Unsupported token_type: {}'.format(str(error))
                raise UnsupportedTokenTypeError(description=description)
            return req
    
    
    class OAuth2ClientAuth(AuthBase, ClientAuth):
        """Attaches OAuth Client Authentication to the given Request object.
        """
        def __call__(self, req):
            req.url, req.headers, req.body = self.prepare(
                req.method, req.url, req.headers, req.body
            )
            return req
    
    
    class OAuth2Session(OAuth2Client, Session):
        """Construct a new OAuth 2 client requests session.
        :param client_id: Client ID, which you get from client registration.
        :param client_secret: Client Secret, which you get from registration.
        :param authorization_endpoint: URL of the authorization server's
            authorization endpoint.
        :param token_endpoint: URL of the authorization server's token endpoint.
        :param token_endpoint_auth_method: client authentication method for
            token endpoint.
        :param revocation_endpoint: URL of the authorization server's OAuth 2.0
            revocation endpoint.
        :param revocation_endpoint_auth_method: client authentication method for
            revocation endpoint.
        :param scope: Scope that you needed to access user resources.
        :param redirect_uri: Redirect URI you registered as callback.
        :param token: A dict of token attributes such as ``access_token``,
            ``token_type`` and ``expires_at``.
        :param token_placement: The place to put token in HTTP request. Available
            values: "header", "body", "uri".
        :param update_token: A function for you to update token. It accept a
            :class:`OAuth2Token` as parameter.
        """
        client_auth_class = OAuth2ClientAuth
        token_auth_class = OAuth2Auth
        SESSION_REQUEST_PARAMS = (
            'allow_redirects', 'timeout', 'cookies', 'files',
            'proxies', 'hooks', 'stream', 'verify', 'cert', 'json'
        )
    
        def __init__(self, client_id=None, client_secret=None,
                     token_endpoint_auth_method=None,
                     revocation_endpoint_auth_method=None,
                     scope=None, redirect_uri=None,
                     token=None, token_placement='header',
                     update_token=None, **kwargs):
    
            Session.__init__(self)
            update_session_configure(self, kwargs)
    
            OAuth2Client.__init__(
                self, session=self,
                client_id=client_id, client_secret=client_secret,
                token_endpoint_auth_method=token_endpoint_auth_method,
                revocation_endpoint_auth_method=revocation_endpoint_auth_method,
                scope=scope, redirect_uri=redirect_uri,
                token=token, token_placement=token_placement,
                update_token=update_token, **kwargs
            )
    
        def fetch_access_token(self, url=None, **kwargs):
            """Alias for fetch_token."""
            return self.fetch_token(url, **kwargs)
    
        def request(self, method, url, withhold_token=False, auth=None, **kwargs):
            """Send request with auto refresh token feature (if available)."""
            if not withhold_token and auth is None:
                if not self.token:
                    raise MissingTokenError()
                auth = self.token_auth
            return super(OAuth2Session, self).request(
                method, url, auth=auth, **kwargs)
    
        @staticmethod
        def handle_error(error_type, error_description):
            raise OAuthError(error_type, error_description)
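
    OAuth2Auth.ensure_active_token() is where the access token's expiry matters: before a request is signed, the client checks whether the stored token is still active (and may try to refresh it). As the OAuth2Session docstring notes, tokens carry an expires_at timestamp. The expiry test itself can be sketched with the stdlib; is_token_expired() and its 60-second leeway are assumptions, not authlib's exact code:

```python
import time
from typing import Optional


def is_token_expired(token: dict, leeway: int = 60,
                     now: Optional[float] = None) -> Optional[bool]:
    """Return None when the token carries no expiry information."""
    expires_at = token.get("expires_at")
    if expires_at is None:
        return None
    now = time.time() if now is None else now
    # Treat tokens expiring within `leeway` seconds as already expired.
    return expires_at - leeway < now
```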
    
    

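    Here are two lists of registered claims that authlib defines. The first is the JWT claims (RFC 7519) that are checked when a token such as an ID token is validated; the second is the client metadata used for dynamic client registration (RFC 7591), which is not involved in ID token validation: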

    REGISTERED_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti']
    
    REGISTERED_CLAIMS = [
            'redirect_uris',
            'token_endpoint_auth_method',
            'grant_types',
            'response_types',
            'client_name',
            'client_uri',
            'logo_uri',
            'scope',
            'contacts',
            'tos_uri',
            'policy_uri',
            'jwks_uri',
            'jwks',
            'software_id',
            'software_version',
        ]
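
    To see which of the registered JWT claims a provider actually puts in its ID token, it can help to look at the payload. The sketch below only base64url-decodes the middle segment of a JWT; it does NOT verify the signature, so it is for inspection and debugging only and must never replace the verification that parse_id_token() performs. peek_jwt_claims() is a hypothetical helper:

```python
import base64
import json

JWT_REGISTERED_CLAIMS = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]


def peek_jwt_claims(token: str) -> dict:
    """Split a JWT's UNVERIFIED payload into registered and other claims."""
    payload_b64 = token.split(".")[1]
    # JWTs omit base64url '=' padding; restore it before decoding.
    payload_b64 += "=" * (-len(payload_b64) % 4)
    claims = json.loads(base64.urlsafe_b64decode(payload_b64))
    registered = {k: v for k, v in claims.items() if k in JWT_REGISTERED_CLAIMS}
    other = {k: v for k, v in claims.items() if k not in JWT_REGISTERED_CLAIMS}
    return {"registered": registered, "other": other}
```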
    

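    At each stage of the OAuth/OpenID process, authlib performs multiple checks. For the ID token specifically: when the token response contains an id_token and a nonce was stored in the session state, authorize_access_token() calls parse_id_token(), which verifies the token's signature against the provider's published keys (JWKS) and then validates its claims, including issuer (when the provider metadata declares one), audience (your client_id), nonce and expiry, with a 120-second leeway. If either condition is not met, this verification step is skipped and no userinfo is added to the returned token.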

    I hope that my answer helps you understand the function authorize_access_token() better.

    Happy Coding!!