
Problem with defining a derived class in Python


I am learning how to use classes in Python to customize some Keras methods so I can build various forms of Generative Adversarial Networks (GANs). In this case, I am trying to implement the gradient penalty modification to the Wasserstein GAN architecture, based on an example from the Keras website: https://keras.io/examples/generative/wgan_gp/ Since I am new to classes and inheritance, I decided to experiment with some simple examples that mirror the Keras example. I am confused by this segment of code:

from tensorflow import keras

class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

To understand it, I wrote my own simple base class and derived class following the same pattern, adapting the inheritance example from the W3Schools website https://www.w3schools.com/python/python_inheritance.asp :

class Person:
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname
    def printname(self):
        print(self.firstname, self.lastname)

class Student(Person):
    def __init__(self, age, height):
        super(Student, self).__init__()
        self.age = age
        self.height = height

I tested it with:

s = Student(1,2)

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_2168/1913336311.py in <module>
----> 1 s = Student(1,2)

~\AppData\Local\Temp/ipykernel_2168/376196858.py in __init__(self, age, height)
      8 class Student(Person):
      9     def __init__(self, age, height):
---> 10         super(Student, self).__init__()
     11         self.age = age
     12         self.height = height

TypeError: __init__() missing 2 required positional arguments: 'fname' and 'lname'

Why does the Keras code work with empty parentheses in super(WGAN, self).__init__(), while mine fails? I feel like I am taking the same approach. Thanks


Solution

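  • super(Student, self) gives you access to the parent class's methods (strictly, it delegates to the next class in the method resolution order, which with single inheritance is the parent). The parent class of Student is Person, and Person.__init__ requires two arguments, fname and lname, so calling __init__() with no arguments raises the TypeError. Whenever you call Person's __init__, you have to provide both inputs.

    The parent class of WGAN is keras.Model, whose __init__ has no required arguments, so super(WGAN, self).__init__() works.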
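A minimal sketch of the fix, in plain Python with no Keras dependency: Student takes fname and lname in addition to its own arguments and forwards them to Person.__init__. For contrast, BaseWithDefaults and Derived are hypothetical stand-ins (not Keras classes) for a base like keras.Model whose __init__ has only optional parameters, which is why an empty super().__init__() call works there. Inside a method, Python 3's zero-argument super() is equivalent to super(Student, self).

```python
class Person:
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname

    def printname(self):
        print(self.firstname, self.lastname)


class Student(Person):
    def __init__(self, fname, lname, age, height):
        # Person.__init__ requires fname and lname, so pass them on explicitly.
        super().__init__(fname, lname)
        self.age = age
        self.height = height


# Hypothetical stand-in for a base class such as keras.Model:
# every parameter of its __init__ has a default value.
class BaseWithDefaults:
    def __init__(self, name=None):
        self.name = name


class Derived(BaseWithDefaults):
    def __init__(self, extra):
        # Empty parentheses are fine here because nothing is required.
        super().__init__()
        self.extra = extra


s = Student("John", "Doe", 1, 2)
s.printname()  # prints: John Doe

d = Derived(42)
print(d.name, d.extra)  # prints: None 42
```

The design choice is simply that a derived class must supply whatever the parent's __init__ requires; it can take those values as its own parameters (as above) or hard-code them.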