It took me a while to understand `rv_continuous`. By cobbling together knowledge from a bunch of examples (I should have documented them, sorry), I think I arrived at a working solution.
The only issue is that the domain needs to be known in advance, but I can work with that. If someone has ideas for how to fix that, please let me know.
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import scipy as sp
# Short aliases for the SciPy helpers used when tabulating the distribution.
interp1d = sp.interpolate.interp1d
# `scipy.integrate.trapz` was removed in SciPy 1.14; `trapezoid` is the same
# routine under its current name.  The old alias name is kept so that code
# further down that calls `trapz(...)` keeps working unchanged.
trapz = sp.integrate.trapezoid
# Time domain vector - needed in class.
# The grid starts at dt (not 0) because the class divides by T.
# np.linspace with an explicit point count is used instead of np.arange with a
# float step: with arange, whether the end point is included depends on
# floating-point rounding, so the grid length can silently change by one.
dt = 0.01
t_max = 10
T = np.linspace(dt, t_max, int(round(t_max / dt)))
# Distance domain vector - needed in class (runs from 0 to d_max inclusive).
dd = 0.01
d_max = 30
D = np.linspace(0, d_max, int(round(d_max / dd)) + 1)
class MultiplicativeModel(stats.rv_continuous):
    """Distribution of the product ``D = T * V`` of two independent variables.

    The density of the product is tabulated once, at construction time, on a
    fixed distance grid using the product-distribution formula

        f_D(d) = integral over t of  f_T(t) * f_V(d / t) / t  dt,   t > 0

    The table is then renormalised so that it integrates to one over the
    distance grid, and pdf, cdf, sf, ppf and the raw moments are all served
    from it by linear interpolation or trapezoidal integration.  The density
    is treated as zero outside the distance grid.

    Parameters
    ----------
    Tmodel : object
        Time model; must expose a vectorised ``pdf`` method (for example a
        frozen ``scipy.stats`` distribution).
    Vmodel : object
        Velocity model; same requirement as ``Tmodel``.
    *args, **kwargs
        Forwarded to ``scipy.stats.rv_continuous``.  Unless the caller
        supplies the support bounds ``a`` / ``b``, they default to the first
        and last point of the distance grid.
    t_grid : array_like, optional
        Strictly increasing, strictly positive integration grid for the time
        variable.  Defaults to the module-level ``T``.
    d_grid : array_like, optional
        Strictly increasing grid on which the product density is tabulated.
        Defaults to the module-level ``D``.

    Attributes
    ----------
    Tmodel, Vmodel : object
        The two input models, as passed in.
    t_grid, d_grid : numpy.ndarray
        The grids actually used.
    captured_mass : float
        Integral of the tabulated density before renormalisation.  A value
        well below one means the grids miss a noticeable part of the
        distribution and should be widened.
    pdf_vec, cdf_vec, sf_vec : numpy.ndarray
        Tabulated pdf, cdf and survival function on ``d_grid``.

    Raises
    ------
    ValueError
        If a grid is not a finite, strictly increasing 1-D sequence with at
        least two points, if ``t_grid`` is not strictly positive, or if the
        tabulated density has no positive finite mass on the grids.
    """
    # NOTE: rv_continuous passes the class docstring through old-style string
    # template substitution, so the docstring above must stay free of
    # percent signs.

    def __init__(self, Tmodel, Vmodel, *args, t_grid=None, d_grid=None,
                 **kwargs):
        # Fall back to the module-level grids so existing two-argument calls
        # keep working exactly as before.
        t_grid = self._as_grid(T if t_grid is None else t_grid, 't_grid')
        d_grid = self._as_grid(D if d_grid is None else d_grid, 'd_grid')
        if t_grid[0] <= 0:
            raise ValueError('t_grid must be strictly positive '
                             '(the integrand divides by t)')

        # rv_continuous takes (momtype, a, b, ...) positionally; only fill in
        # the support bounds when the caller has not supplied them already.
        if len(args) < 2:
            kwargs.setdefault('a', float(d_grid[0]))
        if len(args) < 3:
            kwargs.setdefault('b', float(d_grid[-1]))
        super().__init__(*args, **kwargs)

        self.Tmodel = Tmodel  # The time-domain probability function
        self.Vmodel = Vmodel  # The velocity-domain probability function
        self.t_grid = t_grid
        self.d_grid = d_grid

        trapezoid = sp.integrate.trapezoid
        # f_T(t) / t does not depend on d, so evaluate it once up front.
        t_weight = Tmodel.pdf(t_grid) / t_grid
        raw_pdf = np.array([
            trapezoid(t_weight * Vmodel.pdf(d / t_grid), x=t_grid)
            for d in d_grid
        ])

        mass = trapezoid(raw_pdf, x=d_grid)
        if not (np.isfinite(mass) and mass > 0):
            raise ValueError('the product density has no positive finite '
                             'mass on the supplied grids')
        self.captured_mass = float(mass)

        # Normalise so the table is a proper density on d_grid; the CDF then
        # runs exactly from 0 to 1 and the PPF never has to extrapolate.
        self.pdf_vec = raw_pdf / mass
        cdf = sp.integrate.cumulative_trapezoid(self.pdf_vec, x=d_grid,
                                                initial=0)
        self.cdf_vec = cdf / cdf[-1]  # absorb the last rounding error
        self.sf_vec = 1 - self.cdf_vec

    @staticmethod
    def _as_grid(values, name):
        """Return ``values`` as a validated 1-D float grid."""
        grid = np.asarray(values, dtype=float)
        if grid.ndim != 1 or grid.size < 2:
            raise ValueError(f'{name} must be a 1-D sequence with at least '
                             'two points')
        if not np.all(np.isfinite(grid)) or np.any(np.diff(grid) <= 0):
            raise ValueError(f'{name} must be finite and strictly increasing')
        return grid

    # Key functions for the rv_continuous machinery.  They are regular
    # methods taking only ``x`` / ``q``, so no shape parameters are inferred.

    def _pdf(self, x):
        """Linearly interpolated density; zero outside the distance grid."""
        return np.interp(x, self.d_grid, self.pdf_vec, left=0.0, right=0.0)

    def _cdf(self, x):
        """Linearly interpolated CDF; 0 below and 1 above the grid."""
        return np.interp(x, self.d_grid, self.cdf_vec, left=0.0, right=1.0)

    def _sf(self, x):
        """Linearly interpolated survival function; 1 below, 0 above."""
        return np.interp(x, self.d_grid, self.sf_vec, left=1.0, right=0.0)

    def _ppf(self, q):
        """Inverse CDF by interpolating the (cdf, distance) table.

        The table spans exactly [0, 1], so uniform draws made by ``rvs``
        always map to a point inside the distance grid.
        """
        return np.interp(q, self.cdf_vec, self.d_grid)

    def _munp(self, n, *args):
        """Raw moment of order ``n`` by trapezoidal integration."""
        return sp.integrate.trapezoid(self.pdf_vec * self.d_grid ** n,
                                      x=self.d_grid)
