
Counting Inversions using TreeSet Java


I'm using a TreeSet to solve the counting-inversions problem. I'm following an approach based on gnu_pbds, whose ordered set answers order_of_key queries in O(log n) time.

Algorithm

  1. Insert the first element of the array into the Ordered_Set.
  2. For each remaining element in arr[], do the following:
  • Insert the current element into the Ordered_Set.
  • Find the number of elements strictly less than current element + 1 in the Ordered_Set, using order_of_key(arr[i] + 1).
  • The difference between the size of the Ordered_Set and order_of_key(current element + 1) gives the inversion count for the current element.

You can read more about this algorithm here.
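To make the steps concrete, here is a minimal Java sketch that follows them, with a sorted ArrayList standing in for the pbds ordered set (a swapped-in structure, chosen only because it keeps duplicates the way less_equal<int> does). orderOfKey mirrors order_of_key by counting elements strictly less than the key. Each insertion is O(n), so this is only for checking the semantics. The class and method names are illustrative, and values are assumed to be below Integer.MAX_VALUE so that x + 1 cannot overflow.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of the ordered-set steps with a sorted ArrayList in place of the
// pbds tree. Duplicates are kept (like less_equal<int>), but each insert is
// O(n), so this only illustrates the semantics. Assumes values < Integer.MAX_VALUE.
class SortedListInversions {

    // Like order_of_key: number of elements strictly less than key
    // (binary search for the first index whose value is >= key).
    static int orderOfKey(List<Integer> sorted, int key) {
        int lo = 0, hi = sorted.size();
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (sorted.get(mid) < key) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }

    static long count(int[] arr) {
        List<Integer> sorted = new ArrayList<>();
        long invcount = 0;
        // The first element needs no special case: it contributes 0.
        for (int x : arr) {
            // Insert x after any equal values, keeping the list sorted.
            sorted.add(orderOfKey(sorted, x + 1), x);
            // size - order_of_key(x + 1) = number of stored elements > x.
            invcount += sorted.size() - orderOfKey(sorted, x + 1);
        }
        return invcount;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[] {32, 35, 43, 1, 38, 39, 42})); // 6
        System.out.println(count(new int[] {2, 2, 1}));                   // 2
    }
}
```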

For order_of_key, I'm using TreeSet's headSet(k, true) method to do the calculation.
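For reference, pbds order_of_key(k) counts the elements strictly less than k. In a TreeSet, headSet(k) (the same as headSet(k, false)) contains exactly those elements, while headSet(k, true) also contains k when k is present. A small check (HeadSetDemo and headSetSizes are illustrative names):

```java
import java.util.List;
import java.util.TreeSet;

// Compares the exclusive and inclusive forms of TreeSet.headSet.
class HeadSetDemo {

    // Returns {headSet(k).size(), headSet(k, true).size()}.
    static int[] headSetSizes(TreeSet<Integer> s, int k) {
        int exclusive = s.headSet(k).size();       // elements < k, like order_of_key(k)
        int inclusive = s.headSet(k, true).size(); // elements <= k
        return new int[] {exclusive, inclusive};
    }

    public static void main(String[] args) {
        TreeSet<Integer> s = new TreeSet<>(List.of(1, 2, 3));
        int[] sizes = headSetSizes(s, 2);
        System.out.println(sizes[0] + " " + sizes[1]); // 1 2
    }
}
```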

My code

import java.util.ArrayList;
import java.util.TreeSet;

public class Solution {

    // Intended to mirror pbds order_of_key(k)
    static int order_of_key(TreeSet<Integer> s, int k) {
        return s.headSet(k, true).size();
    }

    public int solve(ArrayList<Integer> a) {
        TreeSet<Integer> s = new TreeSet<>();
        s.add(a.get(0));
        int invcount = 0;
        for (int i = 1; i < a.size(); i++) {
            s.add(a.get(i));
            int key = order_of_key(s, a.get(i) + 1);
            // if (i + 1 == a.size()) key--;
            invcount += s.size() - key;
            // System.out.println(s + " " + (a.get(i) + 1) + " " + key + " " + invcount);
        }
        return invcount;
    }
}

Equivalent C++ code

// Ordered set (GNU pbds) based
// approach for inversion count
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;

// Ordered set tree; less_equal<int> keeps duplicate keys
typedef tree<int, null_type, less_equal<int>,
            rb_tree_tag,
            tree_order_statistics_node_update>
    ordered_set;

// Debug helper: prints the current contents of the set.
// Iterating over the set itself never reads past its last element
// (calling find_by_order(i) for i >= size() and dereferencing it is undefined).
void print(const ordered_set& s) {
    for (int x : s) {
        cout << x << " ";
    }
}

// Returns inversion count in
// arr[0..n-1]
int getInvCount(int arr[], int n)
{
    int key;
    // Initialise the ordered_set
    ordered_set set1;

    // Insert the first
    // element in set
    set1.insert(arr[0]);

    // Initialise inversion
    // count to zero
    int invcount = 0;

    // Finding the inversion
    // count for current element
    for (int i = 1; i < n; i++) {
        set1.insert(arr[i]);
        // Number of elements strictly
        // less than arr[i] + 1
        key = set1.order_of_key(arr[i] + 1);

        // Difference between set size
        // and key gives the
        // inversion count
        invcount += set1.size() - key;

        // Debug output: set contents, arr[i] + 1, key, running count
        print(set1);
        cout << arr[i] + 1 << " " << key << "  " << invcount << endl;
    }
    return invcount;
}

// Driver's Code
int main()
{
    int arr[] = { 32, 35, 43, 1, 38, 39, 42 };
    int n = sizeof(arr) / sizeof(int);
    cout << n << endl;
    // Function call to count
    // inversion
    cout << getInvCount(arr, n);
    return 0;
}

Observations

  • My solution works for most test cases, but for some trivial ones it returns the inversion count minus 1. If I change order_of_key to the following, it breaks some other test cases:

    int order_of_key(TreeSet<Integer> s, int k) {
        if (s.contains(k))
            return s.headSet(k).size();
        else
            return s.headSet(k, true).size();
    }

  • The gnu_pbds version, however, works every time.
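The "minus 1" results can be reproduced with the sample array from the C++ driver. Because the question's order_of_key uses headSet(k, true), the call with k = a[i] + 1 counts elements <= a[i] + 1, so a stored element equal to a[i] + 1 (which is greater than a[i]) is not counted as an inversion. In the sample, 43 is already in the set when 42 arrives, so the Java version reports 5 instead of 6. The two-element input [2, 1] gives 0 instead of 1. Here is a minimal reproduction that copies the question's logic into a hypothetical QuestionRepro class (taking a List rather than an ArrayList for convenience):

```java
import java.util.List;
import java.util.TreeSet;

// The question's logic, copied to show the off-by-one: headSet(k, true)
// counts elements <= k, whereas order_of_key(k) counts elements < k.
class QuestionRepro {

    static int order_of_key(TreeSet<Integer> s, int k) {
        return s.headSet(k, true).size();
    }

    static int solve(List<Integer> a) {
        TreeSet<Integer> s = new TreeSet<>();
        s.add(a.get(0));
        int invcount = 0;
        for (int i = 1; i < a.size(); i++) {
            s.add(a.get(i));
            int key = order_of_key(s, a.get(i) + 1);
            invcount += s.size() - key;
        }
        return invcount;
    }

    public static void main(String[] args) {
        // 43 is in the set when 42 arrives, and headSet(43, true) counts it.
        System.out.println(solve(List.of(32, 35, 43, 1, 38, 39, 42))); // 5 (expected 6)
        System.out.println(solve(List.of(2, 1)));                       // 0 (expected 1)
    }
}
```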

Please help me fix this code using TreeSet. Also, if I'm missing any other points about how the pbds order_of_key method works, please let me know.


Solution

  • The order_of_key function returns the number of elements strictly less than the provided key. In your C++ code, you add the element to the set and then call order_of_key with the element plus one, so that the count includes the element you just added.

    Preserving this arrangement in the Java code, the way to implement something like order_of_key faithfully using a Java TreeSet is to call

    return s.headSet(k, false).size();
    

    or simply

    return s.headSet(k).size();
    

    because either of these returns the size of the head set containing elements strictly less than k. There's no need to do a contains-check first. When I run this using the input array from your C++ code, I get the same intermediate results and final result for the number of inversions: 6.
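    Putting that one-line change into the question's Solution class gives something like the following. The main method and sample input are only for demonstration. Like the question's code, this assumes the values are distinct (a TreeSet stores each value only once) and below Integer.MAX_VALUE, so that a.get(i) + 1 cannot overflow.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

// The question's Solution with order_of_key counting elements strictly
// less than k, which is what pbds order_of_key does.
class Solution {

    static int order_of_key(TreeSet<Integer> s, int k) {
        return s.headSet(k).size(); // same as headSet(k, false)
    }

    public int solve(ArrayList<Integer> a) {
        TreeSet<Integer> s = new TreeSet<>();
        s.add(a.get(0));
        int invcount = 0;
        for (int i = 1; i < a.size(); i++) {
            s.add(a.get(i));
            int key = order_of_key(s, a.get(i) + 1); // elements <= a[i]
            invcount += s.size() - key;              // elements > a[i]
        }
        return invcount;
    }

    public static void main(String[] args) {
        ArrayList<Integer> a = new ArrayList<>(List.of(32, 35, 43, 1, 38, 39, 42));
        System.out.println(new Solution().solve(a)); // 6
    }
}
```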

    Note that Java's TreeSet can also create a tail set, that is, a view of the elements greater than (or, if inclusive, equal to) a given element. In my opinion, this gives a much simpler way to compute the number of inversions. Before an element is added, the number of elements in the set strictly greater than it is the number of inversions it forms with the earlier elements. Since we want the tail set to be strictly larger than the current element, we use tailSet(k, false). We can then do this for each element in the input list:

    int inversions(List<Integer> a) {
        var s = new TreeSet<Integer>();
        int invcount = 0;
        for (int k : a) {
            invcount += s.tailSet(k, false).size();
            s.add(k);
        }
        return invcount;
    }
    
    inversions(List.of(32, 35, 43, 1, 38, 39, 42)) // result is 6
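    One point worth adding about the O(log n) assumption: pbds order_of_key is O(log n) because each tree node keeps its subtree size, but java.util.TreeSet keeps no such counts. In the OpenJDK implementation, calling size() on a headSet or tailSet view iterates over the view's elements, so each query above costs time proportional to the view's size and the whole loop is O(n^2) in the worst case. If that matters, a common alternative (a swapped-in technique, not part of the original answer) is a Fenwick tree (binary indexed tree) over the ranks of the values, which also handles duplicates. A minimal sketch with illustrative names:

```java
import java.util.Arrays;

// Counts pairs i < j with a[i] > a[j] in O(n log n) using a Fenwick tree
// (binary indexed tree) over each value's rank among the distinct values.
// Equal values are not counted as inversions, and duplicates are handled.
class FenwickInversions {

    static long count(int[] a) {
        // Coordinate compression: sorted distinct values; rank = index + 1.
        int[] sorted = Arrays.stream(a).sorted().distinct().toArray();
        long[] tree = new long[sorted.length + 1]; // 1-based Fenwick array

        long inversions = 0;
        for (int i = 0; i < a.length; i++) {
            int rank = Arrays.binarySearch(sorted, a[i]) + 1;

            // Prefix sum: earlier elements with rank <= this rank (value <= a[i]).
            long notGreater = 0;
            for (int j = rank; j > 0; j -= j & -j) notGreater += tree[j];

            // The remaining i - notGreater earlier elements are strictly greater.
            inversions += i - notGreater;

            // Record a[i] at its rank.
            for (int j = rank; j < tree.length; j += j & -j) tree[j]++;
        }
        return inversions;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[] {32, 35, 43, 1, 38, 39, 42})); // 6
        System.out.println(count(new int[] {2, 2, 1}));                   // 2
    }
}
```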
    

    UPDATE 2020-06-24

    The above code works only for input with unique values. If the input has duplicate values, it gives the wrong count. Note that in the C++ code, a tree that uses the less_equal<int> comparison function keeps duplicates, while one that uses less<int> discards them.

    Keeping duplicates matters because each element, even a duplicate, can take part in an inversion. Thus, input of [2, 2, 1] is considered to have two inversions. A Java TreeSet stores each value only once, so we have to do some additional work to preserve duplicates.
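    To see this concretely, here is the tailSet version from above run on [2, 2, 1]. The second add(2) leaves the set unchanged, so when 1 arrives only one larger element is stored and the count comes out as 1 rather than 2. DuplicateDemo is an illustrative name; the method body is the answer's inversions method with explicit types instead of var.

```java
import java.util.List;
import java.util.TreeSet;

// The answer's tailSet-based count, run on input containing a duplicate.
class DuplicateDemo {

    static int inversions(List<Integer> a) {
        TreeSet<Integer> s = new TreeSet<>();
        int invcount = 0;
        for (int k : a) {
            invcount += s.tailSet(k, false).size();
            s.add(k); // returns false and changes nothing if k is already present
        }
        return invcount;
    }

    public static void main(String[] args) {
        System.out.println(inversions(List.of(2, 2, 1)));    // 1, not 2
        System.out.println(new TreeSet<>(List.of(2, 2, 1))); // [1, 2]
    }
}
```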

    One way to allow "duplicate" int values is to make them unique somehow. This can be done by creating a new object to contain the int value paired with a counter that's always incremented. Duplicate values are made unique since they'll have different counter values. Here's a class that does that:

    static class UniqInt implements Comparable<UniqInt> {
        static int count = 0;
        final int value;
        final int uniq;
        UniqInt(int value) { this.value = value; uniq = count++; }
        public int compareTo(UniqInt other) {
            int c = Integer.compare(this.value, other.value);
            return c != 0 ? c : Integer.compare(this.uniq, other.uniq);
        }
    }
    

    Note that the compareTo method here compares both the value and the "uniq" counter value, so multiple UniqInt instances with the same value still differ from each other and are totally ordered. Once we have this class, we do essentially the same thing as before, except that the TreeSet holds UniqInt instead of Integer objects.

    int inversions1(List<Integer> a) {
        var s = new TreeSet<UniqInt>();
        int invcount = 0;
        for (int k : a) {
            var u = new UniqInt(k);
            invcount += s.tailSet(u, false).size();
            s.add(u);
        }
        return invcount;
    }
    

    For the input provided in the comment (which includes duplicate values),

    84, 2, 37, 3, 67, 82, 19, 97, 91, 63, 27, 6, 13, 90, 63, 89, 100, 60, 47, 96, 54, 26, 64, 50, 71, 16, 6, 40, 84, 93, 67, 85, 16, 22, 60

    this gives the expected result of 290.
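    Results like this are easy to cross-check against a brute-force O(n^2) count that compares every pair i < j directly. A minimal sketch (BruteForceInversions is an illustrative name); run on the 35 values above, it also gives 290.

```java
// Reference O(n^2) inversion count: checks every pair i < j directly.
class BruteForceInversions {

    static long count(int[] a) {
        long inversions = 0;
        for (int i = 0; i < a.length; i++)
            for (int j = i + 1; j < a.length; j++)
                if (a[i] > a[j]) inversions++; // equal values are not inversions
        return inversions;
    }

    public static void main(String[] args) {
        int[] input = {84, 2, 37, 3, 67, 82, 19, 97, 91, 63, 27, 6, 13, 90, 63, 89, 100, 60,
                       47, 96, 54, 26, 64, 50, 71, 16, 6, 40, 84, 93, 67, 85, 16, 22, 60};
        System.out.println(count(input)); // 290
    }
}
```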

    This is a bit fragile, however. Creating a new UniqInt always produces an instance that is greater than any existing instance with the same value, because the counter is always incremented (until you've created 2^31 of them). So when a new instance is created, its tail set will never include any of the duplicate values already in the TreeSet. This is probably reasonable for this small example, but if this were part of a larger system, I'd think more carefully about how to get the correct head set or tail set relative to some value, without relying on the most recently created UniqInt being greater than all previous ones.
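    One way to avoid depending on creation order (a sketch, not from the original answer) is to make the tie-breaker explicit and query with a sentinel. Store each value paired with its index, and ask for the strict tail set above (value, Integer.MAX_VALUE), which sorts after every stored entry with that value. The tail set then holds exactly the stored entries with a strictly larger value, whatever their indices. Entry, ORDER and SentinelInversions are illustrative names. The tail set's size() is still computed by iteration, as with the other TreeSet versions.

```java
import java.util.Comparator;
import java.util.List;
import java.util.TreeSet;

// Keeps duplicates by storing (value, index) pairs, ordered by value and then
// by index. Queries use a sentinel index, so the result does not depend on
// which entry was created most recently.
class SentinelInversions {

    static final class Entry {
        final int value;
        final int index;
        Entry(int value, int index) { this.value = value; this.index = index; }
    }

    static final Comparator<Entry> ORDER =
            Comparator.<Entry>comparingInt(e -> e.value).thenComparingInt(e -> e.index);

    static int inversions(List<Integer> a) {
        TreeSet<Entry> s = new TreeSet<>(ORDER);
        int invcount = 0;
        for (int i = 0; i < a.size(); i++) {
            int k = a.get(i);
            // (k, MAX_VALUE) sorts after every stored entry with value k, so the
            // strict tail set is exactly the stored entries with value > k.
            invcount += s.tailSet(new Entry(k, Integer.MAX_VALUE), false).size();
            s.add(new Entry(k, i));
        }
        return invcount;
    }

    public static void main(String[] args) {
        System.out.println(inversions(List.of(2, 2, 1)));                   // 2
        System.out.println(inversions(List.of(32, 35, 43, 1, 38, 39, 42))); // 6
    }
}
```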