Tags: spring, spring-boot, spring-security, jwt

How do I catch exceptions in JwtAuthFilter and handle with Global Error Handler in Spring Boot?


As the title states, I am exploring validating received JWTs as part of a request, but I am struggling to figure out how to catch the related exceptions thrown in my JWT auth filter and handle them with my global error handler, since I like the response format it produces.

So far I've tried catching the exceptions in the JWT auth filter and in the JWT service, but the error handler never picks them up. I've come to realize that it won't, because the filter is a separate layer from the controllers and services, which is where the error handler finds exceptions. I've tried incorporating what I've been able to find on Stack Overflow and in video tutorials, but without success.
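To illustrate that layering: servlet filters, including the Spring Security chain that runs `JwtAuthFilter`, execute before `DispatcherServlet`, and `@RestControllerAdvice` methods are only consulted for exceptions raised while `DispatcherServlet` is handling the request. The following is a plain-Java simulation, not Spring code; `Handler`, `Filter`, `dispatch` and `advice` are hypothetical stand-ins. A controller exception gets formatted by the "advice", while a filter exception escapes it.

```java
import java.util.function.Function;

public class FilterLayeringDemo {

    // Stand-in for a controller method.
    interface Handler { String handle(String request); }

    // Stand-in for a servlet filter: it may reject the request or pass it on.
    interface Filter { String doFilter(String request, Handler next); }

    // Stand-in for @RestControllerAdvice: turns an exception into a response body.
    static final Function<RuntimeException, String> advice =
            e -> "{\"message\": \"" + e.getMessage() + "\"}";

    // Stand-in for DispatcherServlet: the advice only wraps the controller call.
    static Handler dispatch(Handler controller) {
        return request -> {
            try {
                return controller.handle(request);
            } catch (RuntimeException e) {
                return advice.apply(e);
            }
        };
    }

    // A JWT-style filter that throws before the dispatcher is ever reached.
    static final Filter jwtFilter = (request, next) -> {
        if (!request.startsWith("Bearer ")) {
            throw new IllegalStateException("invalid token");
        }
        return next.handle(request);
    };

    static String serve(String request) {
        Handler controller = r -> {
            throw new IllegalArgumentException("bad id");
        };
        return jwtFilter.doFilter(request, dispatch(controller));
    }

    public static void main(String[] args) {
        // Controller exception: the advice formats it.
        System.out.println(serve("Bearer abc"));
        // Filter exception: it propagates past the advice untouched.
        try {
            serve("no token");
        } catch (IllegalStateException e) {
            System.out.println("escaped the advice: " + e.getMessage());
        }
    }
}
```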

Any help would be greatly appreciated, especially if you can point me to a concrete example so I can see exactly how it is done properly. Please let me know if I have not provided enough information. The following are what I believe are the relevant classes in their latest state (SecurityConfig, JwtAuthFilter, JwtService, GlobalErrorHandler). Please forgive the mess, as I have been all over the place the last two weeks trying to find something that works.

SecurityConfig

@Configuration
@EnableWebSecurity
@RequiredArgsConstructor
public class SecurityConfig {

    private final JwtAuthFilter authFilter;

    private final UserDetailsService userDetailsService;

    private final UserAuthenticationEntryPoint userAuthenticationEntryPoint;

    @Bean
    public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        System.out.println("filter chain");
            return
                    http
                            .exceptionHandling().authenticationEntryPoint(userAuthenticationEntryPoint)
                            .and()
                            .addFilterBefore(authFilter, UsernamePasswordAuthenticationFilter.class)
                            .csrf().disable()

//                whitelisted
                            .authorizeHttpRequests()
                            .requestMatchers("/api/menu/**", "/api/order/**", "/api/owners-tools/login").permitAll()
                            .and()
//               restricted
                            .authorizeHttpRequests().requestMatchers("/api/owners-tools/**")
                            .authenticated().and()
                            .sessionManagement()
                            .sessionCreationPolicy(SessionCreationPolicy.STATELESS)
                            .and()
                            .authenticationProvider(authenticationProvider())
//                            .addFilterBefore(authFilter, UsernamePasswordAuthenticationFilter.class)
                            .build();
    }

    @Bean
    public  PasswordEncoder passwordEncoder(){
        return new BCryptPasswordEncoder();
    }

    @Bean
    public AuthenticationProvider authenticationProvider(){
        DaoAuthenticationProvider daoAuthenticationProvider = new DaoAuthenticationProvider();
        daoAuthenticationProvider.setUserDetailsService(userDetailsService);
        daoAuthenticationProvider.setPasswordEncoder(passwordEncoder());
        return daoAuthenticationProvider;
    }

    @Bean
    public AuthenticationManager authenticationManager(AuthenticationConfiguration config)throws Exception{
        return config.getAuthenticationManager();
    }
}

JwtAuthFilter

@Component
@RequiredArgsConstructor
public class JwtAuthFilter extends OncePerRequestFilter {

    private final JwtService jwtService;
    private final OwnerRepository ownerRepository;



    @Override
    protected void doFilterInternal(HttpServletRequest request, @NotNull HttpServletResponse response,
                                    @NotNull FilterChain filterChain)
            throws ServletException, IOException {


try {
            String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
            String token = null;
            String username = null;
            Date expiration = null;
            Date issuedAt = null;
//        String passwrd = null;
//        System.out.println(authHeader);
            if (authHeader != null && authHeader.startsWith("Bearer ")) {
                token = authHeader.substring(7);
//            System.out.println("token = " + token);
                username = jwtService.extractUsername(token);
//            System.out.println("username = " + username);
                expiration = jwtService.extractExpiration(token);
                System.out.println("expiration = " + expiration);
                issuedAt = jwtService.extractIssuedAt(token);
                System.out.println("issued at = " + issuedAt);
            }

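            // both dates are still null when no Bearer token was sent (e.g. whitelisted routes)
            if (issuedAt != null && expiration != null && !issuedAt.before(expiration)) {
                throw new JwtException("invalid date");
            }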
            if (username != null && SecurityContextHolder.getContext().getAuthentication() == null) {
                System.out.println("encypted username: " + username);
                System.out.println("decrypted : " + jwtService.decrypt(username));
                UserDetails userDetails = userDetailsService().loadUserByUsername(jwtService.decrypt(username));
                if (userDetails == null) {
                    throw new UsernameNotFoundException("Invalid user");
                }
                System.out.println("token valid: " + jwtService.isTokenValid(token, userDetails));
                //           UserDetails userDetails = userDetailsService().loadUserByUsername(username);
                jwtService.validateToken(token, userDetails);
                UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken(userDetails, null
                        , userDetails.getAuthorities());
                authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
                SecurityContextHolder.getContext().setAuthentication(authToken);
                System.out.println("jwt filter");
//                filterChain.doFilter(request, response);
            }
//        filterChain.doFilter(request, response);
        } catch (ExpiredJwtException e) {
            System.out.println(e.getLocalizedMessage());
            response.setContentType("application/json");
            response.setStatus(HttpServletResponse.SC_FORBIDDEN);
            // a single, valid JSON body whose message matches the exception being caught
            response.getOutputStream().println("{ \"error\": \"token expired\" }");
            // the response has been written, so don't pass the request on down the chain
            return;
        }
        filterChain.doFilter(request, response);
    }
//            throws ServletException, IOException {
//        String authHeader = request.getHeader("Authorization");
//        String token = null;
//        String username = null;
////        String passwrd = null;
////        System.out.println(authHeader);
//        if (authHeader!=null && authHeader.startsWith("Bearer ")){
//            token = authHeader.substring(7);
////            System.out.println("token = " + token);
//            username = jwtService.extractUsername(token);
////            System.out.println("username = " + username);
//        }
//        if (username != null && SecurityContextHolder.getContext().getAuthentication()==null){
//            UserDetails userDetails = userDetailsService().loadUserByUsername(jwtService.decrypt(username));
//            if (userDetails == null){
//                throw new UsernameNotFoundException("Invalid user");
//            }
//            //           UserDetails userDetails = userDetailsService().loadUserByUsername(username);
//            jwtService.validateToken(token, userDetails);
//            UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken(userDetails, null
//                    , userDetails.getAuthorities());
//            authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
//            SecurityContextHolder.getContext().setAuthentication(authToken);
//            System.out.println("jwt filter");
//        }
//        filterChain.doFilter(request, response);
//    }

    @Bean
    UserDetailsService userDetailsService(){
        return username -> ownerRepository.findByUsername(username)
                .orElseThrow(() -> new UsernameNotFoundException("Username invalid."));
    }
}
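Requests to the whitelisted routes (`/api/menu/**`, `/api/order/**`) typically arrive without an `Authorization` header, so `issuedAt` and `expiration` are still `null` when the date comparison runs unless it is guarded; Java `assert` statements are also disabled unless the JVM is started with `-ea`. Here is a small plain-Java sketch of keeping token extraction and the date check null-safe. `extractBearerToken` and `checkDates` are hypothetical helpers, the dates are passed in rather than read with jjwt, and `IllegalArgumentException` stands in for `JwtException`.

```java
import java.util.Date;
import java.util.Optional;

public class BearerTokenCheck {

    private static final String PREFIX = "Bearer ";

    // Returns the raw token only when the header is present and well formed.
    static Optional<String> extractBearerToken(String authHeader) {
        if (authHeader == null || !authHeader.startsWith(PREFIX)) {
            return Optional.empty();
        }
        String token = authHeader.substring(PREFIX.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }

    // Mirrors the filter's sanity check, but refuses to run on missing claims.
    static void checkDates(Date issuedAt, Date expiration) {
        if (issuedAt == null || expiration == null) {
            throw new IllegalArgumentException("token is missing iat or exp");
        }
        if (!issuedAt.before(expiration)) {
            throw new IllegalArgumentException("invalid date");
        }
    }

    public static void main(String[] args) {
        // No header: the filter should simply continue the chain.
        System.out.println(extractBearerToken(null));
        // Header present: only now parse the token and check its dates.
        System.out.println(extractBearerToken("Bearer abc.def"));
        checkDates(new Date(0), new Date(1000));
    }
}
```

Only when `extractBearerToken` returns a value would the filter go on to read the claims and call `checkDates`.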


JwtService

@Service
public class JwtService {

    //create token
    @Value("${key}")
    private String SECRET;

    @Value("${BEGIN_KEY}")
    private int BEGIN_KEY;
    @Value("${END_KEY}")
    private int END_KEY;

    @Value("${charmin}")
    private int charmin;

    @Value("${charmax}")
    private int charmax;

    @Value("${submin}")
    private int submin;

    @Value("${submax}")
    private int submax;

    @Value("${ex1}")
    private int ex1;
    @Value("${ex2}")
    private int ex2;
    @Value("${ex3}")
    private int ex3;

    @Value("${CHARSET}")
    private String CHARSET;

    private Key getSignKey(){
        byte[] keyBytes = Decoders.BASE64.decode(SECRET);
        return Keys.hmacShaKeyFor(keyBytes);
    }
    private String buildToken(String username){
//        set time variable instead of creating new
        String token = Jwts.builder()
                .setSubject(username)
                .setIssuedAt(new Date(System.currentTimeMillis()))
//                16 hours, reflective of our owners work day - to be altered to facilitate mitigation of token capture
                .setExpiration(new Date(System.currentTimeMillis() + (1000 * 60 * 60) * 16))
                .signWith(getSignKey(), SignatureAlgorithm.HS256).compact();
        System.out.println(token);
        System.out.println("token issued: " + new Date(System.currentTimeMillis()));
        System.out.println("token expires: " + new Date(System.currentTimeMillis() + (1000 * 60 * 60) * 16));
//        try {
//            return token;
//        }catch (io.jsonwebtoken.security.SignatureException exception){
//            System.out.println(exception.getLocalizedMessage());
//        }
        return token;
    }


    public String generateToken(String username){
        return buildToken(username);
    }

//    validate token

    private Claims extractAllClaims(String token){
//        try {
//            return
//                    Jwts
//                            .parserBuilder()
//                            .setSigningKey(getSignKey())
//                            .build()
//                            .parseClaimsJws(token)
//                            .getBody();
//        } catch (ExpiredJwtException e) {
//            System.out.println(e.getLocalizedMessage());
//            throw new JwtException("bad jwt");
//        }
        return
                Jwts
                        .parserBuilder()
                        .setSigningKey(getSignKey())
                        .build()
                        .parseClaimsJws(token)
                        .getBody();
    }

    public <T> T extractClaim(String token, Function<Claims, T> claimsResolver){
        final Claims claims = extractAllClaims(token);
        return claimsResolver.apply(claims);
    }
    public String extractUsername(String token){
        return extractClaim(token, Claims::getSubject);
    }

    public Date extractExpiration(String token){
        return extractClaim(token, Claims::getExpiration);
    }
    public Date extractIssuedAt(String token){
        return extractClaim(token, Claims::getIssuedAt);
    }
    private Boolean isTokenExpired(String token){
        return extractExpiration(token).before(new Date());
    }

//    possibly condense into one
    public boolean isTokenValid(String token, UserDetails userDetails) {
        final String username = decrypt(extractUsername(token));
//        if (!username.equals(userDetails.getUsername()) && isTokenExpired(token)){
//            throw new SignatureException("token no good");
//        }
        return (username.equals(userDetails.getUsername())) && !isTokenExpired(token);
    }

    public void validateToken(String token, UserDetails userDetails){
        // fail explicitly when the subject doesn't match or the token has expired
        if (!isTokenValid(token, userDetails)) {
            throw new JwtException("token no good");
        }
    }

//  encrypt - used during development as a means to encrypt credentials
//  before storing them and facilitating decryption means

    public String encrypt(String string){
//        rework

        System.out.println("value to be encrypted: " + string);

        byte[] codeBytes = string.getBytes(StandardCharsets.UTF_8);
        List<Integer> rolledCodeBytes = new ArrayList<>();
        int codeByteValue;
//        System.out.println("begin key: " + BEGIN_KEY);
//        System.out.println("code bytes " + Arrays.toString(codeBytes));
        for (byte codeByte : codeBytes) {
            codeByteValue = codeByte;
//            System.out.println(codeByteValue += BEGIN_KEY);
//            System.out.println("key " + BEGIN_KEY);
            codeByteValue += BEGIN_KEY;
            rolledCodeBytes.add(codeByteValue);
        }
//        System.out.println("rolled code bytes: " + rolledCodeBytes);
//      new collection with altered char values
        List<Character> chars = new ArrayList<>();
        for (int integer : rolledCodeBytes) {
            chars.add((char) integer);
        }
//      convert chars to string
//        StringBuilder rolledCharBuilder = new StringBuilder(chars.size());
//        for (Character ch : chars) {
//            rolledCharBuilder.append(ch);
//        }
//        System.out.println("chars: " + chars);
//        System.out.println("rolled charbuilder: " + rolledCharBuilder);
//      for each element insert three new random chars
        for (int i = 0; i < chars.size(); i++) {
            chars.add(i, randomChar());
            i++;
            chars.add(i, randomChar());
            i++;
            chars.add(i, randomChar());
            i++;
        }
        chars.add(randomChar());
        chars.add(randomChar());
        chars.add(randomChar());
//        -----------------------------------
        StringBuilder encryptionBuilder = new StringBuilder(chars.size());
        for (Character ch : chars) {
            encryptionBuilder.append(ch);
        }
        System.out.println("value encrypted: " + encryptionBuilder);
        return encryptionBuilder.toString();
    }

//    decrypt
    public String decrypt(String  encodedString)  {
        String decodedStart = String.valueOf(encodedString.charAt(BEGIN_KEY));
        String decodedEnd = String.valueOf(encodedString.charAt(encodedString.length() - END_KEY));
        String wholeDecoded = "";
        StringBuilder decoded = new StringBuilder();
        for (int i = BEGIN_KEY; i < encodedString.length(); i = i + END_KEY) {
            decoded.append(encodedString.charAt(i));
        }
        decoded = new StringBuilder(decoded.substring(submin, decoded.toString().length() - submax));
        wholeDecoded = wholeDecoded.concat(decodedStart + decoded + decodedEnd);
        byte[] decodedBytes = wholeDecoded.getBytes(StandardCharsets.UTF_8);
        int decodeByteValue;
        List<Character> decodedChars = new ArrayList<>();
        StringBuilder decrypt = new StringBuilder(0);
        for (byte codeByte : decodedBytes) {
            decodeByteValue = codeByte;
            decodeByteValue -= BEGIN_KEY;
            decodedChars.add((char) decodeByteValue);
        }
        for (Character ch : decodedChars) {
            decrypt.append(ch);
        }
        return decrypt.toString();
    }

    private char randomChar() {
        int min = charmin, max = charmax;
        int random = (int) (Math.random() * (max - min) + min);
        int[] excluded = {ex1, ex2, ex3};
        // draw again if the value hits ANY excluded code, otherwise use it
        for (int ex : excluded) {
            if (random == ex) {
                return randomChar();
            }
        }
        return (char) random;
    }
}
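The comment in `buildToken` ("set time variable instead of creating new") points at a small issue: `System.currentTimeMillis()` is called separately for the issued-at claim, the expiration claim and the log lines, so they can differ by a few milliseconds. Sixteen hours is 16 × 60 × 60 × 1000 = 57,600,000 ms. A sketch that captures the clock once; `TokenTimes` is a hypothetical helper, and its two `Date` values are what you would pass to `setIssuedAt` and `setExpiration`.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Date;

public class TokenTimes {

    // 16 hours, matching the lifetime used in buildToken.
    static final Duration LIFETIME = Duration.ofHours(16);

    final Date issuedAt;
    final Date expiration;

    private TokenTimes(Instant now) {
        this.issuedAt = Date.from(now);
        this.expiration = Date.from(now.plus(LIFETIME));
    }

    // Capture the clock once so both claims (and any logging) agree.
    static TokenTimes from(Instant now) {
        return new TokenTimes(now);
    }

    public static void main(String[] args) {
        TokenTimes t = TokenTimes.from(Instant.now());
        System.out.println("token issued: " + t.issuedAt);
        System.out.println("token expires: " + t.expiration);
        System.out.println(t.expiration.getTime() - t.issuedAt.getTime()); // 57600000
    }
}
```

Taking `Instant` as a parameter also makes the expiry logic easy to test with a fixed time.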

GlobalErrorHandler

@Data
@RestControllerAdvice
public class GlobalErrorHandler{

    private String message;

    private enum LogStatus{
        STACK_TRACE, MESSAGE_ONLY
    }

    @ExceptionHandler(NumberFormatException.class)
    @ResponseStatus(code = HttpStatus.BAD_REQUEST)
    public Map <String, Object> handleNumberFormatException(
            NumberFormatException e, WebRequest webRequest){
        return createExceptionMessage(e.getLocalizedMessage(), HttpStatus.BAD_REQUEST, webRequest);
    }
    @ExceptionHandler(EntityNotFoundException.class)
    @ResponseStatus(code = HttpStatus.NOT_FOUND)
    public Map <String, Object> handleEntityNotFoundException(
            EntityNotFoundException e, WebRequest webRequest) {
        return createExceptionMessage(e.getLocalizedMessage(), HttpStatus.NOT_FOUND, webRequest);
    }

    @ExceptionHandler(IllegalArgumentException.class)
    @ResponseStatus(code = HttpStatus.BAD_REQUEST)
    public Map <String, Object> handleIllegalArgumentException(
            IllegalArgumentException e, WebRequest webRequest){
        return createExceptionMessage(e.getLocalizedMessage(), HttpStatus.BAD_REQUEST, webRequest);
    }

    @ExceptionHandler(UsernameNotFoundException.class)
    @ResponseStatus(code = HttpStatus.FORBIDDEN)
    public Map<String, Object> handleUsernameNotFoundException(
            UsernameNotFoundException e, WebRequest webRequest){
        return  createExceptionMessage(e.getLocalizedMessage(), HttpStatus.FORBIDDEN, webRequest);
    }

    @ExceptionHandler(SignatureException.class)
    @ResponseStatus(code = HttpStatus.FORBIDDEN)
    public Map<String, Object> handleSignatureException(
            SignatureException e, WebRequest webRequest){
        return  createExceptionMessage(e.getLocalizedMessage(), HttpStatus.FORBIDDEN, webRequest);
    }

    @ExceptionHandler(JwtException.class)
    @ResponseStatus(code = HttpStatus.FORBIDDEN)
    public Map<String, Object> handleJwtException(
            JwtException e, WebRequest webRequest){
        return  createExceptionMessage(e.getLocalizedMessage(), HttpStatus.FORBIDDEN, webRequest);
    }

    @ExceptionHandler(MalformedJwtException.class)
    @ResponseStatus(code = HttpStatus.FORBIDDEN)
    public Map<String, Object> handleMalformedJwtException(
            MalformedJwtException e, WebRequest webRequest){
        return  createExceptionMessage(e.getLocalizedMessage(), HttpStatus.FORBIDDEN, webRequest);
    }

    @ExceptionHandler(ExpiredJwtException.class)
    @ResponseStatus(code = HttpStatus.FORBIDDEN)
    public Map<String, Object> handleExpiredJwtException(
            ExpiredJwtException e, WebRequest webRequest){
        return  createExceptionMessage(e.getMessage(), HttpStatus.FORBIDDEN, webRequest);
    }

    @ExceptionHandler(BadCredentialsException.class)
    @ResponseStatus(code = HttpStatus.UNAUTHORIZED)
    public Map<String, Object> handleBadCredentialsException(
            BadCredentialsException e, WebRequest webRequest
    ){
        return createExceptionMessage(e.getLocalizedMessage(), HttpStatus.UNAUTHORIZED, webRequest);
    }





//    ---------------- Being reworked


// alter to not just create the message but also log the error
// create method to log the error to an internal file
    private Map<String,Object> createExceptionMessage(String e, HttpStatus status, WebRequest webRequest) {

        Map <String, Object> error = new HashMap<>();
        String timestamp = ZonedDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME);

        if(webRequest instanceof ServletWebRequest){
            error.put("uri",
                    ((ServletWebRequest)webRequest).getRequest().getRequestURI());
        }
        error.put("message", e);
        error.put("status code", status.value());
        error.put("timestamp", timestamp);
        error.put("reason", status.getReasonPhrase());
        return error;
    }
}
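Once an exception reaches these handlers, which method runs is decided by type. `GlobalErrorHandler` declares handlers for `JwtException` and also for its subclasses `ExpiredJwtException` and `MalformedJwtException`; Spring picks the handler whose declared type is closest to the thrown exception in the class hierarchy (internally it sorts matches with `ExceptionDepthComparator`). The plain-Java sketch below imitates that rule with hypothetical stand-in exception classes mirroring jjwt's hierarchy; it is not Spring's implementation.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

public class ClosestHandlerDemo {

    // Stand-ins for jjwt: ExpiredJwtException -> ClaimJwtException -> JwtException.
    static class JwtException extends RuntimeException {
        JwtException(String m) { super(m); }
    }
    static class ClaimJwtException extends JwtException {
        ClaimJwtException(String m) { super(m); }
    }
    static class ExpiredJwtException extends ClaimJwtException {
        ExpiredJwtException(String m) { super(m); }
    }

    // Registered handlers, keyed by the exception type they declare.
    static final Map<Class<? extends Throwable>, Function<Throwable, String>> HANDLERS = new LinkedHashMap<>();
    static {
        HANDLERS.put(RuntimeException.class, e -> "500 generic");
        HANDLERS.put(JwtException.class, e -> "403 jwt: " + e.getMessage());
        HANDLERS.put(ExpiredJwtException.class, e -> "403 expired: " + e.getMessage());
    }

    // Walk up from the thrown type; the first registered type found is the closest match.
    static String resolve(Throwable e) {
        for (Class<?> c = e.getClass(); c != null; c = c.getSuperclass()) {
            Function<Throwable, String> h = HANDLERS.get(c);
            if (h != null) {
                return h.apply(e);
            }
        }
        return "unhandled";
    }

    public static void main(String[] args) {
        System.out.println(resolve(new ExpiredJwtException("expired")));  // 403 expired: expired
        System.out.println(resolve(new ClaimJwtException("bad claim")));  // 403 jwt: bad claim
        System.out.println(resolve(new IllegalStateException("boom")));   // 500 generic
    }
}
```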


Solution

  • Your global error handler won't catch exceptions thrown at the filter level. You can inject Spring's HandlerExceptionResolver (the bean named handlerExceptionResolver) into your filter and then, in case of an exception, call its resolveException method. There's a similar question here: "How to manage exceptions thrown in filters in Spring?"

    The second answer there is the one I think you're looking for (it reuses the GlobalExceptionHandler annotated with @RestControllerAdvice).

    And here's an example:

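    import org.springframework.beans.factory.annotation.Qualifier;
    import org.springframework.web.servlet.HandlerExceptionResolver;

    @Component
    public class JwtAuthFilter extends OncePerRequestFilter {
    
        private final JwtService jwtService;
        private final OwnerRepository ownerRepository;
        private final HandlerExceptionResolver resolver;
    
        // Explicit constructor, because Lombok doesn't copy @Qualifier onto its generated constructor by default.
        // Spring Boot registers more than one HandlerExceptionResolver bean; "handlerExceptionResolver"
        // is the one that delegates to your @RestControllerAdvice.
        public JwtAuthFilter(JwtService jwtService, OwnerRepository ownerRepository,
                             @Qualifier("handlerExceptionResolver") HandlerExceptionResolver resolver) {
            this.jwtService = jwtService;
            this.ownerRepository = ownerRepository;
            this.resolver = resolver;
        }
    
        @Override
        protected void doFilterInternal(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response,
                                        @NotNull FilterChain filterChain)
                throws ServletException, IOException {
            try {
                // logic: parse/validate the token and populate the SecurityContext
            } catch (Exception e) {
                // let GlobalErrorHandler build the response, then stop the chain
                resolver.resolveException(request, response, null, e);
                return;
            }
            filterChain.doFilter(request, response);
        }
    }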
    

    Hope that works
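Stripped of Spring, the pattern in this answer is: the filter catches the exception, hands it to the same handler logic the advice uses, and only continues the chain when nothing went wrong. A plain-Java sketch under those assumptions; `Response`, `ErrorResolver`, `GLOBAL_HANDLER` and the "expired" token rule are hypothetical stand-ins for `HttpServletResponse`, `HandlerExceptionResolver`, `GlobalErrorHandler` and real JWT validation, not the servlet API.

```java
public class DelegatingFilterDemo {

    // Minimal stand-in for HttpServletResponse.
    static class Response {
        int status = 200;
        String body = "";
        boolean reachedController = false;
    }

    // Stand-in for HandlerExceptionResolver: turns an exception into an error response.
    interface ErrorResolver {
        void resolve(Exception e, Response response);
    }

    // Mimics GlobalErrorHandler: one status and message format for every caller.
    static final ErrorResolver GLOBAL_HANDLER = (e, response) -> {
        response.status = 403;
        response.body = "{\"message\": \"" + e.getMessage() + "\", \"status code\": 403}";
    };

    // The filter: validate a token if present; on failure delegate instead of rethrowing.
    static Response handle(String authHeader) {
        Response response = new Response();
        try {
            if (authHeader != null && authHeader.startsWith("Bearer ")) {
                String token = authHeader.substring("Bearer ".length());
                // Hypothetical rule standing in for jwtService parsing/validation.
                if (token.isEmpty() || token.equals("expired")) {
                    throw new IllegalStateException("invalid or expired token");
                }
            }
        } catch (Exception e) {
            GLOBAL_HANDLER.resolve(e, response);
            return response; // error already written: skip the rest of the chain
        }
        // Equivalent of filterChain.doFilter(request, response).
        response.reachedController = true;
        response.body = "ok";
        return response;
    }

    public static void main(String[] args) {
        Response rejected = handle("Bearer expired");
        System.out.println(rejected.status + " " + rejected.body);
        Response anonymous = handle(null);
        System.out.println(anonymous.status + " " + anonymous.body);
    }
}
```

Returning right after resolving matters: once the error body is written, letting the request reach the controller (or Spring Security's entry point) would try to write a second response.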