Tags: spring-boot, websocket, jwt, spring-websocket

How to disconnect a user's WebSocket connection from the server when their JWT access token expires (Spring Boot + WebSockets)?


Current Situation

I have a web application (a REST API with JWT authentication) built with Spring Boot. I used Spring WebSocket to implement a STOMP WebSocket server backed by RabbitMQ, and I have a separate React frontend that consumes both the REST endpoints and the WebSocket.

When the frontend connects to the WebSocket, it passes the JWT access token as a query parameter; if authentication succeeds, the WebSocket connection is established. I use this connection only to push messages from the server to the client (via queues).
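To give each connection an expiry time, the server needs the token's `exp` claim when the handshake arrives. The following is a minimal JDK-only sketch, not the asker's code: the class `HandshakeToken`, its methods, and the query parameter name `token` are hypothetical. It reads the parameter from the handshake URI and pulls `exp` out of the JWT payload with a regex. It does NOT verify the signature, so it assumes the token has already been validated by the existing authentication code; the JWT library already in the project (or a JSON parser) would be more robust than the regex.

```java
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Base64;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helpers for a token passed as ?token=... on the handshake URL.
final class HandshakeToken {
    // Shortcut: matches "exp": <digits> in the decoded JSON payload.
    private static final Pattern EXP = Pattern.compile("\"exp\"\\s*:\\s*(\\d+)");

    private HandshakeToken() {}

    // Returns the URL-decoded value of the first query parameter named `name`.
    static Optional<String> queryParam(URI uri, String name) {
        String query = uri.getRawQuery();
        if (query == null) {
            return Optional.empty();
        }
        for (String pair : query.split("&")) {
            String[] kv = pair.split("=", 2);
            if (URLDecoder.decode(kv[0], StandardCharsets.UTF_8).equals(name)) {
                return Optional.of(kv.length > 1 ? URLDecoder.decode(kv[1], StandardCharsets.UTF_8) : "");
            }
        }
        return Optional.empty();
    }

    // Reads the "exp" claim (seconds since the epoch) from a compact JWT.
    // Does NOT check the signature: only call it on an already-validated token.
    // Throws IllegalArgumentException if the payload is not valid base64url.
    static Optional<Instant> expiry(String jwt) {
        String[] parts = jwt.split("\\.");
        if (parts.length < 2) {
            return Optional.empty();
        }
        String payload = new String(Base64.getUrlDecoder().decode(parts[1]), StandardCharsets.UTF_8);
        Matcher m = EXP.matcher(payload);
        return m.find()
                ? Optional.of(Instant.ofEpochSecond(Long.parseLong(m.group(1))))
                : Optional.empty();
    }
}
```

In Spring, a `HandshakeInterceptor` receives the handshake in `beforeHandshake(ServerHttpRequest, ServerHttpResponse, WebSocketHandler, Map<String, Object>)`; passing `request.getURI()` to `queryParam` and putting the resulting expiry into the `attributes` map makes it available later through `WebSocketSession.getAttributes()`.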

The Problem

The problem is that the WebSocket connection stays open even after the access token has expired, which is a serious security issue. I want a way to close users' connections from the server side when their tokens expire. Unfortunately, I could not find any example of, or built-in mechanism for, handling this situation.

What I have in mind

  1. Maintain some sort of expiry time for every WebSocket session. If the user obtains a new access token before the current one expires, extend the expiry time. When the expiry time is reached, close the connection from the server. Is this even possible?
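The idea in point 1 can be implemented with one timer per session instead of polling. A minimal sketch, assuming the session id, a close callback, and the token expiry are available at connect time (the class `ExpiryTimers` and its methods are hypothetical): each call to `schedule` records the latest deadline and arms a one-shot task, and when a task fires it closes the session only if its deadline is still the current one, so a refreshed token effectively extends the connection.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Hypothetical helper: closes each session once its latest known deadline passes.
class ExpiryTimers {
    private final ScheduledExecutorService scheduler;
    // Latest deadline per session id; a token refresh simply overwrites it.
    private final Map<String, Instant> deadlines = new ConcurrentHashMap<>();

    ExpiryTimers(ScheduledExecutorService scheduler) {
        this.scheduler = scheduler;
    }

    // Call on connect, and again whenever the client presents a fresh token.
    void schedule(String sessionId, Instant expiresAt, Runnable closeAction) {
        deadlines.put(sessionId, expiresAt);
        long delayMs = Math.max(0, Duration.between(Instant.now(), expiresAt).toMillis());
        scheduler.schedule(() -> {
            // remove(key, value) succeeds only if no later deadline replaced this
            // one and the session was not cancelled in the meantime.
            if (deadlines.remove(sessionId, expiresAt)) {
                closeAction.run();
            }
        }, delayMs, TimeUnit.MILLISECONDS);
    }

    // Call when the connection closes so any pending task becomes a no-op.
    void cancel(String sessionId) {
        deadlines.remove(sessionId);
    }
}
```

Stale tasks are not cancelled; when they fire they find a newer deadline (or none) and do nothing. That avoids check-then-act races between a refresh and a firing timer, at the cost of a few idle tasks per session.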

Could someone please give me a solution to this problem?

I didn't add any code because I'm not sure which code to include here.


Solution

  • I found a solution to the problem (even though it is not that pretty) that works quite well.

    The trick is to store the WebSocketSession object (created when a user connects) in a store (I used a ConcurrentHashMap) alongside the token's expiry time, then periodically check for expired sessions and close them.

    import org.springframework.context.annotation.Configuration;
    import org.springframework.messaging.simp.config.MessageBrokerRegistry;
    import org.springframework.web.socket.CloseStatus;
    import org.springframework.web.socket.WebSocketHandler;
    import org.springframework.web.socket.WebSocketSession;
    import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
    import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
    import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
    import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
    import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
    import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
    
    @Configuration
    @EnableWebSocketMessageBroker
    public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
    
    
        @Override
        public void registerStompEndpoints(StompEndpointRegistry registry) {
            // STOMP endpoint registration (e.g. the handshake URL) omitted for brevity.
        }
    
        @Override
        public void configureMessageBroker(MessageBrokerRegistry config){
            // Broker configuration (RabbitMQ STOMP relay) omitted for brevity.
        }
    
    
        @Override
        public void configureWebSocketTransport(WebSocketTransportRegistration registry) {
            WebSocketMessageBrokerConfigurer.super.configureWebSocketTransport(registry);
    
            registry.addDecoratorFactory(new WebSocketHandlerDecoratorFactory() {
                @Override
                public WebSocketHandler decorate(WebSocketHandler webSocketHandler) {
                    return new WebSocketHandlerDecorator(webSocketHandler) {
                        @Override
                        public void afterConnectionEstablished(final WebSocketSession session) throws Exception {
                            try{
                                // The WebSocketSession is available here: session.getPrincipal()
                                // returns the authenticated user, and session.getAttributes() holds
                                // anything stored during the handshake (e.g. the JWT expiry time).

                                // Store the session alongside its expiry time
                                // (in a ConcurrentHashMap or similar), keyed by session.getId().
    
                                super.afterConnectionEstablished(session);
    
                            } catch (Exception e){
                                // session.close(...) can be called at any time to close the connection.
                                session.close(CloseStatus.BAD_DATA);
                            }
                        }
    
                        @Override
                        public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
                            try{
                                //Remove the WebSocketSession object from the store
                            }catch(Exception e){
                                e.printStackTrace();
                            }finally {
                                super.afterConnectionClosed(session, closeStatus);
                            }
    
    
                        }
                    };
                }
            });
        }
    }
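The store and periodic check described above can be sketched as a small registry. This is a hypothetical JDK-only class (`SessionExpiryRegistry`, `Closer`, and all method names are assumptions), generic over the session type so its logic can be exercised without Spring; in the application the type parameter would be `WebSocketSession`, keyed by `session.getId()`.

```java
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical store: session id -> (session, token expiry).
class SessionExpiryRegistry<S> {

    // Closes one session; may throw, as WebSocketSession.close(CloseStatus) does.
    @FunctionalInterface
    interface Closer<T> {
        void close(T session) throws Exception;
    }

    private record Entry<T>(T session, Instant expiresAt) {}

    private final Map<String, Entry<S>> sessions = new ConcurrentHashMap<>();
    private final Closer<S> closer;
    private final Clock clock;

    SessionExpiryRegistry(Closer<S> closer, Clock clock) {
        this.closer = closer;
        this.clock = clock;
    }

    // Adds a session, or replaces its deadline when a fresh token arrives.
    void register(String sessionId, S session, Instant expiresAt) {
        sessions.put(sessionId, new Entry<>(session, expiresAt));
    }

    void remove(String sessionId) {
        sessions.remove(sessionId);
    }

    int size() {
        return sessions.size();
    }

    // Closes and forgets every session whose deadline is not after "now";
    // returns the ids of the sessions that were removed.
    List<String> closeExpired() {
        Instant now = clock.instant();
        List<String> expired = new ArrayList<>();
        for (Map.Entry<String, Entry<S>> e : sessions.entrySet()) {
            Entry<S> entry = e.getValue();
            // remove(key, value) fails if the entry was replaced (e.g. by a
            // token refresh) after we read it, so that session stays open.
            if (!entry.expiresAt().isAfter(now) && sessions.remove(e.getKey(), entry)) {
                expired.add(e.getKey());
                try {
                    closer.close(entry.session());
                } catch (Exception ex) {
                    System.err.println("Could not close session " + e.getKey() + ": " + ex);
                }
            }
        }
        return expired;
    }
}
```

In the decorator above, `afterConnectionEstablished` would call `register(session.getId(), session, expiry)` and `afterConnectionClosed` would call `remove(session.getId())`, with a closer such as `s -> s.close(CloseStatus.POLICY_VIOLATION)`. `closeExpired()` then has to run periodically, for example via `ScheduledExecutorService.scheduleAtFixedRate` or a Spring `@Scheduled` method; the polling interval bounds how long an expired session can linger.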