I disagree with strangepoop's answer, mostly in the idea that "If you understand how backprop works in those, you can desugar to understand backprop here." `einsum` is an elegant operation that is more fundamental than `matmul` or any other tensor operation. Understanding backpropagation in `matmul` only amounts to understanding a special case of `einsum` and presents a very limited view.
In the case of a standard `matmul` operation:

c = einsum("ij,jk->ik", a, b)

the gradient of `c` with respect to `a` is computed in a very simple way:

dc/da = einsum("ik,jk->ij", np.ones_like(c), b)
What happened here is extremely simple: we swapped the first operand with the output, together with their subscript strings. In the place of `a` we put `np.ones_like(c)` (which plays the role of the incoming gradient of `c`), and in the place of the output subscripts `ik` we put `a`'s subscripts `ij`. That's it.
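For concreteness, here is a quick NumPy check of that rule (the shapes are made up for the example): the flipped `einsum` reproduces the familiar `ones @ b.T` gradient of `sum(a @ b)` with respect to `a`.

```python
import numpy as np

# Made-up shapes, just for the check
a = np.random.rand(2, 3)
b = np.random.rand(3, 4)

c = np.einsum("ij,jk->ik", a, b)                     # forward: ordinary matmul

# Backward pass for sum(c) w.r.t. a, using the "flip" rule
grad_a = np.einsum("ik,jk->ij", np.ones_like(c), b)

# Familiar matrix-calculus result: d(sum(a @ b))/da = ones @ b.T
assert np.allclose(grad_a, np.ones_like(c) @ b.T)
```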
In the case of your operation:

C = tf.einsum('ijkm,ijkn->imn', A, B)

the gradient with respect to `A` is just:

dC/dA = tf.einsum('imn,ijkn->ijkm', tf.ones_like(C), B)

The middle operand `B` stayed the same; we just swapped the first and last operands together with their subscript strings.
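As a sanity check, here is a minimal sketch (made-up shapes, assuming eager-mode TensorFlow 2) showing that the flipped `einsum` matches what `tf.GradientTape` computes for the gradient of `sum(C)` with respect to `A`:

```python
import numpy as np
import tensorflow as tf

# Made-up shapes for the indices i, j, k, m, n
A = tf.random.normal((2, 3, 4, 5))
B = tf.random.normal((2, 3, 4, 6))

with tf.GradientTape() as tape:
    tape.watch(A)                          # A is a plain tensor, so watch it explicitly
    C = tf.einsum('ijkm,ijkn->imn', A, B)
    loss = tf.reduce_sum(C)                # upstream gradient dloss/dC is all ones

auto_grad = tape.gradient(loss, A)         # TensorFlow's autodiff result

# The flipped-einsum rule from above
manual_grad = tf.einsum('imn,ijkn->ijkm', tf.ones_like(C), B)

assert np.allclose(auto_grad.numpy(), manual_grad.numpy(), atol=1e-5)
```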
So what's actually going on? It's just a natural generalization of the normal multiplication operation to arbitrary tensors. The same way that in normal multiplication `e = a * b * c` gives `de/da = de/de * b * c` (where `de/de` is just `np.ones_like(e)`), in `einsum` it's the same thing, except `np.ones_like(e)` is now an array of ones instead of just being 1, and the `*` operation is replaced by the specific `einsum` string.
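To make that analogy concrete, here is a tiny NumPy sketch (the elementwise product and its subscript string are my own illustration, not from the original question): putting `np.ones_like(e)` in `a`'s place in the same `einsum` string gives exactly the `b * c` you would expect from ordinary product-rule differentiation.

```python
import numpy as np

a, b, c = (np.random.rand(3) for _ in range(3))

# Elementwise product, written both ways
e = a * b * c
assert np.allclose(e, np.einsum("i,i,i->i", a, b, c))

# Gradient of sum(e) w.r.t. a: put ones in a's place, keep everything else
grad_a = np.einsum("i,i,i->i", np.ones_like(e), b, c)
assert np.allclose(grad_a, b * c)
```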
If you'd like to read more about it, great! I know of exactly zero resources that present it this way. If you find some, please let me know :)