I have a matrix (a TensorFlow variable) whose shape is determined dynamically, so I need to re-assign the matrix while creating the graph. In Python, we have the option of either resetting the graph or re-assigning the variable with validate_shape set to False.
Code Snippet:
Embedding matrix in the TensorFlow graph:
embedding_matrix = tf.get_variable("EMB_MATRIX",
                                   shape=[vocab_size, 300],
                                   dtype=tf.float32,
                                   initializer=tf.random_uniform_initializer(-0.1, 0.1, dtype=tf.float32),
                                   trainable=False)
Re-assigning the embedding matrix with a new shape:
# Insert one row into the embedding matrix at index vocab_size
word_emb_matrix = np.insert(word_emb_matrix, vocab_size, np.arange(300), axis=0)
# Assign the embedding matrix with its new shape; validate_shape=False skips the shape check
session.run(tf.assign(mtrain.embedding_matrix, word_emb_matrix, name="EMB_MATRIX", validate_shape=False))
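To make the shape change concrete, here is a minimal NumPy-only sketch of the row insertion. The toy sizes (4 words, 3 dimensions instead of 300) and the zero-filled starting matrix are assumptions for illustration only. It does not touch TensorFlow; it only shows why the assigned value no longer matches the variable's declared shape.

```python
import numpy as np

# Toy sizes, assumed for illustration; the real matrix is [vocab_size, 300].
vocab_size = 4
emb_dim = 3

# Stand-in for the pretrained embeddings (zeros here, purely hypothetical data).
word_emb_matrix = np.zeros((vocab_size, emb_dim), dtype=np.float32)

# Insert one row at index vocab_size, which is the position just past the last row.
new_emb_matrix = np.insert(word_emb_matrix, vocab_size, np.arange(emb_dim), axis=0)

# The matrix now has one more row than the variable was declared with,
# which is why the assignment needs validate_shape=False.
print(word_emb_matrix.shape, "->", new_emb_matrix.shape)
```

Since np.insert casts the inserted values to the dtype of the original array, the result stays float32 and can be fed to the assign op as-is.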
Now, when we use the saved model/graph in Java, how can the same functionality be achieved?