I have trained a CNN model using Keras with the TensorFlow backend. After training the model, I am trying to get a subset of the output of layer n. I can access the output of layer n with:

Model.layers[n].output 

which is

<tf.Tensor 'dense_2_1/Identity:0' shape=(None, 64) dtype=float32>

and I can get a contiguous range of it with a command like this:

Model.layers[n].output[...,1:5]

Now I am trying to subset the tensor using only a few indexes out of the 64 (for instance 1, 5, 10).

Any idea how I can do that?

Here is the code for reference:

n                   = 15   
sub_indexes         = [1,5,10]
final_fmap_index    = 10
penultimate_output  = Model.layers[final_fmap_index].output
layer_input         = Model.input
loss                = Model.layers[n].output[...,sub_indexes]
grad_wrt_fmap       = K.gradients(loss,penultimate_output)[0]
grad_wrt_fmap_fn    = K.function([layer_input,K.learning_phase()],
                                      [penultimate_output,grad_wrt_fmap])

which gives me this error:

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got [1, 5, 10]
A_sh
  • Does this answer your question? [What is the tensorflow equivalent of numpy tuple/array indexing?](https://stackoverflow.com/questions/66999579/what-is-the-tensorflow-equivalent-of-numpy-tuple-array-indexing) – Lescurel Apr 13 '21 at 06:52
  • @Lescurel Thank you for the link. I followed the instructions there and figured out how to get the subset for my case; `gather_nd()` does the job. I will post the solution based on that. – A_sh Apr 18 '21 at 21:42

1 Answer

Using tf.gather_nd() I could get the subset of the tensor. To pick the indexes [a,b,c], they have to be put in the format [[0,a],[1,b],[2,c]], where each pair is a full [row, column] coordinate (so this picks element a of row 0, element b of row 1 and element c of row 2), and then passed to gather_nd().

import tensorflow as tf
indexes = [[0, a], [1, b], [2, c]]   # one [row, column] pair per selected element
subset  = tf.gather_nd(MyTensor, indexes, batch_dims=0)

More details on the function: https://www.tensorflow.org/api_docs/python/tf/gather_nd
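To make the index format concrete, here is a minimal sketch on a small stand-in tensor. The name `my_tensor` and the batch size of 3 are assumptions for illustration only, not taken from the model in the question. It also shows `tf.gather` with `axis=-1`, which is not part of the solution above but may be closer to what the question asks for if the same indexes should be taken from every row of the batch.

```python
import tensorflow as tf

# Hypothetical stand-in for a layer output of shape (batch, 64), with batch = 3.
# Element [i, j] holds the value 64 * i + j, so results are easy to check by eye.
my_tensor = tf.reshape(tf.range(3 * 64, dtype=tf.float32), (3, 64))
sub_indexes = [1, 5, 10]

# gather_nd with [row, column] pairs returns one scalar per pair.
pairs = [[row, col] for row, col in enumerate(sub_indexes)]  # [[0, 1], [1, 5], [2, 10]]
per_row = tf.gather_nd(my_tensor, pairs)                     # shape (3,)

# gather along the last axis returns the same columns for every row.
columns = tf.gather(my_tensor, sub_indexes, axis=-1)         # shape (3, 3)

print(per_row.numpy())
print(columns.numpy())
```

With the pairs format the result has one value per pair, while gathering along the last axis keeps the batch dimension and returns a (batch, 3) tensor. Which of the two is appropriate depends on whether the loss should use the same units for every sample.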

A_sh