jax

What is the recommended way to do embeddings in jax?


Specifically, suppose you have a categorical feature $X$ (already encoded as integers) and you want to embed it using an embedding matrix $A$ of shape arity x n_embed.

What is the usual way to do this? Is using a for loop and vmap correct? I do not want something like jax.nn, but something more efficient, like the Keras embedding layer:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

For example, consider high arity and a low embedding dimension.

Is it jnp.take as in the flax.linen implementation here? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624


Solution

  • Indeed, the typical way to do this in pure JAX is with jnp.take. Given an array A of embeddings with shape (num_embeddings, num_features) and a categorical feature x of integers with shape (n,), the following gives you the embedding lookup:

    jnp.take(A, x, axis=0)  # shape: (n, num_features)
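
As a minimal sketch of this lookup, the snippet below builds a random embedding table and a batch of integer categories, then compares jnp.take with plain integer indexing and with a dense one-hot matrix product. The sizes, the random data, and the helper name `embed` are assumptions made for illustration only and are not part of the original answer.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration (high arity, low embedding dim).
num_embeddings, num_features, n = 1000, 4, 5

key_A, key_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_A, (num_embeddings, num_features))  # embedding table
x = jax.random.randint(key_x, (n,), 0, num_embeddings)        # integer categories

# Gather the rows of A selected by x.
emb_take = jnp.take(A, x, axis=0)    # shape: (n, num_features)

# Integer-array indexing selects the same rows.
emb_index = A[x]                     # shape: (n, num_features)

# Dense alternative: builds an (n, num_embeddings) one-hot matrix first,
# which costs more memory and compute when the arity is high.
emb_onehot = jax.nn.one_hot(x, num_embeddings) @ A

# Hypothetical helper: the lookup can be jit-compiled like any other JAX function.
@jax.jit
def embed(table, ids):
    return jnp.take(table, ids, axis=0)

emb_jit = embed(A, x)
```

All four results should agree. The gather-based versions avoid materializing the one-hot matrix, which is why they tend to be preferred when the arity is high and the embedding dimension is low.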
    

    If you are using Flax, the recommended way is to use the flax.linen.Embed module, which achieves the same effect:

    import flax.linen as nn
    
    class Model(nn.Module): 
      @nn.compact
      def __call__(self, x):
        emb = nn.Embed(num_embeddings, num_features)(x)  # shape