
I have several image classifiers in Flax. For one of the models I saved the training state, and for the other two I saved only the parameters, as a FrozenDict in a file with the .flax extension. How can I convert these models to PyTorch and load the saved weights so that I get an identical model in PyTorch?

For example, one of the models is this:

import flax.linen as nn

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x, training=True):
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
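A sketch of a PyTorch counterpart of this CNN plus a weight-copying helper, under stated assumptions: 28x28 single-channel inputs (so the first dense layer sees 32*7*7 = 1568 features), Flax's auto-generated layer names `Conv_0`, `Conv_1`, `Dense_0`, `Dense_1`, and a params dict of arrays such as the one `flax.serialization.msgpack_restore` returns. `TorchCNN` and `load_flax_cnn_params` are hypothetical names. Three layout differences matter: Flax conv kernels are (H, W, in, out) while `nn.Conv2d` stores (out, in, H, W); Flax Dense kernels are (in, out) while `nn.Linear` stores (out, in); and flattening an NHWC activation gives a different feature order than flattening NCHW, so the rows of the first dense kernel must be permuted.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TorchCNN(nn.Module):
    """PyTorch port of the Flax CNN; takes NCHW input (assumed 1x28x28)."""

    def __init__(self, in_channels=1, height=28, width=28):
        super().__init__()
        # padding=1 matches Flax's default padding='SAME' for a 3x3, stride-1 conv
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # two 2x2 poolings divide each spatial size by 4 (rounding down)
        self.feat_hw = (height // 4, width // 4)
        self.fc1 = nn.Linear(32 * self.feat_hw[0] * self.feat_hw[1], 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flattens in (C, H, W) order
        x = F.relu(self.dropout(self.fc1(x)))  # same dropout-then-relu order as Flax
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def _to_tensor(a):
    return torch.from_numpy(np.ascontiguousarray(a, dtype=np.float32))


def load_flax_cnn_params(model, params):
    """Copy a Flax params dict (nested dicts of arrays) into a TorchCNN in place."""
    with torch.no_grad():
        for name, layer in (("Conv_0", model.conv1), ("Conv_1", model.conv2)):
            # Flax conv kernel: (H, W, in, out) -> PyTorch: (out, in, H, W)
            layer.weight.copy_(_to_tensor(np.transpose(params[name]["kernel"], (3, 2, 0, 1))))
            layer.bias.copy_(_to_tensor(params[name]["bias"]))

        # Dense_0 rows follow Flax's (H, W, C) flatten order; reorder them to
        # PyTorch's (C, H, W) order, then transpose (in, out) -> (out, in).
        fh, fw = model.feat_hw
        c = model.conv2.out_channels
        k = np.asarray(params["Dense_0"]["kernel"])
        k = k.reshape(fh, fw, c, -1).transpose(2, 0, 1, 3).reshape(c * fh * fw, -1)
        model.fc1.weight.copy_(_to_tensor(k.T))
        model.fc1.bias.copy_(_to_tensor(params["Dense_0"]["bias"]))

        # Dense kernel: (in, out) -> nn.Linear weight: (out, in)
        model.fc2.weight.copy_(_to_tensor(np.asarray(params["Dense_1"]["kernel"]).T))
        model.fc2.bias.copy_(_to_tensor(params["Dense_1"]["bias"]))
```

Inputs have to be converted from NHWC to NCHW as well (for example `x.permute(0, 3, 1, 2)`), and `model.eval()` plays the role of calling the Flax model with `training=False`.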

Another is a ResNet18.
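For a ResNet18 the same kernel transposes apply to every conv and dense layer; the extra piece is BatchNorm. In Flax the learned `scale`/`bias` live in the `params` collection while the running `mean`/`var` live in a separate `batch_stats` collection, so both have to be available and mapped onto PyTorch's `weight`, `bias`, `running_mean` and `running_var`. A minimal NumPy sketch of per-layer helpers follows (the function names are hypothetical); matching each Flax layer name to the corresponding PyTorch module depends on how the ResNet18 was written, so that part is left out.

```python
import numpy as np


def flax_bn_to_torch(bn_params, bn_stats):
    """Map one Flax BatchNorm layer to BatchNorm2d state_dict entries.

    bn_params: entry from the 'params' collection, {'scale', 'bias'}
    bn_stats:  matching entry from the 'batch_stats' collection, {'mean', 'var'}
    """
    return {
        "weight": np.asarray(bn_params["scale"], dtype=np.float32),
        "bias": np.asarray(bn_params["bias"], dtype=np.float32),
        "running_mean": np.asarray(bn_stats["mean"], dtype=np.float32),
        "running_var": np.asarray(bn_stats["var"], dtype=np.float32),
    }


def flax_conv_to_torch(kernel):
    """Flax conv kernel (H, W, in, out) -> Conv2d weight (out, in, H, W)."""
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)), dtype=np.float32)


def flax_dense_to_torch(kernel):
    """Flax Dense kernel (in, out) -> Linear weight (out, in)."""
    return np.ascontiguousarray(np.asarray(kernel).T, dtype=np.float32)
```

Two further things are worth checking against the Flax definition. First, the BatchNorm `epsilon` versus PyTorch's `eps` (both default to 1e-5). Second, whether stride-2 convolutions use `padding='SAME'`: XLA's SAME puts the extra padding row/column at the bottom/right, whereas PyTorch's `padding=1` pads symmetrically, so an explicit asymmetric `nn.ZeroPad2d` in front of an unpadded conv may be needed to reproduce the Flax outputs exactly.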

Thanks.

m0ss
