As the title says, I am trying to create a mixture of multivariate normal distributions using the TensorFlow Probability package.
In my original project, I feed the categorical weights, the loc and the variance from the output of a neural network. However, when building the graph, I get the following error:
components[0] batch shape must be compatible with cat shape and other component batch shapes
I recreated the same problem using placeholders:
import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()
l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')
log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                    initializer=tf.constant_initializer(1.0),
                                    trainable=True)
mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 1], name='weights')
cat = tfp.distributions.Categorical(probs=[mix, 1.-mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]
bimix_gauss = tfp.distributions.Mixture(
    cat=cat,
    components=components)
So my question is: what am I doing wrong? Looking into the error, it seems to be raised when a tensorshape_util.is_compatible_with check fails, but I don't see why the shapes would be incompatible.
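One way to see where the incompatibility could come from is to compare the batch shapes by hand. The sketch below uses NumPy as a stand-in for the graph tensors, assuming (as TensorFlow's automatic packing does) that a Python list of tensors passed as probs is stacked along a new leading axis. The batch size n and the helpers categorical_batch_shape and mvn_diag_batch_shape are hypothetical, written only to mirror how a Categorical treats the last axis of probs as classes and how MultivariateNormalDiag broadcasts the non-event axes of loc and scale_diag.

```python
import numpy as np

n = 4  # hypothetical batch size standing in for the placeholders' None dimension

mix = np.random.uniform(size=(n, 1)).astype(np.float32)  # like the 'weights' placeholder
loc = np.zeros((n, 2), dtype=np.float32)                 # like l1 / l2
scale_diag = np.exp(np.ones((1, 2), dtype=np.float32))   # like tf.exp(log_std)


def categorical_batch_shape(probs):
    # Hypothetical helper: the last axis of probs indexes the classes,
    # every axis before it is batch.
    return probs.shape[:-1]


def mvn_diag_batch_shape(loc, scale_diag):
    # Hypothetical helper: batch shape is the broadcast of the non-event axes.
    return np.broadcast_shapes(loc.shape[:-1], scale_diag.shape[:-1])


# A Python list packs the two (n, 1) arrays along a new leading axis.
probs_list = np.array([mix, 1. - mix])
print(probs_list.shape)                       # (2, 4, 1)
print(categorical_batch_shape(probs_list))    # (2, 4): one class, batch (2, n)

# Concatenating along the last axis gives one row of 2 probabilities per example.
probs_concat = np.concatenate([mix, 1. - mix], axis=-1)
print(probs_concat.shape)                     # (4, 2)
print(categorical_batch_shape(probs_concat))  # (4,)

print(mvn_diag_batch_shape(loc, scale_diag))  # (4,): each component's batch shape
```

Under that assumption, the categorical built from the list would have batch shape [2, None] and a single class, while each component has batch shape [None], which matches the wording of the error. If TensorFlow behaves like this sketch, building probs with tf.concat([mix, 1. - mix], axis=-1) instead of a list would give a categorical with batch shape [None] and two classes; this is an untested suggestion, not a verified fix.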
Thanks!