With the above defined, we get results like:
# Velocity grid; only used below to draw the velocity PDF.
dv = 0.01
v_max = 10
V = np.arange(0, v_max + dv, dv)

# Time model: normal with loc 3 and scale 1.
# Velocity model: uniform with loc 2 and scale 2.
model = MultiplicativeModel(
    stats.norm(3, 1),
    stats.uniform(loc=2, scale = 2),
)

# Exercise the summary statistics.  The values in the comments are the
# output recorded from one run; expect the exact digits to depend on the
# grids and library versions in use.
print(f'median: {model.median()}')
#   median: 8.700970199181763
print(f'moments: {model.stats(moments = "mvsk")}')
#   moments: (array(9.00872026), array(12.2315612), array(0.44131568), array(0.16819043))

# Overlay the two input densities with the pdf, cdf and sf of the product.
plt.figure(figsize=(6, 4))
curves = [
    (T, model.Tmodel.pdf(T), 'Time PDF'),
    (V, model.Vmodel.pdf(V), 'Velocity PDF'),
    (D, model.pdf(D), 'Distance PDF'),
    (D, model.cdf(D), 'Distance CDF'),
    (D, model.sf(D), 'Distance SF'),
]
for grid, values, curve_label in curves:
    plt.plot(grid, values, label=curve_label)

# Draw random variates and compare their histogram with the curves above.
x = model.rvs(size=10**5)
plt.hist(
    x,
    bins=50,
    density=True,
    alpha=0.5,
    label='Sampled distribution',
)
plt.legend()
plt.xlim([0, 30])
