I want to compute the KL divergence DKL( ∏q(zk|x) ‖ p(z) ) in TensorFlow,
where ∏q(zk|x) is the product of N distributions (k=1 to k=N) that are not independent, each of type tensorflow_probability.layers.MultivariateNormalTriL. p(z) is a tfp.distributions.MultivariateNormalDiag. Computing the KL divergence between a single q(z|x) and p(z), DKL(q(z|x) ‖ p(z)), works.
I've tried to subclass the TFP Distribution class to create a product distribution Q(z) = ∏q(zk|x) and then compute DKL(Q(z) ‖ p(z)), but that required me to implement the _kl_divergence function myself, which took me back to where I started.
I would like to solve this with TensorFlow Probability, or at least with TensorFlow.