Search code examples
Tags: algorithm, dynamic-programming, partitioning, greedy

Can I partition the array into K parts?


I am trying to implement the algorithm from this question: "Need idea for solving this algorithm puzzle". However, I am missing some edge case that sends my code into an infinite loop. I can fix it with a cosmetic change, but that only shows I didn't understand the algorithm.

Can someone help me figure out what I am missing?

#include <stdio.h>

/* Fully parenthesized max with no trailing semicolon. Arguments may be
 * evaluated twice, so only pass cheap, side-effect-free expressions. */
#define max(a, b) (((a) > (b)) ? (a) : (b))

/*
 * Return the largest value in a[i..size-1], or 0 when that range is empty.
 * The scan starts from 0, so the result is never negative.
 *
 * Written as a loop so each element is read once. The recursive form passed
 * the recursive call straight into max(), which evaluated it twice and made
 * the running time exponential for ascending input.
 */
int get_max(int *a, int i, int size)
{
    int best = 0;

    for (; i < size; i++)
        best = max(best, a[i]);
    return best;
}

/* Total of a[i..size-1]; an empty range (i >= size) totals 0. */
int get_sum(int *a, int i, int size)
{
    int acc = 0;

    for (; i < size; ++i)
        acc += a[i];
    return acc;
}

/*
 * Greedy left-to-right scan: a[i] joins the current group while the group's
 * sum stays <= bound; otherwise a new group is started with a[i].
 *
 * NOTE(review): partitions is only incremented when a group overflows, so
 * the group that is still open when the loop ends is never counted. For a
 * non-empty array whose elements are all <= bound, the result is one less
 * than the number of groups (it returns 0 when everything fits in a single
 * group). This is one of the mistakes pointed out in the Solution below.
 */
int get_partition(int *a, int size, int bound) {
    int running_sum = 0;
    int partitions = 0, i;

    for (i=0;i<size;i++) {
        if (a[i] + running_sum <= bound) {
            running_sum += a[i];
        } else {
            /* close the current group and open a new one holding a[i] */
            running_sum = 0;
            running_sum += a[i];
            partitions++;
        }
    }
    return partitions;
}

/*
 * Apparently intended (per the Solution below): binary-search the smallest
 * bound between the largest element and the total sum for which
 * get_partition() reports at most k groups.
 *
 * NOTE(review): as written this loops forever for the input in main.
 * Once higher == lower + 1, bound = (lower + higher) / 2 equals lower, and if
 * get_partition(bound) >= k the branch assigns lower = bound, leaving the
 * state unchanged (the repeating "lower 8 higher 9" trace in the Output).
 * It also returns the last partition count rather than a bound. Both points
 * are addressed in the Solution below.
 */
int foo(int *a, int size, int k)
{
    int lower = get_max(a, 0, size);
    int higher = get_sum(a, 0, size);
    int partition; /* only assigned inside the loop */

    while (lower < higher) {
        /* rounds down, so bound == lower when higher == lower + 1 */
        int bound = (lower + (higher))/2;

        partition = get_partition(a, size, bound);
        printf("partition %d bound %d lower %d higher %d\n", partition, bound, lower, higher);
        if (partition >= k) 
            lower = bound;
        else
            higher = bound;
    }
    return partition;
}

/* Element count of a true array (not a pointer parameter). The expansion and
 * the argument are parenthesized so the macro is safe inside larger
 * expressions. */
#define SIZE(a) (sizeof(a) / sizeof((a)[0]))
int main(void) {
    int a[] = {2, 3, 4, 5, 6};
    /* foo takes an int count; the array is tiny, so the conversion is exact */
    printf("%d\n", foo(a, (int)SIZE(a), 3));
    return 0;
}

Output:

partition 1 bound 13 lower 6 higher 20
partition 2 bound 9 lower 6 higher 13
partition 3 bound 7 lower 6 higher 9
partition 3 bound 8 lower 7 higher 9
partition 3 bound 8 lower 8 higher 9
...last line keeps repeating.

Solution

  • You have a couple of mistakes:

    • During the binary search, your while test should be while (lower+1 < higher) { and not while (lower < higher) {. You enter an infinite loop when lower = 8 and higher = 9. At that point bound = (lower+higher)/2 = 8, and you update lower = bound, which changes nothing.
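    • At the end of foo you should return higher (not partition). Your binary-search invariant is that for bound <= lower the greedy scan needs more than k parts, and for bound >= higher it needs k or fewer. Note that this invariant only holds from the start if lower is initialized to get_max(a, 0, size) - 1. A bound equal to the largest element may already be feasible: for example, {2, 3, 4, 5, 6} with k = 5 has answer 6.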
    • Your calculation in get_partition is wrong. You don't take the last partition group into account, because you only increment partitions when running_sum overflows. After the for loop you should add the statement:

      if (running_sum > 0) 
          partitions++; 
      

    Putting it all together:

    #include <stdio.h>
    
    /* Fully parenthesized max with no trailing semicolon. Arguments may be
     * evaluated twice, so only pass cheap, side-effect-free expressions. */
    #define max(a, b) (((a) > (b)) ? (a) : (b))

    /*
     * Return the largest value in a[i..size-1], or 0 when that range is empty.
     * The scan starts from 0, so the result is never negative.
     *
     * Written as a loop so each element is read once. The recursive form passed
     * the recursive call straight into max(), which evaluated it twice and made
     * the running time exponential for ascending input.
     */
    int get_max(int *a, int i, int size)
    {
        int best = 0;

        for (; i < size; i++)
            best = max(best, a[i]);
        return best;
    }
    
    /* Add up a[i..size-1]; an empty range (i >= size) totals 0. */
    int get_sum(int *a, int i, int size)
    {
        int total = 0;

        while (i < size)
            total += a[i++];
        return total;
    }
    
    /*
     * Count the groups produced by a greedy left-to-right scan in which each
     * contiguous group's sum may not exceed bound: a[i] joins the current
     * group if it still fits, otherwise it opens a new group.
     *
     * Returns 0 for an empty array and at least 1 otherwise. Each group is
     * counted when it is opened, so the final group is always included, even
     * when its sum is 0. The old "running_sum > 0" check returned 0 for
     * {0, 0}. Also, an oversized first element no longer adds a phantom empty
     * group.
     *
     * An element larger than bound still gets a group of its own, so the
     * count is only meaningful for bound >= the largest element.
     */
    int get_partition(int *a, int size, int bound) {
        int running_sum = 0;
        int partitions = 0, i;

        for (i = 0; i < size; i++) {
            if (partitions == 0 || running_sum + a[i] > bound) {
                /* first element, or a[i] does not fit: open a new group */
                partitions++;
                running_sum = a[i];
            } else {
                running_sum += a[i];
            }
        }
        return partitions;
    }
    
    /*
     * Return the smallest possible "largest group sum" when a[0..size-1] is
     * split into at most k contiguous groups. Assumes non-negative elements
     * and k >= 1; for k <= 0 no bound qualifies and, for a non-empty array,
     * the total sum is returned. An empty array yields 0.
     *
     * Binary search on the bound, with the invariant:
     *   - every bound <= lower is infeasible (below the largest element, or
     *     the greedy scan needs more than k groups);
     *   - bound == higher is feasible (k or fewer groups).
     * lower starts one below the largest element, because the largest element
     * may itself be the answer (e.g. {2, 3, 4, 5, 6} with k = 5 gives 6).
     * Starting at the largest element would skip that value. higher starts at
     * the total sum, which always fits in one group.
     *
     * Prints one trace line per probed bound plus a final summary line.
     */
    int foo(int *a, int size, int k)
    {
        int lower = get_max(a, 0, size) - 1;
        int higher = get_sum(a, 0, size);
        int partition = 0; /* stays 0 if no bound is probed */

        while (lower + 1 < higher) {
            /* midpoint computed without forming lower + higher */
            int bound = lower + (higher - lower) / 2;

            partition = get_partition(a, size, bound);
            printf("partition %d bound %d lower %d higher %d\n", partition, bound, lower, higher);
            if (partition > k)
                lower = bound;
            else
                higher = bound;
        }
        printf("partition %d lower %d higher %d\n", partition, lower, higher);
        return higher;
    }
    
    /* Element count of a true array (not a pointer parameter). The expansion
     * and the argument are parenthesized so the macro is safe inside larger
     * expressions. */
    #define SIZE(a) (sizeof(a) / sizeof((a)[0]))
    int main(void) {
        int a[] = {2, 3, 4, 5, 6};
        /* foo takes an int count; the array is tiny, so the conversion is exact */
        printf("%d\n", foo(a, (int)SIZE(a), 3));
        return 0;
    }