
In TensorFlow, how can I get the number of rows and columns of a tensor?

Suppose I have a weight tensor defined as follows:

W1 = tf.get_variable("W1", shape=[784, 512], initializer=tf.contrib.layers.xavier_initializer())

I would like to get the number of rows and columns of W1 and store them in variables automatically.

For example:
W1_rows = the number of rows of W1
W1_columns = the number of columns of W1

(I know that tf.shape or get_shape() can print the shape, but I don't see how to use either of them to store the dimensions in variables such as W1_rows and W1_columns.)
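In TensorFlow 1.x graph mode, tf.shape returns a symbolic tensor whose value only exists after it is run in a session, which is why it cannot be assigned to Python variables directly. The sketch below assumes TensorFlow 2.x with eager execution (the question itself uses 1.x APIs) and uses a zero-filled tf.Variable as a stand-in for W1; under that assumption the dynamic shape can be evaluated and unpacked immediately:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Stand-in for W1: a 784 x 512 variable filled with zeros (values don't matter here).
W1 = tf.Variable(tf.zeros([784, 512]), name="W1")

# tf.shape returns a 1-D int32 tensor holding the runtime (dynamic) shape.
shape_tensor = tf.shape(W1)

# In eager mode the tensor converts to NumPy; unpack it into plain Python ints.
W1_rows, W1_columns = (int(d) for d in shape_tensor.numpy())
print(W1_rows, W1_columns)
```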

Hyunseung Kim

1 Answer

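import tensorflow as tf  # TensorFlow 1.x (tf.get_variable and tf.contrib were removed in 2.x)

W1 = tf.get_variable("W1", shape=[784, 512],
                     initializer=tf.contrib.layers.xavier_initializer())
# get_shape() returns the static shape; as_list() converts it to plain Python ints
W1_rows, W1_columns = W1.get_shape().as_list()
print(W1_rows, W1_columns)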

784 512

It works. Thank you, Oswald.
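For readers on TensorFlow 2.x, where tf.get_variable and tf.contrib no longer exist, a rough equivalent of the answer might look like the sketch below. It assumes TensorFlow 2.x and swaps in tf.keras.initializers.GlorotUniform (Keras's name for Xavier uniform initialization) for tf.contrib.layers.xavier_initializer; the static shape is read from the variable's .shape attribute instead of get_shape():

```python
import tensorflow as tf  # assumes TensorFlow 2.x

# GlorotUniform is Xavier uniform initialization, used here in place of
# tf.contrib.layers.xavier_initializer, which was removed in 2.x.
initializer = tf.keras.initializers.GlorotUniform()
W1 = tf.Variable(initializer(shape=(784, 512)), name="W1")

# .shape is the static TensorShape; as_list() gives plain Python ints.
W1_rows, W1_columns = W1.shape.as_list()
print(W1_rows, W1_columns)
```

Because the static shape is known when the variable is created, no session or eager evaluation is needed to read it.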

Hyunseung Kim