python inheritance overwrite

Python method that returns instance of class or subclass while keeping subclass attributes


I'm writing a Python class A with a method square() that returns a new instance of that class with its first attribute squared. For example:

class A:
    def __init__(self, x):
        self.x = x

    def square(self):
        return self.__class__(self.x**2)

I would like to use this method in a subclass B so that it returns an instance of B with x squared but all additional attributes of B unchanged (i.e. taken from the instance). I can get it to work by overriding square() like this:

class B(A):
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y

    def square(self):
        return self.__class__(self.x**2, self.y)

If I don't override square(), this small example fails because the constructor of B requires a value for y:

#test.py

class A:
    def __init__(self, x):
        self.x = x 

    def square(self):
        return self.__class__(self.x**2)

class B(A):
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y 

    #def square(self):
    #    return self.__class__(self.x**2, self.y)

a = A(3)
a2 = a.square()
print(a2.x)
b = B(4, 5)
b2 = b.square()
print(b2.x, b2.y)

$ python test.py
9
Traceback (most recent call last):
  File "test.py", line 20, in <module>
    b2 = b.square()
  File "test.py", line 6, in square
    return self.__class__(self.x**2)
TypeError: __init__() takes exactly 3 arguments (2 given)

Overriding the method once isn't a problem. But A potentially has multiple methods similar to square(), and there might be more sub(sub)classes. If possible, I would like to avoid overriding all those methods in all those subclasses.

So my question is this: Can I somehow implement square() in A so that it returns a new instance of the current subclass with x squared and all other attributes the constructor needs taken from self (kept constant)? Or do I have to override square() in each subclass?

Thanks in advance!


Solution

  • I'd suggest implementing a __copy__() method (and possibly __deepcopy__() as well) for both classes.

    Then your square() can be a simple method:

    # requires: from copy import copy
    def square(self):
        newObj = copy(self)
        newObj.x = self.x ** 2
        return newObj
    

    This works with inheritance, provided every child class implements __copy__ correctly.

    EDIT: fixed typo with call to copy()

    Full working example:

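    #test.py

    from copy import copy


    class A(object):  # new-style class, so super() also works on Python 2
        def __init__(self, x):
            self.x = x

        def square(self):
            # copy(self) dispatches to the __copy__ of the actual class,
            # so the result has the same type as self.
            newObj = copy(self)
            newObj.x = self.x ** 2
            return newObj

        def __copy__(self):
            return A(self.x)

    class B(A):
        def __init__(self, x, y):
            super(B, self).__init__(x)
            self.y = y

        def __copy__(self):
            # Rebuild a B with every constructor argument taken from self.
            return B(self.x, self.y)

    a = A(3)
    a2 = a.square()
    print(a2.x)  # a2 is an A with x == 9
    b = B(4, 5)
    b2 = b.square()
    print(b2.x, b2.y)  # b2 is a B with x == 16 and y == 5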
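
As a possible simplification of the approach above: if no class defines __copy__ at all, copy() falls back to a generic shallow copy that keeps the instance's class and its attributes without calling __init__. Under that assumption, subclasses would not need to override anything. The sketch below assumes Python 3; the sub-subclass C and its tags attribute are hypothetical and only there to illustrate the sub(sub)class case from the question.

```python
from copy import copy


class A:
    def __init__(self, x):
        self.x = x

    def square(self):
        # No __copy__ is defined anywhere, so copy() makes a generic
        # shallow copy: same class, same attributes, __init__ not called.
        new_obj = copy(self)
        new_obj.x = self.x ** 2
        return new_obj


class B(A):
    def __init__(self, x, y):
        super().__init__(x)
        self.y = y


class C(B):  # hypothetical sub-subclass, not part of the original question
    def __init__(self, x, y, tags):
        super().__init__(x, y)
        self.tags = tags


b2 = B(4, 5).square()
print(type(b2).__name__, b2.x, b2.y)  # B 16 5

c = C(2, 3, ["t"])
c2 = c.square()
# The copy is shallow, so the list object is shared between c and c2.
print(type(c2).__name__, c2.x, c2.y, c2.tags is c.tags)  # C 4 3 True
```

Because the copy is shallow, mutable attributes such as the list above are shared between the original and the result. If that is not wanted, deepcopy() from the same module could be used instead, or a class can still define its own __copy__ as in the answer.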