
I've been implementing VAE and IWAE models on the Caltech silhouettes dataset and am running into an issue where the VAE outperforms the IWAE by a modest margin (test NLL ~120 for the VAE vs. ~133 for the IWAE!). I don't believe this should be the case, according to both theory and the experiments reported here.

I'm hoping someone can spot an issue in my implementation that's causing this.

The networks I'm using to approximate q and p are the same as those detailed in the appendix of the paper above. The calculation part of the model is below:

data_k_vec = data.repeat_interleave(K, 0) # repeat each datapoint K times to draw K samples (in my case K=50 produces this behavior)

mu, log_std = model.encode(data_k_vec)
z = model.reparameterize(mu, log_std) # z = mu + torch.exp(log_std)*epsilon (epsilon ~ N(0,1))
decoded = model.decode(z) # this is the sigmoid output of the model

log_prior_z = torch.sum(-0.5 * z ** 2, 1) - 0.5 * z.shape[1] * torch.log(torch.tensor(2 * np.pi))
log_q_z = compute_log_probability_gaussian(z, mu, log_std) # Definitions below
log_p_x = compute_log_probability_bernoulli(decoded,data_k_vec) 

if model_type == 'iwae':
    log_w_matrix = (log_prior_z + log_p_x - log_q_z).view(-1, K)
elif model_type == 'vae':
    log_w_matrix = (log_prior_z + log_p_x - log_q_z).view(-1, 1) * 1/K

log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

loss = -torch.sum(ws_sum_per_datapoint) # value of loss that gets returned to training function. loss.backward() will get called on this value

Here are the likelihood functions. I had to fuss with the Bernoulli log-likelihood in order to not get NaNs during training:

def compute_log_probability_gaussian(obs, mu, logstd, axis=1):
    return torch.sum(-0.5 * ((obs - mu) / torch.exp(logstd)) ** 2 - logstd, axis) - 0.5 * obs.shape[1] * torch.log(torch.tensor(2 * np.pi))

def compute_log_probability_bernoulli(theta, obs, axis=1):
    # add 1e-18 inside the logs to avoid NaNs during training
    return torch.sum(obs * torch.log(theta + 1e-18) + (1 - obs) * torch.log(1 - theta + 1e-18), axis)
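As an aside, if I exposed the decoder's pre-sigmoid logits, a more numerically stable alternative to the 1e-18 hack would be to compute the Bernoulli log-likelihood from the logits directly. A sketch (compute_log_probability_bernoulli_logits and its logits argument are hypothetical, since my model currently only returns the sigmoid output):

import torch.nn.functional as F

def compute_log_probability_bernoulli_logits(logits, obs, axis=1):
    # binary_cross_entropy_with_logits stably computes
    # -[obs*log(sigmoid(logits)) + (1-obs)*log(1-sigmoid(logits))],
    # so its negation is the Bernoulli log-likelihood
    return -torch.sum(F.binary_cross_entropy_with_logits(logits, obs, reduction='none'), axis)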

There's a "shortcut" in this code: in the model_type == 'iwae' case the row-wise importance weights are computed over the K=50 samples in each row, while in the model_type == 'vae' case they're computed over the single value left in each row, so each weight just comes out to 1. Maybe this is the issue?

Any and all help is hugely appreciated - I thought that addressing the NaN issue would get me out of the weeds for good, but now I have this new problem.

EDIT: I should add that the training scheme is the same as in the paper linked above. That is, for each round i = 0, ..., 7, train for 2**i epochs with a learning rate of 1e-4 * 10**(-i/7).
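In code, that schedule looks roughly like this (assuming an Adam optimizer; train_one_epoch is a placeholder for my training loop):

optimizer = torch.optim.Adam(model.parameters())  # assuming Adam, as in the paper
for i in range(8):                                # rounds i = 0, ..., 7
    for group in optimizer.param_groups:
        group['lr'] = 1e-4 * 10 ** (-i / 7)
    for epoch in range(2 ** i):                   # 2**i epochs this round
        train_one_epoch(model, optimizer)         # placeholder training helper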


1 Answer

The K-sample importance weighted ELBO is

$$ \textrm{IW-ELBO}(x,K) = \log \frac{1}{K} \sum_{k=1}^K \frac{p(x \vert z_k)\, p(z_k)}{q(z_k;x)} $$
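with $z_k \sim q(z;x)$ drawn i.i.d. At $K=1$ this reduces, in expectation, to the standard ELBO, and Burda et al. show the bound is monotonically non-decreasing in $K$:

$$ \textrm{ELBO}(x) = \mathbb{E}\left[\textrm{IW-ELBO}(x,1)\right] \le \mathbb{E}\left[\textrm{IW-ELBO}(x,K)\right] \le \log p(x) $$

so a correctly implemented IWAE should not lose to the VAE on held-out log-likelihood, which is consistent with your expectation.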

For the IWAE there are K samples originating from each datapoint x, so you want the same latent statistics mu_z, Sigma_z obtained through the amortized inference network, but you draw K samples of z for each x.

So it's computationally wasteful to compute the encoder forward pass on data_k_vec = data.repeat_interleave(K, 0); you should compute the forward pass once per original datapoint, then repeat the statistics output by the inference network before sampling:

mu = torch.repeat_interleave(mu, K, 0)
log_std = torch.repeat_interleave(log_std, K, 0)

Then sample z_k. Now repeat your datapoints with data_k_vec = data.repeat_interleave(K, 0), and use the resulting tensor to efficiently evaluate the conditional p(x | z_k) for each importance sample z_k.
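Putting that together, a sketch using the encode/reparameterize/decode API from your post:

mu, log_std = model.encode(data)                 # one encoder pass per datapoint
mu = torch.repeat_interleave(mu, K, 0)           # repeat the statistics K times
log_std = torch.repeat_interleave(log_std, K, 0)
z = model.reparameterize(mu, log_std)            # K independent samples per x
data_k_vec = data.repeat_interleave(K, 0)        # repeat data to match z's shape
decoded = model.decode(z)                        # decoder pass for each z_k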

Note that you may also want to use the logsumexp operation when calculating the IW-ELBO for numerical stability. I can't quite follow the log_w_matrix calculation in your post, but this is what I would do:

log_pz = ...    # log p(z): your log_prior_z
log_qzCx = ...  # log q(z|x): your log_q_z
log_pxCz = ...  # log p(x|z): your log_p_x

log_iw = log_pxCz + log_pz - log_qzCx
log_iw = log_iw.reshape(-1, K)
iwelbo = torch.logsumexp(log_iw, dim=1) - np.log(K)

EDIT: Actually, after thinking about it a bit and using the score function identity, you can interpret the IWAE gradient as an importance weighted estimate of the standard single-sample gradient, so the method in the OP for calculating the importance weights is equivalent (if a bit wasteful), provided you place a stop_gradient operator around the normalized importance weights, which you call ws_norm. So the main problem is the absence of this stop_gradient operator.
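Concretely, in your code that's a one-line change before the loss is formed (detach is PyTorch's stop_gradient):

ws_norm = ws_norm.detach()  # stop_gradient: treat the normalized weights as constants
ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)
loss = -torch.sum(ws_sum_per_datapoint)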
