
I want to compute the KL divergence DKL( ∏q(zk|x) ‖ p(z) ) in TensorFlow,


where ∏q(zk|x) is the product of N distributions (k = 1 to N), not independent, each of type tensorflow_probability.layers.MultivariateNormalTriL, and p(z) is a tfp.distributions.MultivariateNormalDiag. The KL divergence between a single q(z|x) and p(z), i.e. DKL(q(z|x) ‖ p(z)), works fine.


I've tried subclassing the tfp Distribution class to create a product distribution Q(z) = ∏q(zk|x) and then compute DKL(Q(z) ‖ p(z)), but I had to implement the _kl_divergence function myself, which took me back to square one.


I would like to solve this with TensorFlow Probability, or at least TensorFlow.

JulienBr

1 Answer


I don't think there is an analytic KL for this case. But since KL[q||p] is the expectation of log q(x) - log p(x) under x ~ q, you can use a Monte Carlo approximation for the KL. In TFP this probably looks like:

q = QDist(...)     # your product distribution Q(z)
p = MVNDiag(...)   # the prior p(z)
x = q.sample()     # x ~ q(x)
kl = q.log_prob(x) - p.log_prob(x)  # single-sample estimate of KL[q||p]

Or, for lower variance, average over several samples:

x = q.sample(10)   # draw 10 samples, x ~ q(x)
kl = tf.math.reduce_mean(q.log_prob(x) - p.log_prob(x))  # average over samples
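As a sanity check of this estimator, here is a dependency-free sketch in plain Python (not TFP) that compares the Monte Carlo estimate of KL[q||p] against the closed-form KL for two univariate Gaussians; the distributions and parameter values are illustrative, not from the question:

```python
import math
import random

def log_normal_pdf(x, mu, sigma):
    # Log density of a univariate Gaussian N(mu, sigma^2).
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def mc_kl(mu_q, s_q, mu_p, s_p, n=100_000, seed=0):
    # Monte Carlo estimate of KL[q||p]: average log q(x) - log p(x) over x ~ q.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu_q, s_q)
        total += log_normal_pdf(x, mu_q, s_q) - log_normal_pdf(x, mu_p, s_p)
    return total / n

def analytic_kl(mu_q, s_q, mu_p, s_p):
    # Closed-form KL between two univariate Gaussians, for comparison.
    return math.log(s_p / s_q) + (s_q ** 2 + (mu_q - mu_p) ** 2) / (2 * s_p ** 2) - 0.5

print(analytic_kl(0, 1, 1, 2))  # ~0.4431
print(mc_kl(0, 1, 1, 2))        # should be close to the analytic value
```

With enough samples the two numbers agree closely, which is exactly what the `reduce_mean` version above is doing inside TFP.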
Brian Patton