
Memory Leak in C-extension for Python


This is the first time I am writing a C extension for Python, and below you can see my ugly and probably super inefficient C++ implementation of a convolution. I have a problem with the memory management: each time I call this function in Python it consumes about 500 MB of memory (for a batch of size 100x112x112x3 and a kernel of size 3x3x3x64) and doesn't free it afterwards. Do I have to take care of reference counting, even though this is not a class method? Or do I have to free memory manually somewhere in the code? Note that I excluded all the error checks for a better overview. Thanks.

#include <Python.h>

PyObject* conv2d(PyObject*, PyObject* args)
{
    PyObject* data;
    PyObject* shape;
    PyObject* kernel;
    PyObject* k_shape;
    int stride;

    PyArg_ParseTuple(args, "OOOOi", &data, &shape, &kernel, &k_shape, &stride);

    Py_ssize_t dims = PyTuple_Size(shape);
    Py_ssize_t kernel_dims = PyTuple_Size(k_shape);

    int shape_c[3];   /* input shape: rows, cols, channels */
    int k_shape_c[4]; /* kernel shape: rows, cols, in-channels, out-channels */

    for (int i = 0; i < kernel_dims; i++)
    {
        if (i < dims)
        {
            shape_c[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
        }
        k_shape_c[i] = PyLong_AsLong(PyTuple_GetItem(k_shape, i));
    }

    PyObject* data_item;
    PyObject* kernel_item;
    PyObject* ret_array = PyList_New(0);
    double conv_val, channel_sum;

    /* Loop over output channels, the strided output rows/cols, input channels,
       and kernel rows/cols; accumulate one convolution sum per output position. */
    for (int oc = 0; oc < k_shape_c[3]; oc++)
    {
        for (int row = 0; row < shape_c[0]; row += stride)
        {
            for (int col = 0; col < shape_c[1]; col += stride)
            {
                channel_sum = 0;
                for (int ic = 0; ic < shape_c[2]; ic++)
                {
                    conv_val = 0;
                    for (int k_row = 0; k_row < k_shape_c[0]; k_row++)
                    {
                        for (int k_col = 0; k_col < k_shape_c[1]; k_col++)
                        {
                            data_item = PyList_GetItem(data, row + k_row);
                            if (!data_item)
                            {
                                PyErr_Format(PyExc_IndexError, "Index out of bounds");
                                return NULL;
                            }
                            data_item = PyList_GetItem(data_item, col + k_col);
                            data_item = PyList_GetItem(data_item, ic);
                            kernel_item = PyList_GetItem(kernel, k_row);
                            kernel_item = PyList_GetItem(kernel_item, k_col);
                            kernel_item = PyList_GetItem(kernel_item, ic);
                            kernel_item = PyList_GetItem(kernel_item, oc);
                            conv_val += PyFloat_AsDouble(data_item) * PyFloat_AsDouble(kernel_item);
                        }
                    }
                    channel_sum += conv_val;
                }
                PyList_Append(ret_array, PyFloat_FromDouble(channel_sum));
            }
        }
    }
    return ret_array;
}

Solution

  • The leak comes from:

    PyList_Append(ret_array, PyFloat_FromDouble(channel_sum));
    

    PyFloat_FromDouble returns a new reference, and PyList_Append takes shared ownership of it (it increments the reference count; it does not steal/consume your reference). So when you use PyList_Append and want the list to end up as the sole owner of your reference, you must explicitly release that reference after appending, e.g. (error checks omitted):

    PyObject *pychannel_sum = PyFloat_FromDouble(channel_sum);
    PyList_Append(ret_array, pychannel_sum);
    Py_DECREF(pychannel_sum);
    

    The alternative (and faster, if suitable) solution is to preallocate the list to the correct size and fill in the entries with PyList_SetItem/PyList_SET_ITEM, both of which steal a reference rather than incrementing the reference count (a sketch follows below). In general, APIs that don't explicitly document reference stealing won't steal, and you'll need to police your own reference counts.
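
    As a minimal sketch of that approach (error checks omitted; out_rows, out_cols, n_out and idx are just illustrative helpers, and the output count is derived from the loop bounds in the question's code):

    /* The number of outputs is known up front, so the list can be created at its
       final size and filled with PyList_SET_ITEM, which steals the reference
       returned by PyFloat_FromDouble -- no Py_DECREF needed afterwards. */
    Py_ssize_t out_rows = (shape_c[0] + stride - 1) / stride;
    Py_ssize_t out_cols = (shape_c[1] + stride - 1) / stride;
    Py_ssize_t n_out = (Py_ssize_t)k_shape_c[3] * out_rows * out_cols;

    PyObject* ret_array = PyList_New(n_out);
    Py_ssize_t idx = 0;

    /* ... and inside the innermost output loop, in place of PyList_Append: */
    PyList_SET_ITEM(ret_array, idx++, PyFloat_FromDouble(channel_sum));

    Keep in mind that PyList_SET_ITEM performs no bounds or error checking, so it should only be used to fill the slots of a freshly created list exactly once.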

    Note that, memory-wise, individual PyFloats are quite a bit more expensive than the C doubles they wrap; on a 64-bit system, each PyFloat in a list consumes 32 bytes (eight for the pointer in the list, 24 for the PyFloat itself), vs. eight bytes for the raw C double.

    You may want to look into using Python's array module (create an array of the correct size/type, use the buffer protocol to get a C-level view of it, then fill in the buffer); the code will be a tad more complex, but the memory usage will drop by roughly a factor of four, as in the sketch below. numpy types will provide the same advantage (and the result may be used more flexibly).
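
    A rough sketch of that idea (error checks omitted; array_mod, raw, arr, view, out and idx are illustrative names, and n_out is the precomputed number of outputs from the sketch above):

    /* Build an array.array('d') of n_out doubles, then write results straight into
       its buffer instead of creating one PyFloat object per output value. */
    PyObject* array_mod = PyImport_ImportModule("array");
    PyObject* raw = PyBytes_FromStringAndSize(NULL, n_out * sizeof(double)); /* uninitialized backing bytes */
    PyObject* arr = PyObject_CallMethod(array_mod, "array", "sO", "d", raw); /* array.array('d', raw) */
    Py_DECREF(raw);
    Py_DECREF(array_mod);

    Py_buffer view;
    PyObject_GetBuffer(arr, &view, PyBUF_WRITABLE);
    double* out = (double*)view.buf;
    Py_ssize_t idx = 0;

    /* ... inside the loops, in place of appending PyFloats: */
    out[idx++] = channel_sum;

    /* once all values are written: */
    PyBuffer_Release(&view);
    return arr;

    Every slot of the backing buffer is overwritten by the loops, so the uninitialized bytes are never observed; at eight bytes per value this is where the roughly fourfold memory saving over a list of PyFloats comes from.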