Search code examples
amazon-web-services spring-boot spring-security amazon-cognito graphql-subscriptions

Authenticate graphql subscription using aws cognito


I have a GraphQL server written in Java using Spring for GraphQL. We plan to use it for sending data to React clients using GraphQL's subscription over web sockets. Now, we need to deploy it into AWS EKS and authenticate traffic via AWS Cognito using JWT. How can we achieve that? Is there a way to configure Spring Security in order to validate JWT tokens using Cognito for GraphQL subscriptions? I really appreciate any suggestions.

I was trying to use org.springframework.graphql.server.WebSocketGraphQlInterceptor, but without luck. After adding Spring Security to the project, all traffic is blocked, and I'm unable to reach the interceptor while debugging.

EDIT

I have configured DefaultSecurityFilterChain and WebSocketGraphQlInterceptor, but the interceptor catches only mutations and not subscriptions. Each time I fire a mutation, I am able to catch it within the intercept method inside WebSocketGraphQlInterceptor, but this is not true for subscriptions - a subscription always returns 401 Unauthorized instead of going into handleConnectionInitialization first.

This is my SecurityFilterChain:

 /**
  * Builds the stateless security filter chain for the GraphQL endpoints.
  *
  * <p>CORS and CSRF handling are switched off, GET requests to {@code /subscription} are let
  * through without authentication, and every other request must be authenticated. The OAuth2
  * resource server JWT support is given a decoder created from the configured issuer location.
  *
  * @param http the security builder supplied by Spring Security
  * @return the built filter chain
  * @throws Exception if the chain cannot be built
  */
 @Bean
 public DefaultSecurityFilterChain graphqlSecurityChain(HttpSecurity http) throws Exception {

    http.cors(AbstractHttpConfigurer::disable);
    http.csrf(AbstractHttpConfigurer::disable);

    // No HTTP session is created or used; each request is handled on its own.
    http.sessionManagement(
        sessions -> sessions.sessionCreationPolicy(SessionCreationPolicy.STATELESS));

    http.authorizeHttpRequests(
        requests -> {
            requests.requestMatchers(HttpMethod.GET, "/subscription").permitAll();
            requests.anyRequest().authenticated();
        });

    http.oauth2ResourceServer(
        resourceServer ->
            resourceServer.jwt(jwt -> jwt.decoder(JwtDecoders.fromIssuerLocation(issuer))));

    return http.build();
 }

What could be the reason that WebSocketGraphQlInterceptor does not handle subscriptions properly?


