Tags: spring-security, redis, spring-session

How to work around SpringSessionBackedReactiveSessionRegistry bug? (Spring Security, Spring Session)


Any suggestions on how to work around the bug below, or am I doing something fundamentally wrong? Whenever I do something like val contextAttr = session.getAttribute<Map<String, Any>>(springAttribute), contextAttr definitely comes back as a LinkedHashMap. These objects are stored in Redis as attributes of the main session, which I believe are all in JSON / Map<String, Any?> format.

Describe the bug

When my Spring logout handler calls this:

fun invalidateSession(sessionId: String): Mono<Void> {
    logger.info("Invalidating sessionId: ${sessionId}")
    // handle the session invalidation process
    return reactiveSessionRegistry.getSessionInformation(sessionId)
        .flatMap { session ->
            // invalidate session
            session.invalidate()
                .then(
                    // delete session
                    webSessionStore.removeSession(sessionId)
                )
                .doOnSuccess {
                    logger.info("Session invalidated and removed: ${sessionId}")
                }
                .doOnError { error ->
                    logger.error("Error invalidating session: ${sessionId}", error)
                }
        }
}

The following function inside SpringSessionBackedReactiveSessionRegistry gets called:

    @Override
    public Mono<ReactiveSessionInformation> getSessionInformation(String sessionId) {
        return this.sessionRepository.findById(sessionId).map(SpringSessionBackedReactiveSessionInformation::new);
    }

The inner class is implemented as follows:

class SpringSessionBackedReactiveSessionInformation extends ReactiveSessionInformation {

        SpringSessionBackedReactiveSessionInformation(S session) {
            super(resolvePrincipalName(session), session.getId(), session.getLastAccessedTime());
        }

        private static String resolvePrincipalName(Session session) {
            String principalName = session
                .getAttribute(ReactiveFindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME);
            if (principalName != null) {
                return principalName;
            }
            SecurityContext securityContext = session.getAttribute(SPRING_SECURITY_CONTEXT);
            if (securityContext != null && securityContext.getAuthentication() != null) {
                return securityContext.getAuthentication().getName();
            }
            return "";
        }

        @Override
        public Mono<Void> invalidate() {
            return super.invalidate()
                .then(Mono.defer(() -> SpringSessionBackedReactiveSessionRegistry.this.sessionRepository
                    .deleteById(getSessionId())));
        }

    }

The problem is this line:

SecurityContext securityContext = session.getAttribute(SPRING_SECURITY_CONTEXT);

The SPRING_SECURITY_CONTEXT attribute comes back from Redis as a HashMap or LinkedHashMap, so it cannot be cast to SecurityContext without proper deserialisation.

This is exactly the error I see: the cast fails with a ClassCastException (LinkedHashMap cannot be cast to SecurityContext).
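The cast failure can be reproduced without Spring or Redis. Spring Session's Session.getAttribute is generic (<T> T getAttribute(String)), so the conversion to SecurityContext is an unchecked cast that only fails where the caller actually uses the value. The sketch below models this in plain Kotlin; FakeSession and SecurityContextLike are hypothetical stand-ins, not Spring types.

```kotlin
// Hypothetical stand-in for Spring Security's SecurityContext (not the real interface)
interface SecurityContextLike {
    val principalName: String
}

// Hypothetical stand-in for Spring Session's Session: attributes are held as Any?, and the
// generic getter performs an unchecked cast, like <T> T getAttribute(String)
class FakeSession(private val attributes: Map<String, Any?>) {
    @Suppress("UNCHECKED_CAST")
    fun <T> getAttribute(name: String): T? = attributes[name] as T?
}

// Mirrors the library code: it trusts that the attribute really is a context object,
// so a LinkedHashMap fails with ClassCastException as soon as the value is used
fun unsafePrincipalName(session: FakeSession): String {
    val context: SecurityContextLike? = session.getAttribute("SPRING_SECURITY_CONTEXT")
    return context?.principalName ?: ""
}

// Defensive variant: checks the runtime type before using the value
fun safePrincipalName(session: FakeSession): String =
    when (val attribute = session.getAttribute<Any>("SPRING_SECURITY_CONTEXT")) {
        is SecurityContextLike -> attribute.principalName
        else -> "" // e.g. a LinkedHashMap produced by a serializer without type information
    }
```

With a LinkedHashMap stored under the key, unsafePrincipalName throws ClassCastException while safePrincipalName returns an empty string; with a real context object both return its principal name.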

Two calls to get session?

Also, I'm not sure whether this calls Redis again to get the security context, but is that necessary? Just before the /logout endpoint is called, the session and security context are retrieved anyway.

The session ID comes from that session (it is read in the line just before fun invalidateSession(sessionId: String): Mono<Void> is called), so calling getSessionInformation(String sessionId), and with it this.sessionRepository.findById(sessionId), a second time seems a bit wasteful.
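If the aim is only to end the current request's session on /logout, one way to avoid the second lookup is to invalidate the WebSession the request already carries, which also avoids going through SpringSessionBackedReactiveSessionRegistry on this path. WebSession.invalidate() clears the session through the configured WebSessionStore; with Spring Session's SpringSessionWebSessionStore this should end up as a deleteById on the session repository. A minimal sketch, assuming Spring Security's reactive logout support; the class name InvalidateWebSessionLogoutHandler is hypothetical:

```kotlin
import org.springframework.security.core.Authentication
import org.springframework.security.web.server.WebFilterExchange
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler
import reactor.core.publisher.Mono

// Hypothetical handler: invalidates the WebSession already attached to the request,
// so no second sessionRepository.findById(...) call is made
class InvalidateWebSessionLogoutHandler : ServerLogoutHandler {

    override fun logout(exchange: WebFilterExchange, authentication: Authentication): Mono<Void> =
        exchange.exchange.session                       // Mono<WebSession> for the current request
            .flatMap { session -> session.invalidate() } // removes it via the WebSessionStore
}
```

It could be registered through ServerHttpSecurity's logout configuration, for example combined with the default SecurityContextServerLogoutHandler via DelegatingServerLogoutHandler. It only ends the current session; invalidating other sessions by ID still needs the registry or the repository.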

To Reproduce

Store sessions in Redis, then try to invalidate a session by calling the functions above.

Expected behavior

The attribute should be properly deserialised before the cast; a LinkedHashMap cannot be cast directly to a SecurityContext object.
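"Properly deserialised" in practice means the stored JSON has to carry enough type information to rebuild the original class. The toy model below (plain Kotlin maps, not Jackson) shows the difference: with a class discriminator the decoder can rebuild the object, without one it can only hand back the raw map. The "@class" key mirrors the property name Spring Security's Jackson modules use for type information; Auth, SecurityCtx, encode and decode are hypothetical.

```kotlin
// Hypothetical domain types standing in for SecurityContextImpl and an Authentication
data class Auth(val name: String, val authorities: List<String>)
data class SecurityCtx(val authentication: Auth?)

// Encode to a Map (the shape a JSON serializer produces), optionally adding a type tag
fun encode(ctx: SecurityCtx, withType: Boolean): Map<String, Any?> {
    val auth = ctx.authentication?.let { mapOf("name" to it.name, "authorities" to it.authorities) }
    val body: Map<String, Any?> = mapOf("authentication" to auth)
    return if (withType) body + ("@class" to "SecurityCtx") else body
}

// Decode: only a tagged map can be turned back into SecurityCtx; otherwise the raw map is returned
fun decode(map: Map<String, Any?>): Any =
    when (map["@class"]) {
        "SecurityCtx" -> {
            val auth = map["authentication"] as? Map<*, *>
            SecurityCtx(
                auth?.let {
                    Auth(
                        name = it["name"] as String,
                        authorities = (it["authorities"] as List<*>).map { a -> a as String },
                    )
                },
            )
        }
        else -> map // no type information: the caller gets a (Linked)HashMap-like value
    }
```

This is the same trade-off a real JSON serializer faces: without the tag, the caller receives a map and any cast to the domain type fails.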

Sample

See above. The GitHub code can be found here:

My implementation: https://github.com/dreamstar-enterprises/docs/blob/master/Spring%20BFF/BFF/src/main/kotlin/com/frontiers/bff/auth/sessions/SessionRegistryConfig.kt

Spring implementation (where I believe the error is): https://github.com/spring-projects/spring-session/blob/main/spring-session-core/src/main/java/org/springframework/session/security/SpringSessionBackedReactiveSessionRegistry.java

Appreciate any help or suggestions given!


Solution

  • This is a common problem when storing sessions in Redis: attributes are serialized to JSON (or a similar format), and without type information they come back as a Map instead of the original object type. I faced a similar problem on another Spring project in my organization and had to write custom serialization/deserialization, so here is code that I think can work for your case too:

    import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
    import com.fasterxml.jackson.module.kotlin.readValue
    import org.springframework.data.redis.serializer.RedisSerializer
    import org.springframework.data.redis.serializer.SerializationException
    import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
    import org.springframework.security.core.Authentication
    import org.springframework.security.core.authority.SimpleGrantedAuthority
    import org.springframework.security.core.context.SecurityContext
    import org.springframework.security.core.context.SecurityContextImpl

    class SecurityContextSerializer : RedisSerializer<SecurityContext> {

        // jacksonObjectMapper() registers the Kotlin module (the KotlinModule() constructor is deprecated)
        private val objectMapper = jacksonObjectMapper()

        override fun serialize(securityContext: SecurityContext?): ByteArray? {
            return try {
                // Convert SecurityContext to JSON bytes
                objectMapper.writeValueAsBytes(securityContext)
            } catch (e: Exception) {
                throw SerializationException("Error serializing SecurityContext", e)
            }
        }

        override fun deserialize(bytes: ByteArray?): SecurityContext? {
            if (bytes == null || bytes.isEmpty()) {
                return null
            }
            return try {
                // The JSON carries no type information, so read it as a plain map first...
                val map: Map<String, Any?> = objectMapper.readValue(bytes)
                // ...and rebuild the SecurityContext by hand
                mapToSecurityContext(map)
            } catch (e: Exception) {
                throw SerializationException("Error deserializing SecurityContext", e)
            }
        }

        private fun mapToSecurityContext(map: Map<String, Any?>): SecurityContext {
            val authMap = map["authentication"] as? Map<*, *>
            val authentication = authMap?.let { mapToAuthentication(it) }
            return SecurityContextImpl(authentication)
        }

        private fun mapToAuthentication(map: Map<*, *>): Authentication {
            // This example always rebuilds a UsernamePasswordAuthenticationToken; other token types
            // (e.g. OAuth2AuthenticationToken) need their own mapping. A principal that was an object
            // comes back as a Map, so getName() would return that map's toString().
            val principal = map["principal"]
            val credentials = map["credentials"]
            // Plain Jackson writes each SimpleGrantedAuthority as {"authority": "..."}
            val authorities = (map["authorities"] as? List<*>).orEmpty()
                .mapNotNull { entry -> (entry as? Map<*, *>)?.let { it["authority"] } as? String }
                .map { SimpleGrantedAuthority(it) }
            return UsernamePasswordAuthenticationToken(principal, credentials, authorities)
        }
    }
    

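    You also need to make sure Spring Session uses a type-aware serializer when reading and writing session data. Spring Session writes every session attribute through a single RedisSerializer<Object>, so a serializer typed to SecurityContext (like the one above) cannot be handed to the session repository directly; in addition, ReactiveRedisSessionRepository expects a reactive ReactiveRedisOperations<String, Object>, not a blocking RedisTemplate. The supported hook is a RedisSerializer<Object> bean named springSessionDefaultRedisSerializer. A straightforward way to get a SecurityContext back (instead of a LinkedHashMap) is to build that bean from Spring Security's own Jackson modules:

    import com.fasterxml.jackson.databind.ObjectMapper
    import org.springframework.context.annotation.Bean
    import org.springframework.context.annotation.Configuration
    import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer
    import org.springframework.data.redis.serializer.RedisSerializer
    import org.springframework.security.jackson2.SecurityJackson2Modules
    import org.springframework.session.data.redis.config.annotation.web.server.EnableRedisWebSession

    @Configuration
    @EnableRedisWebSession
    class RedisConfig {

        // Spring Session picks up a RedisSerializer<Object> bean with this exact name and uses it
        // for every session attribute, including SPRING_SECURITY_CONTEXT
        @Bean
        fun springSessionDefaultRedisSerializer(): RedisSerializer<Any> {
            val mapper = ObjectMapper()
            // Spring Security's modules cover SecurityContextImpl, its common Authentication types and
            // authorities, and turn on "@class" type information, so the JSON is read back as those
            // classes rather than as a LinkedHashMap. Custom principal or details classes may need
            // their own Jackson mixins.
            mapper.registerModules(SecurityJackson2Modules.getModules(javaClass.classLoader))
            return GenericJackson2JsonRedisSerializer(mapper)
        }
    }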
    

    Additionally:

    Avoid redundant Redis calls: as you pointed out, looking the session up again with sessionRepository.findById(sessionId) is redundant if you already hold the session information. If the caller already has the ReactiveSessionInformation, you can pass it directly to the invalidation method instead of calling getSessionInformation(sessionId) again. This only saves the lookup when the caller already holds that object, and any ReactiveSessionInformation obtained from SpringSessionBackedReactiveSessionRegistry has already gone through resolvePrincipalName, so this change alone does not avoid the cast error:

    fun invalidateSession(session: ReactiveSessionInformation): Mono<Void> {
        logger.info("Invalidating sessionId: ${session.sessionId}")
        // With SpringSessionBackedReactiveSessionRegistry, invalidate() already calls
        // sessionRepository.deleteById(...), so the removeSession call below is likely redundant
        return session.invalidate()
            .then(
                // delete session
                webSessionStore.removeSession(session.sessionId)
            )
            .doOnSuccess {
                logger.info("Session invalidated and removed: ${session.sessionId}")
            }
            .doOnError { error ->
                logger.error("Error invalidating session: ${session.sessionId}", error)
            }
    }
    

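    Handle Maps in resolvePrincipalName: since the SecurityContext is coming back as a Map, resolvePrincipalName could detect this and convert it. Because resolvePrincipalName is a private method of a Spring Session class, this means providing your own ReactiveSessionRegistry implementation (for example, a copy of SpringSessionBackedReactiveSessionRegistry) rather than editing the library: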

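    // Needs java.util.Collections, java.util.Map, SecurityContextImpl and
    // UsernamePasswordAuthenticationToken imports in addition to the existing ones
    private static String resolvePrincipalName(Session session) {
        Object principalNameObj = session.getAttribute(ReactiveFindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME);
        if (principalNameObj instanceof String) {
            return (String) principalNameObj;
        }
        Object securityContextObj = session.getAttribute(SPRING_SECURITY_CONTEXT);
        SecurityContext securityContext = null;
        if (securityContextObj instanceof SecurityContext) {
            // Attribute was deserialized with its type information intact
            securityContext = (SecurityContext) securityContextObj;
        }
        else if (securityContextObj instanceof Map) {
            // Attribute came back from Redis as a plain (Linked)HashMap
            securityContext = deserializeSecurityContext((Map<?, ?>) securityContextObj);
        }
        if (securityContext != null && securityContext.getAuthentication() != null) {
            return securityContext.getAuthentication().getName();
        }
        return "";
    }

    private static SecurityContext deserializeSecurityContext(Map<?, ?> map) {
        // Minimal conversion: only the principal name is needed here. Assumes the authentication
        // was serialized as a map with a "name" entry (as plain Jackson does via getName()).
        Object authObj = map.get("authentication");
        if (!(authObj instanceof Map)) {
            return null;
        }
        Object name = ((Map<?, ?>) authObj).get("name");
        if (!(name instanceof String)) {
            return null;
        }
        return new SecurityContextImpl(
                new UsernamePasswordAuthenticationToken(name, null, Collections.emptyList()));
    }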
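On the Kotlin side (for example where the question reads the attribute with session.getAttribute<Map<String, Any>>(...)), the same Map fallback can be written as a small pure function. This is a sketch that assumes the context was written by a serializer without type information, as nested maps with "authentication", "name", "principal" and "username" keys; those key names are assumptions about the serializer's output, not a Spring contract, and principalNameFromContextMap is a hypothetical name.

```kotlin
// Extracts a principal name from a SPRING_SECURITY_CONTEXT attribute that came back as nested maps.
// The key names used here are assumptions about how the context was serialized.
fun principalNameFromContextMap(attribute: Any?): String {
    val context = attribute as? Map<*, *> ?: return ""
    val authentication = context["authentication"] as? Map<*, *> ?: return ""
    // Prefer the "name" property (written from Authentication.getName()) if present...
    (authentication["name"] as? String)?.let { return it }
    // ...otherwise fall back to the principal: a plain String, or a UserDetails-like map
    return when (val principal = authentication["principal"]) {
        is String -> principal
        is Map<*, *> -> principal["username"] as? String ?: ""
        else -> ""
    }
}
```

Returning an empty string when nothing matches mirrors the library's own resolvePrincipalName fallback.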