
I'm working my way through this tutorial: An Introduction to Inference in Pyro

What I don't understand is the following. To get $p(\mathit{weight} \mid \mathit{guess}, \mathit{measurement} = 9.5)$, we can use the pyro.condition function with

    import pyro
    import pyro.distributions as dist

    def scale(guess):
        weight = pyro.sample("weight", dist.Normal(guess, 1.0))
        print(weight)
        return pyro.sample("measurement", dist.Normal(weight, 0.75))

and then conditioning it on the observed measurement:

    conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})
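Because the prior on weight and the measurement noise are both Gaussian, the posterior being asked for has a closed form (the standard Normal-Normal conjugate update). Writing it out shows what posterior samples of weight should look like; here the second argument of $\mathcal{N}$ is the standard deviation, as in dist.Normal.

```latex
\begin{aligned}
\mathit{weight} &\sim \mathcal{N}(\mathit{guess}, 1), \qquad
\mathit{measurement} \mid \mathit{weight} \sim \mathcal{N}(\mathit{weight}, 0.75) \\
\sigma_{\text{post}}^2 &= \left(\frac{1}{1^2} + \frac{1}{0.75^2}\right)^{-1} = 0.36 \\
\mu_{\text{post}} &= \sigma_{\text{post}}^2 \left(\frac{\mathit{guess}}{1^2} + \frac{9.5}{0.75^2}\right)
\end{aligned}
```

With guess = 0.3, as in the script, this gives $\mu_{\text{post}} \approx 6.19$ and $\sigma_{\text{post}} = 0.6$ (guess = 8.5 would give about 9.14), so genuine posterior samples of weight would sit far from a prior draw such as -1.0905.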

I wrote the following script:

    pyro.set_rng_seed(101)
    scale(0.3)  # prints weight: tensor(-1.0905)
    pyro.set_rng_seed(101)
    conditioned_scale(0.3)  # prints weight: tensor(-1.0905)

Both calls print the same sample for weight. Isn't the tutorial saying that conditioned_scale draws weight from the distribution conditioned on measurement = 9.5? If so, shouldn't the weight samples differ, since the first call observes no data and the second conditions on it?

Thanks!

f_3464gh
Since this is more of a conceptual question, I think this is not really on-topic for this forum. Although stats.stackexchange.com is the place that's usually appropriate for conceptual questions, this is really pretty specific to Pyro, so you might not get much traction there. My advice is to ask on a Pyro-specific forum. Sorry I can't be more helpful. – Robert Dodier Apr 30 '21 at 17:48

1 Answer


Running the model will not produce samples from the posterior; you'll need to run inference (like SVI or MCMC).
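To make "you need to run inference" concrete without depending on Pyro being installed, here is a stdlib-only Python sketch. It swaps in a hand-written random-walk Metropolis sampler (an assumption for illustration; this is not Pyro's MCMC/NUTS API) that holds measurement = 9.5 fixed as observed and samples weight. With guess = 0.3 its samples should centre near the closed-form posterior mean of about 6.19 (sd 0.6), not near a prior draw such as -1.0905.

```python
import math
import random


def log_normal_pdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) at x (sigma is the standard deviation)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def log_joint(weight, guess, measurement):
    """log p(weight | guess) + log p(measurement | weight) for the scale model."""
    return log_normal_pdf(weight, guess, 1.0) + log_normal_pdf(measurement, weight, 0.75)


def metropolis(guess, measurement, num_samples=20000, burn_in=2000, step=1.0, seed=101):
    """Random-walk Metropolis over weight, with measurement held fixed (observed)."""
    rng = random.Random(seed)
    weight = guess  # start at the prior mean
    current = log_joint(weight, guess, measurement)
    samples = []
    for i in range(burn_in + num_samples):
        proposal = weight + rng.gauss(0.0, step)  # symmetric Gaussian proposal
        candidate = log_joint(proposal, guess, measurement)
        # Accept with probability min(1, p(proposal | data) / p(weight | data)).
        if rng.random() < math.exp(min(0.0, candidate - current)):
            weight, current = proposal, candidate
        if i >= burn_in:
            samples.append(weight)
    return samples


def analytic_posterior(guess, measurement, prior_sd=1.0, noise_sd=0.75):
    """Closed-form Normal-Normal posterior (mean, sd), used as a check."""
    var = 1.0 / (1.0 / prior_sd**2 + 1.0 / noise_sd**2)
    mean = var * (guess / prior_sd**2 + measurement / noise_sd**2)
    return mean, math.sqrt(var)


if __name__ == "__main__":
    samples = metropolis(guess=0.3, measurement=9.5)
    print("MCMC mean:", sum(samples) / len(samples))
    print("analytic (mean, sd):", analytic_posterior(0.3, 9.5))
```

In Pyro itself you would pass conditioned_scale to an inference algorithm such as pyro.infer.MCMC or pyro.infer.SVI rather than calling it directly.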

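pyro.condition replaces the value at a named sample site with the value you specify (and marks that site as observed). Since you only specified a value for measurement, the weight site is unaffected: running conditioned_scale still draws weight from its prior. The model you've written defines the joint density $p(\mathit{weight}, \mathit{measurement} \mid \mathit{guess}) = \mathcal{N}(\mathit{weight}; \mathit{guess}, 1)\,\mathcal{N}(\mathit{measurement}; \mathit{weight}, 0.75)$, and by conditioning you've only stated that measurement = 9.5. Conversely, conditioning on weight instead, e.g. pyro.condition(scale, data={"weight": 9.5}), with the same seed/key will produce different measurements. Below I've written the same program in NumPyro. You should also check out https://forum.pyro.ai/.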

    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro import handlers


    def scale(rng_key, guess):
        # Split the key so each sample site gets its own randomness.
        w_key, m_key = random.split(rng_key)

        weight = numpyro.sample("weight", dist.Normal(guess, 1.0), rng_key=w_key)
        print(weight)
        return numpyro.sample("measurement", dist.Normal(weight, 0.75), rng_key=m_key)


    if __name__ == '__main__':
        rng_key = random.PRNGKey(0)
        print(scale(rng_key, 0.3))  # -0.49476373

        # Fix "weight" to 9.5; "measurement" is still drawn, now centred on 9.5.
        conditioned_scale = handlers.condition(scale, data={"weight": 9.5})
        print(conditioned_scale(rng_key, 0.3))  # 8.561346
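Because condition only swaps in a value at the named site, the effect seen in the question can be reproduced in a few lines of plain Python. This is a hypothetical toy: the sample helper and the data dict are illustrative stand-ins, not Pyro's or NumPyro's handler machinery, and unlike Pyro it does not record log-densities for observed sites (the part inference actually uses).

```python
import random


def sample(name, mu, sigma, rng, data):
    """Hypothetical stand-in for a sample site: return the conditioned value
    if `name` is in `data`, otherwise draw from Normal(mu, sigma)."""
    if name in data:
        return data[name]
    return rng.gauss(mu, sigma)


def scale(guess, seed=101, data=None):
    """Toy version of the scale model; returns (weight, measurement)."""
    data = data or {}
    rng = random.Random(seed)  # fresh RNG so every call starts from the same state
    weight = sample("weight", guess, 1.0, rng, data)
    measurement = sample("measurement", weight, 0.75, rng, data)
    return weight, measurement


if __name__ == "__main__":
    print(scale(0.3))                             # unconditioned
    print(scale(0.3, data={"measurement": 9.5}))  # same weight, measurement fixed at 9.5
    print(scale(0.3, data={"weight": 9.5}))       # weight fixed, measurement changes
```

With the same seed, conditioning on measurement returns exactly the unconditioned weight, which is what the question observed, while conditioning on weight changes the measurement draw.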