Tags: java, sorting, collections, java-7

Sort a list of Java objects by multiple fields and group by a particular field


I am trying to sort a List of Java objects according to more than one field. The object is of type:

public class Employee {
    String empId;
    String groupId;
    String salary;
    ...
}

All employees with the same groupId must be grouped together. groupId can be null, and employees with a null groupId form a group of their own. The group with the highest total salary (the sum of the salaries of all employees in the group) must be at the top of the list, with the remaining groups following in descending order of total salary. Within each group, employees must be sorted in decreasing order of salary.

Example: Given data:

+-------+---------+--------+
| empId | groupId | salary |
+-------+---------+--------+
| emp1  | grp1    |    500 |
| emp2  | null    |    600 |
| emp3  | null    |    700 |
| emp4  | grp2    |    800 |
| emp5  | grp1    |    700 |
| emp6  | grp2    |   1000 |
| emp7  | grp1    |    800 |
| emp8  | null    |   1000 |
| emp9  | grp2    |    600 |
+-------+---------+--------+

Expected output:

+-------+---------+--------+
| empId | groupId | salary |
+-------+---------+--------+
| emp6  | grp2    |   1000 |
| emp4  | grp2    |    800 |
| emp9  | grp2    |    600 |
| emp8  | null    |   1000 |
| emp3  | null    |    700 |
| emp2  | null    |    600 |
| emp7  | grp1    |    800 |
| emp5  | grp1    |    700 |
| emp1  | grp1    |    500 |
+-------+---------+--------+
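
The expected order follows from the group totals: grp2 = 800 + 1000 + 600 = 2400, null = 600 + 700 + 1000 = 2300, grp1 = 500 + 700 + 800 = 2000. Below is a minimal sketch of that arithmetic. The class and method names (GroupTotals, totalsByGroup) are hypothetical, and parallel arrays stand in for the Employee objects to keep it short:

```java
import java.util.HashMap;
import java.util.Map;

class GroupTotals {

    // Sums the salaries per group. HashMap accepts a null key, so employees
    // without a group are totalled under the key null.
    static Map<String, Integer> totalsByGroup(String[] groupIds, int[] salaries) {
        Map<String, Integer> totals = new HashMap<String, Integer>();
        for (int i = 0; i < groupIds.length; i++) {
            Integer soFar = totals.get(groupIds[i]);
            totals.put(groupIds[i], soFar == null ? salaries[i] : soFar + salaries[i]);
        }
        return totals;
    }

    public static void main(String[] args) {
        // The nine rows of the "Given data" table, in order
        String[] groupIds = {"grp1", null, null, "grp2", "grp1", "grp2", "grp1", null, "grp2"};
        int[] salaries = {500, 600, 700, 800, 700, 1000, 800, 1000, 600};

        Map<String, Integer> totals = totalsByGroup(groupIds, salaries);
        System.out.println("grp2 = " + totals.get("grp2")); // grp2 = 2400
        System.out.println("null = " + totals.get(null));   // null = 2300
        System.out.println("grp1 = " + totals.get("grp1")); // grp1 = 2000
    }
}
```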

My solution:

public class Employee {
    String empId;
    String groupId;
    int salary;

    ...

    public String getEmpId() {
        return empId;
    }

    public void setEmpId(String empId) {
        this.empId = empId;
    }

    public String getGroupId() {
        return groupId;
    }

    public void setGroupId(String groupId) {
        this.groupId = groupId;
    }

    public int getSalary() {
        return salary;
    }

    public void setSalary(int salary) {
        this.salary = salary;
    }

    ...

}

class EmployeeChainedComparator implements Comparator<Employee> {

    private List<Comparator<Employee>> listComparators;

    public EmployeeChainedComparator(Comparator<Employee>... comparators) {
        this.listComparators = Arrays.asList(comparators);
    }

    @Override
    public int compare(Employee o1, Employee o2) {
        for (Comparator<Employee> comparator : listComparators) {
            int result = comparator.compare(o1, o2);
            if (result != 0)
                return result;
        }

        return 0;
    }

}

class EmployeeGroupComparator implements Comparator<Employee> {

    @Override
    public int compare(Employee o1, Employee o2) {
        if(o2.getGroupId() == null)
            return (o1.getGroupId() == null) ? 0 : -1;
        if(o1.getGroupId() == null)
            return 1;
        return o1.getGroupId().compareTo(o2.getGroupId());
    }

}


class EmployeeSalaryComparator implements Comparator<Employee> {

    @Override
    public int compare(Employee o1, Employee o2) {
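        // Descending by salary; Integer.compare avoids the overflow that subtraction can cause
        return Integer.compare(o2.getSalary(), o1.getSalary());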
    }

}

class Solution {
    void sortEmployees(List<Employee> employees) {
        Collections.sort(employees, new EmployeeChainedComparator(new EmployeeGroupComparator(), new EmployeeSalaryComparator()));
    }
}

Solution

  • The solution you have posted seems more complex than this problem needs. A cleaner approach is given below:

    1. Sort the employees using the comparators defined in the class Solution, so that each group is contiguous and sorted by descending salary.
    2. Compute the total salary of each group. In other words, create a Map in which groupId is the key and the sum of the salaries of the employees in that group is the value.
    3. Iterate over the entries of the map created in step 2, sorted by descending value, and append the employees of each group to the result list.

      The following code implements this algorithm:

      // Sort employees using the comparators defined in the class, Solution
      new Solution().sortEmployees(empList);
      
      // Group employees by group ID with the sum of salary as the grouping function
      Map<String, Integer> map = new HashMap<>();
      for (Employee e : empList) {
          String grp = e.getGroupId();
          if (grp == null) {
              grp = "null";
          }
          Integer salary = map.get(grp);
          map.put(grp, salary == null ? e.getSalary() : e.getSalary() + salary);
      }
      
      // Result list
      List<Employee> result = new ArrayList<>();
      
      // Iterate the sorted entry set of `map` and put the records corresponding to
      // an entry into the result list
      for (Entry<String, Integer> entry : entriesSortedByValues(map)) {
          String grp = entry.getKey();
          int i;
      
          // Find the starting index of `grp` in empList
      
          if ("null".equals(grp)) {// Special handling for employees with `null` group
              // Find the index in `empList` where employees with the group as `null` starts
              for (i = 0; i < empList.size() && empList.get(i).getGroupId() != null; i++)
                  ;
      
              // Add elements before a different group is encountered
              for (int j = i; j < empList.size() && empList.get(j).getGroupId() == null; j++) {
                  result.add(empList.get(j));
              }
          } else {
              // Find the index in `empList` where employees with the group as `grp` starts
              for (i = 0; i < empList.size() && !grp.equals(empList.get(i).getGroupId()); i++)
                  ;
      
              // Add elements before a different group is encountered
              for (int j = i; j < empList.size() && grp.equals(empList.get(j).getGroupId()); j++) {
                  result.add(empList.get(j));
              }
          }
      }
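
As an alternative, the same ordering can be produced with a single sort if the group totals are computed first and then consulted inside one comparator. The following is only a sketch under stated assumptions: TotalsFirstSort, Emp, sortByGroupTotal, sample and ids are hypothetical names, and Emp is a minimal stand-in for Employee. Because HashMap accepts a null key, no "null" placeholder string is needed here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TotalsFirstSort {

    // Hypothetical, minimal stand-in for the Employee class
    static class Emp {
        final String empId;
        final String groupId;
        final int salary;

        Emp(String empId, String groupId, int salary) {
            this.empId = empId;
            this.groupId = groupId;
            this.salary = salary;
        }
    }

    static void sortByGroupTotal(List<Emp> emps) {
        // Total salary per group; the null group is stored under the null key
        final Map<String, Integer> totals = new HashMap<String, Integer>();
        for (Emp e : emps) {
            Integer soFar = totals.get(e.groupId);
            totals.put(e.groupId, soFar == null ? e.salary : soFar + e.salary);
        }

        Collections.sort(emps, new Comparator<Emp>() {
            @Override
            public int compare(Emp a, Emp b) {
                // 1. Group with the larger total first
                int byTotal = Integer.compare(totals.get(b.groupId), totals.get(a.groupId));
                if (byTotal != 0)
                    return byTotal;
                // 2. Keep groups with equal totals apart (null group last)
                int byGroup = compareGroupIds(a.groupId, b.groupId);
                if (byGroup != 0)
                    return byGroup;
                // 3. Higher salary first within a group
                return Integer.compare(b.salary, a.salary);
            }
        });
    }

    static int compareGroupIds(String g1, String g2) {
        if (g1 == null)
            return g2 == null ? 0 : 1;
        if (g2 == null)
            return -1;
        return g1.compareTo(g2);
    }

    // The nine employees from the question, in their original order
    static List<Emp> sample() {
        return new ArrayList<Emp>(Arrays.asList(
                new Emp("emp1", "grp1", 500), new Emp("emp2", null, 600), new Emp("emp3", null, 700),
                new Emp("emp4", "grp2", 800), new Emp("emp5", "grp1", 700), new Emp("emp6", "grp2", 1000),
                new Emp("emp7", "grp1", 800), new Emp("emp8", null, 1000), new Emp("emp9", "grp2", 600)));
    }

    // Joins the employee IDs with commas, for compact display
    static String ids(List<Emp> emps) {
        StringBuilder sb = new StringBuilder();
        for (Emp e : emps) {
            if (sb.length() > 0)
                sb.append(',');
            sb.append(e.empId);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<Emp> emps = sample();
        sortByGroupTotal(emps);
        System.out.println(ids(emps)); // emp6,emp4,emp9,emp8,emp3,emp2,emp7,emp5,emp1
    }
}
```

The trade-off is that the comparator depends on a map built from the same list, so the totals must be recomputed whenever the list changes before sorting again.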
      

      Demo

      import java.util.ArrayList;
      import java.util.Arrays;
      import java.util.Collections;
      import java.util.Comparator;
      import java.util.HashMap;
      import java.util.List;
      import java.util.Map;
      import java.util.Map.Entry;
      import java.util.Objects;
      import java.util.SortedSet;
      import java.util.TreeSet;
      
      class Employee {
          String empId;
          String groupId;
          int salary;
      
          public Employee(String empId, String groupId, int salary) {
              this.empId = empId;
              this.groupId = groupId;
              this.salary = salary;
          }
      
          public String getEmpId() {
              return empId;
          }
      
          public String getGroupId() {
              return groupId;
          }
      
          public int getSalary() {
              return salary;
          }
      
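          @Override
          public boolean equals(Object obj) {
              if (this == obj)
                  return true;
              // Guard against null and foreign types before casting
              if (!(obj instanceof Employee))
                  return false;
              Employee other = (Employee) obj;
              return Objects.equals(empId, other.empId) && Objects.equals(groupId, other.groupId)
                      && salary == other.salary;
          }

          // Keep hashCode consistent with equals
          @Override
          public int hashCode() {
              return Objects.hash(empId, groupId, salary);
          }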
      
          @Override
          public String toString() {
              return "Employee [empId=" + empId + ", groupId=" + groupId + ", salary=" + salary + "]";
          }
      }
      
      class EmployeeChainedComparator implements Comparator<Employee> {
      
          private List<Comparator<Employee>> listComparators;
      
          public EmployeeChainedComparator(Comparator<Employee>... comparators) {
              this.listComparators = Arrays.asList(comparators);
          }
      
          @Override
          public int compare(Employee o1, Employee o2) {
              for (Comparator<Employee> comparator : listComparators) {
                  int result = comparator.compare(o1, o2);
                  if (result != 0)
                      return result;
              }
      
              return 0;
          }
      
      }
      
      class EmployeeGroupComparator implements Comparator<Employee> {
      
          @Override
          public int compare(Employee o1, Employee o2) {
              if (o2.getGroupId() == null)
                  return (o1.getGroupId() == null) ? 0 : -1;
              if (o1.getGroupId() == null)
                  return 1;
              return o1.getGroupId().compareTo(o2.getGroupId());
          }
      
      }
      
      class EmployeeSalaryComparator implements Comparator<Employee> {
      
          @Override
          public int compare(Employee o1, Employee o2) {
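              // Descending by salary; Integer.compare avoids the overflow that subtraction can cause
              return Integer.compare(o2.getSalary(), o1.getSalary());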
          }
      
      }
      
      class Solution {
          void sortEmployees(List<Employee> employees) {
              Collections.sort(employees,
                      new EmployeeChainedComparator(new EmployeeGroupComparator(), new EmployeeSalaryComparator()));
          }
      }
      
      public class Q62447064 {
          public static void main(String[] args) {
              List<Employee> empList = new ArrayList<>(Arrays.asList(new Employee("emp1", "grp1", 500),
                      new Employee("emp2", null, 600), new Employee("emp3", null, 700), new Employee("emp4", "grp2", 800),
                      new Employee("emp5", "grp1", 700), new Employee("emp6", "grp2", 1000),
                      new Employee("emp7", "grp1", 800), new Employee("emp8", null, 1000),
                      new Employee("emp9", "grp2", 600)));
      
              // Sort employees using the comparators defined in the class, Solution
              new Solution().sortEmployees(empList);
      
              // Group employees by group ID with the sum of salary as the grouping function
              Map<String, Integer> map = new HashMap<>();
              for (Employee e : empList) {
                  String grp = e.getGroupId();
                  if (grp == null) {
                      grp = "null";
                  }
                  Integer salary = map.get(grp);
                  map.put(grp, salary == null ? e.getSalary() : e.getSalary() + salary);
              }
      
              // Result list
              List<Employee> result = new ArrayList<>();
      
              // Iterate the sorted entry set of `map` and put the records corresponding to
              // an entry into the result list
              for (Entry<String, Integer> entry : entriesSortedByValues(map)) {
                  String grp = entry.getKey();
                  int i;
      
                  // Find the starting index of `grp` in empList
      
                  if ("null".equals(grp)) {// Special handling for employees with `null` group
                      // Find the index in `empList` where employees with the group as `null` starts
                      for (i = 0; i < empList.size() && empList.get(i).getGroupId() != null; i++)
                          ;
      
                      // Add elements before a different group is encountered
                      for (int j = i; j < empList.size() && empList.get(j).getGroupId() == null; j++) {
                          result.add(empList.get(j));
                      }
                  } else {
                      // Find the index in `empList` where employees with the group as `grp` starts
                      for (i = 0; i < empList.size() && !grp.equals(empList.get(i).getGroupId()); i++)
                          ;
      
                      // Add elements before a different group is encountered
                      for (int j = i; j < empList.size() && grp.equals(empList.get(j).getGroupId()); j++) {
                          result.add(empList.get(j));
                      }
                  }
              }
      
              // Display result list
              for (Employee e : result) {
                  System.out.println(e);
              }
          }
      
          private static <K, V extends Comparable<? super V>> SortedSet<Map.Entry<K, V>> entriesSortedByValues(
                  Map<K, V> map) {
              SortedSet<Map.Entry<K, V>> sortedEntries = new TreeSet<Map.Entry<K, V>>(new Comparator<Map.Entry<K, V>>() {
                  @Override
                  public int compare(Map.Entry<K, V> e1, Map.Entry<K, V> e2) {
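                      // Descending by value. Ties return 1 instead of 0 so that the TreeSet keeps
                      // entries with equal values; this breaks the Comparator contract, so do not
                      // rely on contains() or remove() on the returned set.
                      int res = e2.getValue().compareTo(e1.getValue());
                      return res != 0 ? res : 1;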
                  }
              });
              sortedEntries.addAll(map.entrySet());
              return sortedEntries;
          }
      }
      

      Output:

      Employee [empId=emp6, groupId=grp2, salary=1000]
      Employee [empId=emp4, groupId=grp2, salary=800]
      Employee [empId=emp9, groupId=grp2, salary=600]
      Employee [empId=emp8, groupId=null, salary=1000]
      Employee [empId=emp3, groupId=null, salary=700]
      Employee [empId=emp2, groupId=null, salary=600]
      Employee [empId=emp7, groupId=grp1, salary=800]
      Employee [empId=emp5, groupId=grp1, salary=700]
      Employee [empId=emp1, groupId=grp1, salary=500]
      

      Note: The method entriesSortedByValues has been copied from this post.
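
      The comparator inside entriesSortedByValues never returns 0, which is what lets the TreeSet hold two groups with the same total. If a plain ordered list is enough, one possible alternative is to copy the entries into a List and sort that, which keeps ties without bending the Comparator contract. This is a sketch only; EntrySorting, entriesByValueDescending and keysByValueDescending are hypothetical names, not part of the copied method:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class EntrySorting {

    // Returns the map's entries ordered by descending value. Collections.sort is
    // stable, so entries with equal values keep the map's iteration order.
    static <K, V extends Comparable<? super V>> List<Map.Entry<K, V>> entriesByValueDescending(Map<K, V> map) {
        List<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> e1, Map.Entry<K, V> e2) {
                return e2.getValue().compareTo(e1.getValue());
            }
        });
        return entries;
    }

    // Joins the keys in sorted order with commas, for compact display
    static String keysByValueDescending(Map<String, Integer> map) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, Integer> entry : entriesByValueDescending(map)) {
            if (sb.length() > 0)
                sb.append(',');
            sb.append(entry.getKey());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Group totals from the example data, keyed as in the demo above
        Map<String, Integer> totals = new LinkedHashMap<String, Integer>();
        totals.put("grp1", 2000);
        totals.put("null", 2300);
        totals.put("grp2", 2400);
        System.out.println(keysByValueDescending(totals)); // grp2,null,grp1
    }
}
```

      The returned list can be iterated in the same for-each loop that currently consumes the sorted set.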