Questions tagged [flax]

40 questions
1 vote · 0 answers

Access a submodule of a Flax class/module without calling model.apply()

I have a module/class of this kind: class autoencoder(nn.Module): hidden_dim: int z_dim: int output_dim: int def setup(self): self.encoder = encoder(self.hidden_dim, self.z_dim) self.decoder =…
Jabby • 43 • 1 • 7
1 vote · 0 answers

How to convert saved Flax checkpoints to a model?

https://github.com/google-research/scenic/tree/main/scenic/projects/mbt I am trying to use the pretrained model from this repository, which is a Flax checkpoint. I want to convert it back into a model. How can I do that?
1 vote · 0 answers

How to build PyTorch-like code in JAX/Flax

I am trying to build a neural network with a dropout layer to avoid overfitting, but I ran into trouble when writing it in JAX/Flax. Here is the original model I built in PyTorch: class MLPModel(nn.Module): def __init__(self, layer, dp_rate=0.1): …
Woody Wan • 11 • 2
1 vote · 2 answers

Is there a way to trace grads through the self.put_variable method in Flax?

I would like to trace the grads through self.put_variable. Is there any way to make that possible? Or is there another way to update the param supplied to the module that is traced? import jax from jax import numpy as jnp from jax import…
hal9000 • 222 • 2 • 12
1 vote · 1 answer

Pickle changes type in JAX

I have a Flax struct dataclass containing a JAX NumPy array. When I pickle-dump this object and load it again, the array is no longer a JAX NumPy array; it has been converted to a NumPy array. Here is the code to reproduce it: import flax import…
Valentin Macé • 1,150 • 1 • 10 • 25
0 votes · 0 answers

What is the difference between state.apply and model.apply in Flax (if any)?

state.apply and model.apply appear to have the same signature. Given that a state is loaded from a previous checkpoint and the goal is to run inference from that state, which method should be used?
Cola • 2,097 • 4 • 24 • 30
0 votes · 0 answers

How to load checkpoints into a Flax state in JAX 0.4.14 when the save path has changed?

I just updated JAX to 0.4.14 on our lab's server. After training with the same script that had previously run under JAX 0.3.25, I found that my test script cannot load the checkpoint file under either 0.4.14 or 0.3.25. Moreover, the save path…
Gerry • 1
0 votes · 0 answers

Prefetching an iterator of 128-dim arrays to device

I'm having trouble using flax.jax_utils.prefetch_to_device for the simple function below. I'm loading the SIFT 1M dataset and converting the array to a jnp array. I then want to prefetch the iterator of 128-dim arrays. import tensorflow_datasets as…
jeffreyveon • 13,400 • 18 • 79 • 129
0 votes · 1 answer

Computing dot product of gradients with itself for a neural network model in JAX

I have the following JAX code for my neural network model, model: (loss, (inner_state, logits)), grad = jax.value_and_grad( lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model) So grad is actually a type of…
abc • 211 • 1 • 3 • 10
0 votes · 0 answers

How to input a negative prompt with FlaxStableDiffusionImg2ImgPipeline when using diffusers?

prompt = "masterpiece, best quality, 1girl, at dusk" neg_prompt = "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2)" num_samples =…
Aero Wang • 8,382 • 14 • 63 • 99
0 votes · 0 answers

Data download issue with the official Flax ImageNet example

I am still trying to understand this official Flax example. For convenience, I have created my own copy. The section on running locally does not seem to include a download command. Therefore, when I run python main.py…
RanWang • 310 • 2 • 12
0 votes · 1 answer

How to unroll the training loop so that JAX can train multiple steps on GPU/TPU

When using powerful hardware, especially TPUs, it is often preferable to train multiple steps at a time. For example, in TensorFlow, this is possible. with strategy.scope(): model = create_model() optimizer_inner = AdamW(weight_decay=1e-6) …
RanWang • 310 • 2 • 12
0 votes · 0 answers

Loading FlaxHybridCLIP trained model

System Info transformers version: 4.27.4 Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.29 Python version: 3.8.10 Huggingface_hub version: 0.13.4 PyTorch version (GPU?): 1.9.0+cpu (False) Tensorflow version (GPU?): 2.9.1 (True) Flax version…
khaled • 9 • 1 • 4
0 votes · 0 answers

pip install gives ResolutionImpossible: the user requested flax 0.6.8 but t5x depends on flax 0.6.8

I am trying to install a requirements file that depends on versions of the packages flax and t5x at specific commits. The problem can be reproduced with the following command: pip install "flax @…
BioGeek • 21,897 • 23 • 83 • 145
0 votes · 0 answers

Is there a way to convert a custom checkpoint from TensorFlow to JAX/Flax/Optax?

I have been searching for an answer to this question for hours. In Hugging Face Transformers, if we don't modify the model, we can load it directly into JAX/Flax/Optax. However, what if I want to train a TensorFlow model utilizing its TPU…
RanWang • 310 • 2 • 12