Tags: python-3.x, inheritance, types, super

Initializing an object in a parent class with the type of the child


I wrote a parent class in which I define a few methods (for example __mul__, __truediv__, etc.) that all of the child classes should have. These methods should return an object of the child class's type, not of the parent class.

Here is some code to explain what I mean:

import numpy as np

class Magnet:

    def __init__(self, length, strength):
        self.length = length
        self.strength = strength

    def __mul__(self, other):
        if np.isscalar(other):
            # Creates a Magnet regardless of which subclass self belongs to
            return Magnet(self.length, self.strength * other)
        else:
            return NotImplemented

class Quadrupole(Magnet):

    def __init__(self, length, strength, name):
        super().__init__(length, strength)
        self.name = name

Now if I do this:

Quad1 = Quadrupole(2, 10, 'Q1')
Quad2 = Quad1 * 2

Then Quad1 is of type '__main__.Quadrupole' and Quad2 of type '__main__.Magnet'.

I would like to know how to write this so that the child's type is kept rather than being recast into the parent type. One solution would be to redefine these methods in the child classes and change

        if np.isscalar(other):
            return Magnet(self.length, self.strength * other)

to

        if np.isscalar(other):
            return Quadrupole(self.length, self.strength * other)

but the main reason for using inheritance was to avoid copy-pasting code. Maybe there is something like super() but in the downward direction, or some placeholder for the class type...

I'm thankful for any help.

Adopted solution

Using

return type(self)(self.length, self.strength * other)

does the trick. It raises an error here only because I forgot to add the 'name' argument to Magnet.__init__() (my original code has it; I messed it up when simplifying for the example).

I also found the same question here: Returning object of same subclass in __add__ operator
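A minimal runnable sketch of the adopted solution, assuming (as the asker says of their original code) that Magnet.__init__ also takes a name, so __mul__ can forward every constructor argument. The exact signature is an assumption for illustration; the key point is that type(self) evaluates to the runtime class of the instance.

```python
import numpy as np


class Magnet:
    # Assumption: the base class also stores a name, as in the asker's original code.
    def __init__(self, length, strength, name):
        self.length = length
        self.strength = strength
        self.name = name

    def __mul__(self, other):
        if np.isscalar(other):
            # type(self) is the runtime class (e.g. Quadrupole), not Magnet.
            return type(self)(self.length, self.strength * other, self.name)
        return NotImplemented


class Quadrupole(Magnet):
    # Inherits __init__ and __mul__ unchanged from Magnet.
    pass


quad1 = Quadrupole(2, 10, "Q1")
quad2 = quad1 * 2
print(type(quad2).__name__, quad2.strength, quad2.name)  # Quadrupole 20 Q1
```

Because every constructor argument is forwarded, nothing is lost, but this only works as long as all subclasses accept the same constructor signature.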


Solution

    Solution 1

    You can get the type using type(self) and create a new object out of that.

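    def __mul__(self, other):
        if np.isscalar(other):
            # type(self) is the class of the instance, e.g. Quadrupole
            return type(self)(self.length, self.strength * other)
        return NotImplemented
    

    (Keep returning NotImplemented rather than raising it: NotImplemented is a special value, not an exception, so raise NotImplemented itself fails with a TypeError. Returning it tells Python to try other.__rmul__ and, if that also fails, to raise the usual TypeError.)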

    With your Quadrupole class as it stands, this results in:

        return type(self)(self.length, self.strength * other)
    TypeError: __init__() missing 1 required positional argument: 'name'
    

    To fix this, Quadrupole needs a default value for name:

    class Quadrupole(Magnet):
        def __init__(self, length, strength, name='unknown'):
            super().__init__(length, strength)
            self.name = name
    

    Now your code is happy, but you might not be: the new object has lost the name of the original Quadrupole instance.
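Putting Solution 1 together, here is a runnable sketch using the names from the question and the 'unknown' default suggested above. It shows that the result keeps the subclass type but not the name, and that returning NotImplemented for a non-scalar (None is used here as an arbitrary example) lets Python raise its usual TypeError.

```python
import numpy as np


class Magnet:
    def __init__(self, length, strength):
        self.length = length
        self.strength = strength

    def __mul__(self, other):
        if np.isscalar(other):
            # Build the result with the instance's own class.
            return type(self)(self.length, self.strength * other)
        # Signal "unsupported" so Python can try other.__rmul__ or raise TypeError.
        return NotImplemented


class Quadrupole(Magnet):
    def __init__(self, length, strength, name="unknown"):
        super().__init__(length, strength)
        self.name = name


quad1 = Quadrupole(2, 10, "Q1")
quad2 = quad1 * 2
print(type(quad2).__name__, quad2.name)  # Quadrupole unknown

try:
    quad1 * None  # np.isscalar(None) is False, so __mul__ returns NotImplemented
except TypeError as exc:
    error_message = str(exc)
    print("TypeError:", error_message)
```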

    Solution 2

    You are returning a new instance of the class. Sometimes this is not required, and you can simply mutate the existing instance instead. This would simplify your code to:

    def __mul__(self, other):
        if np.isscalar(other):
            # Scale this instance in place instead of building a new one
            self.strength *= other
            return self
        return NotImplemented
    

    This mutates your original instance, so after Quad2 = Quad1 * 2 both names refer to the same object.
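If in-place mutation is what you want, Python has a dedicated hook for it: __imul__, which is called by the *= operator. The sketch below assumes you want * to keep returning a new object (as in Solution 1) while *= mutates; splitting the behaviour this way is a design choice, not something taken from the question.

```python
import numpy as np


class Magnet:
    def __init__(self, length, strength):
        self.length = length
        self.strength = strength

    def __mul__(self, other):
        # a * b: return a new object of the same class, leave self untouched.
        if np.isscalar(other):
            return type(self)(self.length, self.strength * other)
        return NotImplemented

    def __imul__(self, other):
        # a *= b: modify self and return it, so the name stays bound to self.
        if np.isscalar(other):
            self.strength *= other
            return self
        return NotImplemented


class Quadrupole(Magnet):
    def __init__(self, length, strength, name="unknown"):
        super().__init__(length, strength)
        self.name = name


quad = Quadrupole(2, 10, "Q1")
original = quad
quad *= 2
print(quad is original, quad.strength, quad.name)  # True 20 Q1
```

With this split, Quad2 = Quad1 * 2 no longer changes Quad1, and only an explicit Quad1 *= 2 does.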

    Solution 3

    The main issue with Solution 1 is that you lose information because you create a new instance from scratch. A possible alternative is to copy the existing instance instead. Unfortunately, copying an object is not always straightforward.

    Based on this SO question, deepcopy will probably work in this case, but with a more complex class structure you might have to customize how the copy is made to get what you want.

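    from copy import deepcopy

    def __mul__(self, other):
        if np.isscalar(other):
            # Copy the whole instance (class and all attributes), then scale the copy
            magnet_copy = deepcopy(self)
            magnet_copy.strength *= other
            return magnet_copy
        return NotImplemented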
    

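    Optionally, you can customize how the object is copied. Note that deepcopy looks for a __deepcopy__(self, memo) method, whereas a __copy__ method like the one below is only used by copy.copy, so to use it you would replace deepcopy(self) above with copy(self) (also imported from copy). For the provided code snippet this was not necessary, but in more complex cases it might be.

    def __copy__(self):
        # Called by copy.copy(self); defined in Quadrupole
        return Quadrupole(self.length, self.strength, self.name)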
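A self-contained sketch of Solution 3 using the standard library's copy module. Because the copy carries over every attribute, both the subclass type and the name survive, and the original instance is left unchanged. It relies on the default deepcopy behaviour with no custom __copy__ or __deepcopy__; whether that is enough for a more complex class is something you would need to check.

```python
from copy import deepcopy

import numpy as np


class Magnet:
    def __init__(self, length, strength):
        self.length = length
        self.strength = strength

    def __mul__(self, other):
        if np.isscalar(other):
            # Copy the whole instance (same class, all attributes), then scale it.
            result = deepcopy(self)
            result.strength *= other
            return result
        return NotImplemented


class Quadrupole(Magnet):
    def __init__(self, length, strength, name):
        # No default needed: deepcopy does not call __init__.
        super().__init__(length, strength)
        self.name = name


quad1 = Quadrupole(2, 10, "Q1")
quad2 = quad1 * 2
print(type(quad2).__name__, quad2.strength, quad2.name, quad1.strength)
# Quadrupole 20 Q1 10
```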