Search code examples
c++algorithmdivide-and-conquer

Divide and Conquer algorithm in C++


There is a problem on an online judge that I cannot figure out how to get accepted.

The problem goes like this: the first line contains two numbers:

N (0 < N < 2^18) 
M (0 < M < 2^20)

The second line contains N numbers:

ai (0 < ai < 2^40)

The question is: how many values of X satisfy the following equation?

M = floor(X/a1) + floor(X/a2) + ... + floor(X/an)

My naive solution:

#include<bits/stdc++.h>
using namespace std;

// Brute-force reference solution.
//
// Reads "n m" followed by n divisors from stdin and prints how many X satisfy
//     floor(X/a_1) + floor(X/a_2) + ... + floor(X/a_n) == m.
//
// Returns 1 without printing anything when the input is malformed or violates
// the positivity constraints of the problem statement (n, m, a_i > 0).
int main()
{
    long long n = 0, m = 0;
    if (!(std::cin >> n >> m) || n <= 0 || m <= 0) return 1;

    // Heap-backed storage instead of a variable-length array, which is not
    // standard C++ and would sit on the stack.
    std::vector<long long> divisors(static_cast<std::size_t>(n));
    for (long long& d : divisors) std::cin >> d;

    const long long smallest = *std::min_element(divisors.begin(), divisors.end());
    if (smallest <= 0) return 1;  // also avoids dividing by zero below

    // Every candidate lies in [smallest, (m+1)*smallest - 1]:
    //  - below `smallest` every term is 0, so the sum is 0 < m;
    //  - from (m+1)*smallest on, the term of the smallest divisor alone is
    //    already >= m+1.
    // With m < 2^20 and a_i < 2^40 the product stays below 2^60.
    const long long last = (m + 1) * smallest - 1;

    long long matches = 0;
    for (long long x = smallest; x <= last; ++x) {
        long long total = 0;
        for (const long long d : divisors) total += x / d;
        if (total == m) {
            ++matches;
        } else if (total > m) {
            break;  // the sum never decreases as x grows, so no later x can match
        }
    }
    std::cout << matches << '\n';
    return 0;
}

Update1: My binary search solution (still didn't pass the time limit):

#include<bits/stdc++.h>
using namespace std;

// Capacity of `ar`; the problem statement guarantees N < 2^18.
const long long MAXN = 1LL << 18;

// Problem parameters (n = number of divisors, m = target sum) followed by
// file-scope working variables.
long long n,m,i,l,r,mid,ans,tmp,cnt;
// The divisors a_1..a_n; only the first n entries are meaningful.
long long ar[MAXN];

// Returns floor(x/ar[0]) + floor(x/ar[1]) + ... + floor(x/ar[n-1]).
//
// The accumulator and the loop index are locals, so a call leaves every
// file-scope variable (in particular the shared index `i`) untouched.
// Every ar[k] with k < n must be non-zero.
long long func(long long x){
    long long total = 0;
    for (long long idx = 0; idx < n; idx++) total += x/ar[idx];
    return total;
}

// Returns the smallest x in [lo, hi] with func(x) >= target.
// The caller must guarantee func(hi) >= target. func never decreases as x
// grows, which is what makes bisection valid here.
static long long first_reaching(long long target, long long lo, long long hi)
{
    while (lo < hi) {
        const long long probe = lo + (hi - lo) / 2;
        if (func(probe) < target) lo = probe + 1;
        else hi = probe;
    }
    return lo;
}

// Reads "n m" and n divisors into the globals n, m and ar, then prints how
// many X satisfy func(X) == m.
//
// The X with func(X) == m form one contiguous range, so the count is
//     (first X with func(X) >= m+1) - (first X with func(X) >= m).
// Two bisections replace the earlier search that could spin forever when no
// X matched, and the linear walks that cost one func() call per matching X.
//
// Returns 1 without printing when the input is malformed, does not fit in
// `ar`, or violates the positivity constraints (n, m, a_i > 0).
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    if (!(std::cin >> n >> m)) return 1;
    const long long capacity = sizeof(ar) / sizeof(ar[0]);
    if (n <= 0 || n > capacity || m <= 0) return 1;
    for (long long idx = 0; idx < n; idx++) std::cin >> ar[idx];

    const long long smallest = *std::min_element(ar, ar + n);
    if (smallest <= 0) return 1;  // func() would divide by zero

    // func(x) == 0 for x < smallest, and func((m+1)*smallest) >= m+1 because
    // the smallest divisor alone contributes m+1 there. Both thresholds are
    // therefore reached inside [smallest, (m+1)*smallest].
    // With m < 2^20 and a_i < 2^40 the product stays below 2^60.
    const long long ceiling = (m + 1) * smallest;

    const long long range_begin = first_reaching(m, smallest, ceiling);
    const long long range_end = first_reaching(m + 1, range_begin, ceiling);
    std::cout << range_end - range_begin << '\n';
    return 0;
}

