I am wondering if anyone here knows how to get Flax LSTM layers to work in 2023. I have tried some of the code snippets in the official Flax documentation, such as:
https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html
In particular, the first example provided there,
import flax.linen as nn
import jax
import jax.numpy as jnp

class LSTM(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    ScanLSTM = nn.scan(
      nn.LSTMCell, variable_broadcast="params",
      split_rngs={"params": False}, in_axes=1, out_axes=1)

    lstm = ScanLSTM(self.features)
    input_shape = x[:, 0].shape
    carry = lstm.initialize_carry(jax.random.key(0), input_shape)
    carry, x = lstm(carry, x)
    return x

x = jnp.ones((4, 12, 7))
module = LSTM(features=32)
y, variables = module.init_with_output(jax.random.key(0), x)
throws an error. I have looked for other examples, but the API seems to have changed at some point in 2023, so what I could find online no longer works.
In short, what I am looking for is a simple example of how to pass a time series into an LSTM in Flax.
Thank you for your help.
The snippet you provided runs correctly with the most recent version of flax (version 0.7.4). If you're using an older version of flax, you should change jax.random.key to jax.random.PRNGKey. For some information about this JAX PRNG key change, see JEP 9263: Typed Keys and Pluggable PRNGs.
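
If you want the same script to run whether or not jax.random.key is available, one option is a small fallback like the sketch below. Note that make_key is a hypothetical helper name used only for illustration, not part of JAX or Flax, and this only addresses the key-creation call; it assumes the rest of the snippet is otherwise compatible with your installed Flax version.

```python
import jax

def make_key(seed):
    # Hypothetical helper (not a JAX API): use the newer typed-key
    # constructor when this JAX install provides it, and fall back to
    # the legacy jax.random.PRNGKey otherwise.
    key_fn = getattr(jax.random, "key", jax.random.PRNGKey)
    return key_fn(seed)

# Both kinds of key are accepted by the jax.random sampling functions.
key = make_key(0)
sample = jax.random.normal(key, (4, 12, 7))
print(jax.__version__, sample.shape)
```

In the question's snippet you would then replace both jax.random.key(0) calls with make_key(0).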