Avoid in-memory calculation for OneToMany ... when fetching MAX 1 Child


I am using Spring Data JPA and have entities that map to their respective tables. I need a query that fetches all parents and, for each parent, at most one child, chosen by its strength.

import javax.persistence.*;
import java.util.List;

@Entity
public class Parent {
  @Id
  Long id;

  @OneToMany(mappedBy = "parent", cascade = CascadeType.ALL)
  private List<Child> children;

  // getters and setters omitted
}

@Entity
public class Child {
  @Id
  Long id;

  @Enumerated(EnumType.STRING)
  private Strength strength;

  @ManyToOne(fetch = FetchType.EAGER)
  @JoinColumn(name = "parent_id", nullable = false, insertable = true, updatable = false)
  private Parent parent;

  // getters and setters omitted
}

public enum Strength {
  STRONG,
  NORMAL,
  WEAK
}

I have a basic CRUD repository:

@Repository
public interface ParentRepository extends CrudRepository<Parent, Long>{}

Rules and assumptions:

  1. A child belongs to a single parent.
  2. A parent can have many children in the database.
  3. A parent has 0 or 1 child with strength = STRONG.
  4. A parent has exactly 1 child with strength = NORMAL.
  5. A parent has 0 or more children with strength = WEAK.
  6. A WEAK child is never returned.
  7. The getParentAndStrongChildren method below should return at most 1 child per parent.
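Rules 3 to 7 boil down to a single ordering: drop WEAK children, then prefer STRONG over NORMAL. A minimal plain-Java sketch of that selection, assuming the enum constants are declared in the order shown above (so STRONG sorts before NORMAL). `ChildSelection`, `pickChild`, and the stripped-down `Child` class are hypothetical names for illustration, not the real entities.

```java
import java.util.*;

public class ChildSelection {
    // Same constants, in the same declaration order, as the Strength enum above
    enum Strength { STRONG, NORMAL, WEAK }

    // Hypothetical minimal stand-in for the Child entity (illustration only)
    static final class Child {
        final long id;
        final Strength strength;
        Child(long id, Strength strength) { this.id = id; this.strength = strength; }
    }

    // Drops WEAK children, then returns the one with the lowest ordinal:
    // STRONG if present, otherwise NORMAL. Empty if nothing qualifies.
    static Optional<Child> pickChild(List<Child> children) {
        if (children == null) return Optional.empty();
        return children.stream()
            .filter(c -> c.strength != Strength.WEAK)
            .min(Comparator.comparing((Child c) -> c.strength));
    }
}
```

Relying on enum declaration order mirrors the `CASE ... WHEN 'STRONG' THEN 1 WHEN 'NORMAL' THEN 2` ranking used in the SQL below; reordering the enum constants would silently change the result.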

I can call findAll() on ParentRepository and then filter the results in memory, something like this:

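import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

public List<Parent> getParentAndStrongChildren() {
    // CrudRepository.findAll() returns an Iterable, so wrap it in a Stream
    return StreamSupport.stream(parentRepository.findAll().spliterator(), false)
        .map(p -> {
            List<Child> children = p.getChildren();
            if (children != null && children.size() > 1) {
                // Prefer the STRONG child, fall back to the NORMAL one; WEAK children are dropped
                Child found = children.stream()
                    .filter(c -> c.getStrength() == Strength.STRONG)
                    .findFirst()
                    .orElseGet(() -> children.stream()
                        .filter(c -> c.getStrength() == Strength.NORMAL)
                        .findFirst()
                        .orElse(null));
                p.setChildren(found == null ? null : Arrays.asList(found));
            }
            return p;
        })
        .collect(Collectors.toList());
}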

Q: Is there a way to avoid filtering in memory and instead rely on JPQL and the @Query annotation?


Solution

  • This is a typical "top N per category" SQL query. I personally doubt this can be done with JPQL, but maybe someone else will provide an answer. Here is a standard SQL solution using a LATERAL join; from Spring Data JPA you can run it as a native query with @Query(value = "...", nativeQuery = true):

    SELECT p.*, c.id AS child_id, c.strength
    FROM parent p
    LEFT JOIN LATERAL (
      SELECT c.*
      FROM child c
      WHERE p.id = c.parent_id
      AND c.strength <> 'WEAK'
      ORDER BY CASE c.strength WHEN 'STRONG' THEN 1 WHEN 'NORMAL' THEN 2 END
      FETCH FIRST ROW ONLY
    ) c ON 1 = 1
    

    Alternatively, using window functions (also standard SQL):

    SELECT p.*, c.id AS child_id, c.strength
    FROM parent p
    LEFT JOIN (
      SELECT c.*, row_number() OVER (
        PARTITION BY c.parent_id
        ORDER BY CASE c.strength WHEN 'STRONG' THEN 1 WHEN 'NORMAL' THEN 2 END
      ) AS rk
      FROM child c
      WHERE c.strength <> 'WEAK'
    ) c ON p.id = c.parent_id AND c.rk = 1
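To make the window-function version concrete, here is a minimal plain-Java sketch of what the child subquery computes: filter out WEAK rows, partition by parent_id, order each partition by the same CASE ranking, and keep the first n rows. `TopOnePerParent`, `Row`, `rank`, and `topNPerParent` are hypothetical names for illustration. It only emulates the subquery (not the LEFT JOIN back to parent) and is meant to explain the SQL, not to replace running it in the database.

```java
import java.util.*;

public class TopOnePerParent {
    // Hypothetical row shape: one child row with its parent_id and strength
    static final class Row {
        final long parentId;
        final long childId;
        final String strength;
        Row(long parentId, long childId, String strength) {
            this.parentId = parentId;
            this.childId = childId;
            this.strength = strength;
        }
    }

    // Mirrors: CASE c.strength WHEN 'STRONG' THEN 1 WHEN 'NORMAL' THEN 2 END
    static int rank(String strength) {
        switch (strength) {
            case "STRONG": return 1;
            case "NORMAL": return 2;
            default: return 3; // not reached: WEAK rows are filtered out first
        }
    }

    // Emulates: WHERE strength <> 'WEAK',
    // row_number() OVER (PARTITION BY parent_id ORDER BY rank), keep rk <= n
    static Map<Long, List<Row>> topNPerParent(List<Row> rows, int n) {
        // PARTITION BY parent_id (TreeMap just keeps parent ids sorted)
        Map<Long, List<Row>> partitions = new TreeMap<>();
        for (Row r : rows) {
            if ("WEAK".equals(r.strength)) continue; // WHERE c.strength <> 'WEAK'
            partitions.computeIfAbsent(r.parentId, k -> new ArrayList<>()).add(r);
        }
        for (List<Row> partition : partitions.values()) {
            // ORDER BY the CASE ranking within each partition
            partition.sort(Comparator.comparingInt(r -> rank(r.strength)));
            // Keep only rows whose row number is <= n
            if (partition.size() > n) partition.subList(n, partition.size()).clear();
        }
        return partitions;
    }
}
```

Calling `topNPerParent(rows, 1)` corresponds to the `c.rk = 1` join condition; a larger n gives the general top-N-per-category result. Note that `List.sort` is stable while `row_number()` breaks ties arbitrarily; rules 3 and 4 rule out ties between STRONG and NORMAL here.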