I'm trying to create a Gaussian HMM model in pyro to infer the parameters of a very simple Markov sequence. However, my model fails to infer the parameters and something wired happened during the training process. Using the same sequence, hmmlearn has successfully infer the true parameters.
Full code can be accessed in here:
https://colab.research.google.com/drive/1u_4J-dg9Y1CDLwByJ6FL4oMWMFUVnVNd#scrollTo=ZJ4PzdTUBgJi
My model is modified from the example in here:
https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py
I manually created a first order Markov sequence where there are 3 states, the true means are [-10, 0, 10], sigmas are [1,2,1].
Here is my model
def model(observations, num_state):
assert not torch._C._get_tracing_state()
with poutine.mask(mask = True):
p_transition = pyro.sample("p_transition",
dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))
p_init = pyro.sample("p_init",
dist.Dirichlet((1 / num_state) * torch.ones(num_state)))
p_mu = pyro.param(name = "p_mu",
init_tensor = torch.randn(num_state),
constraint = constraints.real)
p_tau = pyro.param(name = "p_tau",
init_tensor = torch.ones(num_state),
constraint = constraints.positive)
current_state = pyro.sample("x_0",
dist.Categorical(p_init),
infer = {"enumerate" : "parallel"})
for t in pyro.markov(range(1, len(observations))):
current_state = pyro.sample("x_{}".format(t),
dist.Categorical(Vindex(p_transition)[current_state, :]),
infer = {"enumerate" : "parallel"})
pyro.sample("y_{}".format(t),
dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
obs = observations[t])
My model is compiled as
device = torch.device("cuda:0")
obs = torch.tensor(obs)
obs = obs.to(device)
torch.set_default_tensor_type("torch.cuda.FloatTensor")
guide = AutoDelta(poutine.block(model, expose_fn = lambda msg : msg["name"].startswith("p_")))
Elbo = Trace_ELBO
elbo = Elbo(max_plate_nesting = 1)
optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, elbo)
As the training goes, the ELBO has decreased steadily as shown. However, the three means of the states converges.
I have tried to put the for loop of my model into a pyro.plate and switch pyro.param to pyro.sample and vice versa, but nothing worked for my model.