I am trying to load a saved model within a MirroredStrategy scope. I find that the loaded model does not work correctly: only one replica is found, even though 4 visible devices are actually specified. This does not happen for a model that is constructed directly within the MirroredStrategy scope.

It is worth mentioning that subclasses of tf.keras.models.Model and tf.keras.layers.Layer are used here, which I suspect may be the cause of this wrong behavior. I have confirmed that loading a saved tf.keras.Sequential model works fine within MirroredStrategy.
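
For reference, a minimal sketch of that Sequential control case (the layer choice is just a placeholder; it assumes the same 4-GPU strategy as in the reproducible code below):

with strategy.scope():
    seq_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
seq_model.save("seq_demo")
with strategy.scope():
    seq_model2 = tf.keras.models.load_model("seq_demo")
# running seq_model2 via strategy.run uses all 4 replicas as expected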

Reproducible code:

import numpy as np
import tensorflow as tf

# NOTE: TestLayer is referenced but was not defined in the original snippet; this
# is a minimal reconstruction based on the log line quoted below. It prints the
# replica id (formatted at trace time) and casts the inputs to float for Dense.
class TestLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        global_replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
        tf.print("global_replica_id: {}".format(global_replica_id))
        return tf.cast(inputs, tf.float32)

class Demo(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super(Demo, self).__init__(**kwargs)
        self.test_layer = TestLayer()
        self.dense_layer = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros")

    def call(self, inputs):
        vector = self.test_layer(inputs)
        logit = self.dense_layer(vector)
        return logit, vector

    def summary(self):
        inputs = tf.keras.Input(shape=(10,), dtype=tf.int64)
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()

@tf.function
def _step(inputs, labels, model):
    logit, vector = model(inputs)
    return logit, vector

def tf_dataset(keys, labels, batchsize, repeat):
    dataset = tf.data.Dataset.from_tensor_slices((keys, labels))
    dataset = dataset.repeat(repeat)
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset

def _dataset_fn(input_context):
    global_batch_size = 16384
    keys = np.ones((global_batch_size, 10))
    labels = np.random.randint(low=0, high=2, size=(global_batch_size, 1))
    replica_batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf_dataset(keys, labels, batchsize=replica_batch_size, repeat=1)
    dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    return dataset

# Save model within MirroredStrategy scope
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = Demo()
model.compile()
model.summary()
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model))
# model(tf.keras.Input(shape=(10,), dtype=tf.int64))
model.save("demo")

# Load model within MirroredStrategy scope
with strategy.scope():
    model2 = tf.keras.models.load_model("demo")
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model2))

Actual log:

------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step  0 ------------------------------
global_replica_id: 0

Expected log:

------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)

The log lines come from tf.print("global_replica_id: {}".format(global_replica_id)) inside TestLayer.call.
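
Note that strategy.num_replicas_in_sync still reports 4 after loading (it is a property of the strategy itself and is unaffected by load_model), so it is only the loaded model's call that behaves as if there were a single replica:

print("num_replicas_in_sync:", strategy.num_replicas_in_sync)  # prints 4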