Solution

  • As it turns out, we made a typo in our GraphQL Spring Boot configuration file. After correcting it, everything works perfectly fine. I'm attaching the Spring Boot GraphQL config and source code in case anyone faces a similar problem.

    application.yaml

    # Port the server listens on; falls back to 4000 when GRAPHQL_PORT is not set.
    server:
      port: ${GRAPHQL_PORT:4000}
    spring:
      graphql:
        # Path of the GraphQL HTTP endpoint (default: /graphql).
        path: ${GRAPHQL_PATH:/graphql}
        websocket:
          # Path of the GraphQL WebSocket endpoint (default: /subscription).
          # NOTE(review): this reuses the GRAPHQL_PATH placeholder of spring.graphql.path, so setting
          # that variable gives both endpoints the same path - confirm this is intended.
          path: ${GRAPHQL_PATH:/subscription}
        graphiql:
          # GraphiQL dashboard switch; enabled unless GRAPHQL_DASHBOARD overrides it.
          enabled: ${GRAPHQL_DASHBOARD:true}
      security:
        oauth2:
          resourceserver:
            jwt:
              # Issuer URI of the Cognito user pool; read by the Java configuration via @Value.
              issuer-uri: ${COGNITO_URI:https://cognito-idp.eu-west-1.amazonaws.com/eu-west-1_xxxxxxxxxx}
              # Cognito client id; JwtConfig uses it as the expected "client_id" claim.
              audiences: ${COGNITO_CLIENT:xxxxxxxxxx}
    

    GraphqlHttpSecurityAdapter.java

    @Slf4j
    @Configuration
    @EnableWebSecurity
    @EnableMethodSecurity()
    public class GraphqlHttpSecurityAdapter {
    
        @Value("${spring.security.oauth2.resourceserver.jwt.issuer-uri}")
        private String issuer;
    
        @Bean
        public DefaultSecurityFilterChain graphqlSecurityChain(HttpSecurity http) throws Exception {
    
            return http.cors(AbstractHttpConfigurer::disable)
                .csrf(AbstractHttpConfigurer::disable)
                .sessionManagement(
                    session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
                .authorizeHttpRequests(auth -> auth.anyRequest().permitAll())
                .oauth2ResourceServer(
                    oauthConfigurer ->
                        oauthConfigurer.jwt(
                            jwtConfigurer -> jwtConfigurer.decoder(JwtDecoders.fromIssuerLocation(issuer))))
                .exceptionHandling(
                    exHandlingConfigurer ->
                        exHandlingConfigurer.accessDeniedHandler(
                            (request, response, accessDeniedException) ->
                                log.error(accessDeniedException.getMessage(), accessDeniedException)))
                .build();
        }
    }
    

    GraphqlSubscriptionInterceptor.java

    @Slf4j
    @Configuration
    @RequiredArgsConstructor
    public class GraphqlSubscriptionInterceptor implements WebSocketGraphQlInterceptor {
    
        public static final String BEARER_TOKEN_PREFIX = "Bearer";
        private static final String AUTHENTICATION_KEY = "Authorization";
        private final @NonNull ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
    
        @Value("${spring.security.oauth2.resourceserver.jwt.issuer-uri}")
        private String issuer;
    
        @Override
        public Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chain) {
    
            log.trace("Intercepted request -> {}", request);
            return WebSocketGraphQlInterceptor.super.intercept(request, chain);
        }
    
        @Override
        public Mono<Object> handleConnectionInitialization(
            WebSocketSessionInfo sessionInfo, Map<String, Object> connectionInitPayload) {
    
            log.debug("Initialized subscription - session-id: {}", sessionInfo.getId());
            var token =
                StringUtils.removeStart(
                    (String)
                        Optional.ofNullable(connectionInitPayload.get(AUTHENTICATION_KEY))
                            .orElseThrow(() -> new MissingTokenException("Missing JWT token !")),
                    BEARER_TOKEN_PREFIX);
    
            try {
                jwtProcessor.process(token, null);
            } catch (ParseException | BadJOSEException | JOSEException e) {
                throw new InvalidTokenException("Invalid JWT token !", e);
            }
    
            var jwtToken = JwtDecoders.fromIssuerLocation(issuer).decode(token);
            var jwtAuthenticationToken = new JwtAuthenticationToken(jwtToken);
            SecurityContextHolder.getContext().setAuthentication(jwtAuthenticationToken);
    
            return WebSocketGraphQlInterceptor.super.handleConnectionInitialization(
                sessionInfo, connectionInitPayload);
        }
    
        @Override
        public @NonNull Mono<Void> handleCancelledSubscription(
            WebSocketSessionInfo sessionInfo, String subscriptionId) {
    
            log.debug("Cancelled subscription - session-id: {}", sessionInfo.getId());
            return WebSocketGraphQlInterceptor.super.handleCancelledSubscription(
                sessionInfo, subscriptionId);
        }
    
        @Override
        public void handleConnectionClosed(
            WebSocketSessionInfo sessionInfo, int statusCode, Map<String, Object> connectionInitPayload) {
    
            log.debug(
                "Closed subscription - session-id: {} - status-code: {}", sessionInfo.getId(), statusCode);
            WebSocketGraphQlInterceptor.super.handleConnectionClosed(
                sessionInfo, statusCode, connectionInitPayload);
        }
    }
    

    JwtConfig.java

    /**
     * Provides the Nimbus {@link ConfigurableJWTProcessor} bean.
     *
     * <p>Keys are loaded from the issuer's {@code /.well-known/jwks.json} document and selected
     * for the RS256 algorithm. The claims verifier is created with the configured issuer and
     * {@code client_id} as the expected claim values, plus a set of claim names handed over as
     * required claims.
     */
    @Configuration
    public class JwtConfig {

        public static final String WELL_KNOWN_JWKS_JSON = "/.well-known/jwks.json";

        @Value("${spring.security.oauth2.resourceserver.jwt.issuer-uri}")
        private String issuer;

        @Value("${spring.security.oauth2.resourceserver.jwt.audiences}")
        private String audiences;

        /**
         * Creates the JWT processor.
         *
         * @return the configured processor
         * @throws MalformedURLException if the JWKS URL built from the issuer is not a valid URL
         */
        @Bean
        public ConfigurableJWTProcessor<SecurityContext> jwtProcessor() throws MalformedURLException {

            // Signature keys: remote JWK set of the issuer, restricted to RS256.
            URL jwksEndpoint = new URL(issuer + WELL_KNOWN_JWKS_JSON);
            JWKSource<SecurityContext> remoteKeys = new RemoteJWKSet<>(jwksEndpoint);
            JWSKeySelector<SecurityContext> rs256KeySelector =
                new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, remoteKeys);

            // Claims: expected values and the names passed as required claims.
            JWTClaimsSet expectedClaims =
                new JWTClaimsSet.Builder().issuer(issuer).claim("client_id", audiences).build();
            Set<String> mandatoryClaimNames =
                Set.of(
                    JwtClaimNames.SUB,
                    JwtClaimNames.ISS,
                    JwtClaimNames.EXP,
                    JwtClaimNames.IAT,
                    "client_id",
                    "token_use",
                    "scope",
                    "auth_time");

            ConfigurableJWTProcessor<SecurityContext> processor = new DefaultJWTProcessor<>();
            processor.setJWSKeySelector(rs256KeySelector);
            processor.setJWTClaimsSetVerifier(
                new DefaultJWTClaimsVerifier<>(expectedClaims, mandatoryClaimNames));
            return processor;
        }
    }
    

    Additionally, if anyone is interested in how to implement GraphQL subscriptions with Spring for GraphQL in general, I've written a post on Medium -> https://medium.com/@mielczarek.lukasz.karol/implementing-graphql-subscriptions-using-spring-for-graphql-and-redis-80f89a95c94c

    Regards