I want to calculate the mean square displacement (MSD) for several particles, defined as:

$$\mathrm{MSD}_i(\Delta t) = \left\langle \left| \vec{x}_i(t + \Delta t) - \vec{x}_i(t) \right|^2 \right\rangle_t$$

where i is the particle index, Δt is the time interval (Dt in the code), t is the time, and \vec{x}_i is the position of particle i in two dimensions. The average ⟨·⟩_t runs over all possible times t.
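
To make the definition concrete, here is a minimal sketch on hypothetical toy data (one particle, four time steps) that is small enough to check by hand. The helper name `msd_at_lag` is only an illustration and is not used elsewhere in this post.

```python
import numpy as np

# Hypothetical toy data: one particle, four time steps, two coordinates.
# Axes: (particles, time, coordinate)
pos = np.array([[[0.0, 0.0], [1.0, 0.0], [1.0, 2.0], [3.0, 2.0]]])

def msd_at_lag(pos, Dt):
    """Average over t of |x(t+Dt) - x(t)|^2, one value per particle."""
    disp = pos[:, Dt:, :] - pos[:, :-Dt, :]   # displacements over the lag Dt
    return np.mean(np.sum(disp**2, axis=-1), axis=1)

msd_1 = msd_at_lag(pos, 1)  # squared displacements 1, 4, 4 -> mean 3.0
msd_2 = msd_at_lag(pos, 2)  # squared displacements 5, 8    -> mean 6.5
```

Note that the number of terms in the average shrinks as Dt grows, so large lags are estimated from few samples.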

I have managed to implement it with numpy. Note that pos is a np.array with three axes: (particles, time, coordinate).

import numpy as np
import matplotlib.pyplot as plt
import time

#Initialize data
np.random.seed(1)
nTime = 10**4
nParticles = 3
pos = np.zeros((nParticles, nTime, 2)) #Axis: particles, times, coordinates
for t in range(1, nTime):
    pos[:, t, :] = pos[:, t-1, :] + ( np.random.random((nParticles, 2)) - 0.5)

#MSD calculation
def MSD_direct(pos):
    nParticles = pos.shape[0] #Take the size from the input instead of a global
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for k, Dt in enumerate(Dt_r):
        SD = np.sum((pos[:, Dt:,:] - pos[:, 0:-Dt,:])**2, axis = -1)
        MSD[:,k] = np.mean( SD , axis = 1)
        dMSD[:,k] = np.std( SD, axis = 1 ) / np.sqrt(SD.shape[1])

    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_d, dMSD_d = MSD_direct(pos)
print("MSD_direct -- Time: %s s\n" % (time.time() - start_time))

#Plots
plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,0])    
plt.xlabel('t')
plt.ylabel('x')
plt.savefig('pos_x.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,1])    
plt.xlabel('t')
plt.ylabel('y')
plt.savefig('pos_y.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.fill_between(Dt_r, MSD_d[i,:]+dMSD_d[i,:], MSD_d[i,:] - dMSD_d[i,:], alpha = 0.5)
    plt.plot(Dt_r, MSD_d[i,:])
plt.xlabel('Dt')
plt.ylabel('MSD')
plt.savefig('MSD.png', dpi = 300)

Output:

MSD_direct -- Time: 7.793087720870972 s

(Figures: pos_x.png, x versus t; pos_y.png, y versus t; MSD.png, MSD versus Dt with error bands, one curve per particle.)

However, I would like to optimize this code if possible. There is still a loop over Dt, and I do not know how to remove it and fully vectorize the calculation with numpy.
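
For reference, one way to remove the Python loop over Dt entirely is to build all pairwise squared distances and sum them per lag with `np.bincount`. This is only a sketch under a strong assumption: memory grows as nTime**2, so it is realistic for short trajectories and not for nTime = 10**4. The function name `msd_pairwise` is hypothetical, and it returns every lag from 1 to nTime-1, one more than MSD_direct.

```python
import numpy as np

def msd_pairwise(pos):
    """MSD for all lags without a Python loop over Dt (O(nTime**2) memory)."""
    n_particles, n_time, _ = pos.shape
    # Squared distance between every pair of times: (particles, time, time)
    diff = pos[:, None, :, :] - pos[:, :, None, :]
    sq = np.sum(diff**2, axis=-1)
    # Index pairs (i, j) with j > i, and the lag j - i of each pair
    i, j = np.triu_indices(n_time, k=1)
    lags = j - i
    # Number of pairs per lag (nTime - lag); drop the empty lag-0 bin
    counts = np.bincount(lags, minlength=n_time)[1:]
    # Sum the squared distances per lag; each particle gets its own block of bins
    offsets = np.arange(n_particles)[:, None] * n_time
    sums = np.bincount((lags[None, :] + offsets).ravel(),
                       weights=sq[:, i, j].ravel(),
                       minlength=n_particles * n_time)
    sums = sums.reshape(n_particles, n_time)[:, 1:]
    return np.arange(1, n_time), sums / counts
```

The trade-off is purely memory for loop overhead: the arithmetic is still quadratic in nTime, so this does not reduce the amount of work.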


I also rewrote the calculation using numba, which gave roughly a factor-of-two speedup over the previous code. I wonder whether it can be improved further.

import numba as nb
@nb.jit(fastmath=True,parallel=True)
def MSD_numba(pos):
    nParticles = pos.shape[0] #Take the size from the input instead of a global
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for i in nb.prange(nParticles):  
        for Dt in Dt_r:
            SD = (pos[i, Dt:, 0] - pos[i, 0:-Dt, 0])**2 + (pos[i, Dt:, 1] - pos[i, 0:-Dt, 1])**2
            MSD[i, Dt-1] = np.mean( SD )
            dMSD[i, Dt-1] = np.std( SD ) / np.sqrt(len(SD)) 
    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_n, dMSD_n = MSD_numba(pos)
print("MSD_numba -- Time: %s s" % (time.time() - start_time))
print("MSD_numba -- All close to MSD_direct: %r\n" %(np.allclose(MSD_n, MSD_d) )  )

Output:

MSD_numba -- Time: 4.520232915878296 s
MSD_numba -- All close to MSD_direct: True

Note: this type of question has been asked in several posts already, but they use different definitions (Mean square displacement python, Mean squared displacement, Mean square displacement for n-dim matrix python), they do not have an answer (Mean square displacement in Python), they just use one particle (Computing mean square displacement using python and FFT, Mean square displacement of a 1d random walk in python), they use pandas (Vectorized calculation of Mean Square Displacement in Python, Speedup MSD calculation in Python), etc.

Puco4
  • I am pretty sure this line `np.linalg.norm(pos[:, Dt:] - pos[:, 0:-Dt], axis = -1)**2` has two errors: 1) why square the norm? There is no squaring in the definition apart from the one the norm already does. 2) `pos[:, Dt:] - pos[:, 0:-Dt]` should be something like `np.diff(pos[:, Dt:])` – dankal444 Oct 27 '21 at 17:50
  • @dankal444 I don't understand why you believe I calculated the MSD wrongly. The line `np.linalg.norm(pos[:, Dt:] - pos[:, 0:-Dt], axis = -1)**2` is literally the definition I showed at the beginning of my post. It is also numerically equal to `np.sum((pos[:, Dt:] - pos[:, 0:-Dt])**2, axis = -1)`, in case you find that clearer. Also, why use `np.diff`? We are not taking differences between consecutive elements of the array, but differences at intervals `Dt`. – Puco4 Oct 28 '21 at 08:07
  • Maybe what confused you is that `pos` has three axes: (particle, time, coordinate). The norm is taken along the coordinate axis, because the vectors are two-dimensional. – Puco4 Oct 28 '21 at 08:12
  • I see now that I was wrong. Nevertheless, as you said, `np.sum((pos[:, Dt:] - pos[:, 0:-Dt])**2, axis = -1)` is equivalent and **faster**. I found it strange that you take a square root (in the norm) and immediately square the result, which led me to think there must be an error. – dankal444 Oct 28 '21 at 11:26
  • If you run your code more than once, add the `cache=True` option to `@nb.jit(fastmath=True,parallel=True)` so that compilation is not repeated on the next run. When I ran the same script a second time, it was about 6x faster. – dankal444 Oct 28 '21 at 11:38
  • Just FYI: multiprocessing but without numba gave me about the same timings as your numba code but without performance loss (no need for `fastmath`) – dankal444 Oct 28 '21 at 11:45
  • @dankal444 You are right, it is a bit faster using the squared sum. I have now changed the code in the question. Thank you for your comments. – Puco4 Oct 29 '21 at 10:17

1 Answer

Adapting the answer from "Computing mean square displacement using python and FFT", which uses fast Fourier transforms (FFT), I managed to make this calculation faster by two orders of magnitude.
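
A sketch of the decomposition behind the FFT approach, writing N for the number of time steps (nTime in the code) and assuming zero-based times t = 0, ..., N-1. Expanding the square in the definition splits the MSD into two sums:

$$
\mathrm{MSD}_i(\Delta t) = \frac{1}{N-\Delta t} \sum_{t=0}^{N-\Delta t-1} \left| \vec{x}_i(t+\Delta t) - \vec{x}_i(t) \right|^2 = S_1(\Delta t) - 2\, S_2(\Delta t)
$$

$$
S_1(\Delta t) = \frac{1}{N-\Delta t} \sum_{t=0}^{N-\Delta t-1} \left[ \left|\vec{x}_i(t+\Delta t)\right|^2 + \left|\vec{x}_i(t)\right|^2 \right],
\qquad
S_2(\Delta t) = \frac{1}{N-\Delta t} \sum_{t=0}^{N-\Delta t-1} \vec{x}_i(t) \cdot \vec{x}_i(t+\Delta t)
$$

S_2 is an autocorrelation, which an FFT gives for all Δt at once in O(N log N) operations, and S_1 follows for all Δt from cumulative sums of |x_i(t)|^2, so no loop over Δt is needed.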


Generalized function for any n-dimensional array

This function accepts pos as an n-dimensional array; you only need to specify the time axis and the coordinate axis (x, y). It returns the mean square displacement for all particles in pos.

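def MSD_fft_ax(pos, axis_time, axis_coord):
    #Normalize the axes so that negative indices are handled consistently
    axis_time = axis_time % pos.ndim
    axis_coord = axis_coord % pos.ndim
    nTime=pos.shape[axis_time]

    #S2: position autocorrelation via a zero-padded FFT, summed over the coordinates
    S2 = np.sum (  np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = axis_time))**2, axis = axis_time ).take(range(nTime), axis = axis_time).real, axis = axis_coord )

    #D: squared norm of the position at each time (the coordinate axis is summed out)
    D=np.square(pos).sum(axis=axis_coord)

    #D has one axis fewer than pos: shift the time axis if the coordinate axis preceded it
    if axis_coord < axis_time: axis_time -= 1

    shape_t = [nTime if ax==axis_time else 1 for ax, s in enumerate(D.shape)]
    shape_non_t = [1 if ax==axis_time else s for ax, s in enumerate(D.shape)]

    #S1: sum over t of |x(t)|^2 + |x(t+Dt)|^2, obtained from cumulative sums
    D=np.append(D, np.zeros( shape_non_t ), axis = axis_time)
    S1 = ( 2 * np.sum(D, axis = axis_time).reshape(shape_non_t) - np.cumsum( np.insert(D.take(np.arange(0,nTime), axis=axis_time), 0, 0, axis = axis_time) + np.flip(D, axis = axis_time), axis = axis_time ) ).take(np.arange(0,nTime), axis = axis_time) 

    #Divide by the number of terms in each average, nTime - Dt
    MSD = ( S1-2*S2 ) / ( nTime-np.arange(nTime).reshape(shape_t) )

    #Keep the same lags as MSD_direct
    Dt_r = np.arange(1, nTime-1)
    MSD = MSD.take(Dt_r, axis = axis_time)
    return Dt_r, MSD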

start_time = time.time()
Dt_r, MSD_fax = MSD_fft_ax(pos, axis_time = 1, axis_coord=-1)
print("MSD_fft_ax -- Time: %s s" % (time.time() - start_time))
print("MSD_fft_ax -- All close to MSD_direct: %r\n" %(np.allclose(MSD_fax, MSD_d) )  )

Output:

MSD_direct -- Time: 2.1434285640716553 s

MSD_numba -- Time: 1.532573938369751 s
MSD_numba -- All close to MSD_direct: True

MSD_fft_ax -- Time: 0.009054422378540039 s
MSD_fft_ax -- All close to MSD_direct: True
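
The S2 term relies on the fact that the inverse FFT of the power spectrum of a zero-padded signal gives its autocorrelation. A minimal 1D sketch of that step alone, on hypothetical toy data (the helper names `autocorr_direct` and `autocorr_fft` are illustrative and not part of the functions in this answer):

```python
import numpy as np

def autocorr_direct(x):
    """Unnormalized autocorrelation sum_t x[t] * x[t+m], for m = 0..N-1."""
    N = len(x)
    return np.array([np.dot(x[:N - m], x[m:]) for m in range(N)])

def autocorr_fft(x):
    """Same quantity from the power spectrum of the zero-padded signal."""
    N = len(x)
    F = np.fft.fft(x, n=2 * N)   # pad to length 2N to avoid circular wrap-around
    return np.fft.ifft(np.abs(F)**2)[:N].real

rng = np.random.default_rng(0)
x = np.cumsum(rng.random(256) - 0.5)   # 1D random walk (toy data)
same = np.allclose(autocorr_direct(x), autocorr_fft(x))
```

Without the padding to 2N, the FFT would return a circular correlation in which late times wrap around onto early ones. Here `same` should evaluate to True up to floating-point tolerance.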

Function for an array with shape: (particles, time, coordinate).

For clarity, here is the particular case where pos is an array of shape (particles, time, coordinate):

def MSD_fft(pos):
    nTime=pos.shape[1]        

    S2 = np.sum ( np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = -2))**2, axis = -2  )[:,:nTime,:].real , axis = -1 ) / (nTime-np.arange(nTime)[None,:] )

    D=np.square(pos).sum(axis=-1)
    D=np.append(D, np.zeros((pos.shape[0], 1)), axis = -1)
    S1 = ( 2 * np.sum(D, axis = -1)[:,None] - np.cumsum( np.insert(D[:,0:-1], 0, 0, axis = -1) + np.flip(D, axis = -1), axis = -1 ) )[:,:-1] / (nTime - np.arange(nTime)[None,:] )

    MSD = S1-2*S2

    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = MSD[:,Dt_r]
    return Dt_r, MSD

start_time = time.time()
Dt_r, MSD_f = MSD_fft(pos)
print("MSD_fft -- Time: %s s" % (time.time() - start_time))
print("MSD_fft -- All close to MSD_direct: %r\n" %(np.allclose(MSD_f, MSD_d) )  )

Output:

MSD_fft -- Time: 0.007384061813354492 s
MSD_fft -- All close to MSD_direct: True
Puco4