I am trying to put together a numerical simulation of beta-cell dynamics based on Bertram et al. 2007 (https://www.sciencedirect.com/science/article/pii/S0006349507709621). The model itself works, but it is very slow: the simulation step must be around 0.1 ms, and Python is not the fastest language around. With only 15 coupled beta cells in the system, it takes approximately 12 real seconds for every simulated second. In the end, I will need around 1000 beta cells to simulate an entire islet of Langerhans, so you can see why I need to speed things up.
Each beta cell is implemented as a class instance that inherits from the CellParameters and ModelParameters classes.
```python
from numba.experimental import jitclass

# spec: list of (attribute name, Numba type) pairs, defined elsewhere
@jitclass(spec)
class BetaCell:
    def __init__(self, cell_num: int, neighbours: list, G: float):
        # sets initial conditions (23 parameters: floats and lists)
        ...

    def w_ijkl(self, ii, jj, kk, ll, f6p):
        # calculates and returns a specific parameter
        ...

    def run_model_step(self, Ge: float):
        # runs one time step (dt = 0.1 ms) for the cell;
        # has to calculate/update around 55 parameters
        ...


class ModelParameters:
    # Contains all model parameters: time step, intensity of glucose
    # stimulation, start of stimulation, etc.
    # Also defines which time steps are saved for later visualization.

    @staticmethod
    def external_glucose(time):
        # calculates and returns the current level of external glucose
        # using a simple equation
        ...


class CellParameters:
    # Contains approx. 70 parameters (floats) that the model needs for execution.
    # Some of these parameters are changed (once) after initialization
    # to introduce some cell heterogeneity.
    ...
```
The simulation looks like this:
- cell data are imported (locations, coupling, coupling weights)
- each cell is initialized with its cell number (0, 1, 2, 3, ...), its neighbours and the starting glucose; cells are stored in a list named "cells"
- if required, heterogeneity is introduced into cellular parameters
- each step of the simulation is executed
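The neighbour structure that the coupling function (shown further down) iterates over, where `cells_neighbours[i]` yields `(neighbour, weight)` pairs, could be built from the imported coupling data along these lines. This is a minimal sketch assuming the data arrive as an undirected edge list of `(i, j, weight)` triples; the helper name `build_neighbour_lists` and that input format are hypothetical.

```python
def build_neighbour_lists(n_cells, edges):
    """Hypothetical helper: turn undirected (i, j, weight) edges into
    per-cell lists of (neighbour, weight) tuples."""
    neighbours = [[] for _ in range(n_cells)]
    for i, j, w in edges:
        # assumption: gap junctions are symmetric, so register the edge on both cells
        neighbours[i].append((j, w))
        neighbours[j].append((i, w))
    return neighbours


edges = [(0, 1, 0.5), (1, 2, 1.0)]
print(build_neighbour_lists(3, edges))
# [[(1, 0.5)], [(0, 0.5), (2, 1.0)], [(1, 1.0)]]
```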
Simulation step execution:
```python
def run_step(cell):
    # glc is the module-level glucose value set in the loop below
    cell.run_model_step(glc)
    return cell  # without this, map() would replace every cell with None


if __name__ == '__main__':
    # cells, time, etc. are set up as described above
    for step, current_time in enumerate(time):
        # time array is pre-calculated from the provided end_time and simulation step (dt)
        glc = ModelParameters.external_glucose(current_time)
        cells = calculate_gj_coupling(cells)  # gap-junction coupling between connected cells
        cells = list(map(run_step, cells))
```
The above for-loop runs until the end of the simulation is reached. Of course, this is slow: it takes around 10 to 12 seconds for each simulated second (10000 loop iterations at 0.1 ms steps).
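To put these figures in perspective, here is the arithmetic behind the time budget, using the 12 s per simulated second measured with 15 cells. The projection to 1000 cells assumes the cost scales linearly with the number of cells, which is an assumption rather than a measurement.

```python
# Figures from the measurements in this post; linear scaling in the
# number of cells is an assumption, not a measurement.
dt = 1e-4                                   # simulation step: 0.1 ms
steps_per_sim_second = round(1 / dt)        # 10000 loop iterations
wall_seconds_per_sim_second = 12.0          # measured with 15 cells
n_cells_measured = 15

per_step = wall_seconds_per_sim_second / steps_per_sim_second   # ~1.2 ms
per_cell_step = per_step / n_cells_measured                      # ~80 us

n_cells_target = 1000
projected = per_cell_step * n_cells_target * steps_per_sim_second
print(f"{per_step * 1e3:.1f} ms per step, {per_cell_step * 1e6:.0f} us per cell-step")
print(f"~{projected:.0f} s of wall time per simulated second for {n_cells_target} cells")
```

Under that linear-scaling assumption, even a 10-fold speed-up would leave roughly 80 s of wall time per simulated second for 1000 cells.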
I really need to speed things up; a 10-fold speed-up or more would be great.
So far, I have tried the Pool class from the multiprocessing module.
I created a pool and used its map method to run a simulation step for each cell:
```python
from functools import partial
from multiprocessing import Pool


def run_step(cell, glc):
    # glc is passed explicitly: workers do not see later changes to a parent-process global
    cell.run_model_step(glc)
    return cell


if __name__ == '__main__':
    pool = Pool(processes=NUMBER_OF_WORKERS)
    # ...
    for step, current_time in enumerate(time):
        # time array is pre-calculated from the provided end_time and simulation step (dt)
        glc = ModelParameters.external_glucose(current_time)
        cells = calculate_gj_coupling(cells)  # gap-junction coupling between connected cells
        cells = pool.map(partial(run_step, glc=glc), cells)
    pool.terminate()
```
The rest is the same as before, because the slow part is the calculation of individual time steps for every beta cell.
The problem with the above solution is that it makes things worse. I suspect the culprit is moving the class instances between processes on every step (each cell has to be pickled, sent to a worker and sent back), because the same approach worked wonders for a simplified version of the problem (below):
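```python
import time
from multiprocessing import Pool
from random import randint

NUMBER_OF_CELLS = 200
NUMBER_OF_WORKERS = 60
NUMBER_OF_STEPS = 1  # example value


def task_function(dummy_object):
    dummy_object.sum_ab()
    return dummy_object


class DummyObject:
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.ab = 0.0

    def sum_ab(self):
        time.sleep(2)  # simulates a long-running task
        self.ab += self.a + self.b


if __name__ == '__main__':
    pool = Pool(processes=NUMBER_OF_WORKERS)
    cells = [DummyObject(randint(1, 20), randint(1, 20)) for _ in range(NUMBER_OF_CELLS)]
    for i in range(NUMBER_OF_STEPS):
        # keep the returned copies: workers modify pickled copies, not the originals
        cells = pool.map(task_function, cells)
    pool.terminate()
```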
This simple example speeds things up considerably. With sequential execution (the standard way), one iteration of the for-loop takes 400 seconds at NUMBER_OF_CELLS=200 (2 seconds per cell * 200 cells = 400 s). With the pool, one iteration takes only 8 seconds with NUMBER_OF_CELLS=200 and NUMBER_OF_WORKERS=60. But these DummyObjects are, of course, very small and simple, so moving them between processes is quick.
Any ideas on how to implement some version of the above dummy solution would be greatly appreciated.
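Two quick checks make the contrast between the dummy objects and the real cells more concrete. The first reproduces the 400 s and 8 s figures with an idealized model of `Pool.map` that ignores all communication cost; the second compares how many bytes of instance state have to be pickled per object. `expected_wall_time` and `FakeCell` (a stand-in with ~90 float attributes, roughly the 70 parameters plus state variables) are hypothetical illustrations, not the real `BetaCell`.

```python
import math
import pickle


def expected_wall_time(n_tasks, n_workers, seconds_per_task):
    """Idealized Pool.map time: tasks run in ceil(n_tasks / n_workers)
    rounds; pickling, scheduling and inter-process messaging are ignored."""
    rounds = math.ceil(n_tasks / n_workers)
    return rounds * seconds_per_task


class DummyObject:
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.ab = 0.0


class FakeCell:
    """Hypothetical stand-in for a cell with ~90 float attributes
    (about 70 parameters plus state variables)."""

    def __init__(self, n_attrs=90):
        for k in range(n_attrs):
            setattr(self, f"p{k}", float(k) + 0.5)


# Dummy case: 200 tasks of 2 s each
sequential = expected_wall_time(200, 1, 2.0)    # 400.0
parallel = expected_wall_time(200, 60, 2.0)     # 8.0 (4 rounds of 2 s)

# Size of the instance state (vars() returns the instance __dict__)
# that is pickled on the way to a worker and again on the way back
dummy_bytes = len(pickle.dumps(vars(DummyObject(3, 7))))
cell_bytes = len(pickle.dumps(vars(FakeCell())))

print(sequential, parallel)
print(f"dummy state: {dummy_bytes} bytes, fake cell state: {cell_bytes} bytes")
```

The ratio of work to overhead also differs: each dummy task sleeps for 2 s, whereas a real cell step takes on the order of 80 µs (12 s / 10000 steps / 15 cells). The idealized model above ignores the fixed per-task cost of `Pool.map`; for 2 s tasks that cost is amortized, but for ~80 µs tasks it plausibly dominates.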
EDIT 16. 2. 2023: Thanks to Fanchen Bao, I have found the remaining bottleneck in my code. It is the coupling function that calculates the coupling currents between connected cells.
The coupling function looks like this:
```python
from numba import jit


@jit(nopython=True)
def calculate_gj_coupling(cells, cells_neighbours):
    for i, cell in enumerate(cells):
        ca_current = 0.0
        voltage_current = 0.0
        g6p_current = 0.0
        adp_current = 0.0
        # cells_neighbours[i] holds (neighbour index, coupling weight) pairs
        for neighbour, weight in cells_neighbours[i]:
            voltage_current += (cell.Cgjv*weight)*(cells[neighbour].V-cell.V)
            ca_current += (cell.Cgjca*weight)*(cells[neighbour].C-cell.C)
            g6p_current += (cell.Cgjg6p*weight)*(0.3*cells[neighbour].G6P-0.3*cell.G6P)
            adp_current += (cell.Cgjadp*weight)*(cells[neighbour].ADPm - cell.ADPm)
        cell.couplingV = voltage_current
        cell.couplingCa = ca_current
        cell.couplingG6P = g6p_current
        cell.couplingADP = adp_current
    return cells
```
It is basically a nested for-loop: for each cell, it loops over that cell's neighbours, and each connection between two cells has its own weight (the weight parameter).
What would be a more pythonic (and faster) way of writing this up? Keep in mind that this function runs in every simulation step.
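One common alternative to looping over cell objects is a "struct of arrays" layout: each state variable for all cells lives in one NumPy array, and the connections live in flat edge arrays, so a coupling step becomes a few vectorized operations. The sketch below assumes a directed edge list in which every gap junction appears once per direction (matching `cells_neighbours[i]` holding `(neighbour, weight)` pairs), and uses `np.bincount` with `weights` to sum the contributions per cell. The function and array names are hypothetical, and applying it would require moving the cell state out of the `BetaCell` objects.

```python
import numpy as np


def coupling_current(x, conductance, src, dst, weight):
    """Per-cell sum over edges (src -> dst) of
    conductance[src] * weight * (x[dst] - x[src]).

    x, conductance: per-cell arrays of length n_cells
    src, dst, weight: per-edge arrays (one entry per directed connection)
    """
    per_edge = conductance[src] * weight * (x[dst] - x[src])
    # accumulate each edge's contribution onto the cell it belongs to
    return np.bincount(src, weights=per_edge, minlength=x.size)


# tiny example: 3 cells in a chain 0 - 1 - 2, each link listed in both directions
src = np.array([0, 1, 1, 2])
dst = np.array([1, 0, 2, 1])
weight = np.array([0.5, 0.5, 1.0, 1.0])
V = np.array([-60.0, -50.0, -40.0])
Cgjv = np.array([0.1, 0.1, 0.1])

couplingV = coupling_current(V, Cgjv, src, dst, weight)
# the G6P term would pass 0.3 * G6P as x, mirroring the original loop
print(couplingV)
```

The same function would serve all four currents (V, Ca, G6P, ADP) by passing the corresponding state and conductance arrays.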
EDIT 18. 2. 2023: I rewrote the BetaCell class. It now contains all cell parameters (instead of inheriting from the CellParameters class), and all necessary model parameters (dt, save_step) are provided at initialization. This allowed me to add the Numba jitclass decorator with the corresponding spec. It threw an error before, I guess because there is a problem with inheritance during compilation. I also use the Numba typed List() class instead of the Python built-in list.