Questions tagged [flax]

40 questions
0 votes, 0 answers

Convert a Flax model to PyTorch

I have several image classifiers in Flax. For one of the models I have saved the state, and for the other two I have saved the parameters as a FrozenDict with a .flax extension. My question is: how could I convert the whole models to PyTorch and use these…
m0ss
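A minimal sketch of the parameter-conversion half of this, assuming the .flax files hold a nested dict of parameters (flax.serialization.msgpack_restore can turn the raw bytes into such a dict) and that the model uses only nn.Dense and nn.Conv layers. The helper names (flatten, to_torch_layout, convert) and the example parameter tree are hypothetical; the resulting keys still have to be matched to the module names of an equivalent PyTorch model before calling load_state_dict.

```python
from collections.abc import Mapping

import numpy as np


def flatten(tree, prefix=()):
    """Flatten nested dicts into {("Conv_0", "kernel"): array, ...}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):  # covers plain dicts and FrozenDicts
            flat.update(flatten(value, path))
        else:
            flat[path] = np.asarray(value)
    return flat


def to_torch_layout(path, array):
    """Return (state_dict key, array in PyTorch layout) for one Flax leaf."""
    *modules, leaf = path
    name = ".".join(modules)
    if leaf == "kernel" and array.ndim == 2:
        # nn.Dense kernel is (in, out); torch.nn.Linear.weight is (out, in).
        return f"{name}.weight", array.T
    if leaf == "kernel" and array.ndim == 4:
        # nn.Conv kernel is (H, W, in, out); torch.nn.Conv2d.weight is (out, in, H, W).
        return f"{name}.weight", array.transpose(3, 2, 0, 1)
    if leaf == "bias":
        return f"{name}.bias", array
    # BatchNorm scale/mean/var and other layers would need their own rules.
    raise ValueError(f"no conversion rule for {path}")


def convert(params):
    return dict(to_torch_layout(path, arr) for path, arr in flatten(params).items())


# Hypothetical parameter tree with the shapes nn.Conv / nn.Dense would create.
params = {
    "Conv_0": {"kernel": np.zeros((3, 3, 1, 32)), "bias": np.zeros(32)},
    "Dense_0": {"kernel": np.zeros((800, 10)), "bias": np.zeros(10)},
}
state_dict = convert(params)
```

Each array can then be turned into a tensor with torch.tensor(...). One caveat: Flax is channels-last, so the first Dense layer after a flattened convolution expects its inputs in (H, W, C) order, while PyTorch flattens in (C, H, W) order; that layer's weight columns need an extra permutation that this sketch does not perform.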
0 votes, 1 answer

Getting incorrect output from the flax model's init call

I am trying to create a simple neural network using Flax, as shown below. However, the params FrozenDict I receive as the output of model.init is empty instead of containing the parameters of the neural network. Also, type(predictions) is…
Bunny Rabbit
0 votes, 0 answers

Flax Memory Consumption in Backwards Pass

Recently I built my first model in Flax. The forward pass worked fine, but I experienced OOM errors during the backward pass. Originally I had split my model into several small classes, each implemented as its own Flax model inheriting from…
Simon P.
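One common way to cut backward-pass memory in JAX is rematerialization: jax.checkpoint discards a function's intermediate activations after the forward pass and recomputes them during the backward pass. A minimal sketch on a hypothetical stack of tanh layers, checking that the gradients are unchanged; Flax exposes the same idea for modules as nn.remat. Whether this addresses the OOM in the question depends on where the memory actually goes.

```python
import jax
import jax.numpy as jnp


def block(w, x):
    # One layer; its activations are what the backward pass would normally keep.
    return jnp.tanh(x @ w)


def forward(ws, x, remat):
    # With remat=True each block's activations are recomputed on the backward pass.
    f = jax.checkpoint(block) if remat else block
    for w in ws:
        x = f(w, x)
    return jnp.sum(x ** 2)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
ws = [jax.random.normal(k, (64, 64)) * 0.1 for k in keys[:4]]
x = jax.random.normal(keys[4], (8, 64))

g_plain = jax.grad(forward)(ws, x, False)  # stores all activations
g_remat = jax.grad(forward)(ws, x, True)   # trades compute for memory
```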
0 votes, 0 answers

Flax implementation of padding_idx from torch.nn.Embedding

I have been rewriting some of my PyTorch models in JAX/Flax and came across the issue of converting torch.nn.Embedding to flax.linen.Embed. There does not appear to be a direct translation for PyTorch's padding_idx. The keyword essentially zeroes the…
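A minimal sketch of one way to emulate padding_idx in plain JAX, assuming the goal is a zero output and no gradient updates for the padding row. The function name embed_with_padding is hypothetical; the same mask could be applied to the output of flax.linen.Embed. Unlike PyTorch, which returns whatever is stored in the padding row (zeros at initialization), this forces the output to zero.

```python
import jax
import jax.numpy as jnp


def embed_with_padding(table, ids, padding_idx):
    """Embedding lookup where `padding_idx` always maps to a zero vector."""
    emb = jnp.take(table, ids, axis=0)
    is_pad = (ids == padding_idx)[..., None]
    # jnp.where also blocks the gradient, so the padding row receives no updates.
    return jnp.where(is_pad, 0.0, emb)


table = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
ids = jnp.array([[3, 0, 0], [5, 7, 0]])  # 0 is the padding index here
out = embed_with_padding(table, ids, padding_idx=0)

grad_table = jax.grad(lambda t: embed_with_padding(t, ids, 0).sum())(table)
```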
0 votes, 0 answers

Vanishing parameters in MAML JAX (Meta Learning)

I am working on an implementation of MAML (see https://arxiv.org/pdf/1703.03400.pdf) in JAX. When training on a distribution of simple linear regression tasks, it seems to perform fine (it takes a while to converge but ultimately works). However, when…
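For reference, a minimal sketch of the MAML objective for one-dimensional linear regression in JAX: one inner SGD step per task, and an outer gradient that differentiates through that step (second-order MAML). The model, inner learning rate, and toy tasks are hypothetical and not taken from the question.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    w, b = params
    return w * x + b


def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


def adapt(params, x, y, inner_lr=0.1):
    # One inner-loop SGD step on a task's support set.
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - inner_lr * g, params, grads)


def maml_objective(params, xs, ys, xq, yq):
    # Query loss after adaptation, averaged over a batch of tasks.
    per_task = jax.vmap(lambda a, b, c, d: mse(adapt(params, a, b), c, d))(xs, ys, xq, yq)
    return jnp.mean(per_task)


# Two hypothetical tasks of the form y = slope * x + 1.
x = jnp.linspace(-1.0, 1.0, 10)
slopes = jnp.array([2.0, -1.0])
xs = jnp.stack([x, x])
ys = slopes[:, None] * x + 1.0
params = (jnp.array(0.0), jnp.array(0.0))

# jax.grad differentiates through adapt, including its inner jax.grad.
meta_grads = jax.grad(maml_objective)(params, xs, ys, xs, ys)
```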
0 votes, 0 answers

How to change the batch size for a neural network in JAX

Using flax to create a network: def create_train_state(rng, learning_rate, momentum): """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(learning_rate, momentum) return…
MichaelMMeskhi
0 votes, 1 answer

How can I initialize the hidden state (carry) of a (flax linen) GRUCell as a learnable parameter (e.g. using model.init)

I create a GRU model in JAX using Flax, and I initialize the model parameters using model.init as follows: import jax.numpy as np from jax import random import flax.linen as nn from jax.nn import initializers class RNN(nn.Module): n_RNN_units:…
Jabby
0 votes, 1 answer

I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values; workaround needed

I am new to JAX. I am implementing a variational autoencoder (VAE) using JAX and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using compositions of flax.linen.nn modules). Crucially,…
Jabby
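A minimal sketch of the usual workaround, assuming the goal is to store a sampled latent code inside a jitted function: keep the buffer as a jax.numpy array and use the functional .at[...].set(...) update instead of writing into a NumPy array. The function name sample_into and the buffer layout are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_into(key, mu, log_var, latents, idx):
    # Reparameterized VAE sample: z = mu + sigma * eps, with sigma = exp(log_var / 2).
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * eps
    # Writing z into a NumPy array needs a concrete value and fails under jit;
    # .at[idx].set(z) returns an updated copy of the jax array instead.
    return latents.at[idx].set(z)


latents = jnp.zeros((4, 2))
mu = jnp.array([1.0, -1.0])
log_var = jnp.full((2,), -100.0)  # tiny variance so z is essentially mu
latents = sample_into(jax.random.PRNGKey(0), mu, log_var, latents, 2)
```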
0 votes, 1 answer

Can you update parameters of a module from inside the nn.compact of that module? (self modifying networks)

I'm quite new to Flax and I was wondering what the correct way is to get this behavior: param = f.init(key,x) new_param, y = f.apply(param,x) where f is an nn.Module instance, and f might go through multiple operations to get new_param and that…
hal9000
-1 votes, 1 answer

Why does JAX throw an unfiltered stack trace?

I need to jit the train step, but when I do, I get this error: import jax_resnet import jax import jax.numpy as jnp from flax import linen as nn import tensorflow_datasets as tfds from flax.training import train_state import optax import numpy as…