Search code examples
python-3.xkerastensorflow2.0tf.keraskeras-layer

Custom Layer in TensorFlow/Keras


I'm new to tensorflow and I'm trying to build a custom layer that takes multiple inputs (namely, x,y,A) and return z. In this layer, w is the trainable parameter. my code is:

import tensorflow as tf
import tensorflow.keras as tfK
from keras import backend as K
import numpy as np

class MyLayer(tfK.layers.Layer):
    def __init__(self, x, y, A):
        super().__init__()

    def build(self, x.shape, y.shape, A.shape):
        self.w= self.add_weight(
            shape=(x.shape[-1]),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, x, y, A):
        alpha = K.abs(tf.matmul(A,x) - y)/tf.matmul(A,(x_w))
        beta = K.abs(tf.matmul(A,x) - y)/tf.matmul(A,(x-w))
        LHS = tf.matmul(A,x)
        cond  = K.abs(LHS) < y
        lowerProj = (1-beta) * x + beta * w
        upperProj  = (1-alpha) * x + alpha * w
        z = tf.where(cond, upperProj, lowerProj)
        return z

When I run the above code (just want to check if it works) as

A = tf.convert_to_tensor(np.array([[1,2],[2,-1]]), dtype=tf.float32)
y = tf.constant(1, shape= (2,1), dtype=tf.float32)
x = tf.constant(0.5, shape= (2,1), dtype=tf.float32)

inputs = x,y,A
NN = MyLayer(x.shape[0],y.shape[0],A.shape)
output = NN(inputs)
print(output)

I get the following error: TypeError: MyLayer.build() missing 2 required positional arguments: 'y_shape', 'A_shape'

How do I build this layer correctly? I'd highly appreciate any feedback.


Solution

  • the build method signature takes a list, not multiple args, when the call method takes multiple args. Try this.

    def build(self, shapes):
      assert len(shapes) == 
      x_shape, y_shape, A_shape = *shapes 
      self.w= self.add_weight(shape=(x_shape[-1]),
                              initializer="random_normal",
                              trainable=True,
                              )
    

    Also, don't import Keras separately - just use the tf.keras.