I have this simple Bayesian neural network that hangs on model.fit(x, y):

import tensorflow_probability as tfp
from tensorflow.keras.layers import Input, BatchNormalization, Dense, Concatenate
from tensorflow.keras.models import Model

tfd = tfp.distributions
tfpl = tfp.layers

# get_posterior and get_prior are defined elsewhere (not shown).
def get_model(input_shape, loss, optimizer, metrics, kl_weight, output_shape):
    inputs = Input(shape=input_shape)
    x = BatchNormalization()(inputs)
    x = tfpl.DenseVariational(units=128, activation='tanh',
                              make_posterior_fn=get_posterior,
                              make_prior_fn=get_prior,
                              kl_weight=kl_weight)(x)
    count = Dense(1)(x)
    logits = Dense(output_shape, activation='sigmoid')(x)
    # Column 0 is total_count, the remaining columns are the logits.
    cat = Concatenate(axis=-1)([count, logits])
    neg_binom = tfpl.DistributionLambda(
        lambda t: tfd.NegativeBinomial(total_count=t[..., 0:1], logits=t[..., 1:]))
    outputs = neg_binom(cat)
    model = Model(inputs, outputs)
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model

I do not get an error. The model compiles, and when I call model.fit(x, y) I just get:

Epoch 1/500

and it stays stuck there indefinitely (the longest I waited was about 20 minutes).

When I use a Poisson layer instead, as I did before, fitting starts instantly and an epoch takes about 1 s.
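One difference that may be worth ruling out (an assumption, not a confirmed cause of the hang): in the model above, total_count comes straight from a linear Dense(1), so it can be zero or negative, while a negative binomial distribution expects a positive total_count. The sketch below uses plain Python, without TensorFlow, to show how a softplus transform maps arbitrary raw outputs to positive values. The raw values and the EPS floor are made up for illustration.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Hypothetical raw outputs of an unconstrained Dense(1) layer.
raw_counts = [-3.0, -0.5, 0.0, 2.0]

# Small floor (an arbitrary choice) that keeps total_count away from zero,
# since softplus can underflow to 0.0 for very negative inputs.
EPS = 1e-3
total_counts = [softplus(r) + EPS for r in raw_counts]

for raw, tc in zip(raw_counts, total_counts):
    print(f"raw={raw:+.1f} -> total_count={tc:.4f}")
```

In the Keras model, the analogous change would be something like Dense(1, activation='softplus') for the count head. Passing validate_args=True to tfd.NegativeBinomial should also make invalid parameters raise an error instead of passing silently.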

What could be the cause of this?

Many thanks for any insights and tips on things to try to debug this behaviour.

Olli
