I'm trying to implement Hierarchical Dirichlet Process (HDP) topic model using PyMC3. The HDP graphical model is shown below:
I came up with the following code:
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pymc3 as pm
from theano import tensor as tt
np.random.seed(0)
def stick_breaking(beta):
portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
return beta * portion_remaining
def main():
#load data
data = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]])
Wd = [len(doc) for doc in data]
#HDP parameters
T = 10 # top-level truncation
K = 2 # group-level truncation
V = 4 # number of words
D = 3 # number of documents
with pm.Model() as model:
#top-level stick breaking
gamma = pm.Gamma('gamma', 1., 1.)
beta_prime = pm.Beta('beta_prime', 1., gamma, shape=T)
beta = pm.Deterministic('beta', stick_breaking(beta_prime))
#group-level stick breaking
alpha = pm.Gamma('alpha', 1., 1.)
pi_prime = pm.Beta("pi_prime", 1, alpha, shape=K) #Sethuraman's stick breaking
#pi_prime = [pm.Beta("pi_prime_%s_%s" %(j,k), alpha*(beta[k]), alpha*(1-np.sum(beta[:k+1])), shape=1)
# for j in range(K) for k in range(T)] #Teh's stick breaking
pi = pm.Deterministic('pi', stick_breaking(pi_prime))
#top-level DP
H = pm.Dirichlet("H", a=np.ones(V), shape=V)
phi_top = pm.Multinomial('phi_top', n=np.sum(Wd), p=H, shape=(T,V))
G0 = pm.Mixture('G0', w=beta, comp_dists=phi_top)
#group-level DP
phi_group = [pm.Multinomial('phi_group_%s' %j, n=Wd[j], p=G0) for j in range(D)]
Gj = [pm.Mixture('G_%s' %j, w=pi, comp_dists=phi_group[j]) for j in range(D)]
#likelihood
w = [pm.Categorical("w_%s_%s" %(j,n), p = Gj[j], observed=data[j][n]) for j in range(D) for n in range(Wd[j])]
with model:
trace = pm.sample(2000, n_init=1000, random_seed=42)
pm.traceplot(trace)
plt.show()
if __name__ == '__main__':
main()
However, I'm currently getting an AssertionError
that prevents me from debugging the rest of the model, it occurs at the following line:
phi_top = pm.Multinomial('phi_top', n=np.sum(Wd), p=H, shape=(T,V))
There's no additional information about the error. Does anyone know how to resolve this?