I mean a setup where you have a categorical feature $X$ (suppose it has already been converted to ints) and you want to embed it in some dimension using the features $A$, where $A$ has shape arity x n_embed.
What is the usual way to do this? Is using a for loop and vmap correct? I do not want something like jax.nn, but something more efficient, like
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
For example, consider high arity and a low embedding dimension.
Is it jnp.take, as in the flax.linen implementation here? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624
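
To make the question concrete, here is a minimal sketch of the lookup being asked about. The sizes, the random table, and the `embed` helper are hypothetical and chosen only for illustration; the one-hot matmul is included purely as a comparison, on the assumption that it is the kind of dense approach the question wants to avoid.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: high arity, low embedding dimension.
arity, n_embed = 10_000, 4

# Embedding table A with shape (arity, n_embed); random values for illustration.
A = jax.random.normal(jax.random.PRNGKey(0), (arity, n_embed))

# Categorical feature X, already encoded as integer ids in [0, arity).
X = jnp.array([3, 17, 9_999, 3])

# Gather rows of A by index: no Python loop and no vmap needed.
emb_take = jnp.take(A, X, axis=0)

# Integer-array indexing performs the same row gather.
emb_index = A[X]

# Dense comparison (assumed to be the approach to avoid): builds a
# (len(X), arity) one-hot matrix and multiplies it by A.
emb_onehot = jax.nn.one_hot(X, arity) @ A

# The gather also works under jit.
@jax.jit
def embed(table, ids):
    return jnp.take(table, ids, axis=0)

print(emb_take.shape)  # (len(X), n_embed)
```

Both `jnp.take(A, X, axis=0)` and `A[X]` gather rows directly, so the work scales with the number of ids rather than with the arity, whereas the one-hot product materializes an intermediate whose size grows with the arity.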