Tags: spring, spring-security, spring-webflux

Spring reactive OAuth2 resource-server: include both access and ID token claims in the `Authentication`


I have an AWS Cognito user pool issuing tokens to my frontend application. The frontend application then uses these tokens to talk to my backend service.

This flow is working as intended. I am validating the tokens that hit my backend service using org.springframework.security:spring-security-oauth2-resource-server:6.0.1, which is configured to point back to Cognito:

spring:
  security:
    oauth2:
      resourceserver:
        jwt:
          issuer-uri: https://cognito-idp.us-east-1.amazonaws.com/my_pool_endpoint

I have a simple SecurityConfig:

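import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.web.server.SecurityWebFilterChain;

@Configuration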
@EnableWebFluxSecurity
@EnableReactiveMethodSecurity(useAuthorizationManager = true)
public class SecurityConfig {

    @Bean
    SecurityWebFilterChain securityWebFilterChain(final ServerHttpSecurity http) {

        return http.authorizeExchange()
                .pathMatchers("/v3/api-docs/**")
                .permitAll()
                .anyExchange()
                .authenticated()
                .and()
                .oauth2ResourceServer(ServerHttpSecurity.OAuth2ResourceServerSpec::jwt)
                .build();
    }
}

So far everything is looking good.

But how do I gather additional information from incoming tokens? For example, things such as email and username are not included in the token issued by Cognito. A decoded token looks like this:

{
  "sub": "00000000000000000000",
  "cognito:groups": [
    "00000000000000000000"
  ],
  "iss": "https://cognito-idp.us-east-1.amazonaws.com/00000000000000000000",
  "version": 2,
  "client_id": "00000000000000000000",
  "origin_jti": "00000000000000000000",
  "token_use": "access",
  "scope": "openid profile email",
  "auth_time": 1676066347,
  "exp": 1676186814,
  "iat": 1676143614,
  "jti": "00000000000000000000",
  "username": "google_00000000000000000000"
}
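To see which claims a given token actually carries (access vs. ID token), its payload can be decoded locally. A minimal, dependency-free sketch, assuming a standard three-part JWS compact token; the class name is hypothetical, and since it does not verify the signature it is only suitable for inspection and debugging, never for authorization:

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Hypothetical helper: prints what is inside a JWT, e.g. to compare access and ID tokens.
// Inspection only: the signature is NOT verified here.
final class JwtPayloadPeek {

    static String payloadJson(String jwt) {
        // A JWS compact token is header.payload.signature
        String[] parts = jwt.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Not a JWS compact serialization");
        }
        // JWT segments are base64url-encoded without padding; the URL decoder accepts that
        byte[] bytes = Base64.getUrlDecoder().decode(parts[1]);
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
```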

When I need extra information, I call https://my-congito-pool.auth.us-east-1.amazoncognito.com/oauth2/userInfo, passing the JWT as the Bearer token. This works and returns the information I'm looking for, such as email, picture, username, etc.
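For reference, the manual call described above boils down to a GET with the access token as a Bearer credential. This sketch assumes the JDK's java.net.http client (rather than Spring's WebClient) to stay dependency-free; the class and method names are hypothetical. Every invocation is a network round trip to Cognito, which is what makes repeating it for each piece of user information costly.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical helper reproducing the manual userInfo call
final class UserInfoCall {

    // The access token travels in a standard Bearer Authorization header
    static HttpRequest buildRequest(String userInfoUri, String accessToken) {
        return HttpRequest.newBuilder(URI.create(userInfoUri))
                .header("Authorization", "Bearer " + accessToken)
                .header("Accept", "application/json")
                .GET()
                .build();
    }

    // Sends the request and returns the raw JSON body (email, picture, username, ...)
    static String fetch(HttpClient client, String userInfoUri, String accessToken)
            throws IOException, InterruptedException {
        HttpResponse<String> response = client.send(
                buildRequest(userInfoUri, accessToken), HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IllegalStateException("userInfo returned HTTP " + response.statusCode());
        }
        return response.body();
    }
}
```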

My concern is that doing this manually every time I want additional information doesn't seem like the 'correct' way of handling it.

Should I be using something like a UserDetailsService to do this once, transforming the incoming JWT into my own User object that holds this information?

If so, how do I do this with reactive Spring Security?


Solution

  • It looks like Cognito allows enriching ID tokens, but not access tokens. That's unfortunate: most competitors allow it, and it makes Spring resource-server configuration much easier.

    I can think of two solutions:

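    • configure your resource-server with access-token introspection (http.oauth2ResourceServer().opaqueToken()), using Cognito's /oauth2/userInfo as the introspection endpoint and the JWT access token as an "opaque" token. As the userInfo response is not an RFC 7662 introspection response (it has no "active" field), this requires a custom ReactiveOpaqueTokenIntrospector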
    • require clients to send the ID token in a dedicated header (let's say X-ID-Token) in addition to the access token (provided in the Authorization header as usual). Then, in the authentication converter, retrieve and decode this additional header and build your own Authentication holding both the access and ID token strings and claims
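    For illustration, this is what a client request looks like under the second solution: the access token in the usual Authorization header, and the raw ID token (without a "Bearer " prefix, since the converter decodes the header value as-is) in X-ID-Token. The sketch uses the JDK's java.net.http client and a hypothetical class name; a browser frontend would set the same two headers with its own HTTP client.

    ```java
    import java.net.URI;
    import java.net.http.HttpRequest;

    // Hypothetical client-side helper for the "access token + X-ID-Token" convention
    final class DualTokenRequest {
        static final String ID_TOKEN_HEADER_NAME = "X-ID-Token";

        static HttpRequest build(String uri, String accessToken, String idToken) {
            return HttpRequest.newBuilder(URI.create(uri))
                    // Access token: validated by the resource-server JWT decoder as usual
                    .header("Authorization", "Bearer " + accessToken)
                    // ID token: raw JWT, decoded by the custom authentication converter
                    .header(ID_TOKEN_HEADER_NAME, idToken)
                    .GET()
                    .build();
        }
    }
    ```

    If the frontend is served from another origin, X-ID-Token also has to be listed in the CORS allowed headers, otherwise the browser preflight request will fail.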

    I will only detail the second solution, for two reasons:

    • the first has the usual performance cost of token introspection (a call is made from the resource-server to the authorization-server before each request is processed)
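    • the second makes it possible to add data from as many headers as needed to the Authentication instance, for both authentication and authorization (not only the ID token as demoed here), with very little performance impact (like the access token, the ID token is validated locally against the issuer's cached public keys)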

    Spoiler: here is what I got:

    • with valid access and ID tokens: [Postman screenshot with a valid access token and X-ID-Token]
    • with just the access token: [Postman screenshot without a valid ID token in X-ID-Token]

    Isn't it exactly what you are looking for: an Authentication instance with the roles from the access token and the email from the ID token (or Unauthorized if authorization data is missing / invalid / incomplete)?

    Detailed Security Configuration

    Here is the security configuration for a reactive app. For servlet apps, the main lines are the same; only the tooling to statically access the request context is quite different. You can refer to this tutorial I just added to my collection for details.

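    import java.util.Collection;
    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import java.util.Objects;
    
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.http.HttpStatus;
    import org.springframework.http.server.reactive.ServerHttpRequest;
    import org.springframework.security.authentication.AbstractAuthenticationToken;
    import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity;
    import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
    import org.springframework.security.config.web.server.ServerHttpSecurity;
    import org.springframework.security.core.GrantedAuthority;
    import org.springframework.security.core.authority.SimpleGrantedAuthority;
    import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
    import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
    import org.springframework.security.oauth2.jwt.JwtException;
    import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
    import org.springframework.security.web.server.SecurityWebFilterChain;
    import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
    import org.springframework.util.StringUtils;
    import org.springframework.web.bind.annotation.ResponseStatus;
    import org.springframework.web.server.ServerWebExchange;
    
    import reactor.core.publisher.Mono;
    
    @Configuration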
    @EnableReactiveMethodSecurity
    @EnableWebFluxSecurity
    public class SecurityConfig {
        static final String ID_TOKEN_HEADER_NAME = "X-ID-Token";
    
        public static Mono<ServerHttpRequest> getServerHttpRequest() {
            return Mono.deferContextual(Mono::just)
                    .map(contextView -> contextView.get(ServerWebExchange.class).getRequest());
        }
    
        public static Mono<String> getIdTokenHeader() {
            return getServerHttpRequest().map(req -> {
                final var headers = req.getHeaders().getOrEmpty(ID_TOKEN_HEADER_NAME).stream()
                        .filter(StringUtils::hasLength).toList();
                if (headers.size() == 0) {
                    throw new MissingIdTokenException();
                }
                if (headers.size() > 1) {
                    throw new MultiValuedIdTokenException();
                }
                return headers.get(0);
            });
        }
    
        @Bean
        SecurityWebFilterChain securityFilterChain(ServerHttpSecurity http, ReactiveJwtDecoder jwtDecoder) {
            http.oauth2ResourceServer().jwt().jwtAuthenticationConverter(accessToken -> getIdTokenHeader()
                    // defer() turns a synchronous exception thrown by decode() on malformed input into an error
                    // signal, so that onErrorMap() translates any ID token decoding failure into our own 401
                    .flatMap(idTokenString -> Mono.defer(() -> jwtDecoder.decode(idTokenString))
                            .onErrorMap(JwtException.class, e -> new InvalidIdTokenException())
                            .map(idToken -> {
                                final var idClaims = idToken.getClaims();
    
                                // Authorities are taken from the "cognito:groups" claim of the access token
                                @SuppressWarnings("unchecked")
                                final var authorities = ((List<String>) accessToken.getClaims().getOrDefault("cognito:groups",
                                        List.of())).stream().map(SimpleGrantedAuthority::new).toList();
    
                                return new MyAuth(authorities, accessToken.getTokenValue(), idTokenString, accessToken.getClaims(),
                                        idClaims);
                            })));
    
            http.securityContextRepository(NoOpServerSecurityContextRepository.getInstance()).csrf().disable();
    
            http.authorizeExchange().anyExchange().authenticated();
    
            return http.build();
        }
    
        public static class MyAuth extends AbstractAuthenticationToken {
            private static final long serialVersionUID = 9115947200114995708L;
    
            // Save access and ID tokens strings just in case we need to call another
            // micro-service on behalf of the user who initiated the request and as so,
            // position "Authorization" and "X-ID-Token" headers
            private final String accessTokenString;
            private final String idTokenString;
    
            private final Map<String, Object> accessClaims;
            private final Map<String, Object> idClaims;
    
            public MyAuth(Collection<? extends GrantedAuthority> authorities, String accessTokenString,
                    String idTokenString, Map<String, Object> accessClaims, Map<String, Object> idClaims) {
                super(authorities);
                this.accessTokenString = accessTokenString;
                this.accessClaims = Collections.unmodifiableMap(accessClaims);
                this.idTokenString = idTokenString;
                this.idClaims = Collections.unmodifiableMap(idClaims);
    
                // Minimal security checks: assert that issuer and subject claims are the same
                // in access and ID tokens.
                if (!Objects.equals(accessClaims.get(IdTokenClaimNames.ISS), idClaims.get(IdTokenClaimNames.ISS))
                        || !Objects.equals(accessClaims.get(StandardClaimNames.SUB), idClaims.get(IdTokenClaimNames.SUB))) {
                    throw new InvalidIdTokenException();
                }
                // You could also make assertions on ID token audience, but this will require
                // adding a custom property for expected ID tokens audience.
                // You can't just check for audience equality with already validated access
                // token one.
    
                this.setAuthenticated(true);
            }
    
            @Override
            public String getCredentials() {
                return accessTokenString;
            }
    
            @Override
            public String getPrincipal() {
                return (String) accessClaims.get(StandardClaimNames.SUB);
            }
    
            public String getAccessTokenString() {
                return accessTokenString;
            }
    
            public String getIdTokenString() {
                return idTokenString;
            }
    
            public Map<String, Object> getAccessClaims() {
                return accessClaims;
            }
    
            public Map<String, Object> getIdClaims() {
                return idClaims;
            }
    
        }
    
        @ResponseStatus(code = HttpStatus.UNAUTHORIZED, reason = ID_TOKEN_HEADER_NAME + " is missing")
        static class MissingIdTokenException extends RuntimeException {
            private static final long serialVersionUID = -4894061353773464761L;
        }
    
        @ResponseStatus(code = HttpStatus.UNAUTHORIZED, reason = ID_TOKEN_HEADER_NAME + " is not unique")
        static class MultiValuedIdTokenException extends RuntimeException {
            private static final long serialVersionUID = 1654993007508549674L;
        }
    
        @ResponseStatus(code = HttpStatus.UNAUTHORIZED, reason = ID_TOKEN_HEADER_NAME + " is not valid")
        static class InvalidIdTokenException extends RuntimeException {
            private static final long serialVersionUID = -6233252290377524340L;
        }
    }
    

    Now, each time a request is successfully authenticated (isAuthenticated() is true), the security context holds a MyAuth instance containing both the access and ID token claims!
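    The two sanity checks performed above (exactly one non-empty X-ID-Token value, and matching iss and sub claims in both tokens) can be isolated as plain functions, which makes them easy to reason about and unit-test. A dependency-free sketch with a hypothetical class name, operating on plain claim maps rather than Spring's Jwt type:

    ```java
    import java.util.List;
    import java.util.Map;
    import java.util.Objects;

    // Hypothetical, framework-free version of the checks done in getIdTokenHeader() and MyAuth
    final class IdTokenChecks {

        // Mirrors getIdTokenHeader(): exactly one non-empty header value is accepted
        static String uniqueHeaderValue(List<String> values) {
            List<String> nonEmpty = values.stream()
                    .filter(v -> v != null && !v.isEmpty())
                    .toList();
            if (nonEmpty.isEmpty()) {
                throw new IllegalArgumentException("X-ID-Token is missing");
            }
            if (nonEmpty.size() > 1) {
                throw new IllegalArgumentException("X-ID-Token is not unique");
            }
            return nonEmpty.get(0);
        }

        // Mirrors the MyAuth constructor: same issuer and same subject in both tokens
        static boolean sameIssuerAndSubject(Map<String, ?> accessClaims, Map<String, ?> idClaims) {
            return Objects.equals(accessClaims.get("iss"), idClaims.get("iss"))
                    && Objects.equals(accessClaims.get("sub"), idClaims.get("sub"));
        }
    }
    ```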

    Sample Controller

    @RestController
    public class GreetingController {
        
        @GetMapping("/greet")
        @PreAuthorize("isAuthenticated()")
        Mono<String> greet(MyAuth auth) {
            return Mono.just("Hello %s! You are granted with %s".formatted(
                auth.getIdClaims().get("email"),
                auth.getAuthorities()));
        }
    
    }
    

    You may also build your @PreAuthorize expressions based on it. Something like:

    @RequiredArgsConstructor
    @RestController
    @RequestMapping("/something/protected")
    @PreAuthorize("isAuthenticated()")
    public class ProtectedResourceController {
        private final SomeResourceRepository resourceRepo;
    
        @GetMapping("/{resourceId}")
        @PreAuthorize("#auth.idClaims['email'] == #resource.email")
            ResourceDto getProtectedResource(MyAuth auth, @PathVariable("resourceId") SomeResource resource) {
            ...
        }
    
    }
    

    EDIT: Code reusability & spring-addons starters

    I maintain wrappers around spring-boot-starter-oauth2-resource-server. They are very thin and open source. Even if you don't want to use them, you might have a look at how they are built for inspiration:

    • inspect resources to find out what it takes to build your own spring-boot starters
    • inspect beans to pick ideas for creating your own configurable ones
    • browse to dependencies like OpenidClaimSet and OAuthentication which could be of inspiration

    Here is what the sample above becomes with "my" starter for reactive resource-servers with JWT decoders:

    @Configuration
    @EnableReactiveMethodSecurity
    @EnableWebFluxSecurity
    public class SecurityConfig {
        static final String ID_TOKEN_HEADER_NAME = "X-ID-Token";
    
        @Bean
        OAuth2AuthenticationFactory authenticationFactory(
                Converter<Map<String, Object>, Collection<? extends GrantedAuthority>> authoritiesConverter,
                ReactiveJwtDecoder jwtDecoder) {
            return (accessBearerString, accessClaims) -> ServerHttpRequestSupport.getUniqueHeader(ID_TOKEN_HEADER_NAME)
                    // defer() catches decode() failing synchronously on malformed input, so onErrorMap() sees it too
                    .flatMap(idTokenString -> Mono.defer(() -> jwtDecoder.decode(idTokenString))
                            .onErrorMap(JwtException.class, e -> new InvalidHeaderException(ID_TOKEN_HEADER_NAME))
                            .map(idToken -> new MyAuth(
                                    authoritiesConverter.convert(accessClaims),
                                    accessBearerString,
                                    new OpenidClaimSet(accessClaims),
                                    idTokenString,
                                    new OpenidClaimSet(idToken.getClaims()))));
        }
    
        @Data
        @EqualsAndHashCode(callSuper = true)
        public static class MyAuth extends OAuthentication<OpenidClaimSet> {
            private static final long serialVersionUID = 1734079415899000362L;
            private final String idTokenString;
            private final OpenidClaimSet idClaims;
    
            public MyAuth(Collection<? extends GrantedAuthority> authorities, String accessTokenString,
                    OpenidClaimSet accessClaims, String idTokenString, OpenidClaimSet idClaims) {
                super(accessClaims, authorities, accessTokenString);
                this.idTokenString = idTokenString;
                this.idClaims = idClaims;
            }
    
        }
    }
    

    Updated @Controller (note the direct accessor for the email claim):

    @RestController
    public class GreetingController {
        
        @GetMapping("/greet")
        @PreAuthorize("isAuthenticated()")
        Mono<String> greet(MyAuth auth) {
            return Mono.just("Hello %s! You are granted with %s".formatted(
                    auth.getIdClaims().getEmail(),
                    auth.getAuthorities()));
        }
    }
    

    These are the configuration properties (with different claims used as the authorities source depending on the authorization server configured in each profile):

    server:
      error.include-message: always
    
    spring:
      lifecycle.timeout-per-shutdown-phase: 30s
      security.oauth2.resourceserver.jwt.issuer-uri: https://localhost:8443/realms/master
    
    com:
      c4-soft:
        springaddons:
          security:
            issuers:
              - location: ${spring.security.oauth2.resourceserver.jwt.issuer-uri}
                authorities:
                  claims:
                    - realm_access.roles
                    - resource_access.spring-addons-public.roles
                    - resource_access.spring-addons-confidential.roles
                  caze: upper
                  prefix: ROLE_
            cors:
              - path: /greet
              
    
    ---
    spring.config.activate.on-profile: cognito
    spring.security.oauth2.resourceserver.jwt.issuer-uri: https://cognito-idp.us-west-2.amazonaws.com/us-west-2_RzhmgLwjl
    com.c4-soft.springaddons.security.issuers:
      - location: ${spring.security.oauth2.resourceserver.jwt.issuer-uri}
        authorities:
          claims: 
            - cognito:groups
          caze: upper
          prefix: ROLE_
    
    ---
    spring.config.activate.on-profile: auth0
    com.c4-soft.springaddons.security.issuers:
      - location: https://dev-ch4mpy.eu.auth0.com/
        authorities:
          claims:
            - roles
            - permissions
          caze: upper
          prefix: ROLE_
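    To make the claims, caze and prefix properties above concrete, here is a hypothetical, dependency-free reimplementation of that kind of mapping. It is not the starter's actual code: it assumes each entry in claims is a dotted path resolved through nested JSON objects (so "cognito:groups" is a single top-level claim and "realm_access.roles" is nested), and it only handles the upper-case option.

    ```java
    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;
    import java.util.Locale;
    import java.util.Map;

    // Hypothetical sketch of "claims + caze: upper + prefix: ROLE_" authorities mapping
    final class ClaimsToAuthorities {

        // Follows a dotted path such as "realm_access.roles" through nested maps
        static Object resolve(Map<String, ?> claims, String path) {
            Object current = claims;
            for (String segment : path.split("\\.")) {
                if (!(current instanceof Map<?, ?> map)) {
                    return null; // path does not exist in this token
                }
                current = map.get(segment);
            }
            return current;
        }

        // Collects values from every configured claim, upper-cases them and adds the prefix
        static List<String> authorities(Map<String, ?> claims, List<String> claimPaths, String prefix) {
            List<String> result = new ArrayList<>();
            for (String path : claimPaths) {
                if (resolve(claims, path) instanceof Collection<?> values) {
                    for (Object value : values) {
                        result.add(prefix + value.toString().toUpperCase(Locale.ROOT));
                    }
                }
            }
            return result;
        }
    }
    ```

    With the cognito profile, an access token whose cognito:groups claim is ["admin"] would then yield the single authority ROLE_ADMIN.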
    

    Unit tests with a mocked identity for the @Controller above can be as simple as:

    @WebFluxTest(controllers = GreetingController.class)
    @AutoConfigureAddonsWebSecurity
    @Import(SecurityConfig.class)
    class GreetingControllerTest {
    
        @Autowired
        WebTestClientSupport api;
    
        @Test
        @WithMyAuth(authorities = { "AUTHOR" }, idClaims = @OpenIdClaims(email = "[email protected]"))
        void givenUserIsAuthenticated_whenGreet_thenOk() throws Exception {
            api.get("/greet").expectStatus().isOk()
                    .expectBody(String.class).isEqualTo("Hello [email protected]! You are granted with [AUTHOR]");
        }
    
        @Test
        void givenRequestIsAnonymous_whenGreet_thenUnauthorized() throws Exception {
            api.get("/greet").expectStatus().isUnauthorized();
        }
    
    }
    

    With the annotation definition (which builds the custom Authentication implementation and sets it in the test security context):

    @Target({ ElementType.METHOD, ElementType.TYPE })
    @Retention(RetentionPolicy.RUNTIME)
    @Inherited
    @Documented
    @WithSecurityContext(factory = WithMyAuth.MyAuthFactory.class)
    public @interface WithMyAuth {
    
        @AliasFor("authorities")
        String[] value() default {};
    
        @AliasFor("value")
        String[] authorities() default {};
    
        OpenIdClaims accessClaims() default @OpenIdClaims();
    
        OpenIdClaims idClaims() default @OpenIdClaims();
    
        String accessTokenString() default "machin.truc.chose";
    
        String idTokenString() default "machin.bidule.chose";
    
        @AliasFor(annotation = WithSecurityContext.class)
        TestExecutionEvent setupBefore() default TestExecutionEvent.TEST_METHOD;
    
        @Target({ ElementType.METHOD, ElementType.TYPE })
        @Retention(RetentionPolicy.RUNTIME)
        public static @interface Proxy {
            String onBehalfOf();
    
            String[] can() default {};
        }
    
        public static final class MyAuthFactory extends AbstractAnnotatedAuthenticationBuilder<WithMyAuth, MyAuth> {
            @Override
            public MyAuth authentication(WithMyAuth annotation) {
                final var accessClaims = new OpenidClaimSet(super.claims(annotation.accessClaims()));
                final var idClaims = new OpenidClaimSet(super.claims(annotation.idClaims()));
    
                return new MyAuth(super.authorities(annotation.authorities()), annotation.accessTokenString(), accessClaims, annotation.idTokenString(), idClaims);
            }
        }
    }
    

    And this is the pom body:

        <properties>
            <java.version>17</java.version>
            <spring-addons.version>6.0.13</spring-addons.version>
        </properties>
        <dependencies>
            <dependency>
                <groupId>com.c4-soft.springaddons</groupId>
                <artifactId>spring-addons-webflux-jwt-resource-server</artifactId>
                <version>${spring-addons.version}</version>
            </dependency>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-webflux</artifactId>
            </dependency>
    
            <dependency>
                <groupId>org.projectlombok</groupId>
                <artifactId>lombok</artifactId>
                <optional>true</optional>
            </dependency>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-test</artifactId>
                <scope>test</scope>
            </dependency>
            <dependency>
                <groupId>com.c4-soft.springaddons</groupId>
                <artifactId>spring-addons-webflux-jwt-test</artifactId>
                <version>${spring-addons.version}</version>
                <scope>test</scope>
            </dependency>
            <dependency>
                <groupId>io.projectreactor</groupId>
                <artifactId>reactor-test</artifactId>
                <scope>test</scope>
            </dependency>
        </dependencies>
    
        <build>
            <plugins>
                <plugin>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-maven-plugin</artifactId>
                    <configuration>
                        <excludes>
                            <exclude>
                                <groupId>org.projectlombok</groupId>
                                <artifactId>lombok</artifactId>
                            </exclude>
                        </excludes>
                    </configuration>
                </plugin>
            </plugins>
        </build>