Search code examples
c++numpystdmappybind11

insert dtype in std::map


I want to create a map that takes a pair of pybind11::dtype and int as its key and maps it to an OpenCV type constant:

// Lookup table from a (pybind11::dtype, int) key to an OpenCV type constant.
// NOTE(review): the int is presumably the channel count — confirm.
static std::map<std::pair<pybind11::dtype, int>, int> ocv_types;

So I inserted all combinations but there seems to be a problem when adding int32_t and float_t:

    // Register the key (int32 dtype, 3) -> CV_32SC3.
    ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::int32_t>() , 3), CV_32SC3));

    // Register the key (float dtype, 3) -> CV_32FC3.
    // NOTE(review): this is the insert reported as being dropped, presumably
    // because the map's ordering treats both keys as equivalent — confirm.
    ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::float_t>() , 3), CV_32FC3));

When I do this, only the CV_32SC3 entry is actually inserted. My guess is that somewhere the program "thinks" that both keys are equal and therefore does not insert the second one.

How can I actually add these 2?

P.S. I did this check just to "prove" that the types are not equal:

    // Sanity check that the int32 and float dtypes are distinct objects.
    // Uses is() (Python `is`, i.e. object identity) instead of operator==,
    // which pybind11 marks as deprecated for handles and which performs the
    // same pointer comparison.
    if(pybind11::dtype::of<std::int32_t>().is(pybind11::dtype::of<std::float_t>()))
    {
        std::cout << "std::int32_t == std::float_t" << std::endl;
    }
    else
    {
        std::cout << "std::int32_t != std::float_t" << std::endl;
    }

... And of course they are not.

EDIT

I added the < function for dtype and used it in the compare function for the map, but not all elements are present in the map:

// Maps a supported dtype to a small, fixed integer rank so that dtypes can be
// ordered deterministically (for example as part of a std::map key).
//
// Ranks: uint8 -> 1, uint16 -> 2, int16 -> 3, int32 -> 4, float -> 5,
// double -> 6. Matching is done with is(), i.e. against the dtype object
// returned by pybind11::dtype::of<T>().
//
// Returns 0 for any dtype that is not in the list above. Previously control
// fell off the end of the function in that case, which is undefined behaviour
// for a non-void function. Note that all unsupported dtypes therefore share
// rank 0 and compare as equivalent to each other.
int getVal(pybind11::dtype type)
{
    if(type.is(pybind11::dtype::of<std::uint8_t>()))
        return 1;
    if(type.is(pybind11::dtype::of<std::uint16_t>()))
        return 2;
    if(type.is(pybind11::dtype::of<std::int16_t>()))
        return 3;
    if(type.is(pybind11::dtype::of<std::int32_t>()))
        return 4;
    if(type.is(pybind11::dtype::of<std::float_t>()))
        return 5;
    if(type.is(pybind11::dtype::of<std::double_t>()))
        return 6;

    // Unknown dtype: give it a well-defined rank below every supported one.
    return 0;
}

// Orders two dtypes by the integer rank that getVal() assigns to each.
// NOTE(review): if this is declared at global scope, comparisons performed
// inside std (e.g. std::pair's operator<) will not find it through ADL —
// confirm that callers invoke it directly.
inline bool operator <(const pybind11::dtype a, const pybind11::dtype b)
{
    const int lhs_rank = getVal(a);
    const int rhs_rank = getVal(b);
    return lhs_rank < rhs_rank;
}

auto comp = [](const std::pair<pybind11::dtype, int> a, const std::pair<pybind11::dtype, int> b)
{
    return a < b;
};
static std::map<std::pair<pybind11::dtype, int>, int, decltype(comp)> ocv_types(comp);

Solution

  • As you noted, pybind11::dtype does not have any particular order. So, in my opinion, the best approach is to use std::unordered_map and provide the respective hashes. pybind11 already has a hash function, so it only needs to be adapted for std::hash.

    Here is a test I wrote (using Catch2); it passes on my machine:

    main.cpp:

    #include "catch2/catch_all.hpp"
    #include <pybind11/embed.h>
    #include <pybind11/numpy.h>
    #include <unordered_map>
    
    // std::hash adapter for pybind11::dtype: forwards to pybind11::hash() and
    // converts the result to size_t.
    namespace std
    {
        template<>
        struct hash<pybind11::dtype>
        {
            size_t operator()(const pybind11::dtype &dtype_key) const
            {
                const auto py_hash = pybind11::hash(dtype_key);
                return static_cast<size_t>(py_hash);
            }
        };
    }
    
    // std::hash for the (dtype, int) key used by the unordered_map.
    //
    // Combines the dtype hash with the int component using the well-known
    // hash_combine mixing step rather than a bare XOR. A plain XOR with a
    // small int only flips the low bits of the dtype hash, so keys whose dtype
    // hashes differ only in those bits could collide.
    template<>
    struct std::hash<std::pair<pybind11::dtype, int>>
    {
        size_t operator()(const std::pair<pybind11::dtype, int> &t) const
        {
            size_t seed = std::hash<pybind11::dtype>{}(t.first);
            // hash_combine: golden-ratio constant plus shifted copies of the seed.
            seed ^= static_cast<size_t>(t.second) + 0x9e3779b9u + (seed << 6) + (seed >> 2);
            return seed;
        }
    };
    
    
    // Checks that two keys differing only in dtype (int32 vs float, both
    // paired with 3) are stored as separate entries in the unordered_map.
    TEST_CASE("map_with_dtype") {
        // Local stand-ins for the OpenCV constants so the test needs no OpenCV headers.
        constexpr int CV_32SC3 = 1;
        constexpr int CV_32FC3 = 2;

        // Declared before the map so the map is destroyed first.
        pybind11::scoped_interpreter guard{};

        std::unordered_map<std::pair<pybind11::dtype, int>, int> ocv_types;
        REQUIRE(ocv_types.empty());

        const auto int_key = std::make_pair(pybind11::dtype::of<std::int32_t>(), 3);
        const auto int_result = ocv_types.emplace(int_key, CV_32SC3);
        REQUIRE(int_result.second);

        const auto float_key = std::make_pair(pybind11::dtype::of<std::float_t>(), 3);
        const auto float_result = ocv_types.emplace(float_key, CV_32FC3);
        REQUIRE(float_result.second);
        CHECK(float_result.first->second == CV_32FC3);

        CHECK(ocv_types.size() == 2);
    }
    

    CMakeLists.txt:

    cmake_minimum_required(VERSION 3.16)
    
    # set the project name
    project(MapOfPyBind11)
    
    # main.cpp includes catch2/catch_all.hpp, which only exists in Catch2 v3.
    find_package(Catch2 3 REQUIRED)
    find_package(pybind11 REQUIRED)
    
    # add the executable
    add_executable(MapOfPyBind11Test main.cpp)
    # main.cpp defines no main(); in Catch2 v3 the default entry point comes
    # from Catch2::Catch2WithMain, not from Catch2::Catch2.
    target_link_libraries(MapOfPyBind11Test PRIVATE Catch2::Catch2WithMain pybind11::module pybind11::embed)
    
    # register each TEST_CASE with CTest
    include(CTest)
    include(Catch)
    catch_discover_tests(MapOfPyBind11Test)