Questions tagged [flax]
40 questions
0 votes, 0 answers
Convert a Flax model to PyTorch
I have several image classifiers in Flax. For one of the models I saved the full state, and for the other two I saved the parameters as a frozendict with a .flax extension. My question is: how can I convert the whole models to PyTorch and use these…

m0ss
0 votes, 1 answer
Getting incorrect output from a Flax model's init call
I am trying to create a simple neural network using Flax, as shown below.
However, the params frozen dict I receive as the output of model.init is empty instead of containing the parameters of the neural network. Also, type(predictions) is…

Bunny Rabbit
0 votes, 0 answers
Flax Memory Consumption in Backwards Pass
Recently I built my first model in Flax. The forward pass worked fine, but I ran into OOM errors during the backward pass.
Originally I had split my model into several small classes, each implemented as its own Flax model inheriting from…

Simon P.
0 votes, 0 answers
Flax implementation of padding_idx from torch.nn.Embedding
I have been rewriting some of my PyTorch models in JAX/Flax and came across the issue of converting torch.nn.Embedding to flax.linen.Embed.
There does not appear to be a direct translation for PyTorch's padding_idx. The keyword essentially zeros the…
0 votes, 0 answers
Vanishing parameters in MAML with JAX (meta-learning)
I am working on an implementation of MAML (see https://arxiv.org/pdf/1703.03400.pdf) in JAX.
When training on a distribution of simple linear regression tasks it seems to perform fine (takes a while to converge but ultimately works).
However when…

Sefton de Pledge
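The truncated excerpt is not enough to diagnose the vanishing parameters. For reference, here is a minimal second-order MAML sketch for 1-D linear-regression tasks in plain JAX. All names and learning rates are hypothetical. When debugging, it is worth checking whether the outer gradient actually flows through the inner update, as it does here because `adapt` is differentiated inside `maml_loss`.

```python
import jax
import jax.numpy as jnp

INNER_LR = 0.01  # hypothetical inner-loop step size
OUTER_LR = 0.05  # hypothetical meta step size


def mse(params, x, y):
    w, b = params
    return jnp.mean((x * w + b - y) ** 2)


def adapt(params, x, y):
    # One inner SGD step on the support set, kept differentiable.
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - INNER_LR * g, params, grads)


def maml_loss(params, support, query):
    return mse(adapt(params, *support), *query)


def batched_maml_loss(params, support, query):
    # support/query: (x, y) tuples with a leading task axis.
    losses = jax.vmap(maml_loss, in_axes=(None, 0, 0))(params, support, query)
    return jnp.mean(losses)


@jax.jit
def outer_step(params, support, query):
    loss, grads = jax.value_and_grad(batched_maml_loss)(params, support, query)
    return jax.tree_util.tree_map(lambda p, g: p - OUTER_LR * g, params, grads), loss


# Three toy tasks y = a * x + b, reusing the same points as support and query.
xs = jnp.tile(jnp.linspace(-1.0, 1.0, 20), (3, 1))
ys = jnp.array([1.0, 2.0, 3.0])[:, None] * xs + jnp.array([0.5, 1.0, -0.5])[:, None]
params = (jnp.array(0.0), jnp.array(0.0))
losses = []
for _ in range(200):
    params, loss = outer_step(params, (xs, ys), (xs, ys))
    losses.append(float(loss))
```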
0 votes, 0 answers
How to change the batch size for a neural network in JAX
Using Flax to create a network:
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return…

MichaelMMeskhi
0 votes, 1 answer
How can I initialize the hidden state (carry) of a Flax linen GRUCell as a learnable parameter (e.g. using model.init)?
I created a GRU model in JAX using Flax, and I initialize the model parameters using model.init as follows:
import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers
class RNN(nn.Module):
    n_RNN_units:…

Jabby
0 votes, 1 answer
I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values: workaround needed
I am new to JAX.
I am implementing a variational autoencoder (VAE) using JAX and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using compositions of flax.linen modules). Crucially,…

Jabby
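Inside `jax.jit` (or `jax.grad`), values are tracers, and an in-place NumPy assignment such as `buf[i] = z` needs a concrete value. The usual workaround is to keep the buffer as a `jnp` array and update it functionally with `.at[...].set(...)`. A sketch, assuming the reparameterization trick for sampling. `sample_latent` and `fill_latents` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def sample_latent(key, mu, log_var):
    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * log_var) * eps


@jax.jit
def fill_latents(key, mus, log_vars):
    buf = jnp.zeros_like(mus)  # jnp, not np, so it can hold tracers
    keys = jax.random.split(key, mus.shape[0])
    for i in range(mus.shape[0]):  # shapes are static under jit
        # Functional update: returns a new array instead of mutating buf.
        buf = buf.at[i].set(sample_latent(keys[i], mus[i], log_vars[i]))
    return buf


z = fill_latents(jax.random.PRNGKey(0), jnp.zeros((3, 2)), jnp.full((3, 2), -100.0))
```

Here the loop is kept only to show `.at[...].set(...)`. `jax.vmap(sample_latent)(keys, mus, log_vars)` produces the same result without an explicit buffer.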
0 votes, 1 answer
Can you update parameters of a module from inside the nn.compact of that module? (self modifying networks)
I'm quite new to Flax and I was wondering what the correct way is to get this behavior:
param = f.init(key,x)
new_param, y = f.apply(param,x)
where f is an nn.Module instance, and f might go through multiple operations to get new_param and that…

hal9000
-1 votes, 1 answer
Why does JAX throw an unfiltered stack trace?
I need to jit the train step, but when I do, I get this error:
import jax_resnet
import jax
import jax.numpy as jnp
from flax import linen as nn
import tensorflow_datasets as tfds
from flax.training import train_state
import optax
import numpy as…

Christopher Rae
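By default, JAX hides its own internal frames from tracebacks and appends a note pointing to the original, unfiltered exception. The `jax_traceback_filtering` option controls this, and setting it to `"off"` shows every frame. The question does not show the actual error. A common one when jitting a train step is Python control flow on a traced value, sketched below with hypothetical functions `bad_step` and `good_step`.

```python
import jax
import jax.numpy as jnp

# Show JAX-internal frames in tracebacks ("auto" is the default).
jax.config.update("jax_traceback_filtering", "off")


@jax.jit
def bad_step(x):
    # A Python `if` on a traced value cannot be evaluated under jit.
    if x.sum() > 0:
        return x
    return -x


@jax.jit
def good_step(x):
    # The same branch expressed with jnp.where works on tracers.
    return jnp.where(x.sum() > 0, x, -x)


try:
    bad_step(jnp.ones(3))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
```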