
The examples here show that one can subclass tf.keras.Model as follows:

class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

However, what if I want a variable number of layers, and layers of variable types? How do I store the layer objects in my model?

From what I understand, the names I give to the attributes (dense1 and dense2 in the example above) are significant, because they are used to refer to those layers and their variables when saving to a checkpoint, etc. Is that correct?

My question is basically: How do I store my layers in my tf.keras.Model subclass if I don't know how many of them I have available? And then how do I save and restore the weights of those layers?

My first thought was to keep lists of layer objects, but then it is not obvious to me how those layers' weights would be saved and restored, since they would not correspond to distinct attribute names.

arabinelli
MattSt

1 Answer


The short answer is: do what you would normally do; TensorFlow takes care of the rest.

The answer is hidden in the docstring of the save_weights method for tf.keras.Model (emphasis added):

When saving in TensorFlow format, all objects referenced by the network are saved in the same format as tf.train.Checkpoint, including any Layer instances or Optimizer instances assigned to object attributes. For networks constructed from inputs and outputs using tf.keras.Model(inputs, outputs), Layer instances used by the network are tracked/saved automatically. For user-defined classes which inherit from tf.keras.Model, Layer instances must be assigned to object attributes, typically in the constructor.

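The easiest way to accomplish your goal is to store the layers in a Python container, such as a dict or a list, and assign that container to an attribute: Keras tracks layers held in lists and dicts that are assigned to attributes, so their weights are saved as well. In the following example, I'm using a dictionary to preserve the original names.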

class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        # the dict is assigned to an attribute, so Keras tracks the layers stored in it
        self.my_weight_dict = {}
        self.my_weight_dict["dense1"] = tf.keras.layers.Dense(6, activation=tf.nn.relu)
        self.my_weight_dict["dense2"] = tf.keras.layers.Dense(3, activation=tf.nn.softmax)  # 3 units: one per iris class

    def call(self, inputs):
        x = self.my_weight_dict["dense1"](inputs)
        return self.my_weight_dict["dense2"](x)

Because the layers live in an ordinary container, you can also build them programmatically, for example from constructor arguments that control how many layers the model has and of which type. This is useful for automated hyperparameter tuning.
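As a minimal sketch of that idea (the class `ConfigurableModel` and its `spec` format are hypothetical, made up for illustration), the number and type of layers below come from a list of `(kind, value)` pairs. The layers are collected in a local list that is then assigned to the attribute `self.blocks`, which is what lets Keras track them:

```python
import tensorflow as tf


class ConfigurableModel(tf.keras.Model):
    """Hypothetical model whose layer count and layer types come from `spec`."""

    def __init__(self, spec, num_classes=3):
        super().__init__()
        blocks = []
        for kind, value in spec:
            if kind == "dense":
                blocks.append(tf.keras.layers.Dense(value, activation="relu"))
            elif kind == "dropout":
                blocks.append(tf.keras.layers.Dropout(value))
            else:
                raise ValueError(f"unknown layer kind: {kind!r}")
        # assigning the list to an attribute is what makes Keras track these layers
        self.blocks = blocks
        self.head = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs):
        x = inputs
        # Keras propagates the training flag to nested layers such as Dropout
        for layer in self.blocks:
            x = layer(x)
        return self.head(x)


# e.g. a spec produced by a hyperparameter search
model = ConfigurableModel([("dense", 16), ("dropout", 0.2), ("dense", 8)])
probs = model(tf.zeros((2, 4)))
print(probs.shape)  # (2, 3)
# kernel + bias for the two hidden Dense layers and the head; Dropout has no weights
print(len(model.trainable_weights))  # 6
```

A tuner would only need to generate different `spec` lists; the model class itself stays unchanged.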

Here's a fully reproducible example that uses the MyModel class defined above:

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MaxAbsScaler

# load the data and split it into train and test
iris_dataset = load_iris()
X = iris_dataset.data
y = iris_dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)

# scale each feature by its maximum absolute value, computed on the training set only,
# so that train and test are scaled consistently
scaler = MaxAbsScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# create, compile, and fit the model
# (the output layer already applies softmax, so the loss receives probabilities, not logits)
model = MyModel()
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(X_train, y_train, epochs=50, verbose=2, batch_size=128,
          validation_data=(X_test, y_test))

# just call save_weights; without an .h5 suffix, tf.keras in TF 2.x writes a TensorFlow checkpoint
model.save_weights(filepath="path/to/your/weights/file")

# create a new model with the same structure and load the weights into it;
# the restore is applied once the new model creates its variables
model_2 = MyModel()
model_2.load_weights("path/to/your/weights/file")
model_2.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
                loss="sparse_categorical_crossentropy",
                metrics=["accuracy"])
model_2.evaluate(X_test, y_test)
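The same works for the list-based approach from the question. As a sketch, assuming TF 2.x with its bundled tf.keras (Keras 2), where a filepath without an .h5 suffix produces a TensorFlow checkpoint, the hypothetical `StackModel` below keeps its hidden layers in a list. The checkpoint keys follow the attribute path, and list elements are keyed by their position in the list (for example `blocks/0/kernel/...`), so distinct attribute names are not needed. The model is then restored into a fresh instance:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class StackModel(tf.keras.Model):
    """Hypothetical model that keeps a variable number of Dense layers in a list."""

    def __init__(self, n_hidden=2):
        super().__init__()
        # list attribute: tf.keras tracks every layer stored in it
        self.blocks = [tf.keras.layers.Dense(8, activation="relu") for _ in range(n_hidden)]
        self.head = tf.keras.layers.Dense(3, activation="softmax")

    def call(self, inputs):
        x = inputs
        for layer in self.blocks:
            x = layer(x)
        return self.head(x)


x = np.random.rand(5, 4).astype("float32")

model = StackModel(n_hidden=3)
model(x)  # call once so the weights exist before saving
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
model.save_weights(prefix)  # no .h5 suffix: TensorFlow checkpoint format (tf.keras 2.x)

# inspect the keys stored for the list attribute
names = [name for name, _ in tf.train.list_variables(prefix)]
print([n for n in names if n.startswith("blocks/")])

# restore into a fresh instance with the same structure; the restore is
# deferred until the new model creates its variables on the first call
restored = StackModel(n_hidden=3)
restored.load_weights(prefix)
same = np.allclose(model(x).numpy(), restored(x).numpy())
print(same)  # True
```

The only requirement is that the new instance is built with the same structure (here, the same `n_hidden`), so that the keyed paths in the checkpoint match.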
arabinelli