Tags: python, tensorflow, keras, recurrent-neural-network

Embed a custom RNN cell whose __init__ takes more arguments (3 vs 1)


I am trying to create a model similar to the one proposed in this paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8738842

The custom cell code is available at: https://github.com/SungjoonPark/DenoisingRNN/blob/master/dgrud.py

However, I am not able to embed this custom cell into any RNN model, and I assume it is because its __init__ takes 3 arguments instead of the standard "num_units".

I tried following the example at https://keras.io/layers/recurrent/:

    cell = MinimalRNNCell(32)
    x = keras.Input((None, 5))
    layer = RNN(cell)
    y = layer(x)

but I get an error:

    TypeError                                 Traceback (most recent call last)
    in
          2 x = keras.Input((None, 5))
          3 layer = RNN(cell)
    ----> 4 y = layer(x)

    ~/.local/lib/python3.5/site-packages/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
        539 
        540         if initial_state is None and constants is None:
    --> 541             return super(RNN, self).__call__(inputs, **kwargs)
        542 
        543         # If any of initial_state or constants are specified and are Keras

    ~/.local/lib/python3.5/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
        487             # Actually call the layer,
        488             # collecting output(s), mask(s), and shape(s).
    --> 489             output = self.call(inputs, **kwargs)
        490             output_mask = self.compute_mask(inputs, previous_mask)
        491 

    ~/.local/lib/python3.5/site-packages/keras/layers/recurrent.py in call(self, inputs, mask, training, initial_state, constants)
        680                                              mask=mask,
        681                                              unroll=self.unroll,
    --> 682                                              input_length=timesteps)
        683         if self.stateful:
        684             updates = []

    ~/.local/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in rnn(step_function, inputs, initial_states, go_backwards, mask, constants, unroll, input_length)
       3101         constants=constants,
       3102         unroll=unroll,
    -> 3103         input_length=input_length)
       3104     reachable = tf_utils.get_reachable_from_inputs([learning_phase()],
       3105                                                    targets=[last_output])

    ~/.local/lib/python3.5/site-packages/tensorflow/python/keras/backend.py in rnn(step_function, inputs, initial_states, go_backwards, mask, constants, unroll, input_length, time_major, zero_output_for_mask)
       3730   # the value is discarded.
       3731   output_time_zero, _ = step_function(
    -> 3732       input_time_zero, tuple(initial_states) + tuple(constants))
       3733   output_ta = tuple(
       3734       tensor_array_ops.TensorArray(

    ~/.local/lib/python3.5/site-packages/keras/layers/recurrent.py in step(inputs, states)
        671         else:
        672             def step(inputs, states):
    --> 673                 return self.cell.call(inputs, states, **kwargs)
        674 
        675         last_output, outputs, states = K.rnn(step,

    TypeError: call() takes 2 positional arguments but 3 were given

Could you please help me figure out whether this is an __init__ issue, a call() issue, or whether I need to define a custom layer for this custom cell?

I have looked for answers all over the internet, but I can't find a clear explanation of how a custom cell should be embedded in an RNN model.

Thank you in advance,

Sam


Solution

  • I was able to reproduce your issue when importing the standalone keras package directly (import keras) while taking RNN from tensorflow.keras. See below:

    %tensorflow_version 1.x
    import keras                             # standalone Keras package
    from keras import backend as K
    import tensorflow as tf
    from tensorflow.keras import layers
    from tensorflow.keras.layers import RNN  # RNN comes from tf.keras, not from standalone keras
    
    class MinimalRNNCell(keras.layers.Layer):
    
        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(MinimalRNNCell, self).__init__(**kwargs)
    
        def build(self, input_shape):
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True
    
        def call(self, inputs, states):
            prev_output = states[0]
            h = K.dot(inputs, self.kernel)
            output = h + K.dot(prev_output, self.recurrent_kernel)
            return output, [output]
    
    # Let's use this cell in a RNN layer:
    
    cell = MinimalRNNCell(32)
    x = keras.Input((None, 5))
    layer = RNN(cell)
    y = layer(x)
    

    Output -

    TensorFlow is already loaded. Please restart the runtime to change versions.
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-3-0f3bed686a7d> in <module>()
         34 x = keras.Input((None, 5))
         35 layer = RNN(cell)
    ---> 36 y = layer(x)
    
    /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in symbolic_fn_wrapper(*args, **kwargs)
         73         if _SYMBOLIC_SCOPE.value:
         74             with get_graph().as_default():
    ---> 75                 return func(*args, **kwargs)
         76         else:
         77             return func(*args, **kwargs)
    
    TypeError: __call__() takes 2 positional arguments but 3 were given
    

    The error goes away when you import keras from TensorFlow instead (from tensorflow import keras), so that the cell's base class, keras.Input and RNN all come from tf.keras. The code then runs successfully with both TensorFlow 1.x and 2.x. Modify your code as below:

    %tensorflow_version 2.x
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import backend as K  # use tf.keras's backend, not standalone keras
    from tensorflow.keras.layers import RNN
    
    # First, let's define a RNN Cell, as a layer subclass.
    
    class MinimalRNNCell(keras.layers.Layer):
    
        def __init__(self, units, **kwargs):
            super(MinimalRNNCell, self).__init__(**kwargs)
            self.units = units
            self.state_size = units
    
        def build(self, input_shape):
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True
    
        def call(self, inputs, states):
            prev_output = states[0]
            h = K.dot(inputs, self.kernel)
            output = h + K.dot(prev_output, self.recurrent_kernel)
            return output, [output]
    
    # Let's use this cell in a RNN layer:
    
    cell = MinimalRNNCell(32)
    x = keras.Input((None, 5))
    layer = RNN(cell)
    y = layer(x)
    
    print("I Ran Successfully")
    

    Output -

    I Ran Successfully
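Regarding the original question about __init__: keras.layers.RNN never calls the cell's constructor itself, so the number of __init__ arguments should not matter. What RNN needs from the cell is a state_size attribute and a call(inputs, states) method that returns (output, new_states). The sketch below assumes TensorFlow 2.x with tf.keras; ThreeArgRNNCell and its decay and use_bias arguments are hypothetical, chosen only to illustrate a three-argument constructor, and this is not the cell from the linked dgrud.py.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class ThreeArgRNNCell(keras.layers.Layer):
    """Hypothetical cell whose __init__ takes three arguments.

    keras.layers.RNN does not care about the constructor; it only needs
    `state_size` and call(inputs, states) -> (output, new_states).
    """

    def __init__(self, units, decay, use_bias, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.decay = decay          # hypothetical extra hyperparameter
        self.use_bias = use_bias    # hypothetical extra hyperparameter
        self.state_size = units     # required by keras.layers.RNN

    def build(self, input_shape):
        # input_shape is (batch, features) for a single time step
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units), initializer="uniform", name="kernel")
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units), initializer="uniform", name="recurrent_kernel")
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,), initializer="zeros", name="bias")
        super().build(input_shape)

    def call(self, inputs, states):
        prev_output = states[0]
        h = tf.matmul(inputs, self.kernel)
        if self.use_bias:
            h = h + self.bias
        # decay scales the contribution of the previous state
        output = h + self.decay * tf.matmul(prev_output, self.recurrent_kernel)
        return output, [output]

    def get_config(self):
        # Store the extra constructor arguments so a saved model can rebuild the cell.
        config = super().get_config()
        config.update({"units": self.units, "decay": self.decay, "use_bias": self.use_bias})
        return config


cell = ThreeArgRNNCell(32, decay=0.5, use_bias=True)
x = keras.Input((None, 5))
y = keras.layers.RNN(cell)(x)
model = keras.Model(x, y)

out = model(np.zeros((2, 7, 5), dtype="float32"))
print(out.shape)  # (2, 32)
```

No separate custom layer is needed in this sketch: the cell instance, however it was constructed, is simply passed to RNN. If the cell in dgrud.py was written against a different cell interface, it would likely need to be adapted to this state_size / call(inputs, states) contract first.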
    

    Hope this answers your question. Happy Learning.