Search code examples
Tags: c, algorithm, optimization, primes, sieve-algorithm

Counting primes up to 18 digits optimization


I have a task at school to count number of primes up to 10^18 in under 2 minutes and with no more than 2 GB of memory to use. For the first try I've implemented a segmented sieve with the following optimizations:

  • used 32-bit integers as bit arrays to store each sieve segment
  • didn't store the even numbers (only odd numbers are represented in the sieve)
  • segmented the sieve into parts of size sqrt(n)
  • counted the composite numbers while marking them, so I don't have to loop over the sieve again
  • used dynamic memory allocation to store the first primes up to sqrt(n) (in this case I've created a Queue in C to store the primes)

The problem is that counting the primes up to 10^9 takes 13 seconds on my computer (which has pretty decent specs), so counting up to 10^18 would take days.

My question is, is there some optimization that I am missing, or is there a better and faster way to count primes up to a number? The code:

#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/*
 * Hand-rolled fixed-width integer aliases used throughout this file.
 * Assumes char/short/int/long long are 8/16/32/64 bits wide; nothing here
 * checks that.
 * NOTE(review): these are the same names <stdint.h> declares. If that header
 * (or a system header that pulls it in) is visible and maps a name to a
 * different underlying type, these typedefs will conflict; confirm on the
 * target toolchain.
 */
typedef   signed          char  int8_t;
typedef   signed     short int int16_t;
typedef   signed           int int32_t;
typedef   signed long long int int64_t;

typedef unsigned          char  uint8_t;
typedef unsigned     short int uint16_t;
typedef unsigned           int uint32_t;
typedef unsigned long long int uint64_t;

/* Bits per bit-array word: get_bit/clear_bit/set_bit split a position into
 * word index (position / SIZE) and bit offset (position % SIZE). */
#define  SIZE 32 
/* NOTE(review): DEBUG is defined but not referenced in the visible code. */
#define DEBUG

/* Escape sequences spliced into printf output (presumably ANSI terminal
 * colour codes); RESET is emitted after each coloured message. */
#define  KRED "\x1B[31m"
#define  KGRN "\x1B[32m"
#define  KYEL "\x1B[33m"
#define  KBLU "\x1B[34m"
#define  KMAG "\x1B[35m"
#define  KCYN "\x1B[36m"
#define  KWHT "\x1B[37m"
#define RESET "\033[0m"

/* One element of the singly linked FIFO: a stored value and a link to the
 * element enqueued after it (NULL for the newest element). */
struct node {
    uint64_t     data;
    struct node* next;
};

/* FIFO of uint64_t values. `first` is the oldest element, `last` the newest;
 * both are NULL when the queue is empty. `size` is the element count. */
struct queue {
    struct node* first;
    struct node* last;
    uint32_t     size; 
};

typedef struct node  Node;
typedef struct queue Queue;

/* Queue model */

/*
 * Appends `value` to the tail of `queue`.
 * Returns 1 on success, 0 if `queue` is NULL or the node cannot be allocated
 * (the queue is left unchanged in that case).
 */
uint8_t enqueue(Queue* queue, int64_t value) {
    Node* node;

    if (queue == NULL)
        return 0;

    /* sizeof *node, not sizeof(Node*): the whole struct must fit. */
    node = malloc(sizeof *node);
    if (node == NULL)
        return 0;

    node->data = (uint64_t)value;
    node->next = NULL; /* the tail must terminate the list for traversals */

    if (queue->last != NULL)
        queue->last->next = node;

    queue->last = node;
    if (queue->first == NULL)
        queue->first = node;

    queue->size++;
    return 1;
}

/*
 * Removes the oldest element and returns its value.
 * Returns 0 when the queue is NULL or empty; callers that may store 0 should
 * check queue_size() first.
 */
uint64_t dequeue(Queue* queue) {
    Node*    node;
    uint64_t save_data;

    /* Check for emptiness before touching the head node. */
    if (queue == NULL || queue->size == 0 || queue->first == NULL)
        return 0;

    node      = queue->first;
    save_data = node->data;

    queue->first = node->next;
    if (queue->first == NULL)
        queue->last = NULL; /* do not leave `last` pointing at freed memory */

    queue->size--;
    free(node);

    return save_data;
}

/* Returns the oldest node without removing it, or NULL if the queue is empty. */
Node* queue_peek(Queue* queue) {
    return queue->first;
}

/* Returns the number of elements currently stored. */
uint32_t queue_size(Queue* queue) {
    return queue->size;
}

/*
 * Allocates an empty queue. Returns NULL if allocation fails.
 * The caller owns the result: dequeue every element, then free() the queue.
 */
Queue* init_queue(void) {
    Queue* queue = malloc(sizeof *queue);

    if (queue == NULL)
        return NULL;

    queue->first = queue->last = NULL;
    queue->size  = 0;

    return queue;
}

/* Working with bit arrays functions */
/*
 * Returns how many bits of `nbr` are 1.
 * Each pass of the loop clears the lowest set bit, so it iterates once per
 * 1-bit rather than once per bit position.
 */
uint8_t count_set_bits(uint64_t nbr) {
    uint8_t ones = 0;

    for (uint64_t rest = nbr; rest != 0; rest &= rest - 1)
        ++ones;

    return ones;
}

/*
 * Reads bit number `position` of the bit array `array` (SIZE bits per word,
 * least significant bit first). Returns 1 if the bit is set, 0 otherwise.
 */
uint8_t get_bit(uint32_t array[], uint32_t position) {
    const uint32_t word   = array[position / SIZE];
    const uint32_t offset = position % SIZE;

    return (uint8_t)((word >> offset) & 1U);
}

/*
 * Resets bit number `position` of the bit array `array` to 0, leaving every
 * other bit of the containing word untouched.
 */
void clear_bit(uint32_t array[], uint32_t position) {
    uint32_t*      word = &array[position / SIZE];
    const uint32_t bit  = 1U << (position % SIZE);

    *word &= ~bit;
}

/*
 * Sets bit number `position` of the bit array `array` to 1, leaving every
 * other bit of the containing word untouched.
 */
void set_bit(uint32_t array[], uint32_t position) {
    const uint32_t index = position / SIZE;
    const uint32_t bit   = 1U << (position % SIZE);

    array[index] = array[index] | bit;
}

/* Solve the problem */
/*
 * Builds the base primes for the segmented sieve: every prime p with
 * p * p <= limit, in ascending order.
 *
 * 2 is always enqueued first, even when limit < 4, because count_primes
 * unconditionally skips the first node to reach the odd primes.
 *
 * Returns a queue owned by the caller (dequeue every element, then free()
 * the queue), or NULL if any allocation fails.
 */
Queue* initial_sieve(uint64_t limit) {
    Queue*    queue = init_queue();
    uint64_t  root  = (uint64_t)sqrt((double)limit);
    uint32_t* composite;

    if (queue == NULL)
        return NULL;

    /* sqrt() works in double precision and can be off by one for large
     * limits; nudge root to the exact floor(sqrt(limit)). Division is used
     * instead of root * root so the comparison cannot overflow. */
    while (root > 0 && root > limit / root)
        root--;
    while (root + 1 <= limit / (root + 1))
        root++;

    /* Reversed logic: a set bit means "composite", so the zeroed array from
     * calloc starts with every odd number presumed prime. Only odd numbers
     * are stored; odd number v lives at bit v / 2. */
    composite = calloc((root / 2) / SIZE + 1, sizeof *composite);
    if (composite == NULL) {
        free(queue);
        return NULL;
    }

    if (!enqueue(queue, 2))
        goto fail;

    for (uint64_t number = 3; number <= root; number += 2) {
        if (get_bit(composite, (uint32_t)(number / 2)) != 0)
            continue;

        if (!enqueue(queue, number))
            goto fail;

        /* Smaller multiples were already crossed off by smaller primes.
         * number * number cannot overflow: number <= root < 2^32. Stepping
         * by 2 * number visits only the odd multiples. */
        if (number <= root / number) {
            for (uint64_t position = number * number; position <= root; position += number * 2)
                set_bit(composite, (uint32_t)(position / 2));
        }
    }

    free(composite);
    return queue;

fail:
    /* A partial prime list would silently produce a wrong count, so drop it. */
    while (queue_size(queue) > 0)
        dequeue(queue);
    free(queue);
    free(composite);
    return NULL;
}

/* Releases every node of `queue` and then the queue itself. */
static void count_primes_release(Queue* queue) {
    while (queue_size(queue) > 0)
        dequeue(queue);
    free(queue);
}

/*
 * Counts the primes in [2, limit] with a segmented sieve of Eratosthenes.
 *
 * The base primes (those not exceeding delta = floor(sqrt(limit))) come from
 * initial_sieve(). The remaining range (delta, limit] is cut into
 * non-overlapping windows of delta numbers each; inside a window only odd
 * numbers are represented, one bit each, and a set bit means "composite".
 *
 * Prints the first window's bounds and two summary lines to stdout.
 * Returns the prime count, or 0 if memory could not be allocated (for
 * limit >= 2 a genuine result is never 0).
 */
uint64_t count_primes(uint64_t limit) {
    uint64_t  delta, span, lo, hi;
    uint64_t  non_primes_counter = 0, initial_size, base_count = 0, total = 0;
    uint32_t* current_sieve;
    size_t    words;
    Queue*    queue;

    /* Too small for the windowing below: pi(0)=pi(1)=0, pi(2)=1, pi(3)=2. */
    if (limit < 4)
        return limit < 2 ? 0 : limit - 1;

    /* Exact floor(sqrt(limit)); sqrt() alone can be off by one for large
     * values. Comparisons use division so nothing overflows. */
    delta = (uint64_t)sqrt((double)limit);
    while (delta > 0 && delta > limit / delta)
        delta--;
    while (delta + 1 <= limit / (delta + 1))
        delta++;

    queue = initial_sieve(limit);
    if (queue == NULL || queue->first == NULL) {
        fprintf(stderr, "count_primes: could not build the base prime list\n");
        free(queue);
        return 0;
    }
    initial_size = queue->size;

    /* Count base primes by value rather than trusting queue->size, so the
     * result stays right even if the list holds a prime just above delta. */
    for (Node* node = queue->first; node != NULL && node->data <= delta; node = node->next)
        base_count++;

    /* One buffer reused for every window: a window of `span` consecutive
     * numbers holds at most span / 2 + 1 odd numbers. */
    span  = delta;
    words = (size_t)((span / 2) / SIZE + 1);
    current_sieve = calloc(words, sizeof *current_sieve);
    if (current_sieve == NULL) {
        fprintf(stderr, "count_primes: out of memory\n");
        count_primes_release(queue);
        return 0;
    }

    /* delta >= 2 here, so 2 * delta <= delta * delta <= limit. */
    printf("Limits: %llu -> %llu\n", (unsigned long long)(delta + 1), (unsigned long long)(2 * delta));

    for (lo = delta + 1; ; lo = hi + 1) {
        const uint64_t first_odd = lo | 1;

        /* Inclusive upper bound, clamped to limit without overflowing. */
        hi = (limit - lo < span - 1) ? limit : lo + (span - 1);

        if (first_odd <= hi) {
            const uint64_t odd_count = (hi - first_odd) / 2 + 1;
            uint64_t       count     = 0;

            memset(current_sieve, 0, words * sizeof *current_sieve);

            /* Skip 2: even numbers are not represented in the sieve. */
            for (Node* prime = queue->first->next; prime != NULL; prime = prime->next) {
                const uint64_t p = prime->data;
                uint64_t       multiple;

                /* The list is ascending; once p * p > hi no later prime can
                 * have an unmarked multiple in this window. */
                if (p > hi / p)
                    break;

                multiple = p * p; /* odd, and <= hi by the check above */
                if (multiple < first_odd) {
                    /* First multiple of p at or after first_odd ... */
                    const uint64_t rem  = first_odd % p;
                    const uint64_t step = rem ? p - rem : 0;

                    if (step > hi - first_odd)
                        continue;
                    multiple = first_odd + step;

                    /* ... advanced to the next odd multiple if needed. */
                    if ((multiple & 1) == 0) {
                        if (p > hi - multiple)
                            continue;
                        multiple += p;
                    }
                }

                /* Mark odd multiples; count each composite only once. */
                for (;;) {
                    const uint32_t position = (uint32_t)((multiple - first_odd) / 2);

                    if (get_bit(current_sieve, position) == 0) {
                        set_bit(current_sieve, position);
                        count++;
                    }

                    if (2 * p > hi - multiple)
                        break;
                    multiple += 2 * p;
                }
            }

            non_primes_counter += count;
            total              += odd_count - count;
        }

        if (hi == limit)
            break;
    }

    free(current_sieve);
    count_primes_release(queue);

    printf("%sTotal composites and initial size: %llu %llu %s\n", KCYN, (unsigned long long)non_primes_counter, (unsigned long long)initial_size, RESET);
    printf("%sTotal primes: %llu %s\n", KCYN, (unsigned long long)total, RESET);
    return base_count + total;
}

/* Main */
/*
 * Entry point: expects the upper limit as the first command-line argument,
 * prints the number of primes up to that limit and the CPU time consumed.
 * Returns EXIT_FAILURE when the argument is missing or is not a non-negative
 * decimal integer that fits in 64 bits.
 */
int main(int argc, char **argv) {
    clock_t            begin, end;
    double             elapsed;
    unsigned long long limit;
    unsigned long long found;
    char*              tail = NULL;

    if (argc < 2) {
        fprintf(stderr, "Invalid number of parameters\n");
        fprintf(stderr, "Program will exit now.\n");
        return EXIT_FAILURE;
    }

    /* strtoull silently negates "-5", so a minus sign is rejected up front;
     * errno distinguishes an out-of-range value from a genuine maximum. */
    errno = 0;
    limit = strtoull(argv[1], &tail, 10);
    if (tail == argv[1] || *tail != '\0' || errno == ERANGE || strchr(argv[1], '-') != NULL) {
        fprintf(stderr, "Invalid limit: %s\n", argv[1]);
        return EXIT_FAILURE;
    }

    begin = clock();
    found = count_primes(limit);
    printf("%sNumber of primes found up to %s%s: %s%llu.\n%s", KWHT, KCYN, argv[1], KYEL, found, RESET);
    end     = clock();
    elapsed = (double)(end - begin) / CLOCKS_PER_SEC;

    printf("%sTotal time elapsed since the starting of the program: %s%lf seconds.\n%s", KWHT, KYEL, elapsed, RESET);  
    return 0;
}

Thanks, Marcus


Solution

  • You need to count the number of primes, not to find them all (there are too many of them). This is called the prime-counting function.

    In mathematics, the prime-counting function is the function counting the number of prime numbers less than or equal to some real number x. It is denoted by π(x).

    There are plenty of methods that calculate this function. Look at this Wolfram page with methods comparison. It seems it would be a hard task to accomplish this in two minutes.

    As was mentioned in the comments there is also a great answer at math.stackexchange, I think it would be helpful.