There is a problem on an online judge that I have no idea how to get accepted.
The problem is as follows. The first line contains two numbers:
$N$ ($0 < N < 2^{18}$)
$M$ ($0 < M < 2^{20}$)
The second line contains $N$ numbers $a_i$ ($0 < a_i < 2^{40}$).
The question is: how many integers $X$ satisfy
$$M = \left\lfloor \frac{X}{a_1} \right\rfloor + \left\lfloor \frac{X}{a_2} \right\rfloor + \cdots + \left\lfloor \frac{X}{a_N} \right\rfloor ?$$
My naive solution:
#include<bits/stdc++.h>
using namespace std;
// Naive solution: counts the integers X with
// floor(X/a1) + ... + floor(X/aN) == M by trying candidates one at a time.
//
// f(X) = sum of floor(X/a_j) never decreases as X grows (all a_j > 0).
// Every term is 0 for X < min(a), so for M > 0 no solution lies below min(a).
// The scan therefore starts there (at 0 when M <= 0). It stops at the first
// X with f(X) > M, which happens by X = (M+1)*min(a) at the latest.
//
// Fixes over the earlier version:
// - It started at min(a)+1 and stopped before M*min(a). That skipped
//   X = min(a) and every solution in [M*min(a), (M+1)*min(a)). For example,
//   N=1, M=1, a=[2] printed 0 instead of 2.
// - It used a variable-length array, which is not standard C++.
//
// Cost is O(N * (M+1) * min(a)), so this is only usable on small inputs.
int main()
{
    long long n = 0;
    long long m = 0;
    std::cin >> n >> m;
    std::vector<long long> ar(static_cast<std::size_t>(n > 0 ? n : 0));
    for (long long& a : ar) std::cin >> a;
    if (ar.empty()) {   // the problem guarantees N > 0; avoid reading ar[0]
        std::cout << 0 << '\n';
        return 0;
    }
    std::sort(ar.begin(), ar.end());

    long long count = 0;
    for (long long x = (m > 0 ? ar.front() : 0);; ++x) {
        long long sum = 0;
        for (const long long d : ar) sum += x / d;
        if (sum == m) ++count;
        else if (sum > m) break;   // f never decreases, so no later X can match
    }
    std::cout << count << '\n';
}
Update 1: my binary search solution (it still exceeds the time limit):
#include<bits/stdc++.h>
using namespace std;
// Globals shared by main() and func(). ar holds the N divisors; the problem
// guarantees N < 2^18, so 2^18 slots are enough.
long long n,m,i,l,r,mid,ans,tmp,cnt;
long long ar[1 << 18];

// Returns floor(x/ar[0]) + ... + floor(x/ar[n-1]).
// Uses its own loop counter and accumulator instead of the globals `i` and
// `haha`, so calling it can no longer overwrite a loop index used by main().
long long func(long long x){
    long long total = 0;
    for (long long idx = 0; idx < n; idx++) total += x / ar[idx];
    return total;
}
// Counts the integers X with floor(X/a1) + ... + floor(X/aN) == M.
//
// func(X) never decreases as X grows (all a_i > 0), so the solutions form one
// contiguous run. Step 1 binary-searches for any member of that run. Step 2
// finds both ends of the run with two more binary searches.
//
// Fixes over the earlier version:
// - Its search loop (`while (tmp != m)`) could spin forever once l passed r.
//   For example, on input "3 11 / 1 1 1" func jumps from 9 to 12.
// - It then walked outward one X at a time, costing O(run length * N).
//   A run can be about ar[0] (up to ~2^40) long, which exceeded the time limit.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m;
    for (long long idx = 0; idx < n; idx++) std::cin >> ar[idx];
    std::sort(ar, ar + n);

    // Step 1: find any X with func(X) == M. If one exists, the smallest lies
    // in [0, M*ar[0]], because func(M*ar[0]) >= M from the ar[0] term alone.
    // Looping while l <= r always terminates, even when func skips over M.
    long long hit = -1;
    l = 0;
    r = ar[0] * m;
    while (l <= r) {
        mid = l + (r - l) / 2;
        tmp = func(mid);
        if (tmp < m) l = mid + 1;
        else if (tmp > m) r = mid - 1;
        else { hit = mid; break; }
    }
    if (hit < 0) {              // func jumps over M: no solution at all
        std::cout << 0 << '\n';
        return 0;
    }

    // Step 2a: first solution = smallest X in [0, hit] with func(X) >= M.
    long long lo = 0;
    long long hi = hit;
    while (lo < hi) {
        const long long probe = lo + (hi - lo) / 2;
        if (func(probe) >= m) hi = probe;
        else lo = probe + 1;
    }
    const long long first = lo;

    // Step 2b: one past the last solution = smallest X in [hit, (M+1)*ar[0]]
    // with func(X) > M. The right end qualifies because its ar[0] term is M+1.
    lo = hit;
    hi = (m + 1) * ar[0];
    while (lo < hi) {
        const long long probe = lo + (hi - lo) / 2;
        if (func(probe) > m) hi = probe;
        else lo = probe + 1;
    }

    ans = lo - first;
    std::cout << ans << '\n';
}
Finally got accepted using two binary searches (one for the lower bound and one for the upper bound) with this code:
#include<bits/stdc++.h>
using namespace std;
// Globals shared by main() and func(). ar holds the N divisors. The problem
// guarantees N < 2^18, so 2^18 slots suffice; the previous 26214400-element
// array reserved about 200 MB, 100 times what is needed.
long long n,m,i,l,r,mid1,mid2,ans,tmp;
long long ar[1 << 18];

// Returns floor(x/ar[0]) + ... + floor(x/ar[n-1]).
// Keeps its counter and accumulator local instead of writing the globals
// `k` and `haha`.
long long func(long long x){
    long long total = 0;
    for (long long idx = 0; idx < n; idx++) total += x / ar[idx];
    return total;
}
int main()
{
cin >> n >> m;
for(i = 0; i < n; i++) cin >> ar[i];
sort(ar,ar+n);
l = ar[0];
r = ar[0]*m;
mid1 = (l+r)/2;
tmp = func(mid1);
while (l < r){
mid1 = (l+r)/2;
tmp = func(mid1);
if (tmp < m) l = mid1+1;
else if (tmp > m) r = mid1-1;
else r = mid1-1;
}
mid1 = l; //lower bound
l = ar[0];
r = ar[0]*m;
mid2 = (l+r)/2;
tmp = func(mid2);
while (l < r){
mid2 = (l+r)/2;
tmp = func(mid2);
if (tmp < m) l = mid2+1;
else if (tmp > m) r = mid2-1;
else l = mid2+1;
}
mid2 = r; //upper bound
while (mid1 <= mid2 and func(mid1) != m) mid1 += 1;
while (mid2 >= mid1 and func(mid2) != m) mid2 -= 1;
ans = mid2-mid1+1;
cout << ans << endl;
}