
I have a function written in pure NumPy that computes a statistic a huge number of times, and it is too slow. The function contains a triple loop, but I cannot figure out how to translate it to broadcasting.

My actual data is hard to understand, but random data would not make sense here, so I included a sample of it in the following code and drew a graph of the expected output at the end.

import numpy as np
import seaborn as sns

# Setting parameters and data
z_init = np.array([[5.61293390e-01, 9.97100450e-01, 4.23530180e-01, 6.08808896e-01],
       [1.22563280e-01, 1.72015130e-01, 8.71145720e-01, 5.40745844e-01],
       [8.51194500e-02, 1.18289130e-01, 8.90346540e-01, 7.22859351e-01],
       [9.83241310e-01, 9.57282690e-01, 7.22347100e-02, 5.43527399e-02],
       [5.49211550e-01, 3.97858250e-01, 6.86380990e-01, 7.91494336e-01],
       [9.94878920e-01, 6.39160920e-01, 2.01045170e-01, 9.86840712e-01],
       [5.04337540e-01, 5.69995040e-01, 3.99087430e-01, 4.32140476e-01],
       [9.28230540e-01, 9.32143440e-01, 1.02748280e-01, 9.92666867e-01],
       [1.77513660e-01, 1.83466350e-01, 7.99027540e-01, 6.30800256e-01],
       [5.14663640e-01, 6.34361690e-01, 5.33889110e-01, 7.90899958e-01],
       [2.58006640e-01, 2.88319290e-01, 7.09604700e-01, 9.02145588e-01],
       [2.04811730e-01, 1.58717810e-01, 8.41421970e-01, 8.84068574e-01],
       [3.11875950e-01, 2.46353420e-01, 7.58289460e-01, 9.65660849e-01],
       [9.36622730e-01, 8.01263020e-01, 9.83931900e-02, 5.44281251e-01],
       [3.45077880e-01, 2.88884330e-01, 6.99352220e-01, 9.71301027e-01],
       [8.18323020e-01, 8.42968360e-01, 2.68607890e-01, 1.52418342e-01],
       [9.07517590e-01, 7.41841580e-01, 3.94466860e-01, 3.33215046e-01],
       [3.84328210e-01, 2.85716010e-01, 7.73390420e-01, 4.79702183e-01],
       [4.74390660e-01, 7.16142340e-01, 4.52819790e-01, 7.98200958e-01],
       [1.91063150e-01, 3.11021770e-01, 8.17990970e-01, 3.65566296e-02],
       [3.84454330e-01, 5.31087070e-01, 6.46125720e-01, 4.26028784e-01],
       [7.76256310e-01, 9.82152950e-01, 2.24877600e-01, 9.03596071e-01],
       [5.21782070e-01, 6.09994810e-01, 5.49719820e-01, 3.77052128e-01],
       [2.44654800e-01, 2.22705930e-01, 6.92217350e-01, 8.04524950e-01],
       [6.26568380e-01, 6.85160450e-01, 5.05651420e-01, 1.06857236e-01],
       [4.57300600e-02, 3.19478800e-02, 9.56385030e-01, 5.62236853e-01],
       [9.85249100e-02, 9.00401200e-02, 9.15879990e-01, 9.37933793e-01],
       [1.13619610e-01, 1.08790380e-01, 8.86476490e-01, 9.07453097e-01],
       [8.61160760e-01, 9.26073490e-01, 1.51885700e-02, 5.61689264e-01],
       [4.72355650e-01, 8.90940390e-01, 4.79858960e-01, 1.31270153e-01],
       [3.29944340e-01, 3.87579980e-01, 5.10804930e-01, 5.19551698e-01],
       [5.29880000e-03, 7.01797000e-03, 9.93625440e-01, 1.85747216e-01],
       [6.23029300e-01, 4.84573370e-01, 2.68151590e-01, 5.60921564e-03],
       [8.03967200e-01, 7.84008540e-01, 2.37276020e-01, 9.47798098e-01],
       [4.58525700e-01, 6.45049070e-01, 5.76664620e-01, 5.75709041e-01],
       [8.65658770e-01, 8.99023760e-01, 2.40967370e-01, 5.56589158e-01],
       [9.67638170e-01, 8.83972000e-01, 3.44920000e-04, 9.17016810e-01],
       [7.24319710e-01, 8.63911350e-01, 2.64988220e-01, 2.13078474e-01],
       [9.91532810e-01, 9.85368470e-01, 6.58391400e-02, 4.58927178e-04],
       [1.17465420e-01, 1.14261700e-01, 9.05365980e-01, 1.80863318e-01],
       [8.63220080e-01, 7.91506140e-01, 5.18878970e-01, 5.49666344e-01],
       [3.12470000e-02, 3.19607200e-02, 9.72981330e-01, 8.60758375e-01],
       [4.41205810e-01, 4.87292710e-01, 5.39217100e-01, 5.65980037e-01],
       [1.39334500e-02, 1.40076300e-02, 9.84327060e-01, 9.38894626e-01],
       [2.43659030e-01, 2.09662260e-01, 7.81243000e-01, 1.14951150e-01],
       [9.45841720e-01, 9.45075580e-01, 1.37627400e-02, 5.39213927e-01],
       [7.92013050e-01, 6.55037130e-01, 2.08627580e-01, 1.50823215e-01],
       [6.04095200e-02, 4.57398400e-02, 9.53590740e-01, 6.32755639e-01],
       [5.67334500e-01, 2.75674320e-01, 6.47657510e-01, 4.68491101e-01],
       [1.58600060e-01, 1.22128390e-01, 8.47935330e-01, 4.94281577e-01],
       [8.26576000e-03, 4.07989000e-03, 9.95207870e-01, 8.02447365e-02],
       [7.25499790e-01, 7.05574910e-01, 3.94566850e-01, 8.90077195e-01],
       [8.30398180e-01, 7.65006390e-01, 1.04508490e-01, 9.44908637e-01],
       [6.84425700e-01, 8.13177160e-01, 2.55194010e-01, 1.71608600e-01],
       [3.69045830e-01, 4.10810940e-01, 6.39276590e-01, 9.22700243e-01],
       [1.42119190e-01, 1.45086850e-01, 8.87464990e-01, 2.35533293e-01],
       [8.51399930e-01, 8.63513030e-01, 4.28106000e-02, 7.49796027e-01],
       [7.26388760e-01, 9.23435870e-01, 3.21152410e-01, 5.59389176e-01],
       [2.68165680e-01, 2.21699530e-01, 7.13336850e-01, 8.28847266e-01],
       [4.67212100e-02, 6.18397600e-02, 9.08459550e-01, 1.73109978e-01],
       [8.12353540e-01, 6.14787930e-01, 2.36200930e-01, 6.70979632e-01],
       [3.56200600e-01, 2.86300900e-01, 6.87996620e-01, 7.68872468e-01],
       [4.27617260e-01, 4.08906890e-01, 4.65987670e-01, 1.67199623e-01],
       [6.63373240e-01, 9.66214910e-01, 1.39582640e-01, 9.85382902e-01],
       [5.51993350e-01, 4.93202560e-01, 5.63663960e-01, 1.69990831e-01],
       [8.04742160e-01, 7.23388830e-01, 1.97937550e-01, 5.06756753e-01],
       [1.07240370e-01, 1.15115720e-01, 9.07925810e-01, 3.46134208e-01],
       [3.61709450e-01, 2.16649010e-01, 7.91721970e-01, 5.22621049e-01],
       [9.83195600e-01, 9.35189250e-01, 1.09384140e-01, 4.87989100e-01],
       [1.07405620e-01, 1.05033440e-01, 8.76795260e-01, 2.44237928e-01],
       [6.75897130e-01, 6.50329960e-01, 3.04297580e-01, 3.60810270e-01],
       [7.02020600e-02, 4.96392100e-02, 9.33498520e-01, 7.17513612e-01],
       [4.84155500e-01, 6.88098980e-01, 3.46669530e-01, 2.16784063e-01],
       [6.04164790e-01, 7.48494480e-01, 9.49017500e-02, 2.69127829e-03],
       [5.92501140e-01, 7.18188940e-01, 4.79787090e-01, 4.72203718e-01],
       [6.47244640e-01, 9.12962170e-01, 3.94908800e-02, 1.89967176e-02],
       [7.52063710e-01, 8.36582980e-01, 2.56381510e-01, 1.82552057e-01],
       [7.33809600e-01, 5.88942430e-01, 3.17564930e-01, 4.83186793e-02],
       [6.37782580e-01, 7.91589180e-01, 3.08634220e-01, 1.83951279e-01],
       [7.32009020e-01, 9.14051250e-01, 1.80915920e-01, 2.45163585e-01],
       [1.53493780e-01, 1.90967590e-01, 8.19005590e-01, 7.55056039e-01],
       [5.36161820e-01, 5.13641150e-01, 5.01637010e-01, 3.47079632e-01],
       [6.06637230e-01, 6.67565790e-01, 3.33999130e-01, 2.51786198e-01],
       [7.25650010e-01, 8.41152620e-01, 2.36374270e-01, 2.61322095e-01],
       [6.52008490e-01, 8.66015010e-01, 1.90032370e-01, 5.14531432e-01],
       [2.59336300e-02, 3.60464100e-02, 9.42735970e-01, 8.76251330e-01],
       [3.91414850e-01, 3.16164320e-01, 6.36344310e-01, 2.11938819e-01],
       [6.43722130e-01, 5.38235890e-01, 1.13523690e-01, 3.54529909e-01],
       [7.90799970e-01, 7.44277280e-01, 3.24458070e-01, 1.60302427e-01],
       [7.10510700e-02, 8.50407900e-02, 9.08863250e-01, 9.18056054e-02],
       [8.27656880e-01, 7.68024600e-01, 6.35402600e-02, 1.39203186e-01],
       [4.22585470e-01, 4.66851210e-01, 5.36839920e-01, 9.51087042e-02],
       [8.74929100e-02, 9.12235300e-02, 8.91159090e-01, 3.88725280e-02],
       [9.36443830e-01, 8.16299420e-01, 2.90021130e-01, 2.89175878e-01],
       [9.26354550e-01, 9.51074570e-01, 8.49412000e-03, 1.76092602e-01],
       [5.72510800e-02, 3.56584500e-02, 9.67113730e-01, 7.74680782e-01],
       [1.10646890e-01, 1.10709490e-01, 8.85004310e-01, 8.08970193e-01],
       [9.30983140e-01, 9.85525760e-01, 6.47764700e-02, 3.51535913e-01],
       [9.16176650e-01, 8.13787300e-01, 2.80715970e-01, 7.69428516e-01],
       [9.34773620e-01, 8.73895270e-01, 1.23538120e-01, 5.72796569e-01]])
z_init[:,3] = np.random.uniform(size=100)

d_init = z_init.shape[1]
n_init = z_init.shape[0]
N = 999
bp = np.array([0.5,0.2,0.55,0.7])

The loop that needs to be simplified is:

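random_numbers = np.random.uniform(size=(N,n_init,d_init))
vals = np.zeros(shape=(N,d_init))
for index in range(2 ** d_init):
    # The bits of `index` pick one of the 2**d_init boxes obtained by splitting [0, 1]^d at bp:
    # bit 0 -> side [0, bp), bit 1 -> side [bp, 1)
    binary_repr = np.array(list(np.binary_repr(index, width=d_init)), dtype="float64")
    min = bp * binary_repr           # lower corner of the box (shadows the built-in min)
    max = bp ** (1 - binary_repr)    # upper corner of the box (shadows the built-in max)
    lambda_l = np.prod(max - min)    # volume of the box

    for d in np.arange(d_init):

        # Same box with dimension d dropped, and its volume
        mask = np.arange(d_init) != d
        min_without_d = min[mask]
        max_without_d = max[mask]
        lambda_k = np.prod(max_without_d - min_without_d)

        # N copies of the data in which column d is replaced by random numbers
        z_rep = np.repeat(z_init[None,:,:],N,axis=0)
        z_rep[:,:,d] = random_numbers[:,:,d]

        # Fraction of rows falling inside the box, with and without dimension d
        f_l = np.mean(np.all(np.logical_and(z_rep >= min, z_rep < max), axis=2),axis=1)
        f_k = np.mean(np.all(np.logical_and(z_rep[:,:, mask] >= min_without_d, z_rep[:,:, mask] < max_without_d), axis=2),axis=1)

        vals[:,d] += 1 / 2 * f_k ** 2 / lambda_k + f_l ** 2 / lambda_l - 2 * f_k * f_l / lambda_k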
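
One way the inner `for d` loop could be vectorised, sketched under the assumption that every side of every box has positive width (true whenever all entries of `bp` lie strictly between 0 and 1): only column d of `z_rep` is replaced, so `f_k` does not depend on the random numbers at all, and a row lies in the box with side d ignored exactly when its number of violated sides, not counting side d, is zero. Broadcasting over all d at once then replaces the inner loop, while the outer loop over the `2 ** d_init` boxes stays. The helper `box_contribution`, the names `lo`/`hi` and the synthetic data are hypothetical, and the result should match the original loop only up to floating-point rounding, since `lambda_k` is computed here as `lambda_l / widths`.

```python
import numpy as np

def box_contribution(z, lo, hi, r):
    """Vectorised sketch of the inner `for d` loop for one box [lo, hi).

    z: (n, D) data, lo/hi: (D,) box corners, r: (N, n, D) random replacements.
    Returns the (N, D) array that the inner loop would add to `vals`.
    Assumption: every side hi - lo is strictly positive (true when 0 < bp < 1).
    """
    widths = hi - lo
    lambda_l = np.prod(widths)           # volume of the box
    lambda_k = lambda_l / widths         # volume with side d dropped, for every d at once

    inside_z = (z >= lo) & (z < hi)                        # (n, D)
    n_out = (~inside_z).sum(axis=1, keepdims=True)         # (n, 1) number of violated sides
    # all_but_d[i, d] is True when row i is inside the box on every side except possibly d
    all_but_d = (n_out - (~inside_z).astype(int)) == 0     # (n, D)

    # Column d comes from r, the other columns keep their z values
    inside_r = (r >= lo) & (r < hi)                        # (N, n, D)
    f_l = np.mean(all_but_d & inside_r, axis=1)            # (N, D)
    f_k = np.mean(all_but_d, axis=0)                       # (D,), identical for every N

    return 0.5 * f_k ** 2 / lambda_k + f_l ** 2 / lambda_l - 2 * f_k * f_l / lambda_k

# Hypothetical synthetic data; z_init and random_numbers from the question plug in the same way
rng = np.random.default_rng(0)
n, D, N = 100, 4, 50
z = rng.uniform(size=(n, D))
r = rng.uniform(size=(N, n, D))
bp = np.array([0.5, 0.2, 0.55, 0.7])

vals_fast = np.zeros((N, D))
for index in range(2 ** D):  # the outer loop over boxes is kept
    b = np.array(list(np.binary_repr(index, width=D)), dtype="float64")
    vals_fast += box_contribution(z, bp * b, bp ** (1 - b), r)
```

The largest temporary is the `(N, n, D)` boolean array `inside_r`, similar in size to one `z_rep` in the original loop but built once per box rather than once per `(box, d)` pair.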

Plotting the result:

# sns.distplot is deprecated in recent seaborn; kdeplot + rugplot draw an equivalent
# density curve and rug to distplot(hist=False, rug=True), one colour per column
for d in range(d_init):
    sns.kdeplot(vals[:, d], color=f"C{d}")
    sns.rugplot(vals[:, d], color=f"C{d}")

The problem is that the outer loop uses an awkward binary representation of the 2^d cases, which makes it hard for me to simplify. But maybe the second loop (over `d`) can still be vectorised?
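
The binary representation only serves to enumerate the lower and upper corners of the 2^d boxes. As a sketch, assuming `d_init` is small enough for a `(2 ** d_init, d_init)` array to fit in memory (it is not for `d_init` around 100), all corners can be built at once with integer bit shifts instead of `np.binary_repr` strings; `box_corners` is a hypothetical helper name. Because the boxes partition the unit cube, their volumes should sum to 1, which gives a cheap sanity check.

```python
import numpy as np

def box_corners(bp):
    """Lower and upper corners of the 2**d boxes obtained by splitting [0, 1]^d at bp.

    Row k uses the same bit order as np.binary_repr(k, width=d): most significant bit first.
    """
    bp = np.asarray(bp, dtype="float64")
    d = bp.size
    shifts = np.arange(d - 1, -1, -1)
    bits = (np.arange(2 ** d)[:, None] >> shifts) & 1   # shape (2**d, d), entries 0 or 1
    lower = bp * bits           # 0 where the bit is 0, bp where it is 1
    upper = bp ** (1 - bits)    # bp where the bit is 0, 1 where it is 1
    return lower, upper

bp = np.array([0.5, 0.2, 0.55, 0.7])
lower, upper = box_corners(bp)
volumes = np.prod(upper - lower, axis=1)   # lambda_l for every box
print(lower.shape)                         # (16, 4)
print(np.isclose(volumes.sum(), 1.0))      # True: the boxes tile the unit cube
```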

Thanks for the help :)

Edit: since people made relevant comments, I edited the code to include them.

lrnv
  • `min = bp ** (binary_repr) * 0 ** (1 - binary_repr)`: this is very strange. `0 ** (1 - binary_repr)` is just two bitwise ways to do `not` nested together, and `not(not(binary_repr))` is just `binary_repr`. `1**(binary_repr)` in the next line is just `np.ones_like(binary_repr)`. Is this code intentionally obfuscated? Or what the heck is it trying to do? – Daniel F Dec 19 '19 at 12:32
  • I used this bitwise notation of nested `not` to obtain the minimum and maximum corners of each hyperbox when you split the [0,1]^d hypercube at a single point `bp` into 2^d hyperboxes: the first one yields `min = 0,...,0` and `max = bp`, and the last one yields `min = bp` and `max = 1,...,1`. In between, each iteration corresponds to one of the other 2^d - 2 hyperboxes that result from splitting the unit cube at the breakpoint `bp`, with split planes parallel to the axes. Is that clear enough? You are totally right, though: this might not be the easiest way to do it. – lrnv Dec 19 '19 at 12:40
  • You are right: the `1**(binary_repr)` can just be deleted. – lrnv Dec 19 '19 at 12:54
  • What is the maximum number of columns `d_init`? – Daniel F Dec 19 '19 at 12:57
  • Let's say 100, but it varies from dataset to dataset. – lrnv Dec 19 '19 at 13:05
  • Couldn't you make a more minimal example? – norok2 Dec 19 '19 at 13:10
  • Unfortunately, I don't think I can. Sorry. – lrnv Dec 19 '19 at 13:14
  • That's . . . a lot. `2**100` is ~ `10**30`. You won't fit anything in RAM that size, so the outer loop will probably have to stay a loop. – Daniel F Dec 19 '19 at 13:26
  • And since you're calling `random.uniform` in the inner function, which will be different for every outer loop, that can't be vectorized either. Would it be OK to make an `np.random.uniform(size=z_init.shape)` array and replace columns from that? – Daniel F Dec 19 '19 at 13:46
  • Well, actually, the random generation should be **outside** the loop on `index`... it's corrected now. – lrnv Dec 19 '19 at 13:57
  • Why couldn't you just generate the input array as random numbers? – Mad Physicist Dec 19 '19 at 14:03
  • Because it would make the graphs at the end less interesting, since all values in `vals` would be approximately the same. But yes, it can be done on any `(n_init,d_init)`-shaped dataset whose values lie inside [0,1]. – lrnv Dec 19 '19 at 14:06
  • Isn't `min = bp ** binary_repr * binary_repr` the same as `min = bp * binary_repr` (`bp^1*1 = bp` and `bp^0*0 = 0`)? – 9mat Dec 19 '19 at 14:20
  • Yes it is. Thanks. – lrnv Dec 19 '19 at 14:24
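
The identities discussed in these comments can be spot-checked on every 0/1 pattern that `np.binary_repr` produces for the four-dimensional `bp` of the question. This is only a numerical check for d = 4, not a proof, although each identity holds for any entry equal to 0 or 1.

```python
import numpy as np

bp = np.array([0.5, 0.2, 0.55, 0.7])  # breakpoint from the question
d = bp.size

for index in range(2 ** d):
    b = np.array(list(np.binary_repr(index, width=d)), dtype="float64")
    # 9mat: bp ** b * b reduces to bp * b when b is 0 or 1
    assert np.array_equal(bp ** b * b, bp * b)
    # Daniel F: 0 ** (1 - b) is a double negation, i.e. b itself
    assert np.array_equal(0.0 ** (1 - b), b)
    # and 1 ** b is all ones, so that factor can be dropped
    assert np.array_equal(1.0 ** b, np.ones_like(b))

print(f"identities hold for all {2 ** d} patterns")  # identities hold for all 16 patterns
```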
