So I mean something like this: you have a categorical feature $X$ (suppose you have already converted it to ints) and you want to embed it in some dimension using an embedding table $A$, where $A$ has shape arity x n_embed.
What is the usual way to do this? Is using a for loop with vmap correct? I do not want something like jax.nn; I want something more efficient, like tf.keras.layers.Embedding:
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
For example, consider high arity and a low embedding dimension.
Is it `jnp.take`, as in the flax.linen implementation here? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624
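To make that concrete, the for-loop / vmap version I have in mind is roughly the following (just a sketch; the shapes and values are made up):

```python
import jax
import jax.numpy as jnp

arity, n_embed = 1000, 4                 # high arity, low embedding dim
A = jnp.ones((arity, n_embed))           # embedding table (dummy values)
x = jnp.array([3, 17, 999])              # categorical feature as ints, shape (n,)

# look up one row per index, vectorized with vmap instead of a Python loop
emb = jax.vmap(lambda i: A[i])(x)        # shape (3, 4)
```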
Indeed, the typical way to do this in pure JAX is with `jnp.take`. Given an array `A` of embeddings with shape `(num_embeddings, num_features)` and a categorical feature `x` of integers with shape `(n,)`, the following gives you the embedding lookup:

```python
jnp.take(A, x, axis=0)  # shape: (n, num_features)
```
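As a quick sanity check with toy values (shapes chosen arbitrarily for illustration), note that plain integer indexing gives the same result:

```python
import jax
import jax.numpy as jnp

num_embeddings, num_features = 1000, 4
A = jax.random.normal(jax.random.PRNGKey(0), (num_embeddings, num_features))
x = jnp.array([3, 17, 3, 999])       # integer categories, shape (n,)

emb = jnp.take(A, x, axis=0)         # shape (4, 4)
assert jnp.allclose(emb, A[x])       # fancy indexing A[x] is an equivalent lookup
```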
If you are using Flax, the recommended way would be the `flax.linen.Embed` module, which achieves the same effect:

```python
import flax.linen as nn

class Model(nn.Module):
    num_embeddings: int
    num_features: int

    @nn.compact
    def __call__(self, x):
        emb = nn.Embed(self.num_embeddings, self.num_features)(x)  # shape: (n, num_features)
        return emb
```
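A minimal usage sketch (the vocabulary size, feature dimension, and inputs here are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

model = Model(num_embeddings=1000, num_features=8)
x = jnp.array([3, 17, 3, 999])                  # integer categories, shape (n,)
params = model.init(jax.random.PRNGKey(0), x)   # initializes the embedding table
emb = model.apply(params, x)                    # shape (4, 8)
```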