
How to only allow specific fields to sort by in a Spring Data JPA Repository Pageable?


Using a Pageable parameter in a Spring Data JPA Repository allows for specifying fields to sort by like: PageRequest.of(0, 50, Sort.by("field1", "field2")), which would sort by field1 and field2 ascending.

It works by appending an ORDER BY clause directly to the query string, which results in a JPQL query like SELECT a FROM SomeEntity a ORDER BY a.field1, a.field2. However, if a non-existent field name is passed in, the call fails with an org.springframework.dao.InvalidDataAccessApiUsageException, as seen below.
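
To make the mechanism concrete, here is a simplified sketch of how sort properties end up in the query text. It is an illustration only, not Spring Data's actual implementation: OrderBySketch and appendOrderBy are hypothetical names, and the alias a and the asc direction simply mirror the query shown in the exception message below.

```java
import java.util.List;
import java.util.stream.Collectors;

class OrderBySketch {

    // Hypothetical helper: appends one "alias.property asc" entry per sort property.
    // The property names go into the query text verbatim, so an unknown name
    // is only detected later, when the JPA provider compiles the query.
    static String appendOrderBy(String jpql, String alias, List<String> properties) {
        if (properties.isEmpty()) {
            return jpql;
        }
        String clause = properties.stream()
                .map(property -> alias + "." + property + " asc")
                .collect(Collectors.joining(", "));
        return jpql + " order by " + clause;
    }

    public static void main(String[] args) {
        // prints: SELECT a FROM SomeEntity a order by a.field1 asc, a.field2 asc
        System.out.println(appendOrderBy("SELECT a FROM SomeEntity a", "a", List.of("field1", "field2")));
    }
}
```

Because nothing checks the names before they reach the query, validation has to happen earlier, which is what the question is after.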

How do you whitelist the sort fields (that is, allow only specific fields) or otherwise validate the sorting without adding boilerplate code in a service that wraps the repository? The same goes for a @RestController: how do you ensure that a 400-level status such as HttpStatus.BAD_REQUEST is returned to the API client?

org.springframework.dao.InvalidDataAccessApiUsageException: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.springframework.orm.jpa.EntityManagerFactoryUtils.convertJpaAccessExceptionIfPossible(EntityManagerFactoryUtils.java:374)
        at org.springframework.orm.jpa.vendor.HibernateJpaDialect.translateExceptionIfPossible(HibernateJpaDialect.java:257)
        at org.springframework.orm.jpa.AbstractEntityManagerFactoryBean.translateExceptionIfPossible(AbstractEntityManagerFactoryBean.java:531)
        at org.springframework.dao.support.ChainedPersistenceExceptionTranslator.translateExceptionIfPossible(ChainedPersistenceExceptionTranslator.java:61)
        at org.springframework.dao.support.DataAccessUtils.translateIfNecessary(DataAccessUtils.java:242)
        at org.springframework.dao.support.PersistenceExceptionTranslationInterceptor.invoke(PersistenceExceptionTranslationInterceptor.java:154)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.data.jpa.repository.support.CrudMethodMetadataPostProcessor$CrudMethodMetadataPopulatingMethodInterceptor.invoke(CrudMethodMetadataPostProcessor.java:149)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:95)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.aop.framework.JdkDynamicAopProxy.invoke(JdkDynamicAopProxy.java:212)
        at com.sun.proxy.$Proxy340.searchPaged(Unknown Source)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.zeroturnaround.jrebel.integration.springdata.RepositoryReloadingProxyFactoryBuilder$ReloadingMethodHandler.invoke(RepositoryReloadingProxyFactoryBuilder.java:80)
...
Caused by: java.lang.IllegalArgumentException: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:138)
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:181)
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:188)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:725)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:113)
        at jdk.internal.reflect.GeneratedMethodAccessor749.invoke(Unknown Source)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.springframework.orm.jpa.ExtendedEntityManagerCreator$ExtendedEntityManagerInvocationHandler.invoke(ExtendedEntityManagerCreator.java:366)
        at com.sun.proxy.$Proxy265.createQuery(Unknown Source)
        at jdk.internal.reflect.GeneratedMethodAccessor749.invoke(Unknown Source)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.springframework.orm.jpa.SharedEntityManagerCreator$SharedEntityManagerInvocationHandler.invoke(SharedEntityManagerCreator.java:314)
        at com.sun.proxy.$Proxy265.createQuery(Unknown Source)
        at org.springframework.data.jpa.repository.query.AbstractStringBasedJpaQuery.createJpaQuery(AbstractStringBasedJpaQuery.java:150)
        at org.springframework.data.jpa.repository.query.AbstractStringBasedJpaQuery.doCreateQuery(AbstractStringBasedJpaQuery.java:86)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.createQuery(AbstractJpaQuery.java:226)
        at org.springframework.data.jpa.repository.query.JpaQueryExecution$PagedExecution.doExecute(JpaQueryExecution.java:175)
        at org.springframework.data.jpa.repository.query.JpaQueryExecution.execute(JpaQueryExecution.java:88)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.doExecute(AbstractJpaQuery.java:154)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.execute(AbstractJpaQuery.java:142)
        at org.springframework.data.repository.core.support.RepositoryFactorySupport$QueryExecutorMethodInterceptor.doInvoke(RepositoryFactorySupport.java:618)
        at org.springframework.data.repository.core.support.RepositoryFactorySupport$QueryExecutorMethodInterceptor.invoke(RepositoryFactorySupport.java:605)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.data.projection.DefaultMethodInvokingMethodInterceptor.invoke(DefaultMethodInvokingMethodInterceptor.java:80)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.transaction.interceptor.TransactionAspectSupport.invokeWithinTransaction(TransactionAspectSupport.java:367)
        at org.springframework.transaction.interceptor.TransactionInterceptor.invoke(TransactionInterceptor.java:118)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.dao.support.PersistenceExceptionTranslationInterceptor.invoke(PersistenceExceptionTranslationInterceptor.java:139)
        ... 127 more
Caused by: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.hibernate.QueryException.generateQueryException(QueryException.java:120)
        at org.hibernate.QueryException.wrapWithQueryString(QueryException.java:103)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.doCompile(QueryTranslatorImpl.java:220)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.compile(QueryTranslatorImpl.java:144)
        at org.hibernate.engine.query.spi.HQLQueryPlan.<init>(HQLQueryPlan.java:113)
        at org.hibernate.engine.query.spi.HQLQueryPlan.<init>(HQLQueryPlan.java:73)
        at org.hibernate.engine.query.spi.QueryPlanCache.getHQLQueryPlan(QueryPlanCache.java:162)
        at org.hibernate.internal.AbstractSharedSessionContract.getQueryPlan(AbstractSharedSessionContract.java:604)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:716)
        ... 154 more
Caused by: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity
        at org.hibernate.persister.entity.AbstractPropertyMapping.propertyException(AbstractPropertyMapping.java:77)
        at org.hibernate.persister.entity.AbstractPropertyMapping.toType(AbstractPropertyMapping.java:71)
        at org.hibernate.persister.entity.AbstractEntityPersister.toType(AbstractEntityPersister.java:2043)
        at org.hibernate.hql.internal.ast.tree.FromElementType.getPropertyType(FromElementType.java:412)
        at org.hibernate.hql.internal.ast.tree.FromElement.getPropertyType(FromElement.java:520)
        at org.hibernate.hql.internal.ast.tree.DotNode.getDataType(DotNode.java:694)
        at org.hibernate.hql.internal.ast.tree.DotNode.prepareLhs(DotNode.java:269)
        at org.hibernate.hql.internal.ast.tree.DotNode.resolve(DotNode.java:209)
        at org.hibernate.hql.internal.ast.HqlSqlWalker.resolve(HqlSqlWalker.java:1053)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.expr(HqlSqlBaseWalker.java:1303)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderExpr(HqlSqlBaseWalker.java:1887)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderExprs(HqlSqlBaseWalker.java:1681)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderClause(HqlSqlBaseWalker.java:1654)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.query(HqlSqlBaseWalker.java:666)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.selectStatement(HqlSqlBaseWalker.java:325)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.statement(HqlSqlBaseWalker.java:273)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.analyze(QueryTranslatorImpl.java:276)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.doCompile(QueryTranslatorImpl.java:192)
        ... 160 more



Solution

  • I ended up using JSR-303 validations on the repository methods to whitelist the sort fields.

    Register a MethodValidationPostProcessor so that JSR-303 validation annotations are evaluated at the method level.

    ValidationConfig.java

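    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.validation.beanvalidation.MethodValidationPostProcessor;

    @Configuration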
    public class ValidationConfig {
    
        @Bean
        public MethodValidationPostProcessor methodValidationPostProcessor() {
            return new MethodValidationPostProcessor();
        }
    }
    

    Create a validation that takes in a list of sort fields to validate against.

    AllowSortFields.java

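    import static java.lang.annotation.ElementType.ANNOTATION_TYPE;
    import static java.lang.annotation.ElementType.FIELD;
    import static java.lang.annotation.ElementType.PARAMETER;
    import static java.lang.annotation.ElementType.TYPE;
    import static java.lang.annotation.RetentionPolicy.RUNTIME;

    import java.lang.annotation.Documented;
    import java.lang.annotation.Retention;
    import java.lang.annotation.Target;

    import javax.validation.Constraint; // jakarta.validation on Spring Boot 3+
    import javax.validation.Payload;

    @Documented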
    @Constraint(validatedBy = {AllowSortFieldsValidator.class})
    @Target({ANNOTATION_TYPE, TYPE, FIELD, PARAMETER})
    @Retention(RUNTIME)
    public @interface AllowSortFields {
    
        String message() default "Sort field values provided are not within the allowed fields that are sortable.";
    
        Class<?>[] groups() default {};
    
        Class<? extends Payload>[] payload() default {};
    
        /**
         * Specify an array of fields that are allowed.
         *
         * @return the allowed sort fields
         */
        String[] value() default {};
    
    }
    

    AllowSortFieldsValidator.java

    import static java.util.stream.Collectors.joining;

    import java.util.Arrays;
    import java.util.List;

    import javax.validation.ConstraintValidator; // jakarta.validation on Spring Boot 3+
    import javax.validation.ConstraintValidatorContext;

    import org.springframework.data.domain.Pageable;
    import org.springframework.data.domain.Sort;

    /**
     * Validates a list of sort fields within a Pageable against an allowed list.
     */
    public class AllowSortFieldsValidator implements ConstraintValidator<AllowSortFields, Pageable> {

        static final String PROPERTY_NOT_FOUND_MESSAGE = "The following sort fields [%s] are not within the allowed fields. "
                + "Allowed sort fields are: [%s]";

        private List<String> allowedSortFields;

        @Override
        public void initialize(AllowSortFields constraintAnnotation) {
            allowedSortFields = Arrays.asList(constraintAnnotation.value());
        }

        @Override
        public boolean isValid(Pageable value, ConstraintValidatorContext context) {
            // a null Pageable is treated as valid
            if (value == null) {
                return true;
            }

            // an empty allowed list means no restriction
            if (allowedSortFields.isEmpty()) {
                return true;
            }

            // ignore unsorted
            Sort sort = value.getSort();
            if (sort.isUnsorted()) {
                return true;
            }

            String fieldsNotFound = fieldsNotFoundAsCommaDelimited(sort);

            // all requested sort fields are allowed
            if (fieldsNotFound.isEmpty()) {
                return true;
            }

            // replace the default message with one that names the offending fields
            context.disableDefaultConstraintViolation();
            context.buildConstraintViolationWithTemplate(String.format(PROPERTY_NOT_FOUND_MESSAGE, fieldsNotFound, String.join(",", allowedSortFields)))
                    .addConstraintViolation();
            return false;
        }

        // collects the requested sort properties that are not in the allowed list
        private String fieldsNotFoundAsCommaDelimited(Sort sort) {
            return sort.stream()
                    .map(Sort.Order::getProperty)
                    .filter(property -> !allowedSortFields.contains(property))
                    .collect(joining(","));
        }
    }
    

    AllowSortFieldsValidatorSmallTest.java

    public class AllowSortFieldsValidatorSmallTest {
    
        private static final String[] ALLOWED_SORT_FIELDS = new String[]{"allowed1", "allowed2"};
    
        private static final String ALLOWED_SORT_FIELDS_DELIMITED = String.join(",", Arrays.asList(ALLOWED_SORT_FIELDS));
    
        private static Validator validator;
    
        @BeforeClass
        public static void setupValidator() throws Exception {
            ValidatorFactory factory = Validation.buildDefaultValidatorFactory();
            validator = factory.getValidator();
        }
    
        @Test
        public void isValid_TwoOfFourFieldsAllowed_FalseWithExpectedMessageExplainingDisallowedFields() {
    
            List<String> sortFields = List.of("allowed1", "allowed2|desc", "notfound1", "not.found2");
    
            AllowedSortFields toValidate = newAllowedSortFields(sortFields);
    
            Set<ConstraintViolation<AllowedSortFields>> constraintViolations = validator.validate(toValidate, Default.class);
    
            String expected = String.format(AllowSortFieldsValidator.PROPERTY_NOT_FOUND_MESSAGE, "notfound1,not.found2", ALLOWED_SORT_FIELDS_DELIMITED);
            String actual = getConstraintMessages(constraintViolations);
    
            assertEquals(expected, actual);
    
        }
    
        @Test
        public void isValid_NoSortFields_True() {
    
            List<String> sortFields = null;
    
            AllowedSortFields toValidate = newAllowedSortFields(sortFields);
    
            Set<ConstraintViolation<AllowedSortFields>> constraintViolations = validator.validate(toValidate, Default.class);
    
            assertTrue(constraintViolations.isEmpty());
    
        }
    
        @Test
        public void isValid_EmptyAllowedSortFields_True() {
    
            List<String> sortFields = List.of("allowed1", "allowed2|desc", "notfound1", "not.found2");
    
            EmptyAllowedSortFields toValidate = newEmptyAllowedSortFields(sortFields);
    
            Set<ConstraintViolation<EmptyAllowedSortFields>> constraintViolations = validator.validate(toValidate, Default.class);
    
            assertTrue(constraintViolations.isEmpty());
    
        }
    
        @Test
        public void isValid_AllSortFieldsFoundAsAllowed_True() {
    
            List<String> sortFields = Arrays.asList(ALLOWED_SORT_FIELDS);
    
            AllowedSortFields toValidate = newAllowedSortFields(sortFields);
    
            Set<ConstraintViolation<AllowedSortFields>> constraintViolations = validator.validate(toValidate, Default.class);
    
            assertTrue(constraintViolations.isEmpty());
    
        }
    
        @Test
        public void isValid_NullValue_True() {
    
            AllowedSortFields toValidate = new AllowedSortFields();
            toValidate.pageable = null;
    
            Set<ConstraintViolation<AllowedSortFields>> constraintViolations = validator.validate(toValidate, Default.class);
    
            assertTrue(constraintViolations.isEmpty());
    
        }
    
        private String getConstraintMessages(Set<ConstraintViolation<AllowedSortFields>> constraintViolations) {
            String actual = constraintViolations.stream()
                    .map(c -> c.getMessage())
                    .collect(joining(","));
            return actual;
        }
    
        // CustomPageable is a project-specific Pageable implementation that is not shown here.
        private AllowedSortFields newAllowedSortFields(List<String> sortFields) {
            AllowedSortFields toValidate = new AllowedSortFields();
            toValidate.pageable = new CustomPageable().sort(sortFields);
            return toValidate;
        }
    
        private EmptyAllowedSortFields newEmptyAllowedSortFields(List<String> sortFields) {
            EmptyAllowedSortFields toValidate = new EmptyAllowedSortFields();
            toValidate.pageable = new CustomPageable().sort(sortFields);
            return toValidate;
        }
    
        public class AllowedSortFields {
    
            @AllowSortFields({"allowed1", "allowed2"})
            public Pageable pageable;
    
        }
    
        public class EmptyAllowedSortFields {
    
            @AllowSortFields
            public Pageable pageable;
    
        }
    }
    

    Finally, use the annotation in the repository. Be sure to put @Validated on the repository interface; without it, the method-level constraints are not evaluated.

    ExampleSearchRepository.java

    @Validated
    public interface ExampleSearchRepository extends JpaRepository<ExampleSearch, Integer>,
        JpaSpecificationExecutor<ExampleSearch>, PagingAndSortingRepository<ExampleSearch, Integer> {
    
        public Page<ExampleSearch> search(
            @Param("searchCriteria") ExampleSearchCriteria searchCriteria, 
            @AllowSortFields({"field1","subfield.name"}) Pageable pageable);
    }
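
As a rough illustration of the check that AllowSortFieldsValidator performs, the following self-contained sketch runs without Spring or Bean Validation. The class SortWhitelistSketch and the method disallowed are hypothetical names, not part of the code above; the allowed fields are the ones from the repository method, and "password" is just an example of a field that should not be sortable.

```java
import java.util.List;
import java.util.stream.Collectors;

class SortWhitelistSketch {

    // Returns the requested sort properties that are NOT in the allowed list,
    // comma-delimited. An empty result means every requested property is allowed.
    static String disallowed(List<String> requested, List<String> allowed) {
        return requested.stream()
                .filter(property -> !allowed.contains(property))
                .collect(Collectors.joining(","));
    }

    public static void main(String[] args) {
        List<String> allowed = List.of("field1", "subfield.name");

        // prints: true
        System.out.println(disallowed(List.of("field1", "subfield.name"), allowed).isEmpty());

        // prints: password
        System.out.println(disallowed(List.of("field1", "password"), allowed));
    }
}
```

Matching on the exact property string keeps the check simple, but it also means nested paths such as subfield.name have to be listed explicitly.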