
I keep re-running the same function just to test different parameters. The function is `mmm.fit`, found here, and I am following their example with their data. The only parameter I change between runs is `custom_priors`; everything else stays constant.

Let's say I want to test three scenarios:

1. A custom prior on `lag_weight` only:

   ```python
   lag_weight_prior = numpyro.distributions.Beta(
       concentration1=jnp.array([2, 2, 2]),
       concentration0=jnp.array([5, 5, 5]))
   custom_priors = {'lag_weight': lag_weight_prior}
   ```

2. A custom prior on `half_max_effective_concentration` only:

   ```python
   half_max_effective_concentration_prior = numpyro.distributions.Gamma(
       concentration=jnp.array([3, 3, 3]),
       rate=jnp.array([1, 1, 1]))
   custom_priors = {'half_max_effective_concentration': half_max_effective_concentration_prior}
   ```

3. Custom priors on both `lag_weight` and `slope`:

   ```python
   lag_weight_prior = numpyro.distributions.Beta(
       concentration1=jnp.array([2, 2, 2]),
       concentration0=jnp.array([5, 5, 1]))
   slope_prior = numpyro.distributions.Gamma(
       concentration=jnp.array([1, 1, 1]),
       rate=jnp.array([3, 3, 3]))
   custom_priors = {'lag_weight': lag_weight_prior,
                    'slope': slope_prior}
   ```

Then I run this code three times, once for each `custom_priors`:

```python
# Fit model.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data,
        extra_features=extra_features,
        media_prior=costs,
        target=target,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2,
        custom_priors=custom_priors)
```

How can I loop over this process? I have been reading up on `map()`; could I use it here?
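`map(f, items)` calls `f` once per item, so it does the same work as a plain `for` loop. The stdlib-only sketch below shows both forms side by side. `fake_fit` is a hypothetical stand-in for "create a `LightweightMMM` and call `fit`", and the prior values are placeholder strings, so the pattern can run without numpyro or lightweight_mmm installed.

```python
def fake_fit(custom_priors):
    """Hypothetical stand-in for building a LightweightMMM and calling fit()."""
    # Pretend "result": just the names of the priors that were overridden.
    return sorted(custom_priors)


# Placeholder values; in practice these would be numpyro distributions.
custom_priors_list = [
    {"lag_weight": "beta prior"},
    {"half_max_effective_concentration": "gamma prior"},
    {"lag_weight": "beta prior", "slope": "gamma prior"},
]

# Option 1: an explicit for loop.
results_loop = []
for custom_priors in custom_priors_list:
    results_loop.append(fake_fit(custom_priors))

# Option 2: map() is lazy, so wrap it in list() to actually run every fit.
results_map = list(map(fake_fit, custom_priors_list))

print(results_map)
```

Both versions run the fits one after another. The explicit loop makes it easier to add progress prints or per-scenario error handling, while `map` is more compact once the per-scenario function is settled.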

I was thinking of something like this:

```python
# Define the list of custom priors to iterate over
custom_priors_list = [...]

# Define a function to run the mmm.fit function with a given set of custom priors
def run_mmm(custom_priors):
    mmm = lightweight_mmm.LightweightMMM()
    mmm.fit(media=media_data_train_scale,
            extra_features=extra_features_train_scale,
            media_prior=costs_train_scale,
            target=target_train_scale,
            number_warmup=2000,
            number_samples=2000,
            number_chains=2,
            seasonality_frequency=52,
            media_names=media_names,
            # weekday_seasonality=7,
            custom_priors=custom_priors,
            seed=8548654856)
    return mmm.summary()

# Use the map function to iterate over the custom priors and run the mmm.fit function with each one
summaries = list(map(run_mmm, custom_priors_list))

# Print the summaries
for summary in summaries:
    print(summary)
```

But I am struggling to define `custom_priors_list` and keep getting errors.
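In its simplest form, `custom_priors_list` is just a Python list holding one dict per scenario, built exactly like the three `custom_priors` dicts in the question, e.g. `[custom_priors_1, custom_priors_2, custom_priors_3]` (hypothetical names; reusing the single name `custom_priors` for all three overwrites the earlier dicts). The sketch below is one possible variation, not lightweight_mmm's API: it keeps each scenario as plain data using a hypothetical `PriorSpec` helper and a hypothetical `to_custom_priors` function, and only creates the numpyro distributions right before fitting. It assumes jax and numpyro are installed for the conversion step, and that each `dist` string names a class in `numpyro.distributions`.

```python
from dataclasses import dataclass, field


@dataclass
class PriorSpec:
    """Hypothetical helper: a numpyro distribution name plus its keyword arguments."""
    dist: str
    params: dict = field(default_factory=dict)


# One entry per scenario, mirroring the three custom_priors dicts in the question.
# The labels ("lag_weight_only", ...) are hypothetical and only used for bookkeeping.
SCENARIOS = {
    "lag_weight_only": {
        "lag_weight": PriorSpec("Beta", {"concentration1": [2, 2, 2],
                                         "concentration0": [5, 5, 5]}),
    },
    "half_max_only": {
        "half_max_effective_concentration": PriorSpec(
            "Gamma", {"concentration": [3, 3, 3], "rate": [1, 1, 1]}),
    },
    "lag_weight_and_slope": {
        "lag_weight": PriorSpec("Beta", {"concentration1": [2, 2, 2],
                                         "concentration0": [5, 5, 1]}),
        "slope": PriorSpec("Gamma", {"concentration": [1, 1, 1],
                                     "rate": [3, 3, 3]}),
    },
}


def to_custom_priors(scenario):
    """Turn one scenario into a {prior_name: numpyro distribution} dict."""
    # Imported here so that defining SCENARIOS needs only the standard library.
    import jax.numpy as jnp
    import numpyro.distributions as dist

    return {
        name: getattr(dist, spec.dist)(
            **{key: jnp.asarray(value, dtype=jnp.float32)
               for key, value in spec.params.items()})
        for name, spec in scenario.items()
    }


# Usage (requires jax and numpyro):
# custom_priors_list = [to_custom_priors(s) for s in SCENARIOS.values()]
```

Keeping the labels alongside the scenarios makes it easy to tell which summary belongs to which set of priors after the loop finishes.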

Nneka
It might be useful to use [`functools.partial`](https://docs.python.org/3/library/functools.html#functools.partial) to get a function that takes only one parameter, the one that changes. – Jorge Luis Apr 11 '23 at 11:06
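A minimal sketch of that suggestion: `functools.partial` freezes every keyword argument that stays constant, leaving a one-argument function that `map` can call with each `custom_priors`. `run_fit` is a hypothetical wrapper; a real version would build a `LightweightMMM` and call `fit`, but here it simply echoes its inputs so the example runs with the standard library alone.

```python
from functools import partial


def run_fit(custom_priors, *, number_warmup, number_samples, number_chains, seed):
    """Hypothetical wrapper; a real version would call LightweightMMM().fit(...)."""
    return {
        "priors": sorted(custom_priors),
        "warmup": number_warmup,
        "samples": number_samples,
        "chains": number_chains,
        "seed": seed,
    }


# Freeze the settings that do not change between scenarios.
fit_with_priors = partial(run_fit, number_warmup=2000, number_samples=2000,
                          number_chains=2, seed=8548654856)

# Placeholder priors; only custom_priors varies from call to call.
custom_priors_list = [{"lag_weight": None}, {"slope": None}]
results = list(map(fit_with_priors, custom_priors_list))
```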
