Questions tagged [flax]
40 questions

1 vote, 0 answers
Access a submodule of a Flax class/module without calling model.apply()
I have a module/class of this kind:
class autoencoder(nn.Module):
    hidden_dim: int
    z_dim: int
    output_dim: int
    def setup(self):
        self.encoder = encoder(self.hidden_dim, self.z_dim)
        self.decoder =…

Jabby

1 vote, 0 answers
How to convert saved Flax checkpoints to a model?
https://github.com/google-research/scenic/tree/main/scenic/projects/mbt
I am trying to use the pretrained model from this GitHub repository, which is a Flax checkpoint. I want to convert it back into a model. How can I do that?

Verma Sushant

1 vote, 0 answers
How to build PyTorch-like code in JAX/Flax
I am trying to build a neural network with a dropout layer to avoid overfitting, but I ran into trouble when writing it in JAX/Flax.
Here is the original model I built in PyTorch:
class MLPModel(nn.Module):
    def __init__(self, layer, dp_rate=0.1):
        …

Woody Wan

1 vote, 2 answers
Is there a way to trace grads through the self.put_variable method in Flax?
I would like to trace the grads through self.put_variable. Is there any way to make that possible? Or is there another way to update the param supplied to the module that is traced?
import jax
from jax import numpy as jnp
from jax import…

hal9000

1 vote, 1 answer
Pickle changes type in JAX
I have a flax.struct dataclass containing a JAX NumPy array.
When I dump this object with pickle and load it again, the array is no longer a JAX NumPy array; it has been converted to a NumPy array. Here is the code to reproduce it:
import flax
import…

Valentin Macé

0 votes, 0 answers
What is the difference between state.apply and model.apply in flax (if any)?
state.apply and model.apply appear to have the same signature. Given a state loaded from a previous checkpoint, with the goal of running inference from that state, which method should be used?

Cola

0 votes, 0 answers
How to load checkpoints into a Flax state in JAX 0.4.14 when the save path has changed?
I just updated JAX to 0.4.14 on our lab's server. But after training with the same file that had previously run under JAX 0.3.25, I found that my test file cannot load the checkpoint, under either 0.4.14 or 0.3.25. Moreover, the save path…

Gerry

0 votes, 0 answers
Prefetching an iterator of 128-dim array to device
I'm having trouble using flax.jax_utils.prefetch_to_device for the simple function below. I'm loading the SIFT 1M dataset and converting the array to a jnp array.
I then want to prefetch the iterator of 128-dim arrays.
import tensorflow_datasets as…

jeffreyveon

0 votes, 1 answer
Computing dot product of gradients with itself for a neural network model in JAX
I have the following JAX code for my neural network model, model:
(loss, (inner_state, logits)), grad = jax.value_and_grad(
    lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model)
So grad is actually a type of…

abc
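Assuming every leaf of grad is a floating-point array (as it is when model is a pytree of parameters), the dot product of the gradient with itself is the sum of per-leaf dot products, i.e. the squared global L2 norm. A minimal sketch; tree_dot and the toy params and loss are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_dot(a, b):
    """Sum of elementwise products over all matching leaves of two pytrees."""
    per_leaf = jax.tree_util.tree_map(lambda x, y: jnp.vdot(x, y), a, b)
    return sum(jax.tree_util.tree_leaves(per_leaf))


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}


def loss(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2


grad = jax.grad(loss)(params)   # {"w": [2., 4.], "b": 6.}
grad_sq = tree_dot(grad, grad)  # 2**2 + 4**2 + 6**2 = 56
```

If the model pytree also contains non-array leaves (for example activation functions), they have to be filtered out before taking the dot product.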

0 votes, 0 answers
How to pass a negative prompt to FlaxStableDiffusionImg2ImgPipeline when using diffusers?
prompt = "masterpiece, best quality, 1girl, at dusk"
neg_prompt = "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2)"
num_samples =…

Aero Wang

0 votes, 0 answers
Data download issue with the official Flax ImageNet example
I am still trying to understand this official Flax example. For convenience while experimenting, I have created my own copy.
The section on running locally does not seem to include a download command, so when I run python main.py…

RanWang

0 votes, 1 answer
How to unroll the training loop so that JAX can run multiple training steps on GPU/TPU
When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps at a time. In TensorFlow, for example, this is possible:
with strategy.scope():
    model = create_model()
    optimizer_inner = AdamW(weight_decay=1e-6)
    …

RanWang
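One common JAX pattern for running several training steps per device call is to write a single step as a function (carry, batch) -> (carry, metric) and roll k of them into one compiled program with jax.lax.scan under jax.jit. A minimal sketch on a hypothetical linear-regression problem with plain SGD; the learning rate and data are assumptions, and a Flax TrainState (itself a pytree) could be carried in place of params the same way.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # assumed learning rate for this toy problem


def sgd_step(params, batch):
    """One SGD step, returning (new_params, loss) as lax.scan expects."""
    x, y = batch

    def loss_fn(p):
        return jnp.mean((x @ p - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    return params - LR * grads, loss


@jax.jit
def train_k_steps(params, batches):
    # Every leaf of batches has a leading axis of length k;
    # scan runs k steps inside a single compiled call.
    return jax.lax.scan(sgd_step, params, batches)


k, n, d = 20, 32, 3
true_w = jnp.array([1.0, -2.0, 0.5])
xs = jax.random.normal(jax.random.PRNGKey(0), (k, n, d))
ys = xs @ true_w  # shape (k, n)

params, losses = train_k_steps(jnp.zeros(d), (xs, ys))
```

losses[i] is the loss measured before the i-th update, so the host only sees one array of k losses per call instead of synchronizing after every step.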

0 votes, 0 answers
Loading a trained FlaxHybridCLIP model
System Info
transformers version: 4.27.4
Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.29
Python version: 3.8.10
Huggingface_hub version: 0.13.4
PyTorch version (GPU?): 1.9.0+cpu (False)
Tensorflow version (GPU?): 2.9.1 (True)
Flax version…

khaled

0 votes, 0 answers
pip install gives ResolutionImpossible: the user requested flax 0.6.8 but t5x depends on flax 0.6.8
I am trying to use a requirements file that depends on specific commits of the packages flax and t5x.
The problem can be reproduced with the following command:
pip install "flax @…

BioGeek

0 votes, 0 answers
Is there a way to convert a custom checkpoint from TensorFlow to JAX/Flax/Optax?
I have been searching for an answer to this for hours. In Hugging Face Transformers, if we don't modify the model, we can load it directly into JAX/Flax/Optax. However, what if I want to train a TensorFlow model utilizing its TPU…

RanWang