Solution

  • Got accepted (finally) using two binary searches (one for the lower bound and one for the upper bound) with this code:

    #include<bits/stdc++.h>
    using namespace std;
    
    // Capacity of `ar`; the problem statement guarantees N < 2^18, so this is
    // 2 MiB of static storage rather than ~200 MiB.
    const long long MAXN = 1LL << 18;

    // Problem parameters (n = number of divisors, m = target sum) followed by
    // file-scope working variables.
    long long n,m,i,l,r,mid1,mid2,ans,tmp;
    // The divisors a_1..a_n; only the first n entries are meaningful.
    long long ar[MAXN];

    // Returns floor(x/ar[0]) + floor(x/ar[1]) + ... + floor(x/ar[n-1]).
    // Keeps its accumulator and index local, so it has no side effects.
    // Every ar[k] with k < n must be non-zero.
    long long func(long long x){
        long long total = 0;
        for (long long idx = 0; idx < n; idx++) total += x/ar[idx];
        return total;
    }
    
    // Reads "n m" and n divisors into the globals n, m and ar, then prints how
    // many X satisfy func(X) == m, using one bisection for each end of the
    // matching range.
    //
    // Returns 1 without printing when the input is malformed, does not fit in
    // `ar`, or violates the positivity constraints (n, m, a_i > 0).
    int main()
    {
        std::ios::sync_with_stdio(false);
        std::cin.tie(nullptr);

        if (!(std::cin >> n >> m)) return 1;
        const long long capacity = sizeof(ar) / sizeof(ar[0]);
        if (n <= 0 || n > capacity || m <= 0) return 1;
        for (long long idx = 0; idx < n; idx++) std::cin >> ar[idx];
        std::sort(ar, ar + n);
        if (ar[0] <= 0) return 1;  // func() would divide by zero

        // func(x) == 0 for every x < ar[0], and func(limit) >= m+1 because the
        // smallest divisor alone contributes m+1 there. Matching X can reach
        // limit-1, so stopping the search at m*ar[0] would cut the range short.
        // With m < 2^20 and a_i < 2^40 the product stays below 2^60.
        const long long first = ar[0];
        const long long limit = (m + 1) * first;

        // Lower bound: smallest x in [first, limit] with func(x) >= m.
        long long lo = first;
        long long hi = limit;
        while (lo < hi) {
            const long long probe = lo + (hi - lo) / 2;
            if (func(probe) < m) lo = probe + 1;
            else hi = probe;
        }
        const long long lower = lo;

        // Upper bound: largest x in [first-1, limit-1] with func(x) <= m.
        // The probe rounds up so that `lo = probe` always makes progress.
        lo = first - 1;
        hi = limit - 1;
        while (lo < hi) {
            const long long probe = lo + (hi - lo + 1) / 2;
            if (func(probe) > m) hi = probe - 1;
            else lo = probe;
        }
        const long long upper = lo;

        // Every x below `lower` has func(x) < m, hence upper >= lower - 1 and
        // the count is never negative; it is 0 when no X matches.
        std::cout << upper - lower + 1 << '\n';
        return 0;
    }