The short answer is: just do what you would normally do; TensorFlow takes care of the rest.
The answer is hidden in the docstring of the save_weights method of tf.keras.Model (emphasis added):

When saving in TensorFlow format, all objects referenced by the network are saved in the same format as tf.train.Checkpoint, including any Layer instances or Optimizer instances assigned to object attributes. For networks constructed from inputs and outputs using tf.keras.Model(inputs, outputs), Layer instances used by the network are tracked/saved automatically. For user-defined classes which inherit from tf.keras.Model, *Layer instances must be assigned to object attributes, typically in the constructor.*
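To see that last sentence in action: a layer that only lives in a local variable (or inside a closure) is invisible to this tracking, while one assigned to an attribute is picked up. A small sketch to illustrate (the class names are mine):

import tensorflow as tf

class Tracked(tf.keras.Model):
    def __init__(self):
        super(Tracked, self).__init__()
        self.dense = tf.keras.layers.Dense(2)   # attribute -> tracked, will be saved

    def call(self, inputs):
        return self.dense(inputs)

class Untracked(tf.keras.Model):
    def __init__(self):
        super(Untracked, self).__init__()
        dense = tf.keras.layers.Dense(2)        # local variable -> NOT tracked
        self._forward = lambda x: dense(x)      # hiding it in a closure doesn't help

    def call(self, inputs):
        return self._forward(inputs)

for cls in (Tracked, Untracked):
    m = cls()
    m(tf.zeros((1, 4)))                  # build the model by calling it once
    print(cls.__name__, len(m.weights))  # Tracked 2, Untracked 0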
The easiest way to accomplish your goal is therefore to assign the layers to an attribute of the model object. In the following example, I'm storing them in a dictionary (itself assigned to an attribute, so it is tracked) to preserve the original names.
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_weight_dict = {}
        self.my_weight_dict["dense1"] = tf.keras.layers.Dense(6, activation=tf.nn.relu)
        self.my_weight_dict["dense2"] = tf.keras.layers.Dense(3, activation=tf.nn.softmax)  # 3 classes, to fit the dataset below

    def call(self, inputs):
        x = self.my_weight_dict["dense1"](inputs)
        return self.my_weight_dict["dense2"](x)
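A quick way to convince yourself that the layers inside the dictionary are tracked (and will therefore be saved) is to call the model once so its variables get created, then look at model.weights:

m = MyModel()
m(tf.zeros((1, 4)))  # 4 input features, matching the iris data used below
# kernel and bias of both Dense layers are listed, so both will be saved
print([w.name for w in m.weights])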
Storing the layers in a dictionary like this also lets you define the model's structure programmatically, which is useful e.g. for automated hyperparameter tuning; a sketch of that idea follows.
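For example (a minimal sketch; MyTunableModel and hidden_units are my own names, not anything from the question):

class MyTunableModel(tf.keras.Model):
    def __init__(self, hidden_units):
        super(MyTunableModel, self).__init__()
        self.my_weight_dict = {}
        # one hidden layer per entry, sized by the hyperparameter search
        for i, units in enumerate(hidden_units):
            self.my_weight_dict["dense%d" % (i + 1)] = tf.keras.layers.Dense(units, activation=tf.nn.relu)
        self.my_weight_dict["out"] = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs):
        x = inputs
        for i in range(len(self.my_weight_dict) - 1):
            x = self.my_weight_dict["dense%d" % (i + 1)](x)
        return self.my_weight_dict["out"](x)

# e.g. a tuning loop could try different depths and widths here
model = MyTunableModel(hidden_units=[6])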
Here's a fully reproducible example that uses the MyModel class defined above:
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize

# load the data and split it into train and test
iris_dataset = load_iris()
X = iris_dataset.data
y = iris_dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)

# normalize the features
X_train = normalize(X_train, axis=0, norm='max')
X_test = normalize(X_test, axis=0, norm='max')

# create, compile, and fit the model
model = MyModel()
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
              loss="sparse_categorical_crossentropy",
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, verbose=2, batch_size=128,
          validation_data=(X_test, y_test))

# just call save_weights
model.save_weights(filepath="path/to/your/weights/file")

# create a new model with the same structure and load the weights into it
model_2 = MyModel()
model_2.load_weights("path/to/your/weights/file")
model_2.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
                loss="sparse_categorical_crossentropy",
                metrics=['accuracy'])
model_2.evaluate(X_test, y_test)
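Finally, if you want to double-check the round trip: with a path like the one above (no .h5 extension) save_weights uses the TensorFlow checkpoint format, so you can list what was stored with tf.train.list_variables and compare the two models' predictions. The assertion is just my own sanity check, not something the question requires:

import numpy as np

# the dictionary keys ("dense1", "dense2") should show up in the variable names
for name, shape in tf.train.list_variables("path/to/your/weights/file"):
    print(name, shape)

# both models should now make (numerically) identical predictions
np.testing.assert_allclose(model.predict(X_test), model_2.predict(X_test), rtol=1e-6)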