I am defining a custom loss function in TensorFlow 1.9.0 (I can't upgrade due to project restrictions). I have the following variables, obtained from an eigenvalue decomposition:
# eigw.shape = (?, x)
# eigv.shape = (?, x, y)
Now I want to compute the argmax of eigw along axis 1, such that
amax = tf.argmax(eigw, axis=1, output_type=tf.int32)
# amax.shape = (?,)
I want to index eigv with the values given in amax, such that
# result.shape = (?, y)
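To make the intended result concrete, here is a minimal NumPy sketch (not TensorFlow) of the semantics I am after. The data, the batch size of 2, and the helper name batch are made up for illustration; in the real graph the batch dimension is unknown.

```python
import numpy as np

# Made-up example data: batch of 2, x = 3 eigenvalues, y = 2 components each.
eigw = np.array([[0.1, 0.9, 0.3],
                 [0.7, 0.2, 0.5]])        # shape (2, 3)
eigv = np.arange(12).reshape(2, 3, 2)      # shape (2, 3, 2)

# Index of the largest eigenvalue for each batch entry.
amax = np.argmax(eigw, axis=1)             # shape (2,)

# Pair each batch index with its argmax and pick that row of eigv.
batch = np.arange(eigw.shape[0])
result = eigv[batch, amax]                 # shape (2, 2)

print(result.shape)
```

In other words, for every batch entry b, result[b] should equal eigv[b, amax[b]].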
How do I achieve that? I tried indexing it directly, but then I run into an error because the shapes do not have equal rank. I also tried using tf.while_loop, but I'm new to TensorFlow and could not get it to work.
What other options do I have? How do I solve that problem most easily?
Thanks