Search code examples
javaspring-bootpaginationhibernate-criteriacriteria-api

How to iterate over every page returned by a JPA Criteria?


I need to build a background process that periodically processes all the rows of a database table. Since I cannot load all the rows into memory at once, I need to split the table into smaller chunks. Unfortunately, I cannot use a JPQL query with a Stream return type, because the query uses complex filtering logic that cannot be expressed in JPQL.

So I have built the following Criteria query, which returns only one page of Products. After I have processed a page, how can I move on to the next one until I have processed them all?

import org.springframework.data.jpa.repository.query.QueryUtils;

@Repository
@AllArgsConstructor
public class ProductRepositoryImpl implements CustomProductRepository {

    private final EntityManager entityManager;

    /**
     * Returns one page of products matching the filter, together with the total number of matches.
     *
     * <p>Two queries are executed: a data query limited to the requested page and a count query.
     * Each query has its own {@link Root}, so the filter predicates are built separately for each
     * query against that query's root.
     *
     * @param pageable page index, page size and sort order to apply
     * @return the requested page of products, with the total element count of the filtered set
     * @throws ArithmeticException if the page offset does not fit in an {@code int}
     */
    @Override
    public Page<Product> getProducts(Pageable pageable, /* fields used to filter the results */) {
        CriteriaBuilder criteriaBuilder = entityManager.getCriteriaBuilder();

        CriteriaQuery<Product> criteriaQuery = criteriaBuilder.createQuery(Product.class);
        Root<Product> root = criteriaQuery.from(Product.class);

        List<Predicate> predicates =
            getPredicates(criteriaBuilder, root /* , fields used to filter the results */);

        // NOTE(review): offset pagination is only repeatable across queries when the order is
        // deterministic; if pageable.getSort() is unsorted (or not unique), consider appending a
        // unique attribute (e.g. the id, if Product has one) as a tie-breaker.
        criteriaQuery.where(combinePredicatesWithAndStatement(criteriaBuilder, predicates))
            .orderBy(QueryUtils.toOrders(pageable.getSort(), root, criteriaBuilder));

        List<Product> result = entityManager.createQuery(criteriaQuery)
            // setFirstResult takes an int: fail loudly instead of silently wrapping the long offset
            .setFirstResult(Math.toIntExact(pageable.getOffset()))
            .setMaxResults(pageable.getPageSize())
            .getResultList();

        CriteriaQuery<Long> countQuery = criteriaBuilder.createQuery(Long.class);
        Root<Product> rootCount = countQuery.from(Product.class);

        // predicates must be rebuilt against the count query's own root
        List<Predicate> predicatesCount =
            getPredicates(criteriaBuilder, rootCount /* , fields used to filter the results */);

        countQuery.select(criteriaBuilder.count(rootCount))
            .where(combinePredicatesWithAndStatement(criteriaBuilder, predicatesCount));

        Long totalElements = entityManager.createQuery(countQuery).getSingleResult();

        return new PageImpl<>(result, pageable, totalElements);
    }

    /**
     * Builds the filter predicates for a query rooted at {@code root}.
     *
     * @param criteriaBuilder builder used to create the predicates
     * @param root the root of the query the predicates will be attached to
     * @return the predicates to AND together; {@code null} entries are tolerated and skipped
     *     by {@link #combinePredicatesWithAndStatement}
     */
    private List<Predicate> getPredicates(CriteriaBuilder criteriaBuilder, Root<Product> root
            /* , fields used to filter the results */) {
        List<Predicate> predicates = new ArrayList<>();
        // assemble predicates based on some complex conditions not replicable with JPQL
        return predicates;
    }

    /**
     * Combines the given predicates with AND, ignoring {@code null} entries.
     *
     * @param criteriaBuilder builder used to create the conjunction
     * @param predicates predicates to combine; may contain {@code null}s
     * @return the conjunction of all non-null predicates
     */
    private Predicate combinePredicatesWithAndStatement(CriteriaBuilder criteriaBuilder, List<Predicate> predicates) {
        return criteriaBuilder.and(predicates.stream().filter(Objects::nonNull).toArray(Predicate[]::new));
    }
}

