How do I write a GatewayFilter that adds a principal name as a request parameter?


I need to write an org.springframework.cloud.gateway.filter.GatewayFilter that adds the principal's name as a request parameter (say, a parameter called clientId). The obstacles:

  1. getPrincipal() returns a Mono<Principal> rather than a Principal, so I have to pass a callback right away, which seems challenging in this case
  2. I can't block() without getting something like
java.lang.IllegalStateException: blockOptional() is blocking, which is not supported in thread reactor-http-nio-2
    at reactor.core.publisher.BlockingOptionalMonoSubscriber.blockingGet(BlockingOptionalMonoSubscriber.java:145) ~[reactor-core-3.5.10.jar:3.5.10]
    Suppressed: reactor.core.publisher.FluxOnAssembly$OnAssemblyException
  3. ServerWebExchange is immutable
  4. Java lambdas can only capture final or effectively final variables

Here's GatewayFilter's abstract method:

Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain);

So basically I need to implement this:

GatewayFilter addPrincipalNameAsParamFilter() {
    return (serverWebExchange, gatewayFilterChain) -> // ...
}

A (non-working) attempt at doing it:

return (exchange, chain) -> {
    Mono<Principal> principalMono = exchange.getPrincipal();

    return principalMono.flatMap(principal -> {
        String principalName = principal.getName();

        ServerWebExchange updatedExchange = exchange
                    .mutate()
                    .request(requestBuilder -> 
                       requestBuilder.path(exchange.getRequest().getPath() + "?clientId=" + principalName).build())
                    .build();
        return chain.filter(updatedExchange);
    });
}

With the code above, the endpoint appears to be hit first and the lambda executed only afterwards, which is the opposite of the order I need. If only I could block to get the principal.
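For illustration, the four constraints listed earlier (an asynchronous principal, no blocking, an immutable exchange, effectively final lambda captures) can in principle be met by composing on the asynchronous value instead of blocking for it. The sketch below is a JDK-only analogy, not Spring code: CompletableFuture stands in for Mono, thenCompose plays the role of flatMap, and Request, Chain, and filter are hypothetical names invented for this example.

```java
import java.util.concurrent.CompletableFuture;

public class ComposeSketch {

    // Hypothetical immutable request: "changing" it returns a new instance.
    record Request(String path, String query) {
        Request withQueryParam(String name, String value) {
            String pair = name + "=" + value;
            String newQuery = (query == null || query.isEmpty()) ? pair : query + "&" + pair;
            return new Request(path, newQuery);
        }
    }

    // Hypothetical stand-in for the rest of the filter chain.
    interface Chain {
        CompletableFuture<String> proceed(Request request);
    }

    // Composes on the asynchronous principal name instead of blocking for it.
    static CompletableFuture<String> filter(Request request,
                                            CompletableFuture<String> principalName,
                                            Chain chain) {
        // 'request' is never reassigned, so it is effectively final and the
        // lambda may capture it; the updated copy only exists inside the lambda.
        return principalName.thenCompose(name ->
                chain.proceed(request.withQueryParam("clientId", name)));
    }

    public static void main(String[] args) {
        // The "downstream" simply echoes the request it receives.
        Chain echo = r -> CompletableFuture.completedFuture(r.path() + "?" + r.query());
        CompletableFuture<String> principal = CompletableFuture.supplyAsync(() -> "alice");

        // join() is used only here, on the main thread, to print the demo result.
        System.out.println(filter(new Request("/auth/hello-world", null), principal, echo).join());
    }
}
```

The chain is only invoked from inside the composed callback, so in this sketch the downstream call cannot start before the principal name is available.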

Gateway has some out-of-the-box filter factories, but they appear to fit only when you pass static values (and yes, you have to wrap the result in an OrderedGatewayFilter for some reason, otherwise it won't work):

return routeInConstruction.filter(
           new OrderedGatewayFilter(
               new AddRequestParameterGatewayFilterFactory()
                       .apply(nameValueConfig -> nameValueConfig.setName("clientId").setValue("hard-coded_principal_name")), 0
           )
       );

How can (should) I achieve my goal?

UPD

I tried to implement Tavark's suggestion. Here's an MRE showing that it doesn't work (assuming I implemented it correctly):

Gateway

package com.example.gatewaydemo;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class GatewaydemoApplication {

    public static void main(String[] args) {
        SpringApplication.run(GatewaydemoApplication.class, args);
    }

}

package com.example.gatewaydemo.constant;

public enum JWT {

    KEY("jxgEQeXHuPq8VdbyYFNkANdudQ53YUn4"),
    HEADER("Authorization"),
    ACCESS_TOKEN_EXPIRATION("3600000"),
    REFRESH_TOKEN_EXPIRATION("3600000"),
    UUID("uuid");

    private final String value;

    JWT(String value) {
        this.value = value;
    }

    public String getValue() {
        return value;
    }
}

package com.example.gatewaydemo.security;

import com.example.gatewaydemo.constant.JWT;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.SignatureException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;

import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;

@Slf4j
public class JwtAuthorizationFilter implements WebFilter {

    private static final SecretKey KEY = Keys.hmacShaKeyFor(JWT.KEY.getValue().getBytes(StandardCharsets.UTF_8));
    
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {

        String jwt = exchange.getRequest().getHeaders().getFirst(JWT.HEADER.getValue());

        if (jwt != null) {
            try {
                Claims claims = Jwts.parserBuilder()
                        .setSigningKey(KEY)
                        .build()
                        .parseClaimsJws(jwt)
                        .getBody();
                String clientId = String.valueOf(claims.get(JWT.UUID.getValue()));
                Authentication auth = new UsernamePasswordAuthenticationToken(clientId, null, null);
                return chain.filter(exchange).contextWrite(ReactiveSecurityContextHolder.withAuthentication(auth));
            } catch (ExpiredJwtException e) {
                log.error("Token has expired. " + e.getMessage());
                return createError(exchange, e);
            } catch (SignatureException e) {
                log.error("Token signature cannot be verified. " + e.getMessage());
                return createError(exchange, e);
            } catch (MalformedJwtException e) {
                log.error("Token is not properly formatted. " + e.getMessage());
                return createError(exchange, e);
            } catch (Exception e) {
                log.error("Error parsing token. " + e.getMessage());
                return createError(exchange, e);
            }
        }
        return chain.filter(exchange);
    }

    private Mono<Void> createError(ServerWebExchange exchange, Exception e) {
        ServerHttpResponse response = exchange.getResponse();
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        return response.writeWith(Mono.just(createErrorBody(e.getMessage())));
    }

    private DataBuffer createErrorBody(String errorBody) {
        byte[] bytes = errorBody.getBytes();
        return new DefaultDataBufferFactory().wrap(bytes);
    }
}

package com.example.gatewaydemo.security;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpStatus;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.config.web.server.SecurityWebFiltersOrder;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.CorsConfigurationSource;
import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource;
import reactor.core.publisher.Mono;

import java.time.Duration;
import java.util.Collections;
import java.util.List;

@Configuration
@EnableWebFluxSecurity
public class SecurityConfig {

    private final JwtAuthorizationFilter jwtAuthorizationFilter = new JwtAuthorizationFilter();

    @Bean
    public SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
        final var permitAll = new String[]{
                "/api/v1/registration/**",
                "/api/v1/security/**",
                "/api/v1/login/**",
                "/api/v1/exchange-rates/**",
                "/api/v1/bank-branch/**",
                "/api/v1/error-message",

                "/actuator/**",
                "/webjars/**",
                "/swagger-ui.html",
                "/swagger-ui-config",
                "/*/doc"
        };
        return http
                .csrf(ServerHttpSecurity.CsrfSpec::disable)
                .securityContextRepository(NoOpServerSecurityContextRepository.getInstance())
                .httpBasic(ServerHttpSecurity.HttpBasicSpec::disable)
                .formLogin(ServerHttpSecurity.FormLoginSpec::disable)
                .authorizeExchange(exchanges -> exchanges
                        .pathMatchers(permitAll).permitAll()
                        .anyExchange().authenticated()
                )
                .addFilterBefore(jwtAuthorizationFilter, SecurityWebFiltersOrder.AUTHORIZATION)
                .exceptionHandling(exceptionHandlingSpec -> exceptionHandlingSpec
                        .authenticationEntryPoint((exchange, ex) -> Mono.fromRunnable(
                                () -> exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED)
                        )))
                .cors(corsSpec -> corsSpec.configurationSource(corsConfiguration()))
                .build();
    }

    private CorsConfigurationSource corsConfiguration() {
        UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
        CorsConfiguration config = new CorsConfiguration();
        config.setAllowedMethods(Collections.singletonList("*"));
        config.setAllowCredentials(true);
        config.setAllowedHeaders(Collections.singletonList("*"));
        config.setExposedHeaders(List.of("Authorization"));
        config.setMaxAge(Duration.ofHours(1));
        config.setAllowedOriginPatterns(List.of("*"));
        source.registerCorsConfiguration("/**", config);
        return source;
    }
}

package com.example.gatewaydemo.routing;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.util.UriComponentsBuilder;

import java.net.URI;

@Configuration
public class MyRoutingConfig {
    @Bean
    public RouteLocator routeLocator(RouteLocatorBuilder routeLocatorBuilder) {
        return routeLocatorBuilder.routes()
                .route(predicateSpec -> predicateSpec
                        .path("/api/v1/hello-world")
                        .and()
                        .method(HttpMethod.GET)
                        .filters(gatewayFilterSpec -> gatewayFilterSpec.rewritePath("/api/v1", "/auth")
                                .filter(addClientIdAsParamFilter()))
                        .uri("http://localhost:8090")
                ).build();
    }

    private GatewayFilter addClientIdAsParamFilter() {
        return (exchange, chain) -> exchange.getPrincipal()
                .flatMap(principal -> {
                    URI newUri = UriComponentsBuilder
                            .fromUri(exchange.getRequest().getURI())
                            .replaceQuery("?clientId=" + principal.getName())
                            .build(true)
                            .toUri();

                    ServerHttpRequest request = exchange
                            .getRequest()
                            .mutate()
                            .uri(newUri)
                            .build();

                    return chain.filter(exchange.mutate().request(request).build());
                });
    }
}

server:
  port: 8070

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.1.5</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.example</groupId>
    <artifactId>gatewaydemo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>gatewaydemo</name>
    <description>gatewaydemo</description>
    <properties>
        <java.version>17</java.version>
        <spring-cloud.version>2022.0.4</spring-cloud.version>
        <jsonwebtoken.version>0.11.5</jsonwebtoken.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-gateway</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-security</artifactId>
        </dependency>

        <dependency>
            <groupId>io.jsonwebtoken</groupId>
            <artifactId>jjwt-api</artifactId>
            <version>${jsonwebtoken.version}</version>
        </dependency>

        <dependency>
            <groupId>io.jsonwebtoken</groupId>
            <artifactId>jjwt-impl</artifactId>
            <version>${jsonwebtoken.version}</version>
            <scope>runtime</scope>
        </dependency>

        <dependency>
            <groupId>io.jsonwebtoken</groupId>
            <artifactId>jjwt-jackson</artifactId>
            <version>${jsonwebtoken.version}</version>
            <scope>runtime</scope>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

    </dependencies>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-dependencies</artifactId>
                <version>${spring-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

Microservice

package com.example.helloworldmre;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class HelloworldMreApplication {

    public static void main(String[] args) {
        SpringApplication.run(HelloworldMreApplication.class, args);
    }

}

package com.example.helloworldmre.controller;

import com.example.helloworldmre.data.SuccessMessage;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.text.MessageFormat;

@RestController
@RequestMapping("/auth")
@Tag(name = "Message Controller")
public class MessageController {
    @GetMapping("/hello-world")
    public SuccessMessage getHelloWorld(@RequestParam(required = false) String clientId) {
        return new SuccessMessage(MessageFormat.format("hello world to client {0}!", clientId));
    }
}

package com.example.helloworldmre.data;

import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class SuccessMessage extends Message {
    public SuccessMessage(String message) {
        super(message);
    }
}

package com.example.helloworldmre.data;

import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public abstract class Message {
    private String message;

    public Message(String message) {
        this.message = message;
    }
}

server.port=8090

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.6.4</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.example</groupId>
    <artifactId>helloworldservice</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>helloworldservice</name>
    <description>helloworldservice</description>
    <properties>
        <java.version>17</java.version>
        <spring-cloud.version>2021.0.1</spring-cloud.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-dependencies</artifactId>
                <version>${spring-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <image>
                        <builder>paketobuildpacks/builder-jammy-base:latest</builder>
                    </image>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

Curl

curl --location 'http://localhost:8070/api/v1/hello-world' \
--header 'Authorization: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwidXVpZCI6ImNhOTk0ZTk3LThjNDctNDM3YS1iN2Y3LWFjYzYxYTY5NmEyOCIsImlhdCI6MTUxNjIzOTAyMn0.z9hnbc7RgrePm-N8S7eYdWoN75jKDgPL9tlHXlgvR-w'

Response

{
    "message": "hello world to client null!"
}

Solution

  • If I change this (what Tavark suggested)

    (exchange, chain) -> exchange.getPrincipal()
                    .flatMap(principal -> {
                        URI newUri = UriComponentsBuilder
                                .fromUri(exchange.getRequest().getURI())
                                .replaceQuery("?clientId=" + principal.getName())
                                .build(true)
                                .toUri();
    
                        ServerHttpRequest request = exchange
                                .getRequest()
                                .mutate()
                                .uri(newUri)
                                .build();
    
                        return chain.filter(exchange.mutate().request(request).build());
                    });
    

    to this

                    new OrderedGatewayFilter(
                            (exchange, chain) -> exchange.getPrincipal()
                            .flatMap(principal -> {
                                URI newUri = UriComponentsBuilder
                                        .fromUri(exchange.getRequest().getURI())
                                        .queryParam("clientId", principal.getName())
                                        .build(true)
                                        .toUri();
    
                                ServerHttpRequest request = exchange
                                        .getRequest()
                                        .mutate()
                                        .uri(newUri)
                                        .build();
    
                                return chain.filter(exchange.mutate().request(request).build());
                            }), 0);
    

    it works, at least sometimes. Go figure.
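
One difference between the two snippets above is easy to miss: the first passes "?clientId=..." (with a leading question mark) to replaceQuery, while the second calls queryParam("clientId", ...). A possible explanation for the null clientId, which the post itself does not confirm, is that the leading "?" ends up as part of the parameter name. The sketch below illustrates that effect using only the JDK's java.net.URI; it is not Spring's UriComponentsBuilder, and the withQuery and parseQuery helpers are hypothetical.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.util.LinkedHashMap;
import java.util.Map;

public class QueryParamSketch {

    // Hypothetical helper: rebuilds the URI with the given query string.
    static URI withQuery(URI original, String query) {
        try {
            return new URI(original.getScheme(), original.getAuthority(),
                    original.getPath(), query, original.getFragment());
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException(e);
        }
    }

    // Hypothetical helper: naive split of a query string into name/value pairs,
    // roughly what a receiving server might do before binding parameters.
    static Map<String, String> parseQuery(String query) {
        Map<String, String> params = new LinkedHashMap<>();
        if (query == null || query.isEmpty()) {
            return params;
        }
        for (String pair : query.split("&")) {
            int eq = pair.indexOf('=');
            String name = eq < 0 ? pair : pair.substring(0, eq);
            String value = eq < 0 ? "" : pair.substring(eq + 1);
            params.put(name, value);
        }
        return params;
    }

    public static void main(String[] args) {
        URI original = URI.create("http://localhost:8090/auth/hello-world");

        // Query string that already carries a leading "?".
        URI withPrefix = withQuery(original, "?clientId=alice");
        System.out.println(withPrefix);
        System.out.println(parseQuery(withPrefix.getQuery()).keySet());

        // Query string without the prefix.
        URI withoutPrefix = withQuery(original, "clientId=alice");
        System.out.println(withoutPrefix);
        System.out.println(parseQuery(withoutPrefix.getQuery()).keySet());
    }
}
```

In this sketch the first URI contains "??clientId=alice" and its only parameter is named "?clientId", so a lookup for "clientId" finds nothing. Whether the gateway behaves the same way would need to be checked by logging the rewritten URI in the filter.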