Tags: c++ algorithm data-structures segment-tree binary-indexed-tree

Count "minimal" values


Problem:

I have an input of n vectors:

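(x, y, z): x ∈ {1..n}, y ∈ {1..n}, z ∈ {1..n} (every "dimension" is the set {1..n})
*I mean that within one vector x, y, z can be equal (x = y = z is allowed),
 but for any two distinct vectors v1, v2: x1 ≠ x2, y1 ≠ y2, z1 ≠ z2 (so each coordinate is a permutation of 1..n)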

v1 > v2 if and only if x1 > x2, y1 > y2 and z1 > z2.
Let's call a vector v1 "minimal" if and only if ∄v ∈ input: v > v1.

The task is to count the minimal vectors in the input.
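
A minimal O(n²) reference sketch that pins down the definition and is handy for checking a faster solution on small inputs (at n = 100000 it is far too slow for 1 sec). It assumes a smaller number is better (rank 1 beats rank 2), which is how the solution below treats the coordinates; for the literal ">" direction of the definition, first replace each coordinate c with n + 1 - c, which mirrors the problem. Vec3, dominates and count_minimal_brute are hypothetical names.

```cpp
#include <cstddef>
#include <vector>

// One input vector (a participant's three ranks).
struct Vec3 {
    int x, y, z;
};

// True if a is smaller than b in all three coordinates
// (assumption: smaller number = better rank).
bool dominates(const Vec3 &a, const Vec3 &b) {
    return a.x < b.x && a.y < b.y && a.z < b.z;
}

// O(n^2) reference: v counts as "minimal" if no other vector dominates it.
std::size_t count_minimal_brute(const std::vector<Vec3> &vs) {
    std::size_t count = 0;
    for (const Vec3 &v : vs) {
        bool minimal = true;
        for (const Vec3 &u : vs) {
            if (dominates(u, v)) {
                minimal = false;
                break;
            }
        }
        if (minimal) ++count;
    }
    return count;
}
```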

Source:

I found this problem in a local programming contest.

The (translated) formulation is:

n people participated in a competition. The competition had three stages (every competitor
took part in every stage). We say that participant a is better than
participant b if a's rank in each of the three stages is higher than b's rank.
Participant c is one of the best if there is no participant who is better
than participant c. Output the number of best participants.

1 <= n <= 100000, time limit: 1 sec.

My attempts & thoughts

My first idea was to create a class Score (for the competitors' results) and overload operator > (or <) like this:

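struct Score
{
    int first_result, second_result, third_result;

    // true only if this score beats s in all three stages
    bool operator > (const Score &s) const
    {
        return first_result > s.first_result &&
               second_result > s.second_result &&
               third_result > s.third_result;
    }
};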

and then build some array-based structure (a min-heap, for example) that finds the minimum values (using <) so that I can count them (I think I just "recreated" a bad variant of heap sort this way). After this attempt failed, I tried a Fenwick tree (binary indexed tree) for the same task.

But then I realized that my approach is incorrect (the class with the overloaded < doesn't work), and maybe the idea of converting the task to 1D is not good at all.
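
One concrete reason the heap/sort idea breaks: std::sort and std::priority_queue require their comparator to be a strict weak ordering, in which "neither is greater than the other" must be transitive. The dominance relation from Score is only a partial order. A small check, repeating the question's operator with made-up values (incomparable is a hypothetical helper):

```cpp
// Same comparison as the question's Score::operator>.
struct Score {
    int first_result, second_result, third_result;

    bool operator>(const Score &s) const {
        return first_result > s.first_result &&
               second_result > s.second_result &&
               third_result > s.third_result;
    }
};

// Two scores are "incomparable" when neither beats the other.
bool incomparable(const Score &a, const Score &b) {
    return !(a > b) && !(b > a);
}
```

With a = {3, 3, 3}, b = {2, 4, 1} and c = {1, 2, 2}, a and b are incomparable, b and c are incomparable, yet a > c. Incomparability is therefore not transitive, so using this operator as a sort or heap comparator is undefined behavior, and even in the best case "the minimum" is only one of possibly many minimal elements.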

Then I found some information about BITs and segment trees for the n-dimensional case, and I think I can use them to solve this problem. But it's pretty hard for me to implement a working variant (and even to understand how a segment tree works in more than 1D).

Maybe someone can help with the implementation (or find a better solution and explain it)?
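
Since the question asks about a Fenwick tree: one standard way to use a plain 1D BIT here (a different technique from the ordered-map answer below, and only a sketch, not the contest's reference solution) is to sort by x, so that "earlier" already means "smaller x", and then ask, for each vector, for the minimum z among earlier vectors with a smaller y. That is a prefix-minimum query over y, and a BIT can answer prefix minima as long as stored values only ever decrease, which holds because we only insert. Vec3, PrefixMinFenwick and count_minimal_fenwick are hypothetical names; the code assumes each coordinate is a permutation of 1..n and that a smaller number is better.

```cpp
#include <algorithm>
#include <climits>
#include <cstddef>
#include <vector>

struct Vec3 {
    int x, y, z;
};

// Fenwick tree over positions 1..n answering prefix minima.
// A position's value may only decrease, which is all this problem needs.
class PrefixMinFenwick {
public:
    explicit PrefixMinFenwick(int n) : tree_(n + 1, INT_MAX) {}

    // Lower the value at position i (1-based) to at most value.
    void update(int i, int value) {
        for (; i < static_cast<int>(tree_.size()); i += i & -i)
            tree_[i] = std::min(tree_[i], value);
    }

    // Minimum over positions 1..i; INT_MAX if nothing was inserted there.
    int query(int i) const {
        int best = INT_MAX;
        for (; i > 0; i -= i & -i)
            best = std::min(best, tree_[i]);
        return best;
    }

private:
    std::vector<int> tree_;
};

// Counts vectors that no other vector beats in all three coordinates.
std::size_t count_minimal_fenwick(std::vector<Vec3> vs) {
    const int n = static_cast<int>(vs.size());
    std::sort(vs.begin(), vs.end(),
              [](const Vec3 &a, const Vec3 &b) { return a.x < b.x; });

    PrefixMinFenwick fw(n);
    std::size_t count = 0;
    for (const Vec3 &v : vs) {
        // Smallest z among earlier (smaller x) vectors whose y is below v.y.
        if (fw.query(v.y - 1) > v.z) ++count;
        fw.update(v.y, v.z);
    }
    return count;
}
```

Sorting is O(n log(n)) and each vector costs one query and one update, both O(log(n)), so this also fits comfortably in the limits.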


Solution

  • First, we'll need an ordered key/value data structure that supports inserting, deleting, and finding the neighbouring entry (the largest key <= a given key, or the smallest key >= it), each in O(log(n)) time. Think red-black tree, B-tree or skip list.

    I will use the following invented notation for that data structure. I am deliberately making it not look like any real language.

    by_order.prev(key) gives the k-v pair associated with the largest key <= key; it is None if there is no such key. by_order.prev(key).k is that largest key, and by_order.prev(key).v is the value associated with it. by_order.next(key) gives the k-v pair associated with the smallest key >= key (also possibly None), with .k and .v meaning what they did before. by_order.add(key, value) adds a k-v pair. by_order.del(key) removes the k-v pair whose key is key.
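
In C++ this notation maps onto std::map<int, int>: prev(key) is the element just before upper_bound(key), and next(key) is lower_bound(key). A hypothetical C++17 wrapper (ByOrder is not a standard type; std::nullopt stands in for None):

```cpp
#include <map>
#include <optional>
#include <utility>

// Hypothetical adapter from the by_order notation to std::map.
class ByOrder {
public:
    using Pair = std::pair<int, int>;  // (k, v)

    // Pair with the largest key <= key, or std::nullopt.
    std::optional<Pair> prev(int key) const {
        auto it = m_.upper_bound(key);  // first key > key
        if (it == m_.begin()) return std::nullopt;
        --it;
        return Pair(it->first, it->second);
    }

    // Pair with the smallest key >= key, or std::nullopt.
    std::optional<Pair> next(int key) const {
        auto it = m_.lower_bound(key);  // first key >= key
        if (it == m_.end()) return std::nullopt;
        return Pair(it->first, it->second);
    }

    void add(int key, int value) { m_[key] = value; }
    void del(int key) { m_.erase(key); }

private:
    std::map<int, int> m_;
};
```

Every operation is a single O(log(n)) tree operation, which is exactly what the complexity argument needs.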

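    The idea is this. We first sort by x then y then z (x values are distinct, so sorting by x alone is enough). The first vector is minimal. Every vector after that is minimal if and only if its z is less than the lowest z of any previous vector with a lower y (y values are distinct, so "lower or equal" and "lower" mean the same here). We will use the by_order data structure to test that condition: it maps y to z for the minimal vectors seen so far, kept as a "staircase" in which z decreases as y increases. Then the entry with the largest key <= v.y also has the lowest z among all keys <= v.y, so one prev() lookup answers the question. Vectors that are not minimal never need to be stored, because the vector that beats them has both a lower y and a lower z.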

    Assuming that I made no mistakes, here is pseudocode:

    sort(vectors) by x then y then z
    Declare and initialize your empty ordered data structure by_order
    // by_order maps y -> z for the minimal vectors seen so far.
    // Invariant: as the keys (y) increase, the values (z) decrease.
    answer = [] // ie an empty list
    for v in vectors:
        best = by_order.prev(v.y) // None if there is no key <= v.y
        if best is None or v.z < best.v:
            answer.push(v) // Yay!
            // Clear anything in by_order that we are an improvement on:
            // entries with a larger y and a larger z than v.
            while True:
                pair = by_order.next(v.y)
                if pair is not None and pair.v > v.z:
                    by_order.del(pair.k)
                else:
                    break
            // and now we add this one to by_order.
            by_order.add(v.y, v.z)
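
A C++11 sketch of the same algorithm, using std::map directly as by_order: the element before upper_bound(v.y) plays the role of prev(), the iterator returned by upper_bound(v.y) is next() (v.y itself is never present yet), and erasing through that iterator replaces the del() loop. It only counts the minimal vectors instead of collecting them, and it assumes each coordinate is a permutation of 1..n; Vec3 and count_minimal_staircase are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <map>
#include <vector>

struct Vec3 {
    int x, y, z;
};

// Ordered-map ("staircase") version of the pseudocode above.
// stairs maps y -> z for the minimal vectors seen so far; z strictly
// decreases as y increases, so the entry with the largest key below v.y
// holds the smallest z among all keys below v.y.
std::size_t count_minimal_staircase(std::vector<Vec3> vs) {
    std::sort(vs.begin(), vs.end(),
              [](const Vec3 &a, const Vec3 &b) { return a.x < b.x; });

    std::map<int, int> stairs;
    std::size_t count = 0;
    for (const Vec3 &v : vs) {
        auto it = stairs.upper_bound(v.y);  // first key > v.y
        if (it != stairs.begin() && std::prev(it)->second < v.z)
            continue;  // an earlier vector has smaller x, y and z

        ++count;
        // Drop entries with a larger y whose z is also larger than v.z.
        while (it != stairs.end() && it->second > v.z)
            it = stairs.erase(it);
        stairs.emplace_hint(it, v.y, v.z);
    }
    return count;
}
```

Each vector is inserted at most once and erased at most once, so the erase loop does O(n) work in total across the whole run.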
    

    The total time taken for the sort is O(n log(n)).

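    That is followed, for each of the n vectors, by an O(log(n)) lookup to see whether it is minimal. If it is, there is an O(1) push onto the answer, an O(log(n)) lookup of what still follows it, and an O(log(n)) insert. The deletions are charged to the vector being deleted (don't worry, I didn't lose track of them): each vector is inserted at most once, so it is also deleted at most once, and the check that finds it needs to be deleted plus the delete itself cost O(log(n)). Every other pass through the while loop is the final check that breaks, which happens once per minimal vector.

    That's a lot of O(log(n)) terms, but each vector pays for only a constant number of them, so the work is O(log(n)) amortized per vector, n times.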

    The result is a O(n log(n)) algorithm for the whole problem.