I have several image classifiers in Flax. For one of the models I saved the training state, and for the other two I saved the parameters as a FrozenDict in a file with the .flax extension. My question is: how can I convert these models to PyTorch and load these weights so that I get an identical model in PyTorch?
For example, one of the models is this:
import flax.linen as nn

class CNN(nn.Module):
    """A simple CNN model."""
    @nn.compact
    def __call__(self, x, training=True):
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
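A minimal sketch of a PyTorch counterpart of this CNN, plus a function that copies the Flax params into it. It assumes the params dict uses Flax's automatic names for `@nn.compact` submodules (`Conv_0`, `Conv_1`, `Dense_0`, `Dense_1`) and that PyTorch inputs are NCHW; `TorchCNN` and `flax_cnn_to_torch` are hypothetical names, and `nn` here is `torch.nn`, not `flax.linen`. Three details matter: the kernel layouts differ (conv (H, W, in, out) vs (out, in, H, W); dense (in, out) vs (out, in)), Flax's default `padding="SAME"` for a 3x3 stride-1 conv equals PyTorch `padding=1`, and Flax flattens an NHWC tensor, so the PyTorch forward permutes back to NHWC before flattening so that the rows of `Dense_0` line up.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TorchCNN(nn.Module):
    """PyTorch mirror of the Flax CNN (expects NCHW input)."""

    def __init__(self, in_channels: int, flat_features: int):
        super().__init__()
        # Flax nn.Conv defaults to padding="SAME"; for 3x3 kernels with stride 1
        # that is a symmetric padding of 1.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        # flat_features must equal 32 * (H // 4) * (W // 4) for your input size.
        self.fc1 = nn.Linear(flat_features, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.avg_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        x = F.avg_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2)
        # Flax flattened NHWC, so permute back to NHWC before flattening.
        x = x.permute(0, 2, 3, 1).reshape(x.shape[0], -1)
        x = F.relu(self.dropout(self.fc1(x)))  # Dense -> Dropout -> relu, as in Flax
        return F.log_softmax(self.fc2(x), dim=-1)


def _t(a):
    return torch.from_numpy(np.ascontiguousarray(np.asarray(a, dtype=np.float32)))


def flax_cnn_to_torch(params) -> TorchCNN:
    """Build a TorchCNN from a Flax params dict (Conv_0, Conv_1, Dense_0, Dense_1)."""
    model = TorchCNN(
        in_channels=params["Conv_0"]["kernel"].shape[2],
        flat_features=params["Dense_0"]["kernel"].shape[0],
    )
    state_dict = {}
    for flax_name, torch_name in [("Conv_0", "conv1"), ("Conv_1", "conv2")]:
        # Flax conv kernel (H, W, in, out) -> PyTorch weight (out, in, H, W)
        kernel = np.transpose(np.asarray(params[flax_name]["kernel"]), (3, 2, 0, 1))
        state_dict[f"{torch_name}.weight"] = _t(kernel)
        state_dict[f"{torch_name}.bias"] = _t(params[flax_name]["bias"])
    for flax_name, torch_name in [("Dense_0", "fc1"), ("Dense_1", "fc2")]:
        # Flax dense kernel (in, out) -> nn.Linear weight (out, in)
        state_dict[f"{torch_name}.weight"] = _t(np.asarray(params[flax_name]["kernel"]).T)
        state_dict[f"{torch_name}.bias"] = _t(params[flax_name]["bias"])
    model.load_state_dict(state_dict)  # strict by default: fails on missing/extra keys
    return model.eval()
```

To check a real conversion, one could feed the same batch to both models (NHWC to `CNN().apply({"params": params}, x, training=False)`, transposed to NCHW for PyTorch) and compare the outputs with `np.allclose`.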
Another is a ResNet18.
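For a ResNet18 the conv and dense layout rules are the same; the new pieces are BatchNorm and padding. A minimal sketch, assuming your Flax ResNet uses `nn.BatchNorm` (learned `scale`/`bias` under `params`, running `mean`/`var` under `batch_stats`), which map to PyTorch's `weight`/`bias`/`running_mean`/`running_var` in eval mode. Flax's default BatchNorm epsilon (1e-5) matches PyTorch's default, but check the value your model sets. A second caveat: Flax `padding="SAME"` with stride 2 on an even-sized input pads asymmetrically (the extra row and column go at the bottom and right), whereas torchvision's `resnet18` uses symmetric `padding=1`, so copying weights into it unchanged may not reproduce your outputs exactly. `SamePad2d` below (a hypothetical helper) mimics Flax's padding; the mapping from your Flax layer names to PyTorch modules depends on your own ResNet definition.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def same_pads(size: int, kernel: int, stride: int):
    """(before, after) padding that Flax/XLA 'SAME' uses along one spatial axis."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return total // 2, total - total // 2


class SamePad2d(nn.Module):
    """Pads an NCHW tensor like Flax padding='SAME'; follow with a padding=0 layer."""

    def __init__(self, kernel: int, stride: int, value: float = 0.0):
        super().__init__()
        self.kernel, self.stride, self.value = kernel, stride, value

    def forward(self, x):
        top, bottom = same_pads(x.shape[2], self.kernel, self.stride)
        left, right = same_pads(x.shape[3], self.kernel, self.stride)
        # F.pad takes (left, right, top, bottom) for the last two dims.
        return F.pad(x, (left, right, top, bottom), value=self.value)


def _t(a):
    return torch.from_numpy(np.ascontiguousarray(np.asarray(a, dtype=np.float32)))


@torch.no_grad()
def load_conv(conv: nn.Conv2d, flax_conv: dict) -> None:
    # Flax kernel (H, W, in, out) -> PyTorch weight (out, in, H, W)
    conv.weight.copy_(_t(np.transpose(np.asarray(flax_conv["kernel"]), (3, 2, 0, 1))))
    if conv.bias is not None:  # ResNet convs are often use_bias=False in Flax
        conv.bias.copy_(_t(flax_conv["bias"]))


@torch.no_grad()
def load_batchnorm(bn: nn.BatchNorm2d, flax_params: dict, flax_stats: dict) -> None:
    bn.weight.copy_(_t(flax_params["scale"]))
    bn.bias.copy_(_t(flax_params["bias"]))
    bn.running_mean.copy_(_t(flax_stats["mean"]))
    bn.running_var.copy_(_t(flax_stats["var"]))
```

For a `nn.max_pool(..., padding="SAME")` layer, `SamePad2d(3, 2, value=float("-inf"))` followed by `nn.MaxPool2d(3, 2)` should match, since Flax pads max pooling with negative infinity. Remember to call `model.eval()` so BatchNorm uses the copied running statistics.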
Thanks.