Search code examples
Tags: java, spring, spring-boot, spring-security, oauth-2.0

Spring Security 6 OAuth2 Custom Validator


I have a Java application on Spring Boot 3.0.5 that is configured as an OAuth2 resource server as follows:

Maven dependencies:

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
    </dependency>

Security configuration:

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jwt.*;
import org.springframework.security.web.SecurityFilterChain;

/**
 * Security setup for a stateless OAuth2 resource server.
 *
 * <p>Actuator endpoints are open to everyone; every other request must carry a JWT bearer
 * token. Tokens are decoded by a decoder configured from the issuer location and are checked by
 * the default issuer validators plus {@code OperationClaimValidator}.
 */
@Configuration
@EnableWebSecurity
public class SecurityConfiguration {

    /** Issuer location, read from {@code spring.security.oauth2.resourceserver.jwt.issuer-uri}. */
    @Value("${spring.security.oauth2.resourceserver.jwt.issuer-uri}")
    private String issuerUri;

    /**
     * Builds the HTTP security filter chain.
     *
     * <p>CORS and CSRF support are switched off, the session policy is stateless,
     * {@code /actuator/**} is permitted without authentication, and every other request must be
     * authenticated through the JWT resource-server support.
     *
     * @param http the builder supplied by Spring Security
     * @return the configured filter chain
     * @throws Exception if the builder rejects the configuration
     */
    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        http
                .cors(cors -> cors.disable())
                .csrf(csrf -> csrf.disable())
                .authorizeHttpRequests(authorize -> authorize
                        .requestMatchers("/actuator/**").permitAll()
                        .anyRequest().authenticated())
                .sessionManagement(session -> session
                        .sessionCreationPolicy(SessionCreationPolicy.STATELESS))
                // An empty customizer applies the JWT configurer with its defaults.
                .oauth2ResourceServer(resourceServer -> resourceServer.jwt(jwt -> { }));

        return http.build();
    }

    /**
     * Creates the JWT decoder for {@code issuerUri}.
     *
     * <p>It runs the standard issuer validators first, then the custom operation-claim check.
     *
     * @return the decoder used by the resource-server support
     */
    @Bean
    public JwtDecoder jwtDecoder() {
        NimbusJwtDecoder decoder = JwtDecoders.fromIssuerLocation(issuerUri);

        decoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(
                JwtValidators.createDefaultWithIssuer(issuerUri),
                new OperationClaimValidator()));

        return decoder;
    }
}

As you can see, I have defined a `JwtDecoder` bean to register a custom JWT validator, `OperationClaimValidator`, because I want to validate my custom JWT claim `operation`:

import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

/**
 * Validates that a JWT's custom {@code operation} claim matches the HTTP method of the request
 * currently being processed.
 *
 * <p>The current request is looked up through Spring's {@link RequestContextHolder}. If no
 * servlet request is bound to the calling thread, validation fails instead of letting the token
 * through. NOTE(review): this assumes a filter such as Spring's {@code RequestContextFilter} has
 * bound the request before the security filter chain decodes the token. Confirm the filter
 * ordering in the deployment.
 */
public class OperationClaimValidator implements OAuth2TokenValidator<Jwt> {

    private static final String CLAIM_OPERATION = "operation";

    private static final String INVALID_TOKEN = "invalid_token";

    /**
     * Checks the {@code operation} claim against the current request's HTTP method.
     *
     * <p>The comparison is case-sensitive. A token without the claim never matches.
     *
     * @param jwt the decoded token to validate
     * @return success if the claim equals the request method; otherwise a failure carrying an
     *     {@code invalid_token} error
     */
    @Override
    public OAuth2TokenValidatorResult validate(Jwt jwt) {
        String requestMethod = currentRequestMethod();
        if (requestMethod == null) {
            // Fail closed: without a request there is nothing to match the claim against.
            return failure("No HTTP request is bound to the current thread; cannot verify the operation claim");
        }

        // getClaimAsString returns null for a missing claim, so compare from the non-null side.
        String operation = jwt.getClaimAsString(CLAIM_OPERATION);
        if (requestMethod.equals(operation)) {
            return OAuth2TokenValidatorResult.success();
        }
        return failure("The required operation is " + requestMethod);
    }

    /** Returns the HTTP method of the servlet request bound to this thread, or null if none is bound. */
    private static String currentRequestMethod() {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes servletAttributes) {
            return servletAttributes.getRequest().getMethod();
        }
        return null;
    }

    /** Builds a failed validation result with an {@code invalid_token} error and the given description. */
    private static OAuth2TokenValidatorResult failure(String description) {
        return OAuth2TokenValidatorResult.failure(new OAuth2Error(INVALID_TOKEN, description, null));
    }
}

The problem is that I want to compare the value of the `operation` claim with the result of `HttpServletRequest#getMethod`, but I can't figure out how to access the `HttpServletRequest` from inside the `OAuth2TokenValidator` implementation.

I can validate this claim another way, using a `OncePerRequestFilter`:

/**
 * Rejects requests whose JWT {@code operation} claim does not match the request's HTTP method.
 *
 * <p>The check only applies when the request principal is a {@code JwtAuthenticationToken}.
 * Requests without such a principal pass through unchanged, for example unauthenticated calls
 * to the permitted {@code /actuator/**} endpoints. Whether those requests may proceed is left to
 * the security configuration.
 */
@Component
public class JwtAuthenticationFilter extends OncePerRequestFilter {

    private static final String CLAIM_OPERATION = "operation";

    /**
     * Compares the token's {@code operation} claim with {@code request.getMethod()}.
     *
     * <p>On a mismatch, or when the claim is missing, it responds with 401 and stops the chain.
     * Otherwise it continues the chain.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {

        // getUserPrincipal() may be null (no authenticated user) or a non-JWT principal.
        // Casting it unconditionally threw NullPointerException/ClassCastException in those cases.
        Principal principal = request.getUserPrincipal();
        if (!(principal instanceof JwtAuthenticationToken jwtAuthentication)) {
            filterChain.doFilter(request, response);
            return;
        }

        String requestMethod = request.getMethod();
        String operationClaimValue = getOperationClaimValue(jwtAuthentication);

        // equals(null) is false, so a token without the claim is rejected as well.
        if (!requestMethod.equals(operationClaimValue)) {
            response.sendError(HttpStatus.UNAUTHORIZED.value(),
                    "Operation claim in token does not match with http method used in request");
        } else {
            filterChain.doFilter(request, response);
        }
    }

    /** Reads the {@code operation} claim from the authenticated JWT, or returns null if the claim is absent. */
    private String getOperationClaimValue(JwtAuthenticationToken authentication) {
        Jwt token = authentication.getToken();
        return token.getClaimAsString(CLAIM_OPERATION);
    }
}

However, I would like to achieve this by implementing `OAuth2TokenValidator`. Any ideas? Thank you in advance.


Solution

  • You can access the current `HttpServletRequest` statically through Spring's `RequestContextHolder`, as follows:

    /**
     * Looks up the servlet request that Spring's {@code RequestContextHolder} has bound to the
     * current thread.
     *
     * @return the current request, or an empty {@code Optional} when no request attributes are
     *     bound or the bound attributes are not servlet-based
     */
    public static Optional<HttpServletRequest> getRequest() {
        return Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .filter(ServletRequestAttributes.class::isInstance)
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getRequest);
    }