
I am trying to multiply 3000 independent matrices by their corresponding vectors on a TPU to speed up the calculation, but I have run into some issues: I cannot get the final result out, and I am also interested in whether there is a better solution.

Here is the code:

import time

import numpy as np
import tensorflow as tf


n_dim = 100
num_matrices = 3000
a = np.random.random((num_matrices, n_dim, n_dim)).astype(np.float32)
b = np.random.random((num_matrices, n_dim)).astype(np.float32)

atf = tf.constant(a, dtype=tf.float32)
btf = tf.constant(b, dtype=tf.float32)

Here is the version on CPU:

result = []
tic = time.time()
for i in range(num_matrices):
  result.append(tf.linalg.matvec(atf[i, :, :], btf[i, :]))
toc = time.time()
print(f"Time simple tf elapsed {toc -tic}")

Time simple tf elapsed 0.92

Here is the version that I have tried on a TPU (Google Colab):

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# print("All devices: ", tf.config.list_logical_devices('TPU'))
strategy = tf.distribute.TPUStrategy(resolver)

@tf.function
def matmul_fn(x, y):
  def cond_loop(i, x, y, result):
    return tf.less(i, 3000)

  def body_loop(i, x, y, result):
    # NOTE: `result` is overwritten on every pass, so only the last
    # matvec survives the loop.
    result = tf.linalg.matvec(x[i, :, :], y[i, :])
    return [tf.add(i, 1), x, y, result]

  i = tf.constant(0)
  result = tf.constant(np.zeros(y[0, :].shape), dtype=tf.float32)

  # tf.while_loop returns the final loop-variable list [i, x, y, result]
  final_result = tf.while_loop(cond_loop, body_loop, [i, x, y, result])
  return final_result

tic = time.time()
z = strategy.run(matmul_fn, args=(atf, btf))
toc = time.time()
print(f"First Time = {toc -tic}")

inference_time = 0.0
num_iterations = 100

for i in range(num_iterations): 
  tic = time.time()
  result = strategy.run(matmul_fn, args=(atf, btf))
  toc = time.time()
  inference_time += toc - tic
  
print(inference_time / num_iterations)

Here the inference time is only 0.001 s, but I have the following problems:

  1. I can't get the values of the tensor out of result. It returns a list of 4 values of type tensorflow.python.distribute.values.PerReplica, and I would need the last one (see the sketch after this list).
  2. I would like to get the results from all iterations of the while loop.
  3. I wonder if there is a more standard/elegant way to do this.
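For problem 1, a minimal sketch of pulling the values out (assuming the while-loop version above): strategy.run returns the loop-variable list with each entry wrapped as a PerReplica, and strategy.experimental_local_results unpacks one into per-core tensors.

z = strategy.run(matmul_fn, args=(atf, btf))
# z mirrors the loop-variable list [i, x, y, result]; take the last entry.
per_replica_result = z[-1]
# Unpack the PerReplica into a tuple of per-core tensors; every core computed
# the same thing here, so the first entry suffices.
local_result = strategy.experimental_local_results(per_replica_result)[0]
print(local_result.numpy()[:5])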

Thanks for your suggestions!

UPDATE: I have read more about distributed input at https://www.tensorflow.org/tutorials/distribute/input and defined the input as:

dataset_a = tf.data.Dataset.from_tensor_slices(a).batch(512)
dataset_b = tf.data.Dataset.from_tensor_slices(b).batch(512)
dist_dataset_a = strategy.experimental_distribute_dataset(dataset_a)
dist_dataset_b = strategy.experimental_distribute_dataset(dataset_b)
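(A side note: presumably the two datasets could also be zipped into one before distributing, so that a single distributed dataset yields matching pairs; a sketch using tf.data.Dataset.zip:)

dataset_ab = tf.data.Dataset.zip((dataset_a, dataset_b))
dist_dataset_ab = strategy.experimental_distribute_dataset(dataset_ab)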

Then I tried to use TPUStrategy like this:

@tf.function
def multiplication(x, y):
  return tf.linalg.matvec(x, y)
result = []

tic = time.time()
for (x, y) in zip(dist_dataset_a, dist_dataset_b):
  result.append(strategy.run(multiplication, args=(x,y)))
toc = time.time()

print(f"First time = {toc - tic}")

result = []

tic = time.time()
for (x, y) in zip(dist_dataset_a, dist_dataset_b):
  result.append(strategy.run(multiplication, args=(x,y)))
toc = time.time()

print(f"Second time = {toc - tic}")

However, the inference is much slower, circa 1.2 s.


1 Answer


I believe I have solved it. I am posting a solution in case someone else needs it one day. The trick is that matmul can be used on batches of matrices and vectors, as noted in How does tensorflow batch_matmul work?. However, batch_matmul no longer exists, so it is enough to call matmul.

@tf.function(experimental_relax_shapes=True)
def multiply_fn(atf, btf):
  # tf.matmul broadcasts over the leading batch dimension of 3000.
  return tf.matmul(atf, btf)

gtf = tf.expand_dims(btf, axis=-1)  # vectors -> (n_dim, 1) column matrices
tic = time.time()
result = strategy.run(multiply_fn, args=(atf, gtf))
toc = time.time()
print(f"{toc - tic}")

tic = time.time()
result = strategy.run(multiply_fn, args=(atf, gtf))
toc = time.time()

print(f"{toc - tic}")

This works fast and returns results that are easy to read (see the sketch above). Unfortunately, it repeats the computation on all workers, so further optimization should be possible (one idea is sketched at the end of this answer). The alternative:

tic = time.time()
with tf.device('/TPU:0'):
  # matvec also broadcasts over the leading batch axis, no expand_dims needed
  gtf = tf.linalg.matvec(atf, btf)
toc = time.time()
print(f"{toc - tic}")

seems to be slower for some reason, circa 0.002 s.
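As for the repeated computation: one idea (a sketch only, untimed, assuming the 8 cores of a Colab TPU so that 3000 splits evenly) is to hand each replica its own shard of the batch via experimental_distribute_values_from_function:

num_replicas = strategy.num_replicas_in_sync  # 8 on a Colab TPU
shard = num_matrices // num_replicas          # 3000 // 8 = 375

def shard_a(ctx):
  r = ctx.replica_id_in_sync_group
  return atf[r * shard:(r + 1) * shard]

def shard_g(ctx):
  r = ctx.replica_id_in_sync_group
  return gtf[r * shard:(r + 1) * shard]

dist_a = strategy.experimental_distribute_values_from_function(shard_a)
dist_g = strategy.experimental_distribute_values_from_function(shard_g)

# Each replica now multiplies only its 375 pairs; stitch the shards together.
result = strategy.run(multiply_fn, args=(dist_a, dist_g))
full = tf.concat(strategy.experimental_local_results(result), axis=0)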
