Tags: algorithm, language-agnostic, dynamic-programming, segment-tree

How to solve weighted activity selection using segment trees and binary search?


Given N jobs, where every job is represented by the following three elements:

1) Start time

2) Finish time

3) Profit or value associated with the job

Find the maximum profit subset of jobs such that no two jobs in the subset overlap.

I know a dynamic programming solution with O(N^2) complexity. It is close to LIS: for the i-th interval we check all previous intervals it can be merged with and take the one whose merge gives the maximum up to the i-th element. This solution can be improved to O(N log N) using sorting and binary search.
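For reference, the O(N log N) variant mentioned above could look roughly like the sketch below. It sorts by finish time rather than start time, and the names `Job` and `maxProfit` are made up for illustration. It assumes a job may start exactly when another one finishes, as in the O(N^2) check.

```cpp
#include <algorithm>
#include <vector>

struct Job { int start, finish, profit; };  // hypothetical helper type

// Maximum total profit of pairwise non-overlapping jobs (sketch).
long long maxProfit(std::vector<Job> jobs) {
    // Sort by finish time so that compatible earlier jobs form a prefix.
    std::sort(jobs.begin(), jobs.end(),
              [](const Job& a, const Job& b) { return a.finish < b.finish; });
    const int n = static_cast<int>(jobs.size());
    std::vector<int> finish(n);
    for (int i = 0; i < n; ++i) finish[i] = jobs[i].finish;

    // best[i] = best profit using only the first i jobs in finish order.
    std::vector<long long> best(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        // k = how many of the first i jobs finish no later than jobs[i] starts.
        int k = static_cast<int>(
            std::upper_bound(finish.begin(), finish.begin() + i, jobs[i].start)
            - finish.begin());
        // Either skip job i, or take it on top of the best of the first k jobs.
        best[i + 1] = std::max(best[i], best[k] + jobs[i].profit);
    }
    return best[n];
}
```

Each job costs one binary search, so the sort dominates the running time.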

But my friend told me that it can also be solved using segment trees and binary search. I have no clue where or how I would use a segment tree.

Can you help?

On request, here is my code (sorry, it is barely commented).

What I am doing is sorting by start time and storing in DP[i] the maximum value obtainable by a set of non-overlapping intervals that ends with interval i, computed by merging interval i with previous intervals and their maximum obtainable values.

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <utility>
using std::pair;

 void solve()
    {
        int n,i,j,k,high;
        scanf("%d",&n);
        pair < pair < int ,int>, int > arr[n+1];// first pair represents l,r and int alone shows cost
        int dp[n+1]; 
        memset(dp,0,sizeof(dp));
        for(i=0;i<n;i++)
            scanf("%d%d%d",&arr[i].first.first,&arr[i].first.second,&arr[i].second);
        std::sort(arr,arr+n); // by default sorting on the basis of starting index
        for(i=0;i<n;i++)
        {
            high=arr[i].second;
            for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
            {
                if(arr[i].first.first>=arr[j].first.second)  
                        high=std::max(high , dp[j]+arr[i].second);
            }
            dp[i]=high;
        }
        for(i=0;i<n;i++)
                dp[n-1]=std::max(dp[n-1],dp[i]);
        printf("%d\n",dp[n-1]);
    }

int main()
{solve();return 0;}

EDIT: Here is my working code; it took me 3 hours to debug. Moreover, this code is slower than the sorting and binary search one due to a larger constant factor and a bad implementation :P (just for reference)

#include<stdio.h>
#include<algorithm>
#include<vector>
#include<cstring>
#include<iostream>
#include<climits>
#define lc(idx) (2*idx+1)
#define rc(idx) (2*idx+2)
#define mid(l,r) ((l+r)/2)
using namespace std;
int Tree[4*2*10000-1];
void update(int L,int R,int qe,int idx,int value)
{

    if(value>Tree[0])
        Tree[0]=value;
    while(L<R)
    {
        if(qe<= mid(L,R))
        {
            idx=lc(idx);
            R=mid(L,R);
        }
        else
        {
            idx=rc(idx);
            L=mid(L,R)+1;
        }
        if(value>Tree[idx])
            Tree[idx]=value;

    }
    return ;
}
int Get(int L,int R,int idx,int q)
{
    if(q<L )
        return 0;
    if(R<=q)
        return Tree[idx];

    return max(Get(L,mid(L,R),lc(idx),q),Get(mid(L,R)+1,R,rc(idx),q));

}
bool cmp(pair < pair < int , int > , int > A,pair < pair < int , int > , int > B)
{
    return A.first.second< B.first.second;
}
int main()
{

        int N,i;
        scanf("%d",&N);
        pair < pair < int , int  > , int > P[N];
        vector < int > V;
        for(i=0;i<N;i++)
        {
            scanf("%d%d%d",&P[i].first.first,&P[i].first.second,&P[i].second);
            V.push_back(P[i].first.first);
            V.push_back(P[i].first.second);
        }
        sort(V.begin(),V.end());
        for(i=0;i<N;i++)
        {
            int &l=P[i].first.first,&r=P[i].first.second;
            l=lower_bound(V.begin(),V.end(),l)-V.begin();
            r=lower_bound(V.begin(),V.end(),r)-V.begin();
        }
        sort(P,P+N,cmp);
        int ans=0;
        memset(Tree,0,sizeof(Tree));
        for(i=0;i<N;i++)
        {
            int aux=Get(0,2*N-1,0,P[i].first.first)+P[i].second;
            if(aux>ans)
                ans=aux;
            update(0,2*N-1,P[i].first.second,0,ans);
        }
        printf("%d\n",ans);

    return 0;
}

Solution

  • high=arr[i].second;
    for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
    {
        if(arr[i].first.first>=arr[j].first.second)  
        high=std::max(high, dp[j]+arr[i].second);
    }
    dp[i]=high;
    

    This inner loop can be replaced by a single O(log n) query on a segment tree.

    First of all, let's rewrite it a bit. The max you are taking is a bit complicated, because it takes the maximum of a sum involving both i and j. But i is constant in this part, so let's take it out.

    high=0;
    for(j=0;j<i;j++)//checking all previous mergable intervals //Note we will use DP[] of the mergable interval due to optimal substructure
    {
        if(arr[i].first.first>=arr[j].first.second)  
        high=std::max(high, dp[j]);
    }
    dp[i]=high + arr[i].second;
    

    Great, now we have reduced the problem to determining the maximum in [0, i - 1] out of the values that satisfy your if condition.

    If we didn't have the if, it would be a simple application of segment trees.

    Now there are two choices.

    1. Deal with O(log V) query time and O(V) memory for the segment tree

    Here V is the largest endpoint value of any interval.

    With the intervals sorted by start point as in your code, you can build a segment tree into which you insert each interval's finish point, together with its dp value, as you advance i. Then you query over the range of positions [0, start point of interval i]. Something like this, where the segment tree is initialized to -infinity and has size O(V).

    Update(node, index, value):
      if node.associated_interval == [index, index]:
        node.max = max(node.max, value)   // several intervals may share a finish point
        return
    
      if index in node.left.associated_interval:
        Update(node.left, index, value)
      else:
        Update(node.right, index, value)
    
      node.max = max(node.left.max, node.right.max)
    
    Query(node, left, right):
      if [left, right] does not intersect node.associated_interval:
        return -infinity
    
      if node.associated_interval included in [left, right]:
        return node.max
    
      return max(Query(node.left, left, right),
                 Query(node.right, left, right))
    
    [...]
    
    high=max(0, Query(tree, 0, arr[i].first.first))   // 0 when no earlier interval fits
    dp[i]=high + arr[i].second;
    Update(tree, arr[i].first.second, dp[i])
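A compact C++ sketch of this first option might look as follows. The names `MaxTree`, `Job` and `maxProfitTree` are invented for illustration. As assumptions of the sketch, each dp value is stored at the interval's finish point, the tree starts at 0 instead of -infinity (0 meaning "no earlier interval chosen"), and it expects non-negative profits with 0 <= start < finish <= maxCoord.

```cpp
#include <algorithm>
#include <vector>

// Max segment tree over positions 0..size-1, all initially 0 (sketch).
class MaxTree {
public:
    explicit MaxTree(int size) : n_(size), t_(4 * std::max(size, 1), 0) {}

    // Raise the value stored at pos to at least value.
    void update(int pos, long long value) { update(0, 0, n_ - 1, pos, value); }

    // Maximum over positions lo..hi (inclusive).
    long long query(int lo, int hi) const { return query(0, 0, n_ - 1, lo, hi); }

private:
    void update(int node, int l, int r, int pos, long long value) {
        if (l == r) { t_[node] = std::max(t_[node], value); return; }
        int m = l + (r - l) / 2;
        if (pos <= m) update(2 * node + 1, l, m, pos, value);
        else update(2 * node + 2, m + 1, r, pos, value);
        t_[node] = std::max(t_[2 * node + 1], t_[2 * node + 2]);
    }
    long long query(int node, int l, int r, int lo, int hi) const {
        if (hi < l || r < lo) return 0;          // no overlap: neutral value
        if (lo <= l && r <= hi) return t_[node]; // node fully inside the range
        int m = l + (r - l) / 2;
        return std::max(query(2 * node + 1, l, m, lo, hi),
                        query(2 * node + 2, m + 1, r, lo, hi));
    }
    int n_;
    std::vector<long long> t_;
};

struct Job { int start, finish, profit; };  // hypothetical helper type

long long maxProfitTree(std::vector<Job> jobs, int maxCoord) {
    // Process jobs by start point, as in the question's code.
    std::sort(jobs.begin(), jobs.end(),
              [](const Job& a, const Job& b) { return a.start < b.start; });
    MaxTree tree(maxCoord + 1);
    long long best = 0;
    for (const Job& j : jobs) {
        // Best dp value among jobs that finish no later than j starts.
        long long cur = tree.query(0, j.start) + j.profit;
        tree.update(j.finish, cur);  // stored at the finish point
        best = std::max(best, cur);
    }
    return best;
}
```

The tree has O(V) nodes, which is why this option only makes sense when the endpoint values are small.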
    

    2. Reducing to O(log n) query time and O(n) memory for the segment tree

    Since the number of intervals might be significantly less than their length, it's reasonable to think that we might be able to encode them better somehow, so that their length is also O(n). Indeed, we can.

    This involves normalizing your intervals to the range [1, 2*n]. Consider the following intervals:

    8 100
    3 50
    90 92
    

    Let's sort all their endpoints and plot them on a line. They'd look like this:

    3 8 50 90 92 100
    

    Now replace each endpoint with its index in that sorted order:

    1 2 3  4  5  6
    3 8 50 90 92 100
    

    And write your new intervals:

    2 6
    1 3
    4 5
    

    Note that they retain the properties of your initial intervals: the same ones overlap, the same ones are included in each other etc.

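    This can be done with a sort of all the endpoints, followed by a binary search to look up the index of each endpoint. Equal endpoints should get the same index, so that an interval starting exactly where another one finishes stays compatible with it. You can now apply the same segment tree algorithm, except you declare the segment tree for the size 2*n.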
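The normalization step could be sketched like this in C++. The function name `compress` is made up for illustration, and the sketch removes duplicate endpoints so that equal endpoints always map to the same index.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Replaces every endpoint by its 1-based rank among the distinct endpoints.
std::vector<std::pair<int, int>> compress(std::vector<std::pair<int, int>> intervals) {
    // Collect all endpoints, then sort and remove duplicates.
    std::vector<int> coords;
    for (const auto& iv : intervals) {
        coords.push_back(iv.first);
        coords.push_back(iv.second);
    }
    std::sort(coords.begin(), coords.end());
    coords.erase(std::unique(coords.begin(), coords.end()), coords.end());

    // Binary search gives the rank of an endpoint in O(log n).
    auto rank = [&coords](int x) {
        return static_cast<int>(
            std::lower_bound(coords.begin(), coords.end(), x) - coords.begin()) + 1;
    };
    for (auto& iv : intervals) {
        iv.first = rank(iv.first);
        iv.second = rank(iv.second);
    }
    return intervals;
}
```

For the three intervals above, this maps (8, 100), (3, 50), (90, 92) to (2, 6), (1, 3), (4, 5), matching the table.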