Search code examples
Tags: spring, spring-boot, websocket, spring-websocket, spring-rabbit

How to handle spring boot stomp websocket subscribe endpoint to authorize users


We have successfully built a WebSocket + STOMP + RabbitMQ setup, with security, for our project. It works fine, but we have trouble with the following case. The workflow is as follows:

  • First, the user connects to the WebSocket endpoint, which works as expected.

  • Second, after being authorized by their token, the user subscribes to the destination '/user/queue/' + chatRoomId + '.messages', where chatRoomId identifies the chat room the user joins. This also works. However, the user can subscribe to any chatRoomId, because the backend never validates it, and that is the main problem we are trying to solve.

    // Per-user queue for the selected chat room; incomingMessages receives each frame.
    var chatDestination = '/user/queue/' + chatRoomId + '.messages';
    stompClient.subscribe(chatDestination, incomingMessages);

My question is: how can I validate users when they try to subscribe to this destination? In other words, is there a way to handle each individual subscription?

This is our front-end code. If you need the complete page, I will upload it.

 /**
  * Opens a SockJS transport and performs the STOMP CONNECT handshake.
  *
  * Sends two native headers:
  *   - `chatRoomId`: the currently selected chat room.
  *   - `login`: "Bearer <jwt>". The backend splits this value on a space and uses the second part as the token.
  * On success `stompSuccess` is invoked; on failure `stompFailure`.
  *
  * @param {string} [authToken]   Raw JWT, without the "Bearer " prefix. When omitted, the
  *                               hard-coded development token below is used (original behaviour).
  * @param {string} [endpointUrl] SockJS endpoint URL. Defaults to the local development server.
  */
 function connect(authToken, endpointUrl) {
        // NOTE(review): development-only credential. Tokens should not be shipped in page
        // source; pass a real token to connect() and remove this fallback.
        var token = authToken || 'eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIyODg5Iiwic2NvcGVzIjoiUk9MRV9VU0VSLFJPTEVfTU9ERVJBVE9SIiwiaWF0IjoxNTgyMDMxMDA0LCJleHAiOjE1ODQ2MjMwMDR9.NGAAed4R46FgrtgyDmrLSrmd-o3tkqbF60vOg8vAWYg';
        var url = endpointUrl || 'http://localhost:9600/wsss/messages';

        socket = new SockJS(url);
        stompClient = Stomp.over(socket);
        stompClient.connect({ 'chatRoomId' : chatRoomId,
            'login' : 'Bearer ' + token },
            stompSuccess, stompFailure);
    }

    /**
     * STOMP CONNECT success callback.
     *
     * Enables the message input, shows a confirmation, logs the CONNECTED frame, then
     * subscribes (in this order) to the chat-room queue and the notification topic.
     *
     * @param frame the frame passed by the STOMP client on a successful connect
     */
    function stompSuccess(frame) {
        enableInputMessage();
        successMessage("Your WebSocket connection was successfuly established!");
        console.log(frame);

        var subscriptions = [
            ['/user/queue/' + chatRoomId + '.messages', incomingMessages],
            ['/topic/notification', incomingNotificationMessage]
            // ['/app/join/notification', incomingNotificationMessage]
        ];
        subscriptions.forEach(function (entry) {
            stompClient.subscribe(entry[0], entry[1]);
        });
    }

And here is the code I am using for the backend:

/**
 * STOMP-over-WebSocket configuration backed by an external STOMP broker relay.
 *
 * - Registers the STOMP endpoints.
 * - Relays `/queue/` and `/topic/` destinations to the broker configured by the
 *   `spring.rabbitmq.*` properties.
 * - Authenticates the client on its STOMP `CONNECT`/`STOMP` frame, using the JWT carried
 *   in the native `login` header (expected form: `"<scheme> <jwt>"`, e.g. `"Bearer eyJ..."`).
 */
@Configuration
@EnableWebSocketMessageBroker
@Order(Ordered.HIGHEST_PRECEDENCE + 99)
class WebSocketConfig @Autowired constructor(
        val jwtTokenUtil: TokenProvider
) : WebSocketMessageBrokerConfigurer {


    @Autowired
    @Resource(name = "userService")
    private val userDetailsService: UserDetailsService? = null

    @Autowired
    private lateinit var authenticationManager: AuthenticationManager

    @Value("\${spring.rabbitmq.username}")
    private val userName: String? = null
    @Value("\${spring.rabbitmq.password}")
    private val password: String? = null
    @Value("\${spring.rabbitmq.host}")
    private val host: String? = null
    @Value("\${spring.rabbitmq.port}")
    private val port: Int = 0
    // NOTE(review): `endpoint` and `stompBrokerRelay` are not read anywhere in this class.
    @Value("\${endpoint}")
    private val endpoint: String? = null
    @Value("\${destination.prefix}")
    private val destinationPrefix: String? = null
    @Value("\${stomp.broker.relay}")
    private val stompBrokerRelay: String? = null

    /**
     * Enables the STOMP broker relay for `/queue/` and `/topic/` and sets the application
     * destination prefix.
     *
     * @throws IllegalArgumentException if a required property was not injected
     */
    override fun configureMessageBroker(config: MessageBrokerRegistry) {
        config.enableStompBrokerRelay("/queue/", "/topic/")
                .setRelayHost(requireNotNull(host) { "spring.rabbitmq.host is not configured" })
                .setRelayPort(port)
                .setSystemLogin(requireNotNull(userName) { "spring.rabbitmq.username is not configured" })
                .setSystemPasscode(requireNotNull(password) { "spring.rabbitmq.password is not configured" })
        config.setApplicationDestinationPrefixes(
                requireNotNull(destinationPrefix) { "destination.prefix is not configured" })
    }

    /** Registers the raw WebSocket endpoints and the SockJS endpoint used by the browser client. */
    override fun registerStompEndpoints(registry: StompEndpointRegistry) {
        // NOTE(review): "*" accepts any Origin; restrict to the known front-end origin(s) in production.
        registry.addEndpoint("/websocket").setAllowedOrigins("*")
        registry.addEndpoint("/websocket/messages").addInterceptors(customHttpSessionHandshakeInterceptor()).setAllowedOrigins("*")
        registry.addEndpoint("/wsss/messages").addInterceptors(customHttpSessionHandshakeInterceptor()).setAllowedOrigins("*").withSockJS()
    }

    @Bean
    fun customHttpSessionHandshakeInterceptor(): CustomHttpSessionHandshakeInterceptor {
        return CustomHttpSessionHandshakeInterceptor()
    }

    /**
     * Installs an inbound interceptor that authenticates `CONNECT`/`STOMP` frames.
     * All other frames are passed through unchanged, as are messages that carry no STOMP header accessor.
     */
    override fun configureClientInboundChannel(registration: ChannelRegistration) {
        registration.interceptors(object : ChannelInterceptor {
            override fun preSend(message: Message<*>, channel: MessageChannel): Message<*> {
                val accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor::class.java)
                        ?: return message
                if (accessor.command == StompCommand.CONNECT || accessor.command == StompCommand.STOMP) {
                    authenticateConnect(accessor)
                }
                return message
            }
        })
    }

    /**
     * Resolves the user from the JWT in the native `login` header and stores the resulting
     * authentication as the session principal ([StompHeaderAccessor.user]).
     *
     * - Usernames containing `@` are loaded through [userDetailsService] and must pass
     *   `jwtTokenUtil.validateToken`.
     * - Other usernames are authenticated through [authenticationManager], with an empty
     *   password and the authorities read from the token.
     *
     * NOTE(review): if the token yields no username or fails validation, no principal is set
     * and the CONNECT is not rejected here. Confirm downstream security rejects anonymous sessions.
     *
     * @throws IllegalArgumentException if the `login` header is missing or is not `"<scheme> <token>"`
     */
    private fun authenticateConnect(accessor: StompHeaderAccessor) {
        // The raw header value is deliberately not logged: it contains a bearer credential.
        val authorization = requireNotNull(accessor.getNativeHeader(LOGIN_HEADER)?.firstOrNull()) {
            "Missing '$LOGIN_HEADER' header on STOMP CONNECT"
        }
        val parts = authorization.split(" ")
        require(parts.size > 1) { "'$LOGIN_HEADER' header must have the form '<scheme> <token>'" }
        val authToken = parts[1]

        val username = jwtTokenUtil.getUsernameFromToken(authToken) ?: return

        if (username.contains("@")) {
            val userDetails = requireNotNull(userDetailsService) { "userService is not injected" }
                    .loadUserByUsername(username)
            if (jwtTokenUtil.validateToken(authToken, userDetails)) {
                accessor.user = jwtTokenUtil.getAuthentication(
                        authToken, SecurityContextHolder.getContext().authentication, userDetails)
            }
        } else {
            val authorities = jwtTokenUtil.getAuthoritiesFromToken(authToken)
            accessor.user = authenticationManager.authenticate(
                    UsernamePasswordAuthenticationToken(username, "", authorities))
        }
    }

    private companion object {
        /** Native STOMP header in which the client sends `"Bearer <jwt>"`. */
        const val LOGIN_HEADER = "login"
    }
}

Here is the event handler:

/**
 * Listens for STOMP session lifecycle events and keeps the client's chat room id in the
 * WebSocket session attributes under the key `chatRoomId`.
 */
@Component
class WebSocketEvents {

    /**
     * On a client CONNECT, copies the first value of the native `chatRoomId` header into the
     * session attributes. Does nothing when the header is absent or empty, or when the
     * message has no session-attribute map.
     */
    @EventListener
    fun handleSessionConnected(event: SessionConnectEvent) {
        val headers = SimpMessageHeaderAccessor.wrap(event.message)
        val chatId = headers.getNativeHeader(CHAT_ROOM_ID)?.firstOrNull() ?: return
        headers.sessionAttributes?.put(CHAT_ROOM_ID, chatId)
    }

    /**
     * On disconnect, reads the chat room id stored by [handleSessionConnected].
     * The lookup is null-safe, so sessions without an attribute map no longer throw.
     */
    @EventListener
    @Suppress("UNUSED_VARIABLE")
    fun handleSessionDisconnect(event: SessionDisconnectEvent) {
        val headers = SimpMessageHeaderAccessor.wrap(event.message)
        // TODO: value is not used yet; hook chat-room "leave" handling here.
        val chatRoomId = headers.sessionAttributes?.get(CHAT_ROOM_ID)?.toString()
    }

    private companion object {
        /** Header name and session-attribute key for the client's chat room. */
        const val CHAT_ROOM_ID = "chatRoomId"
    }
}

What I have tried so far: as shown above, when the user first connects to the WebSocket endpoint http://localhost:9600/wsss/messages, the client sends the token and the chat room id as headers. The event listener component then copies the chat room id into the session attributes. What I really need is to read the chat room id when the user subscribes to this specific destination and check whether the user belongs to that chat room. If so, the subscription should be allowed (the user joins the chat); if not, an error should be returned. I would really appreciate any thoughts or workarounds!


Solution

  • I spent a couple of days searching for an answer but did not find one, so I figured it out myself. Here is my solution to this problem, though it is not a complete one.

    I created a separate interceptor class that handles all STOMP command types, the same way I was already handling CONNECT. Then it occurred to me: why not use the SUBSCRIBE command to watch what users are doing and respond appropriately? For instance:

        @Component
    class WebSocketTopicHandlerInterceptor(
            // Used to read the username from the JWT on CONNECT.
            private val jwtTokenUtil: TokenProvider
    ) : ChannelInterceptor {

        /**
         * Inspects each inbound STOMP frame before it is dispatched.
         *
         * - `CONNECT`/`STOMP`: requires a native `login` header of the form `"<scheme> <token>"`.
         * - `SUBSCRIBE`: routes on the destination. Destinations not explicitly recognised are
         *   rejected (default deny).
         *
         * Throwing from here aborts delivery of the frame. Messages that carry no
         * [StompHeaderAccessor] are passed through unchanged.
         */
        override fun preSend(message: Message<*>, channel: MessageChannel): Message<*>? {
            val accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor::class.java)
                    ?: return message
            when (accessor.command) {
                StompCommand.CONNECT, StompCommand.STOMP -> authenticateConnect(accessor)
                StompCommand.SUBSCRIBE -> authorizeSubscribe(accessor)
                else -> Unit
            }
            return message
        }

        /** Extracts the token from the `login` header; authentication itself is still a TODO. */
        private fun authenticateConnect(accessor: StompHeaderAccessor) {
            val authorization = accessor.getNativeHeader("login")?.firstOrNull()
                    ?: throw LoggedError(AuthorizationException())
            val authToken = authorization.split(" ").getOrNull(1)
                    ?: throw LoggedError(InvalidTokenException("Token is not valid"))

            val username = jwtTokenUtil.getUsernameFromToken(authToken)

            // TODO: authenticate `username` here and set accessor.user.
        }

        /** Decides, per destination, whether the current session may subscribe. */
        private fun authorizeSubscribe(accessor: StompHeaderAccessor) {
            val destination = accessor.destination
            if (destination.isNullOrBlank()) {
                throw LoggedError(CustomBadRequestException("Subscription destination cannot be null or blank"))
            }

            // Captured group 1 is the chat room id, e.g. "/user/queue/42.messages" -> "42".
            val chatRoomId = CHAT_DESTINATION.matchEntire(destination)?.groupValues?.get(1)
            when {
                chatRoomId != null -> {
                    // TODO: throw unless accessor.user is a member of chatRoomId.
                }
                destination == NOTIFICATION_DESTINATION -> {
                    // TODO: notification-specific checks, if any.
                }
                // Default deny: anything not explicitly routed above cannot be subscribed to.
                else -> throw LoggedError(AuthorizationException())
            }
        }

        private companion object {
            // Hoisted so the patterns are compiled once, not on every frame.
            val CHAT_DESTINATION = Regex("""/user/queue/([a-zA-Z0-9-]+)\.messages""")
            const val NOTIFICATION_DESTINATION = "/topic/notification"
        }
    }
    

    If anything does not make sense, just let me know; I will be happy to elaborate further.

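    In my case, I determine which destination the user is subscribing to and do all the validation there. If validation fails, the interceptor throws and the user cannot subscribe to that channel. This is only secure if destinations that match none of the `when` branches are also rejected, so make sure there is an `else` branch that throws.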