Tags: python, deep-learning, module, recurrent-neural-network

Implementing RNN with flax.nn.Module


I am trying to implement a basic RNN cell with flax.nn.Module. The equations for the RNN cell are quite simple:

a_t = W * h_{t-1} + U * x_t + b

h_t = tanh(a_t)

o_t = V * h_t + c

where h_t is the updated state at time t, x_t is the input, o_t is the output, W, U and V are weight matrices (so * is a matrix-vector product), b and c are bias vectors, and tanh is the activation function.
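For reference, the three equations can be checked outside Flax with plain JAX. This is a minimal sketch, assuming 1-D (unbatched) vectors and a dict of hypothetical parameter arrays `W`, `U`, `b`, `V`, `c` with made-up sizes; `elman_step` and `elman_forward` are illustrative names, not part of any library. `jax.lax.scan` carries h_t from one time step to the next and stacks the outputs o_t.

```python
import jax
import jax.numpy as jnp

def elman_step(params, h_prev, x_t):
    """One step of the Elman equations (hypothetical helper)."""
    W, U, b = params["W"], params["U"], params["b"]
    V, c = params["V"], params["c"]
    a_t = W @ h_prev + U @ x_t + b   # a_t = W h_{t-1} + U x_t + b
    h_t = jnp.tanh(a_t)              # h_t = tanh(a_t)
    o_t = V @ h_t + c                # o_t = V h_t + c
    return h_t, o_t

def elman_forward(params, h0, xs):
    """Run elman_step over xs of shape (T, input_dim)."""
    # scan returns (final carry, stacked per-step outputs)
    return jax.lax.scan(lambda h, x: elman_step(params, h, x), h0, xs)

# Hypothetical sizes, chosen only for illustration
hidden, inp, out, T = 4, 3, 2, 5
k = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W": jax.random.normal(k[0], (hidden, hidden)),
    "U": jax.random.normal(k[1], (hidden, inp)),
    "b": jnp.zeros(hidden),
    "V": jax.random.normal(k[2], (out, hidden)),
    "c": jnp.zeros(out),
}
xs = jax.random.normal(k[3], (T, inp))
h_T, outputs = elman_forward(params, jnp.zeros(hidden), xs)
print(h_T.shape, outputs.shape)  # (4,) (5, 2)
```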

My code, using flax.nn.Module:

import jax.numpy as jnp
from flax import linen as nn

class ElmanCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    # W, U and b are still undefined here
    nextState = jnp.tanh(jnp.dot(W, h) + jnp.dot(U, x) + b)
    return nextState

I don't know how to implement the parameters W, U and b. Are they supposed to be attributes of the nn.Module?


Solution

  • Try something like:

    import jax.numpy as jnp
    from flax import linen as nn

    class RNNCell(nn.Module):
      @nn.compact
      def __call__(self, state, x):
        # W @ h + U @ x + b can be computed efficiently
        # by concatenating the vectors and applying a single dense layer
        x = jnp.concatenate([state, x])
        new_state = jnp.tanh(nn.Dense(state.shape[0])(x))
        return new_state
    

    This way, nn.Dense creates the weight matrix and bias as parameters of the module when you call init, so they will be learned. See https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html