
How can I implement __getitem__ so that it processes its input arguments and then passes them to the underlying NumPy array?


Let A be a simple Python class with a single member, self._mat, which is a NumPy array. The constructor of A takes an integer n, creates an n-by-n NumPy array of zeros, and stores it in the private member self._mat.

How can I implement __getitem__ so that it first preprocesses its arguments (for instance, mapping certain strings, such as serialized IDs, to zero-based integer indices) and then passes them to the inner array self._mat, so that they are handled exactly as if the user of the class had passed the processed arguments directly to the inner NumPy array?

An example (bad) implementation of A looks like this:

import numpy as np

class A():
    def __init__(self, n: int=3):
        self._mat = np.zeros(shape=[n, n])
    def __getitem__(self, val):
        # This implementation is wrong and won't run: _process_val is not defined
        val = _process_val(val)  # How do I process val so that every element of a (possible) multi-index is examined as well?
        return self._mat[val]

Example usages:

a = A(n=4)
print(a[0], a[0, 1], a[(0, 1)], a[:, 1:2])
# Let's say that the string 'deadbeef' maps to the integer 0
print(a['deadbeef'], a['deadbeef', 1], a[('deadbeef', 1)], a[:, 'deadbeef':2])
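To see what has to be preprocessed, it helps to look at the raw key Python hands to __getitem__. The sketch below uses a hypothetical KeyEcho class (not part of the question) whose __getitem__ simply returns its key unchanged:

```python
class KeyEcho:
    """Hypothetical helper: __getitem__ returns the key it receives, untouched."""

    def __getitem__(self, key):
        return key


k = KeyEcho()
print(k[0])                # 0
print(k[0, 1])             # (0, 1): comma-separated indices arrive as one tuple
print(k[(0, 1)])           # (0, 1): identical to the line above
print(k[:, 1:2])           # (slice(None, None, None), slice(1, 2, None))
print(k['deadbeef'])       # deadbeef
print(k[:, 'deadbeef':2])  # (slice(None, None, None), slice('deadbeef', 2, None))
```

So a preprocessing step has to cope with three shapes: a single item, a tuple of items (the multi-index), and slice objects whose start or stop may themselves be IDs.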

Solution

  • The basic idea can be solved with recursion: replace the elements you want to change while keeping the overall structure of the argument the same.

    import numpy as np
    
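    def process_arg(arg):
        # Map known string IDs to integer indices (here only 'deadbeef' -> 0)
        if isinstance(arg, str) and arg == "deadbeef":
            return 0
        # Multi-index: process every element but keep the tuple structure
        if isinstance(arg, tuple):
            return tuple(process_arg(a) for a in arg)
        # Everything else (ints, slices, Ellipsis, ...) is passed through unchanged
        return arg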
    
    
    class A():
        def __init__(self, n: int=3):
            self._mat = np.zeros(shape=[n, n])
        def __getitem__(self, arg):
            arg = process_arg(arg)
            return self._mat[arg]
    
    a = A(n=4)
    print(a[0], a[0, 1], a[(0, 1)], a[:, 1:2])
    print(a['deadbeef'], a['deadbeef', 1], a[('deadbeef', 1)])
    

    I'll leave handling slices such as a[:, 'deadbeef':2] as an exercise for you: there the string arrives wrapped in a slice object, which process_arg passes through unchanged.
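One possible way to complete that exercise is sketched below. It assumes IDs are looked up in a plain dict; the class name IndexedMatrix and the id_to_index parameter are hypothetical. The recursion now also rebuilds slice objects from their processed start and stop, and descends into lists so that fancy indexing with several IDs works too:

```python
import numpy as np


class IndexedMatrix:
    """Hypothetical variant of A: an n x n zeros array that accepts string IDs as indices."""

    def __init__(self, n=3, id_to_index=None):
        self._mat = np.zeros(shape=(n, n))
        # Assumed mapping from serialized IDs to zero-based indices
        self._ids = dict(id_to_index or {})

    def _process(self, arg):
        if isinstance(arg, str):
            # Unknown IDs raise KeyError here instead of reaching NumPy
            return self._ids[arg]
        if isinstance(arg, slice):
            # Rebuild the slice; step is a stride, not a position, so it is kept as-is
            return slice(self._process(arg.start), self._process(arg.stop), arg.step)
        if isinstance(arg, tuple):
            # Multi-index: process each position and keep it a tuple
            return tuple(self._process(a) for a in arg)
        if isinstance(arg, list):
            # Fancy indexing with a list of IDs and/or integers
            return [self._process(a) for a in arg]
        # ints, None, Ellipsis, NumPy arrays, ... pass through unchanged
        return arg

    def __getitem__(self, arg):
        return self._mat[self._process(arg)]


m = IndexedMatrix(n=4, id_to_index={'deadbeef': 0, 'cafebabe': 2})
print(m['deadbeef'].shape)                # (4,)
print(m['deadbeef', 1])                   # 0.0
print(m[:, 'deadbeef':2].shape)           # (4, 2)
print(m[['deadbeef', 'cafebabe']].shape)  # (2, 4)
```

Raising KeyError for an unknown ID is a design choice; with the process_arg above, an unrecognized string is passed through and NumPy rejects it with an IndexError instead.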