Search code examples
Tags: python, c, python-3.x, python-c-api, python-extensions

Defining a Python enum in a C extension - am I doing this right?


I'm working on a Python C extension and I would like to expose a custom enum (as in: a class inheriting from enum.Enum) that would be entirely defined in C.

This turned out not to be a trivial task: the regular mechanism for inheritance, setting .tp_base, doesn't work - most likely because Enum's metaclass is not pulled in.

Basically I'm trying to do this:

import enum

class FooBar(enum.Enum):
    FOO = 1
    BAR = 2

in C.

After a lot of digging into cpython's internals, this is what I came up with, wrapped in an example buildable module:

#include <Python.h>

/* Module docstring; referenced by .m_doc in module_def below. */
PyDoc_STRVAR(module_doc,
"C extension module defining a class inheriting from enum.Enum.");

/*
 * Module definition passed to PyModule_Create() in PyInit_pycenum().
 * Only the name, docstring and state size are set; the remaining slots
 * are left zero-initialized.
 */
static PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    .m_name = "pycenum",
    .m_doc = module_doc,
    /*
     * NOTE(review): per the CPython PyModuleDef documentation, -1 means the
     * module keeps no per-module state block - confirm this is intended.
     */
    .m_size = -1,
};

/* Describes one enum member: its attribute name and its integer value. */
struct enum_descr {
    const char *name;
    long value;
};

/*
 * Members of the FooBar enum.
 *
 * The table is terminated by an entry whose name is NULL; consumers iterate
 * until they reach it. The sentinel is spelled out explicitly because an
 * empty initializer list (`{ }`) is not valid ISO C before C23 and is only
 * accepted as a compiler extension.
 */
static const struct enum_descr foobar_descr[] = {
    { .name = "FOO", .value = 1 },
    { .name = "BAR", .value = 2 },
    { .name = NULL }, /* sentinel */
};

/*
 * Build the bases tuple `(enum.Enum,)` for the class being created.
 *
 * enum_mod: borrowed reference to the imported `enum` module.
 *
 * Returns a new reference to a 1-tuple, or NULL with a Python exception set
 * if the attribute lookup or the tuple allocation fails.
 */
static PyObject *make_bases(PyObject *enum_mod)
{
    PyObject *enum_type, *bases;

    enum_type = PyObject_GetAttrString(enum_mod, "Enum");
    if (!enum_type)
        return NULL;

    /*
     * PyTuple_Pack() does NOT steal references; the tuple takes its own
     * reference to each item. Our reference to enum_type must therefore be
     * released whether or not the tuple was created, otherwise it leaks on
     * the success path.
     */
    bases = PyTuple_Pack(1, enum_type);
    Py_DECREF(enum_type);

    return bases;
}

/*
 * Obtain the class namespace by calling
 * `enum.EnumMeta.__prepare__("FooBarEnum", bases)`.
 *
 * enum_mod: borrowed reference to the imported `enum` module.
 * bases:    borrowed reference to the bases tuple.
 *
 * Returns a new reference to whatever __prepare__ returned, or NULL with a
 * Python exception set.
 */
static PyObject *make_classdict(PyObject *enum_mod, PyObject *bases)
{
    PyObject *meta = PyObject_GetAttrString(enum_mod, "EnumMeta");
    PyObject *ns = NULL;

    if (meta) {
        ns = PyObject_CallMethod(meta, "__prepare__",
                                 "sO", "FooBarEnum", bases);
        Py_DECREF(meta);
    }

    return ns;
}

/*
 * Populate the class namespace with `__module__` and one integer entry per
 * enum member.
 *
 * classdict: borrowed reference to the namespace mapping.
 * modname:   borrowed reference stored under the "__module__" key.
 * descr:     member table, terminated by an entry with a NULL name.
 *
 * Returns 0 on success, -1 with a Python exception set on failure. Entries
 * stored before a failure are left in place.
 */
static int fill_classdict(PyObject *classdict, PyObject *modname,
                          const struct enum_descr *descr)
{
    PyObject *attr_name, *attr_value;
    int status;

    attr_name = PyUnicode_FromString("__module__");
    if (attr_name == NULL)
        return -1;

    status = PyObject_SetItem(classdict, attr_name, modname);
    Py_DECREF(attr_name);
    if (status < 0)
        return -1;

    while (descr->name != NULL) {
        attr_name = PyUnicode_FromString(descr->name);
        if (attr_name == NULL)
            return -1;

        /* Only attempt the store when the value object could be created. */
        attr_value = PyLong_FromLong(descr->value);
        status = -1;
        if (attr_value != NULL)
            status = PyObject_SetItem(classdict, attr_name, attr_value);

        Py_DECREF(attr_name);
        Py_XDECREF(attr_value);
        if (status < 0)
            return -1;

        descr++;
    }

    return 0;
}

/*
 * Create the new class by calling `type(enum_name, bases, classdict)`.
 *
 * classdict: borrowed reference to the prepared and filled class namespace.
 * bases:     borrowed reference to the bases tuple.
 * enum_name: name of the new class, as a C string.
 *
 * Returns a new reference to the created type, or NULL with a Python
 * exception set. The reference counts of classdict and bases are unchanged
 * from the caller's point of view.
 */
