
So I mean something where you have a categorical feature $X$ (suppose you have already converted it to integers) and you want to embed it in some dimension using an embedding table $A$, where $A$ has shape arity × n_embed.

What is the usual way to do this? Is using a for loop and vmap correct? I am not looking for something like jax.nn; I want something more efficient, like

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

For example, consider high arity and low embedding dimension.

Is it jnp.take as in the flax.linen implementation here? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624
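
For concreteness, here is a sketch of the for-loop and vmap versions I had in mind (assuming A has shape (arity, n_embed) and x is an int array of shape (n,)):

import jax
import jax.numpy as jnp

# Naive version: look up one row of A per index.
# Works on concrete arrays, but a Python loop like this scales poorly
# and does not trace well under jit.
def embed_loop(A, x):
    return jnp.stack([A[i] for i in x])

# vmap version: map a single-index lookup over x.
def embed_vmap(A, x):
    return jax.vmap(lambda i: A[i])(x)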

mathtick
  • Can you clarify what you mean with `using a for loop and vmap`? – Geoffrey Negiar Jul 08 '22 at 01:50
  • @GeoffreyNegiar I just meant that instead of using jnp.take you would literally iterate over the indices. But I now think take is the correct way; it looks like that is what various libraries using jax are doing in their implementations. – mathtick Jul 08 '22 at 09:02

2 Answers


Indeed, the typical way to do this in pure JAX is with jnp.take. Given an array A of embeddings with shape (num_embeddings, num_features) and a categorical feature x of integers with shape (n,), the following gives you the embedding lookup:

jnp.take(A, x, axis=0)  # shape: (n, num_features)

If you are using Flax, the recommended way is the flax.linen.Embed module, which achieves the same effect:

import flax.linen as nn

class Model(nn.Module):
  num_embeddings: int
  num_features: int

  @nn.compact
  def __call__(self, x):
    # Embedding lookup; output shape: (*x.shape, num_features)
    return nn.Embed(self.num_embeddings, self.num_features)(x)
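
A minimal usage sketch (the sizes below are placeholders; init/apply is the standard Flax pattern):

import jax
import jax.numpy as jnp

model = Model(num_embeddings=4, num_features=3)  # placeholder sizes
x = jnp.array([3, 1, 2], dtype=jnp.int32)

params = model.init(jax.random.PRNGKey(0), x)  # initializes the embedding table
emb = model.apply(params, x)                   # emb has shape (3, 3)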
Jon Deaton

Suppose that A is the embedding table and x is an array of indices of arbitrary shape. There are (at least) three equivalent ways to do the lookup:

  1. A[x], which is like jnp.take(A, x, axis=0) but simpler.
  2. A vmap-ed version of A[x], which parallelizes along axis 0 of x.
  3. A nested vmap-ed version of A[x], which parallelizes along all axes of x.

Here is the source code for reference:

import jax
import jax.numpy as jnp

# Embedding table: 4 embeddings of dimension 3.
embs = jnp.array([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], dtype=jnp.float32)

# Indices of shape (2, 2); each entry selects a row of embs.
x = jnp.array([[3, 1], [2, 0]], dtype=jnp.int32)

print("\ntake\n", jnp.take(embs, x, axis=0))
print("\nuse []\n", embs[x])
print(
    "\nvmap\n",
    jax.vmap(lambda embs, x: embs[x], in_axes=[None, 0], out_axes=0)(embs, x),
)

print(
    "\nnested vmap\n",
    jax.vmap(
        jax.vmap(lambda embs, x: embs[x], in_axes=[None, 0], out_axes=0),
        in_axes=[None, 0],
        out_axes=0,
    )(embs, x),
)
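
All four variants print the same (2, 2, 3) array: each entry of x is replaced by the corresponding row of embs.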

BTW, I learned the nested-vmap trick from the IREE GPT2 model code by James Bradbury.

cxwangyi