I have a one-hot vector of shape 1 * n
v = [0.0, 1.0, 0.0] for n = 3
and a matrix of shape n * m * r (m and r can be any numbers, but the first dimension is always n):
m = [[[1,2,3],[4,5,6]], [[5,6,7],[7,8,9]], [[2,4,7],[1,8,9]]]
I want to multiply v * m using broadcasting, so that only the sub-matrix corresponding to the 1.0 element of v is kept and all other sub-matrices become zero (because every other element of v is zero), giving:
prod = [[[0,0,0],[0,0,0]], [[5,6,7],[7,8,9]] , [[0,0,0],[0,0,0]]]
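To make the target concrete, here is a minimal sketch of the operation I am after. It assumes NumPy, which is only one possible library, and it reshapes v to (n, 1, 1) so that each element of v scales one whole sub-matrix of m:

```python
import numpy as np

# One-hot vector (n = 3) and the n * m * r matrix from the example above.
v = np.array([0.0, 1.0, 0.0])
m = np.array([[[1, 2, 3], [4, 5, 6]],
              [[5, 6, 7], [7, 8, 9]],
              [[2, 4, 7], [1, 8, 9]]])

# Reshape v to (n, 1, 1) so it broadcasts over the last two axes of m.
# reshape(-1, 1, 1) also works if v is stored with shape (1, n).
prod = v.reshape(-1, 1, 1) * m

print(prod.shape)
print(prod)
```

Because v is a float array, prod comes out as floats; the values match the expected result above.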