Solution

  • To iterate over all the pages you can create a PageStreamer class as follows:

    import org.springframework.data.domain.Page;
    import org.springframework.data.domain.Pageable;
    import org.springframework.data.domain.Slice;
    
    import java.util.Objects;
    import java.util.Spliterator;
    import java.util.Spliterators;
    import java.util.concurrent.atomic.AtomicReference;
    import java.util.function.Consumer;
    import java.util.function.Function;
    import java.util.stream.Stream;
    import java.util.stream.StreamSupport;
    
    /**
     * Lazily streams every page produced by a page-fetching function. It starts from a given
     * {@link Pageable} and follows {@link Page#nextPageable()} until there are no more pages.
     *
     * <p>No page is fetched until the returned stream is consumed. Each subsequent page is
     * requested only when the stream pulls the next element.
     *
     * <p>NOTE(review): with offset-based pages, rows that are inserted, deleted, or changed so they
     * no longer match the filter while the stream is being consumed can shift later pages. That can
     * cause elements to be skipped or seen twice. The fetch function should also apply a stable
     * sort order.
     *
     * @param <T> the element type of each page
     */
    public class PageStreamer<T> {
    
        private final Pageable pageable;
        private final Function<Pageable, Page<T>> getPage;
    
        /**
         * Creates a streamer that starts at {@code pageable}.
         *
         * @param pageable the first page to request
         * @param getPage fetches the page described by its argument; must not return {@code null}
         * @throws NullPointerException if either argument is {@code null}
         */
        public PageStreamer(Pageable pageable, Function<Pageable, Page<T>> getPage) {
            this.pageable = Objects.requireNonNull(pageable, "pageable");
            this.getPage = Objects.requireNonNull(getPage, "getPage");
        }
    
        /**
         * Returns an ordered, sequential stream of pages.
         *
         * <p>Iteration stops at the first page without content, or after a page whose
         * {@link Page#hasNext()} is {@code false}. Every call returns an independent stream that
         * starts again from the first {@code pageable}.
         *
         * @return a lazy stream of pages
         * @throws NullPointerException during consumption if {@code getPage} returns {@code null}
         */
        public Stream<Page<T>> stream() {
            // A Spliterator keeps the "next page" cursor private to this stream and reports ORDERED.
            // Stream.generate(...) is unordered, so takeWhile on it is nondeterministic by contract,
            // and the previous version also advanced its cursor from a side effect inside map().
            Spliterator<Page<T>> pages = new Spliterators.AbstractSpliterator<Page<T>>(
                    Long.MAX_VALUE, Spliterator.ORDERED | Spliterator.NONNULL) {
    
                // null once iteration is finished
                private Pageable next = pageable;
    
                @Override
                public boolean tryAdvance(Consumer<? super Page<T>> action) {
                    if (next == null) {
                        return false;
                    }
                    Page<T> page = Objects.requireNonNull(getPage.apply(next), "getPage returned null");
                    if (!page.hasContent()) {
                        next = null;
                        return false;
                    }
                    next = page.hasNext() ? page.nextPageable() : null;
                    action.accept(page);
                    return true;
                }
            };
            return StreamSupport.stream(pages, false);
        }
    }
    

    Then you can use it as follows:

    PageStreamer<Product> pageStreamer = new PageStreamer<>(
        // starts retrieving the products from the first page,
        // using a 50 elements window size
        // NOTE(review): offset paging needs a stable order; consider e.g.
        // PageRequest.of(0, 50, Sort.by("id")) if Product has a unique "id" attribute
        PageRequest.of(0, 50),
        pageable -> productRepository.getProducts(pageable, /* fields used to filter the results */)
    );
    
    // each stream element is a whole Page<Product>; flatten it to handle individual products
    pageStreamer.stream()
        .flatMap(page -> page.getContent().stream())
        .forEach(product -> {
            // perform some processing
        });