Search code examples
c++algorithmdata-structuresbinary-search

Getting Overflow and TLE issues in this Atcoder question


I am solving this question: D-Base n (https://atcoder.jp/contests/abc192/tasks/abc192_d)

It states that:

Given are a string X consisting of 0 through 9, and an integer M . Let d be the greatest digit in X . How many different integers not greater than M can be obtained by choosing an integer n not less than d + 1 and seeing X as a base- n number?

I solved this question iteratively (which gave me correct answers for small inputs) but it gave me TLE and overflow issues (see my submission at https://atcoder.jp/contests/abc192/submissions/20651499). The editorial said that I must use Binary Search and I have implemented it as follows. But i still dont get the correct outputs. Any advice is welcomed.

#include<iostream>
#include<vector>
#include<algorithm>
#include<cmath>
using namespace std;
int M;
bool getInBase(unsigned long long int x, int base){
    int i=0;
    unsigned long long int ans = 0;
    while(x>0){
        int temp = x%10;
        ans += pow(base,i)*temp;
        i++;
        x/=10;
        if(ans>M) return false;
    }
    return (ans<M);
}

int main(){
    string X;
    cin>>X;
    cin>>M;
    int mx = 0;
    unsigned long long orig =0;
    for(char x: X){
        int sum = (int) x - (int)'0';
        if(sum>mx) mx = sum;
        orig = orig*10 + ((int)x - (int) '0');
    }
    int ans=0;
    unsigned long long int l = mx+1, r =M+1;
    unsigned long long int mid;
    while(l<r){
        mid = (l+r)/2;
        if(getInBase(orig,mid)){
            l = mid +1;
        }
        else{
            r = mid-1;
        }
    }
    
    cout<<l-mx;

}

Solution

  • ll convert(string x, ll base, ll m){
        ll ans=0;
        ll p=0;
        ll cur=1;
        for(int i=x.size()-1;i>=0;i--){
            if(cur<0 || ans<0)
                return m+1;
            ans+=(ll)(x[i]-'0')*cur;
            cur*=base;
            if(ans>m)
                return m+1;
        }
        return ans;
    }
    void solve(){
        ll i, j, m;
        string x;
        cin>>x>>m;
        int n=x.size();
        ll mx=0;
        for(auto i:x){
            mx=max(mx, (ll)(i-'0'));
        }
        if(n==1){
            cout<<(((ll)stoi(x))<=m)<<endl;
            return;
        }
        ll k=(ll)ceil(pow(m*1.0, 1.0/(n-1)));
        ll l=mx+1;
        ll r=k+1;
        ll ans=0;
        while(l<=r){
            ll mid=l+(r-l)/2;
            ll c=convert(x, mid, m);
            //cout<<"l = "<<l<<" mid ="<<mid<<" r= "<<r<<" c= "<<c<<endl;
            if(c<=m){
                ans=max(ans, mid);
                l=mid+1;
            }else{
                r=mid-1;
            }
        }
        cout<<max(0LL, ans-mx)<<endl;
    }