Tags: java, algorithm, time-complexity

Given a list and ranges, find a sum based on a new list in less time


Given an integer list of size n and a list of m ranges, where each range gives the (inclusive) start and end indices into the input list.

First, create a new list from these ranges. For example:

n=6, list = [1, 2, 3, 2, 4, 5]
m=4, ranges = [[0, 1], [3, 4], [0, 0], [3, 4]]

To create the new list, we loop through the ranges (i = 0 to m-1) and append the items at the indices each range covers:

for i=0, [0,1], pick items from list[0] to list[1], and append to newList, so newList = [1,2]
for i=1, [3,4], pick items from list[3] to list[4], and append to newList, so newList = [1,2,2,4]
for i=2, [0,0], pick items from list[0] to list[0], and append to newList, so newList = [1,2,2,4,1]
for i=3, [3,4], pick items from list[3] to list[4], and append to newList, so newList = [1,2,2,4,1,2,4]
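
As a reference point, this first step can be written out literally. The sketch below assumes inclusive, in-bounds ranges; the class `NewListDemo` and method `buildNewList` are hypothetical names, not part of the original code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class NewListDemo {
    // Hypothetical helper: concatenates list[start..end] (inclusive) for every range, in order.
    static List<Integer> buildNewList(List<Integer> list, List<List<Integer>> ranges) {
        List<Integer> newList = new ArrayList<>();
        for (List<Integer> range : ranges) {
            int start = range.get(0);
            int end = range.get(1);
            // subList's upper bound is exclusive, hence end + 1
            newList.addAll(list.subList(start, end + 1));
        }
        return newList;
    }

    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(1, 2, 3, 2, 4, 5);
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(0, 1), Arrays.asList(3, 4),
                Arrays.asList(0, 0), Arrays.asList(3, 4));
        System.out.println(buildNewList(list, ranges)); // [1, 2, 2, 4, 1, 2, 4]
    }
}
```

The size of newList is the sum of all range lengths, which can reach n·m, so building it explicitly is already expensive for large inputs.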

Next, loop i from 0 to n-1. If index i lies in any of the ranges above, add 0 to result; otherwise, count how many items in newList have values less than list[i] and add that count to result.

initialize result = 0
for i=0, it is part of ranges [0,1] and [0,0], so add 0 to result
for i=1, it is part of range [0,1], so add 0 to result
for i=2, it is not part of any range, so count how many items in newList [1,2,2,4,1,2,4] are less than list[2] = 3: we get [1,2,2,1,2], so 5
for i=3, it is part of range [3,4], so add 0 to result
for i=4, it is part of range [3,4], so add 0 to result
for i=5, it is not part of any range, so count how many items in newList [1,2,2,4,1,2,4] are less than list[5] = 5: we get [1,2,2,4,1,2,4], so 7

Result = 0 + 0 + 5 + 0 + 0 + 7 = 12
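
The second step can likewise be sketched literally, taking the already-built newList as input. This is a deliberately slow, O(n·(m + |newList|)) version that may be handy as a correctness oracle when testing faster approaches; `BruteForceDemo` and `sumOfSmallerCounts` are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

class BruteForceDemo {
    // Literal version of the second step; hypothetical name, not the asker's method.
    static long sumOfSmallerCounts(List<Integer> list, List<List<Integer>> ranges, List<Integer> newList) {
        long result = 0;
        for (int i = 0; i < list.size(); i++) {
            // Is index i inside any range (inclusive bounds)?
            boolean covered = false;
            for (List<Integer> range : ranges) {
                if (range.get(0) <= i && i <= range.get(1)) {
                    covered = true;
                    break;
                }
            }
            if (covered) continue; // covered indices contribute 0
            // Count items in newList strictly less than list[i]
            int value = list.get(i);
            for (int x : newList) {
                if (x < value) result++;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(1, 2, 3, 2, 4, 5);
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(0, 1), Arrays.asList(3, 4),
                Arrays.asList(0, 0), Arrays.asList(3, 4));
        List<Integer> newList = Arrays.asList(1, 2, 2, 4, 1, 2, 4);
        System.out.println(sumOfSmallerCounts(list, ranges, newList)); // 12
    }
}
```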

I have implemented this, but it takes too long to run.

Here is my code:

import java.util.*;

public class Main {

    public static long solution(List<Integer> list, List<List<Integer>> ranges) {
        // Create a set to keep track of all indices that contribute to new list
        Set<Integer> contributingIndices = new HashSet<>();
        
        // TreeMap to maintain the frequency of elements in the new list
        TreeMap<Integer, Integer> freqMap = new TreeMap<>();
        
        // Populate the frequency map and contributing indices
        for (List<Integer> range : ranges) {
            int start = range.get(0);
            int end = range.get(1);
            for (int i = start; i <= end; i++) {
                freqMap.put(list.get(i), freqMap.getOrDefault(list.get(i), 0) + 1);
                contributingIndices.add(i);
            }
        }

        long result = 0;

        for (int i = 0; i < list.size(); i++) {
            if (!contributingIndices.contains(i)) {
                // Calculate the number of elements in map that are smaller than list[i]
                result += freqMap.headMap(list.get(i), false)
                                        .values()
                                        .stream()
                                        .mapToInt(Integer::intValue)
                                        .sum();
            }
        }
    
        return result;
    }
    
    static void case1() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(0, 1),
                Arrays.asList(0, 2),
                Arrays.asList(1, 2)
        );
        System.out.println(solution(list, ranges)); // Output: 14
    }
    static void case2() {
        List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(1, 2),
                Arrays.asList(1, 1),
                Arrays.asList(2, 2),
                Arrays.asList(3, 3),
                Arrays.asList(4, 4)
        );
        System.out.println(solution(list, ranges)); // Output: 0
    }
    
    static void case3() {
        List<Integer> list = Arrays.asList(1, 2, 3, 2, 4, 5);
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(0,1),
                Arrays.asList(3,4),
                Arrays.asList(0,0),
                Arrays.asList(3,4)
        );
        System.out.println(solution(list, ranges)); // Output: 12
    }

    public static void main(String[] args) {
        case1();
        case2();
        case3();
    }
}

To speed it up, I tried merging the ranges first and then building freqMap, but merging overlapping intervals loses how many times each index is repeated, so the frequency counts come out wrong.
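
Why merging breaks the counts can be seen by computing, for each index, how many ranges cover it: that number is exactly how often list[i] appears in newList. Expanding merged ranges once only records whether an index is covered (0 or 1). A small sketch with hypothetical names, using the example ranges:

```java
import java.util.Arrays;
import java.util.List;

class MultiplicityDemo {
    // How many ranges cover each index; equals how often list[i] appears in newList.
    static int[] coverageCounts(int n, List<List<Integer>> ranges) {
        int[] counts = new int[n];
        for (List<Integer> range : ranges) {
            for (int i = range.get(0); i <= range.get(1); i++) {
                counts[i]++;
            }
        }
        return counts;
    }

    // What expanding merged ranges once preserves: only whether an index is covered.
    static int[] coveredFlags(int[] counts) {
        int[] flags = new int[counts.length];
        for (int i = 0; i < counts.length; i++) {
            flags[i] = counts[i] > 0 ? 1 : 0;
        }
        return flags;
    }

    public static void main(String[] args) {
        List<List<Integer>> ranges = Arrays.asList(
                Arrays.asList(0, 1), Arrays.asList(3, 4),
                Arrays.asList(0, 0), Arrays.asList(3, 4));
        int[] counts = coverageCounts(6, ranges);
        System.out.println(Arrays.toString(counts));               // [2, 1, 0, 2, 2, 0]
        System.out.println(Arrays.toString(coveredFlags(counts))); // [1, 1, 0, 1, 1, 0]
    }
}
```

The counts sum to 7, the length of newList, while the flags sum to 4.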

I am looking for a better solution with lower time complexity.


Solution

  • First, use a difference array to count, for every index, how many ranges cover it, in O(n + m). That count is exactly the number of times the element at that index appears in the new list, which is the multiplicity that merging the ranges throws away.

    Next, collect all indices covered by at least one range and sort them by their element values. To answer queries quickly, also build an array of cumulative (prefix) sums of the coverage counts in that sorted order.

    Finally, for every index in the original list that no range covers, binary search the sorted list for the last covered element whose value is less than the current element. The prefix sum at that position is the number of new-list elements less than the current element; add it to the result.

    The overall time complexity is O(m + n log n).

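    // Drop-in replacement for solution() in Main above (needs java.util.*; var requires Java 10+)
    public static long solution(List<Integer> list, List<List<Integer>> ranges) {
        // Difference array: +1 at each range start, -1 just past each range end
        var diff = new int[list.size() + 1];
        for (var range : ranges) {
            ++diff[range.get(0)];
            --diff[range.get(1) + 1];
        }
        // Prefix-sum in place so diff[i] = number of ranges covering index i,
        // and collect the indices covered at least once
        var inRange = new ArrayList<Integer>();
        for (int i = 0; i < list.size(); ++i) {
            if (i > 0) diff[i] += diff[i - 1];
            if (diff[i] > 0) inRange.add(i);
        }
        // Sort the covered indices by their value in list
        inRange.sort(Comparator.comparing(list::get));
        // prefSum[k] = total multiplicity of the first k+1 covered indices in sorted order
        var prefSum = new long[inRange.size()];
        long currTot = 0;
        for (int i = 0; i < inRange.size(); ++i) prefSum[i] = currTot += diff[inRange.get(i)];
        long result = 0;
        for (int i = 0; i < list.size(); ++i)
            if (diff[i] == 0) {
                // Binary search: high ends at the last position whose value is < list[i] (-1 if none)
                int low = 0, high = inRange.size() - 1;
                while (low <= high) {
                    int mid = (low + high) >>> 1;
                    if (list.get(i) > list.get(inRange.get(mid))) low = mid + 1;
                    else high = mid - 1;
                }
                if (high >= 0) result += prefSum[high];
            }
        return result;
    }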
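
The same idea (difference array for multiplicities, cumulative counts over sorted covered values, then a predecessor lookup) can also be sketched with a TreeMap instead of a hand-written binary search: `lowerEntry(x)` returns the entry with the greatest key strictly less than x, or null. This is an alternative sketch of the answer's approach, not the answer's code; the class name `TreeMapVariant` is hypothetical. Equal values collapse into one key, and the complexity stays O(m + n log n).

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class TreeMapVariant {
    static long solution(List<Integer> list, List<List<Integer>> ranges) {
        int n = list.size();
        // Difference array: after the prefix pass, cover[i] = number of ranges containing i
        int[] cover = new int[n + 1];
        for (List<Integer> range : ranges) {
            cover[range.get(0)]++;
            cover[range.get(1) + 1]--;
        }
        for (int i = 1; i < n; i++) {
            cover[i] += cover[i - 1];
        }
        // value -> how many times it appears in the (never built) new list
        TreeMap<Integer, Long> weight = new TreeMap<>();
        for (int i = 0; i < n; i++) {
            if (cover[i] > 0) {
                weight.merge(list.get(i), (long) cover[i], Long::sum);
            }
        }
        // value -> number of new-list items with value <= key
        TreeMap<Integer, Long> cumulative = new TreeMap<>();
        long running = 0;
        for (Map.Entry<Integer, Long> e : weight.entrySet()) {
            running += e.getValue();
            cumulative.put(e.getKey(), running);
        }
        long result = 0;
        for (int i = 0; i < n; i++) {
            if (cover[i] == 0) {
                // Greatest covered value strictly less than list[i], if any
                Map.Entry<Integer, Long> below = cumulative.lowerEntry(list.get(i));
                if (below != null) {
                    result += below.getValue();
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> a = Arrays.asList(1, 2, 3, 4, 5);
        System.out.println(solution(a, Arrays.asList(
                Arrays.asList(0, 1), Arrays.asList(0, 2), Arrays.asList(1, 2)))); // 14
        System.out.println(solution(a, Arrays.asList(
                Arrays.asList(1, 2), Arrays.asList(1, 1), Arrays.asList(2, 2),
                Arrays.asList(3, 3), Arrays.asList(4, 4)))); // 0
        System.out.println(solution(Arrays.asList(1, 2, 3, 2, 4, 5), Arrays.asList(
                Arrays.asList(0, 1), Arrays.asList(3, 4),
                Arrays.asList(0, 0), Arrays.asList(3, 4)))); // 12
    }
}
```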