Search code examples
Tags: java, spring-boot, javabeans, servlet-filters, multi-tenant

Accessing UserDetails from Filter (Spring)


Specification

A filter that can call `loadUserByUsername()` on the `UserDetailsService` in order to retrieve the tenant database details from the custom `UserDetails` instance.

Problem

Regardless of what the filter precedence is set to, this custom filter runs before the Spring Security filter, so the security context is not yet populated (its authentication is null). I've confirmed that the context is populated when I access the principal object from a controller.

Attempts

I've set the Spring Security filter order in application.properties to 5 and registered this filter with both larger and smaller order values, but it always runs first. I'm aware that extending `GenericFilterBean` should let me place the filter after the security filters in the security configuration, but I don't know how to combine the configuration and the filter into one `GenericFilterBean`.

TenantFilter.java

/**
 * Servlet filter that publishes the current user's tenant (school code) to {@link TenantStore}
 * for the duration of a request and clears it afterwards.
 *
 * <p>The tenant is read from the Spring Security context. When there is no authentication (for
 * example because this filter runs before the security filters have populated the context), or
 * when the principal is not a {@link User}, the tenant id is set to the empty string.
 */
@Component
public class TenantFilter implements Filter {

    @Autowired
    private TenantStore tenantStore;

    // NOTE(review): not referenced by this filter; kept so existing wiring is unaffected.
    @Autowired
    private UserService userService;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialise.
    }

    /**
     * Sets the tenant id for this request, continues the chain, and always clears the tenant id
     * afterwards.
     *
     * @param servletRequest  the incoming request
     * @param servletResponse the outgoing response
     * @param chain           the remaining filter chain
     * @throws IOException      propagated from the filter chain
     * @throws ServletException propagated from the filter chain
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
            throws IOException, ServletException {
        String tenantId = resolveTenantId();
        try {
            this.tenantStore.setTenantId(tenantId);
            chain.doFilter(servletRequest, servletResponse);
        } finally {
            // Otherwise when a previously used container thread is used, it will have the old tenant id set and
            // if for some reason this filter is skipped, tenantStore will hold an unreliable value
            this.tenantStore.clear();
        }
    }

    /**
     * Returns the school code of the authenticated {@link User}, or {@code ""} when there is no
     * authentication or its principal is some other type (e.g. a String for anonymous requests).
     * The previous code dereferenced a possibly-null authentication and cast blindly, which threw
     * NPE/ClassCastException that the {@code UsernameNotFoundException} catch could never handle.
     */
    private static String resolveTenantId() {
        Object principal = java.util.Optional
                .ofNullable(SecurityContextHolder.getContext().getAuthentication())
                .map(authentication -> authentication.getPrincipal())
                .orElse(null);
        return principal instanceof User ? ((User) principal).getSchool().getCode() : "";
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }
}

TenantFilterConfig.java

/**
 * Wires the tenant filter and a thread-local {@link TenantStore}.
 *
 * <p>{@code tenantStore} is declared as a prototype bean and used as the target of a
 * {@link ThreadLocalTargetSource}. The {@link Primary} {@link ProxyFactoryBean} wraps that target
 * source, so injection points asking for a {@code TenantStore} receive the proxy rather than a
 * raw prototype instance.
 */
@Configuration
public class TenantFilterConfig {

    /**
     * The filter instance installed by {@link #tenantFilterRegistration()}.
     *
     * <p>NOTE(review): {@code TenantFilter} is also annotated {@code @Component}; confirm that only
     * one {@code tenantFilter} bean definition ends up active.
     */
    @Bean
    public Filter tenantFilter() {
        return new TenantFilter();
    }

    /** Registers {@link #tenantFilter()} for every URL, one step ahead of the lowest precedence. */
    @Bean
    public FilterRegistrationBean tenantFilterRegistration() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setName("Tenant Store Filter");
        registration.setOrder(Ordered.LOWEST_PRECEDENCE - 1);
        registration.setUrlPatterns(Lists.newArrayList("/*"));
        registration.setFilter(this.tenantFilter());
        return registration;
    }

    /** Target source that resolves the {@code tenantStore} bean; destroyed with the context. */
    @Bean(destroyMethod = "destroy")
    public ThreadLocalTargetSource threadLocalTenantStore() {
        ThreadLocalTargetSource targetSource = new ThreadLocalTargetSource();
        targetSource.setTargetBeanName("tenantStore");
        return targetSource;
    }

    /** Primary proxy over the thread-local target source; this is what gets injected. */
    @Primary
    @Bean(name = "proxiedThreadLocalTargetSource")
    public ProxyFactoryBean proxiedThreadLocalTargetSource(ThreadLocalTargetSource threadLocalTargetSource) {
        ProxyFactoryBean proxyFactory = new ProxyFactoryBean();
        proxyFactory.setTargetSource(threadLocalTargetSource);
        return proxyFactory;
    }

    /** Prototype-scoped store; a fresh instance is created on each lookup. */
    @Bean(name = "tenantStore")
    @Scope(scopeName = "prototype")
    public TenantStore tenantStore() {
        return new TenantStore();
    }
}

Solution

  • I found a different approach that works nicely: aspects (Spring AOP).

    The pointcut expression makes this advice run around every method execution of every class in the controller package and its sub-packages.

    The tenant store is backed by a ThreadLocal. To avoid memory leaks and stale tenant ids on pooled threads, it must be cleared after every call. Clear it in a `finally` block so it is reset even when the controller method throws.

    Happy coding!

    TenantAspect.java

    /**
     * Aspect that binds the authenticated user's tenant (school code) to {@link TenantStore}
     * around every method executed in {@code com.things.stuff.controller} and its sub-packages.
     */
    @Component
    @Aspect
    public class TenantAspect {

        private final TenantStore tenantStore;

        @Autowired
        public TenantAspect(TenantStore tenantStore) {
            this.tenantStore = tenantStore;
        }

        @Around(value = "execution(* com.things.stuff.controller..*(..))")
        public Object assignForController(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
            return assignTenant(proceedingJoinPoint);
        }

        /**
         * Sets the tenant id when the principal is a {@link User}, proceeds with the advised call,
         * and always clears the store afterwards.
         *
         * <p>The previous version ran {@code proceed()} and {@code return} inside {@code finally}.
         * That silently discarded exceptions from the principal lookup, and it skipped
         * {@code clear()} whenever the advised method threw, leaving a stale tenant id on the thread.
         *
         * <p>NOTE(review): a controller-package bean calling another advised controller-package bean
         * will clear the tenant when the inner call returns; confirm such nesting does not occur.
         *
         * @param proceedingJoinPoint the advised invocation
         * @return whatever the advised method returns
         * @throws Throwable anything thrown by the advised method, unchanged
         */
        private Object assignTenant(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
            // getAuthentication() may be null, and the principal may not be a User (e.g. anonymous).
            Object principal = java.util.Optional
                    .ofNullable(SecurityContextHolder.getContext().getAuthentication())
                    .map(authentication -> authentication.getPrincipal())
                    .orElse(null);
            if (principal instanceof User) {
                tenantStore.setTenantId(((User) principal).getSchool().getCode());
            }
            try {
                return proceedingJoinPoint.proceed();
            } finally {
                tenantStore.clear();
            }
        }
    }