
Using C code in Python returns strange values


I wrote a function in C that returns all prime numbers under n, and imported it into Python with ctypes. However, the function doesn't return what I expected. How do I find out, in Python, the size of the block the returned int* points to?

primes.c: returns an int* pointing to all prime numbers under n:

int* primes(int n)
{
  int i, j;
  int *primes = NULL;
  int *prime_numbers = NULL;
  int num_primes = 0;
  
  if (n < 3) {
    printf("n must be >= 3, you input n: %d\n", n);
    return NULL;
  }

  primes = (int *)calloc(n, sizeof(int));

  if (primes == NULL) {
    printf("Memory allocation failed\n");
    return NULL;
  }

  for (i = 0; i < n; i++) {
    primes[i] = 1;
  }

  primes[0] = 0;
  primes[1] = 0;
  
  for (i = 2; i < n; i++) {
    if (primes[i]) {
      for (j = i*2; j < n; j += i) {
        primes[j] = 0;
      }
      num_primes++;
    }
  }

  j = 0;
  prime_numbers = (int *)calloc(num_primes, sizeof(int));

  if (prime_numbers == NULL) {
    printf("Memory allocation failed\n");
    return NULL;
  }

  for (i = 0; i < n; i++) {
    if (primes[i]) {
      prime_numbers[j] = i;
      j++;
    }
  }
  free(primes);
  return prime_numbers;
}

In Python:

import ctypes
from time import perf_counter

library = ctypes.CDLL('./primes.so')
library.primes.argtypes = [ctypes.c_int]
library.primes.restype = ctypes.POINTER(ctypes.c_int)
libc = ctypes.CDLL("libc.so.6")


# ---------------------------------------------------------------------------- #
# --- Primes ----------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #


def primes_c(n: int) -> list[int]:
    assert isinstance(n, int), "n must be an integer."
    assert (n >= 3), "n must be >= 3."
    primes: list[int] = library.primes(n)
    return primes

def main():
    n: int = 10
    print(f"Getting prime numbers under {n:_}.")

    print("C implementation:")
    start = perf_counter()
    prime_numbers_c = primes_c(n)
    end = perf_counter()
    print(f"took {end - start:.2f} seconds.")

    for i in prime_numbers_c:
        print(i)

    libc.free(prime_numbers_c)
    

if __name__ == '__main__':
    main()

My output looks like this, and then it segfaults:

.
.
.
0
0
0
0
0
0
0
[1]    28416 segmentation fault (core dumped)  python3 primes.py

Solution

  • Listing [Python.Docs]: ctypes - A foreign function library for Python.

    There are several errors in the code:

    • Showstopper: SegFault (Access Violation). You return a pointer from C, but in Python you go beyond its boundaries when iterating over it (a ctypes pointer carries no length, so the loop never stops on its own), accessing memory that's not "yours" (Undefined Behavior)

    • free:

      • Calling it without "preparation" (setting argtypes and restype). Check [SO]: C function called from Python via ctypes returns incorrect value (@CristiFati's answer) for a common pitfall when working with CTypes (calling functions)

      • A subtle one: you're crossing the .dll boundary. Allocating memory on one side (the .dll) and freeing it on the other (the app) yields UB if the two are built with different C runtimes (e.g. when the C runtime is statically linked into one of them). When allocating memory in a .dll, also provide a function that deallocates it

    • Memory leak: if the 2nd array allocation fails, the 1st one (primes) is never freed

    • Some minor ones (including using the name primes for 2 separate things: the function and a local variable)
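The "preparation" for free boils down to declaring argtypes and restype before calling: without them, ctypes assumes int arguments and an int result, which can truncate a 64-bit pointer. A minimal sketch, assuming a POSIX system where ctypes.util.find_library("c") locates the C runtime; it only shows the calling convention (malloc and free from the same libc). For memory allocated inside primes.so, a deallocator exported by primes.so itself remains the safer choice, as explained above.

```python
import ctypes as cts
import ctypes.util

# Assumption: POSIX system; find_library("c") locates the C runtime
libc = cts.CDLL(ctypes.util.find_library("c"))

# Declare the real signatures; otherwise ctypes assumes int arguments/results,
# which can truncate 64-bit pointers
libc.malloc.argtypes = (cts.c_size_t,)
libc.malloc.restype = cts.c_void_p
libc.free.argtypes = (cts.c_void_p,)
libc.free.restype = None

count = 4
raw = libc.malloc(count * cts.sizeof(cts.c_int))  # address as int, or None on failure
assert raw is not None, "malloc failed"

# View the raw block as an int*, fill it, then read back exactly `count` items
ptr = cts.cast(raw, cts.POINTER(cts.c_int))
for idx, value in enumerate((2, 3, 5, 7)):
    ptr[idx] = value
values = ptr[:count]
print(values)

libc.free(raw)  # released by the same allocator that produced it
```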

    Now, in order to fix the (primary) error, you must tell the caller of primes how large the returned memory block is (for more details on this topic, check [SO]: Pickling a ctypes.Structure with ct.Pointer (@CristiFati's answer)). There are a number of ways to do that:
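Why the size has to travel with the pointer: a ctypes POINTER has no length, so iterating over it (as the question's for loop does) keeps reading past the end of the block, whereas slicing with an explicit stop reads exactly that many elements. A minimal sketch that simulates the C-allocated block with a ctypes array owned by Python (so nothing needs freeing here):

```python
import ctypes as cts

# Simulate the block returned by C: 4 ints, seen only through an int*
block = (cts.c_int * 4)(2, 3, 5, 7)
ptr = cts.cast(block, cts.POINTER(cts.c_int))

# The count must come from the C side (sentinel, output argument or struct)
count = 4

# Slicing a POINTER needs an explicit stop and reads exactly `count` items
primes = ptr[:count]
print(primes)

# By contrast, `for x in ptr: ...` has no end condition: ctypes keeps
# indexing ptr[4], ptr[5], ... into memory that doesn't belong to the block
```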

    • Add a sentinel at the end of the returned block (as @interjay pointed out); for this data 0 works, since 0 is never prime

    • Add an (output) argument to the function that will receive the number of primes

    • Encapsulate the pointer and its size in a structure. Although it requires writing more code, that's the one I prefer
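For the sentinel option, the Python side could look like the minimal sketch below. It assumes a hypothetical C variant of primes that allocates count + 1 ints and stores 0 in the last slot; the C result is simulated with a ctypes array, and read_zero_terminated is a hypothetical helper name.

```python
import ctypes as cts


def read_zero_terminated(ptr):
    """Collect ints from `ptr` up to the 0 sentinel (which MUST be present)."""
    values = []
    idx = 0
    while ptr[idx] != 0:  # stops at the sentinel, never reads past it
        values.append(ptr[idx])
        idx += 1
    return values


# Simulated C result: primes under 20 followed by the 0 sentinel
block = (cts.c_int * 9)(2, 3, 5, 7, 11, 13, 17, 19, 0)
ptr = cts.cast(block, cts.POINTER(cts.c_int))
primes = read_zero_terminated(ptr)
print(primes)
```

The trade-off: finding the length takes a linear scan, and if the C side ever forgets the sentinel, the same out-of-bounds read from the question comes back.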

    I prepared an example.

    • primes.c:

      #include <stdio.h>
      #include <stdlib.h>
      
      #if defined(_WIN32)
      #  define DLL00_EXPORT_API __declspec(dllexport)
      #else
      #  define DLL00_EXPORT_API
      #endif
      
      
      typedef unsigned int uint;
      
      typedef struct Buffer_ {
          uint count;
          uint *data;
      } Buffer, *PBuffer;
      
      #if defined(__cplusplus)
      extern "C" {
      #endif
      
      DLL00_EXPORT_API PBuffer primesEratosthenes(uint n);
      DLL00_EXPORT_API void freePtr(PBuffer buf);
      
      #if defined(__cplusplus)
      }
      #endif
      
      
      PBuffer primesEratosthenes(uint n)
      {
          uint i = 0, j = 0, count = 0;
          uint *sieve = NULL, *primes = NULL;
          PBuffer ret = NULL;
      
          if (n < 3) {
              printf("C - n must be >= 3, you input n: %u\n", n);
              return NULL;
          }
      
          sieve = malloc(n * sizeof(uint));
          if (sieve == NULL) {
              printf("C - Memory allocation failed 0\n");
              return NULL;
          }
      
          sieve[0] = 0;
          sieve[1] = 0;
      
          for (i = 2; i < n; ++i) {
              sieve[i] = 1;
          }
      
          for (i = 2; i < n; ++i) {
              if (sieve[i]) {
                  for (j = i * 2; j < n; j += i) {
                      sieve[j] = 0;
                  }
                  ++count;
              }
          }
      
          primes = malloc(count * sizeof(uint));
          if (primes == NULL) {
              printf("C - Memory allocation failed 1\n");
              free(sieve);
              return NULL;
          }
      
          ret = malloc(sizeof(Buffer));
          if (ret == NULL) {
              printf("C - Memory allocation failed 2\n");
              free(primes);
              free(sieve);
              return NULL;
          }
      
          for (i = 2, j = 0; i < n; ++i) {
              if (sieve[i]) {
                  primes[j] = i;
                  ++j;
              }
          }
      
          free(sieve);
          ret->count = count;
          ret->data = primes;
          return ret;
      }
      
      
      void freePtr(PBuffer buf)
      {
          if (buf == NULL) {
              return;
          }
          free(buf->data);
          free(buf);
      }
      
    • code00.py:

      #!/usr/bin/env python
      
      import ctypes as cts
      import sys
      from time import perf_counter
      
      
      DLL_NAME = "./primes.{:s}".format("dll" if sys.platform[:3].lower() == "win" else "so")
      
      UIntPtr = cts.POINTER(cts.c_uint)
      
      
      class Buffer(cts.Structure):
          _fields_ = (
              ("count", cts.c_uint),
              ("data", UIntPtr),
          )
      
      BufferPtr = cts.POINTER(Buffer)
      
      
      def main(*argv):
          dll = cts.CDLL(DLL_NAME)
          primes_eratosthenes = dll.primesEratosthenes
          primes_eratosthenes.argtypes = (cts.c_uint,)
          primes_eratosthenes.restype = BufferPtr
          free_ptr = dll.freePtr
          free_ptr.argtypes = (BufferPtr,)
          free_ptr.restype = None
      
          ns = (
              100,
              50000000,
          )
      
          funcs = (
              primes_eratosthenes,
          )
      
          for n in ns:
              print(f"Searching for primes until {n:d}")
              for func in funcs:
                  start = perf_counter()
                  pbuf = func(n)
                  if not pbuf:
                      print("NULL ptr")
                      continue
                  buf = pbuf.contents
                  primes = buf.data[:buf.count]
                  print(f"\n  Function {func.__name__} took {perf_counter() - start:.3f} seconds")
                  print(f"  Found {buf.count:d} primes")
                  if buf.count <= 100:
                      print(end="  ")
                      for i in primes:
                          print(i, end=" ")
                      print("\n")
                  free_ptr(pbuf)
      
      
      if __name__ == "__main__":
          print("Python {:s} {:03d}bit on {:s}\n".format(" ".join(elem.strip() for elem in sys.version.split("\n")),
                                                         64 if sys.maxsize > 0x100000000 else 32, sys.platform))
          rc = main(*sys.argv[1:])
          print("\nDone.\n")
          sys.exit(rc)
      
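In code00.py, freePtr is skipped if anything raises between the call and the free. A minimal sketch of a wrapper that copies the data out and always releases the buffer, assuming functions configured as in code00.py; fetch_primes is a hypothetical name, and Buffer must mirror the C struct exactly:

```python
import ctypes as cts


class Buffer(cts.Structure):
    # Must mirror the C struct Buffer_ (same field order and types)
    _fields_ = (
        ("count", cts.c_uint),
        ("data", cts.POINTER(cts.c_uint)),
    )


def fetch_primes(alloc_func, free_func, n):
    """Call alloc_func(n), copy the result into a list, always free it."""
    pbuf = alloc_func(n)
    if not pbuf:  # NULL: the C side already printed the reason
        return None
    try:
        buf = pbuf.contents
        return buf.data[:buf.count]  # copy out before the memory is released
    finally:
        free_func(pbuf)  # runs even if the copy above raises
```

Usage with the objects from code00.py would be fetch_primes(primes_eratosthenes, free_ptr, 100). Because the returned list is a plain Python copy, it stays valid after the C memory has been released.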

    Output:

    (py_pc064_03.10_test0) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q076283276]> ~/sopr.sh
    ### Set shorter prompt to better fit when pasted in StackOverflow (or other) pages ###
    
    [064bit prompt]>
    [064bit prompt]> ls
    code00.py  primes.c
    [064bit prompt]>
    [064bit prompt]> gcc -fPIC -shared -o primes.so primes.c
    [064bit prompt]>
    [064bit prompt]> ls
    code00.py  primes.c  primes.so
    [064bit prompt]>
    [064bit prompt]> python ./code00.py
    Python 3.10.11 (main, Apr  5 2023, 14:15:10) [GCC 9.4.0] 064bit on linux
    
    Searching for primes until 100
    
      Function primesEratosthenes took 0.000 seconds
      Found 25 primes
      2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53 59 61 67 71 73 79 83 89 97
    
    Searching for primes until 50000000
    
      Function primesEratosthenes took 2.301 seconds
      Found 3001134 primes
    
    Done.
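The counts above can be cross-checked against a plain reference. A minimal pure-Python sketch of the same Sieve of Eratosthenes, under the same rules as primesEratosthenes (primes strictly below n, n must be >= 3); primes_eratosthenes_py is a hypothetical name. It starts crossing out at i * i rather than i * 2, an optimization that does not change the result:

```python
def primes_eratosthenes_py(n: int) -> list[int]:
    """Pure-Python reference: all primes strictly below n (n >= 3)."""
    if n < 3:
        raise ValueError(f"n must be >= 3, got {n}")
    sieve = bytearray([1]) * n  # 1 = "still considered prime"
    sieve[0] = sieve[1] = 0
    i = 2
    while i * i < n:
        if sieve[i]:
            # Smaller multiples of i were already crossed out by smaller primes
            sieve[i * i::i] = bytes(len(range(i * i, n, i)))
        i += 1
    return [k for k in range(n) if sieve[k]]


primes = primes_eratosthenes_py(100)
print(len(primes), primes[-1])
```

For n = 100 it finds 25 primes, the largest being 97, matching the C run above.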