static PyObject *make_new_type(PyObject *classdict, PyObject *bases,
                               const char *enum_name)
{
    PyObject *name, *args, *new_type;
    int ret;

    name = PyUnicode_FromString(enum_name);
    if (!name)
        return NULL;

    /*
     * PyTuple_Pack() takes its own reference to every item; it does not
     * steal any. So we drop our reference to name on both paths, and we must
     * not add extra references to bases/classdict (doing so would leak one
     * reference to each per call).
     */
    args = PyTuple_Pack(3, name, bases, classdict);
    Py_DECREF(name);
    if (!args)
        return NULL;

    new_type = PyObject_CallObject((PyObject *)&PyType_Type, args);
    Py_DECREF(args);
    if (!new_type)
        return NULL;

    /* NOTE(review): likely redundant for a type created by calling type(); kept as a harmless check - confirm. */
    ret = PyType_Ready((PyTypeObject *)new_type);
    if (ret < 0) {
        Py_DECREF(new_type);
        return NULL;
    }

    return new_type;
}

/*
 * Create an enum.Enum subclass from a member table.
 *
 * modname:   borrowed reference, stored as the class's __module__.
 * enum_name: name of the new class.
 * descr:     member table, terminated by an entry with a NULL name.
 *
 * Imports `enum`, builds the bases tuple and class namespace, fills the
 * namespace and creates the type. Returns a new reference to the type, or
 * NULL with a Python exception set.
 */
static PyObject *make_enum_type(PyObject *modname, const char *enum_name,
                                const struct enum_descr *descr)
{
    PyObject *enum_mod;
    PyObject *bases = NULL;
    PyObject *classdict = NULL;
    PyObject *result = NULL;

    enum_mod = PyImport_ImportModule("enum");
    if (enum_mod == NULL)
        return NULL;

    bases = make_bases(enum_mod);
    if (bases == NULL)
        goto done;

    classdict = make_classdict(enum_mod, bases);
    if (classdict == NULL)
        goto done;

    if (fill_classdict(classdict, modname, descr) < 0)
        goto done;

    result = make_new_type(classdict, bases, enum_name);

done:
    /* Single exit: release whatever intermediates were obtained. */
    Py_XDECREF(bases);
    Py_DECREF(enum_mod);
    Py_XDECREF(classdict);
    return result;
}

PyMODINIT_FUNC PyInit_pycenum(void)
{
    PyObject *module, *modname, *sub_enum_type;
    int ret;

    module = PyModule_Create(&module_def);
    if (!module)
        return NULL;

    ret = PyModule_AddStringConstant(module, "__version__", "0.0.1");
    if (ret < 0) {
        Py_DECREF(module);
        return NULL;
    }

    modname = PyModule_GetNameObject(module);
    if (!modname) {
        Py_DECREF(module);
        return NULL;
    }

    sub_enum_type = make_enum_type(modname, "FooBar", foobar_descr);
    Py_DECREF(modname);
    if (!sub_enum_type) {
        Py_DECREF(module);
        return NULL;
    }

    ret = PyModule_AddObject(module, "FooBar", sub_enum_type);
    if (ret < 0) {
        Py_DECREF(sub_enum_type);
        Py_DECREF(module);
        return NULL;
    }

    return module;
}

Basically I'm calling the EnumMeta's __prepare__ method directly to create a correct classdict and then I call the PyType_Type object too to create the sub-type.

This works and AFAICT results in a class that behaves exactly as expected, but... am I doing this right? Any feedback is appreciated.


Solution

  • The metaclass in Enum is tricky, yes.

    But you can see here that you can create an enum (in Python) like:

    FooBar = enum.Enum('FooBar', dict(FOO=1, BAR=2))
    

    So you can use this technique to easily create an enum class in Python C-API by doing something like:

    PyObject *key, *val, *name, *attrs, *args, *modname, *kwargs, *enum_type, *sub_enum_type;
    
    attrs = PyDict_New();
    key = PyUnicode_FromString("FOO");
    val = PyLong_FromLong(1);
    PyObject_SetItem(attrs, key, val);
    Py_DECREF(key);
    Py_DECREF(val);
    key = PyUnicode_FromString("BAR");
    val = PyLong_FromLong(2);
    PyObject_SetItem(attrs, key, val);
    Py_DECREF(key);
    Py_DECREF(val);
    name = PyUnicode_FromString("FooBar");
    args = PyTuple_Pack(3, name, attrs);
    Py_DECREF(attrs);
    Py_DECREF(name);
    
    // the module name might need to be passed as keyword argument
    PyDict_Type *kwargs = PyDict_New();
    key = PyUnicode_FromString("module");
    modname = PyModule_GetNameObject(module);
    PyObject_SetItem(kwargs, key, modname);
    Py_DECREF(key);
    Py_DECREF(modname);
    
    enum_type = PyObject_GetAttrString(enum_mod, "Enum");
    sub_enum_type = PyObject_Call(enum_type, args, kwargs)
    Py_DECREF(enum_type);
    Py_DECREF(args);
    Py_DECREF(kwargs);
    
    return sub_enum_type