
Use a custom function with custom parameters in a Keras callback


I am training a model in Keras and I want to plot the results after each epoch. I know that Keras callbacks provide an on_epoch_end method that can be overridden to run computations after each epoch. However, my function takes some additional parameters, and when I pass them, the code crashes with a metaclass error. Details below:

Here is what I am doing right now, which works fine:

from keras.callbacks import Callback

class NewCallback(Callback):

    def on_epoch_end(self, epoch, logs=None):  # works fine: prints the epoch after each epoch
        print("EPOCH IS: " + str(epoch))


epochs = 5
batch_size = 16
model_saved = False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            callbacks=[NewCallback()])

But I want my callback to look like this:

class NewCallback(Callback, models, data, batch_size):
    def on_epoch_end(self, epoch, logs=None):
        print("EPOCH IS: " + str(epoch))
        x = models.predict(data)
        plt.plot(x)
        plt.savefig(str(epoch) + ".png")

If I call it like this in fit:

callbacks=[NewCallback(models, data, batch_size=batch_size)]

I get this error:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases 

I am looking for a simpler way to call my function, or a way to resolve this metaclass error. Any help would be much appreciated!
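The error is not specific to Keras. A minimal sketch in plain Python reproduces it, assuming a hypothetical class Base as a stand-in for Callback and an integer in place of batch_size: listing an instance (rather than a class) among the bases of a class statement makes Python treat that instance's type (here int) as a metaclass, and int cannot be reconciled with type, the metaclass of Base.

```python
class Base:
    """Hypothetical stand-in for keras.callbacks.Callback."""
    pass


batch_size = 16  # an int instance, like batch_size in the question

try:
    # batch_size is an instance, not a class, so Python uses type(batch_size),
    # i.e. int, as that "base's" metaclass; it conflicts with type (Base's metaclass).
    class NewCallback(Base, batch_size):
        pass
except TypeError as exc:
    error_message = str(exc)
else:
    error_message = ""

print(error_message)
```

The TypeError is raised as soon as the class statement executes, so it happens before fit is ever called.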


Solution

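  • I think what you want is to define a class that descends from Callback and takes models, data, etc. as constructor arguments. The names in parentheses after the class name are base classes, so your version asks Python to inherit from the model object, the data and the integer batch_size; their types cannot be combined into a single metaclass, which is what the TypeError is complaining about. So: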

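    from keras.callbacks import Callback
    import matplotlib.pyplot as plt

    class NewCallback(Callback):
        """NewCallback descends from Callback."""
        def __init__(self, models, data, batch_size):
            """Save params in constructor."""
            super().__init__()  # let Callback initialise its own attributes
            self.models = models
            self.data = data
            self.batch_size = batch_size

        def on_epoch_end(self, epoch, logs=None):
            x = self.models.predict(self.data, batch_size=self.batch_size)
            plt.plot(x)
            plt.savefig(str(epoch) + ".png")  # epoch is an int, so convert it
            plt.close()  # start a fresh figure for the next epoch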
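Since the question also asks for something simpler, another option (a swapped-in alternative, not the answer's method) is to capture the extra parameters in a closure and hand the resulting function to Keras's LambdaCallback, which calls on_epoch_end with epoch and logs. The names make_epoch_end, plot_fn and FakeModel below are hypothetical; the sketch uses stand-ins so it runs without Keras.

```python
def make_epoch_end(model, data, batch_size, plot_fn):
    """Build an on_epoch_end(epoch, logs) function that closes over the extra parameters.

    plot_fn(predictions, filename) is a hypothetical hook, e.g. a small wrapper
    around plt.plot / plt.savefig / plt.close.
    """
    def on_epoch_end(epoch, logs=None):
        predictions = model.predict(data, batch_size=batch_size)
        plot_fn(predictions, str(epoch) + ".png")
    return on_epoch_end


class FakeModel:
    """Hypothetical stand-in exposing predict(x, batch_size=...) like a Keras model."""
    def predict(self, x, batch_size=None):
        return [2 * v for v in x]


# Record what would have been plotted instead of drawing anything.
saved = []
on_epoch_end = make_epoch_end(FakeModel(), [1, 2, 3], batch_size=16,
                              plot_fn=lambda preds, name: saved.append((name, preds)))

for epoch in range(2):  # mimic fit() invoking the hook after each epoch
    on_epoch_end(epoch, logs={})

print(saved)  # [('0.png', [2, 4, 6]), ('1.png', [2, 4, 6])]
```

With Keras, the returned function could then be passed as callbacks=[LambdaCallback(on_epoch_end=make_epoch_end(vae, data, batch_size, plot_fn))]. The subclass approach works just as well; the closure only saves writing a class.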