I find myself re-running a function over and over just to test different parameters. The function I am using is `mmm.fit` from LightweightMMM (found here). I am following their example with their data. The parameter I keep changing is `custom_priors`; everything else stays constant.
Let's say I want to test three scenarios:

1. `lag_weight` only:

```python
lag_weight_prior = numpyro.distributions.Beta(
    concentration1=jnp.array([2, 2, 2]),
    concentration0=jnp.array([5, 5, 5]))
custom_priors = {'lag_weight': lag_weight_prior}
```

2. `half_max_effective_concentration` only:

```python
half_max_effective_concentration_prior = numpyro.distributions.Gamma(
    concentration=jnp.array([3, 3, 3]),
    rate=jnp.array([1, 1, 1]))
custom_priors = {'half_max_effective_concentration': half_max_effective_concentration_prior}
```

3. `lag_weight` and `slope` together:

```python
lag_weight_prior = numpyro.distributions.Beta(
    concentration1=jnp.array([2, 2, 2]),
    concentration0=jnp.array([5, 5, 1]))
slope_prior = numpyro.distributions.Gamma(
    concentration=jnp.array([1, 1, 1]),
    rate=jnp.array([3, 3, 3]))
custom_priors = {'lag_weight': lag_weight_prior,
                 'slope': slope_prior}
```
Then I run the fit three times, once per scenario:
```python
# Fit model.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data,
        extra_features=extra_features,
        media_prior=costs,
        target=target,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2,
        custom_priors=custom_priors)
```
How can I loop over this process? I was reading up on `map()`; would it be possible to use it here?
I was thinking of something like this:
```python
from lightweight_mmm import lightweight_mmm

# Define the list of custom priors to iterate over
custom_priors_list = [...]

# Fit a fresh model with a given set of custom priors and return it
def run_mmm(custom_priors):
    mmm = lightweight_mmm.LightweightMMM()
    mmm.fit(media=media_data_train_scale,
            extra_features=extra_features_train_scale,
            media_prior=costs_train_scale,
            target=target_train_scale,
            number_warmup=2000,
            number_samples=2000,
            number_chains=2,
            seasonality_frequency=52,
            media_names=media_names,
            # weekday_seasonality=7,
            custom_priors=custom_priors,
            seed=8548654856)
    return mmm

# Use map to run run_mmm once for each set of custom priors
fitted_models = list(map(run_mmm, custom_priors_list))

# Print a posterior summary for each fitted model
for mmm in fitted_models:
    mmm.print_summary()
```
But I am struggling to define `custom_priors_list` and keep getting errors.
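To tell the runs apart and see which scenario actually raises, one option is to label each scenario and record either its result or its exception. The sketch below uses a hypothetical helper, `run_scenarios`, with a plain loop in place of `map`. `fake_run` and the string placeholders are stand-ins for `run_mmm` and the numpyro distributions so the sketch runs without LightweightMMM; in real use you would pass `run_mmm` and dicts like the ones in the scenarios.

```python
from typing import Any, Callable, Dict


def run_scenarios(
    run_fn: Callable[[Dict[str, Any]], Any],
    scenarios: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Call run_fn once per named scenario and keep each outcome under its label.

    If a scenario raises, the exception object is stored instead of a result,
    so one bad set of priors does not stop the remaining fits.
    """
    results: Dict[str, Any] = {}
    for name, custom_priors in scenarios.items():
        try:
            results[name] = run_fn(custom_priors)
        except Exception as exc:  # keep going; inspect failures afterwards
            results[name] = exc
    return results


# Hypothetical stand-in for run_mmm: fails on purpose when "slope" is present.
def fake_run(custom_priors):
    if "slope" in custom_priors:
        raise ValueError("pretend this prior has the wrong shape")
    return sorted(custom_priors)


# String placeholders stand in for the numpyro distributions.
results = run_scenarios(
    fake_run,
    {
        "lag_only": {"lag_weight": "beta(2, 5)"},
        "half_max_only": {"half_max_effective_concentration": "gamma(3, 1)"},
        "lag_and_slope": {"lag_weight": "beta(2, 5)", "slope": "gamma(1, 3)"},
    },
)
for name, result in results.items():
    print(name, "->", result)
```

If the scenarios already live in a list, `dict(zip(names, custom_priors_list))` turns them into the labeled mapping this helper expects. When labels and error capture are not needed, `list(map(run_mmm, custom_priors_list))` does the same iteration.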