
I've looked around on Stack Overflow, but could not find anything that answers my question.

Problem Setup:

I am trying to solve a system of stiff ODEs using scipy.integrate.ode. I've reduced the code to a minimal working example:

import scipy as sp
from scipy import integrate
import matplotlib.pylab as plt
spiketrain =[0]
syn_inst = 0

def synapse(t, t0):
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1) ** (tau_rise / tau_1) - (tau_2 / tau_1) ** (tau_rise / tau_2)) ** -1
    return B*(sp.exp(-(t - t0) / tau_1) - sp.exp(-(t - t0) / tau_2)) #the culprit

def alpha_m(v, vt):
    return -0.32*(v - vt -13)/(sp.exp(-1*(v-vt-13)/4)-1)

def beta_m(v, vt):
    return 0.28 * (v - vt - 40) / (sp.exp((v- vt - 40) / 5) - 1)

def alpha_h(v, vt):
    return 0.128 * sp.exp(-1 * (v - vt - 17) / 18)

def beta_h(v, vt):
    return  4 / (sp.exp(-1 * (v - vt - 40) / 5) + 1)

def alpha_n(v, vt):
    return -0.032*(v - vt - 15)/(sp.exp(-1*(v-vt-15)/5) - 1)

def beta_n(v, vt):
    return 0.5* sp.exp(-1*(v-vt-10)/40)

def inputspike(t):
    if int(t) in a :
        spiketrain.append(t)

def f(t,X):
    V = X[0]
    m = X[1]
    h = X[2]
    n = X[3]

    inputspike(t)
    g_syn = synapse(t, spiketrain[-1])
    syn = 0.5* g_syn * (V - 0)
    global syn_inst
    syn_inst = g_syn 

    dydt = sp.zeros([1, len(X)])[0]
    dydt[0] = - 50*m**3*h*(V-60) - 10*n**4*(V+100) - syn - 0.1*(V + 70)
    dydt[1] = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dydt[2] = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dydt[3] = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return dydt

t_start = 0.0
t_end = 2000
dt = 0.1

num_steps = int(sp.floor((t_end - t_start) / dt) + 1)

a = sp.zeros([1,int(t_end/100)])[0]
a[0] = 500 #so the model settles
sp.random.seed(0)
for i in range(1, len(a)):
    a[i] = a[i-1] + int(round(sp.random.exponential(0.1)*1000, 0))

r = integrate.ode(f).set_integrator('vode', nsteps = num_steps,
                                          method='bdf')
X_start = [-70, 0, 1,0]
r.set_initial_value(X_start, t_start)

t = sp.zeros(num_steps)
syn = sp.zeros(num_steps)
X = sp.zeros((len(X_start),num_steps))
X[:,0] = X_start
syn[0] = 0
t[0] = t_start
k = 1

while r.successful() and k < num_steps:
    r.integrate(r.t + dt)
    # Store the results to plot later
    t[k] = r.t
    syn[k] = syn_inst
    X[:,k] = r.y
    k += 1

plt.plot(t,syn)
plt.show()

Problem:

When I run the code, the time t passed to f appears to go back and forth. As a result, spiketrain[-1] can end up greater than t, so syn becomes strongly negative and significantly distorts my simulations (the negative values are visible in the plot when the code is run).

I am guessing it has something to do with the variable time steps in the solver, so I was wondering whether it is possible to restrict time to forward (positive) propagation only.
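One way to check this guess is to wrap the right-hand side and log every time value the solver asks for. The sketch below does this for a toy stiff equation with solve_ivp's BDF method rather than ode/vode; the toy problem and the helper name `make_logged_rhs` are illustrative assumptions, not part of the original code, and whether backward calls actually show up depends on the problem, the method and the tolerances.

```python
import numpy as np
from scipy.integrate import solve_ivp

def make_logged_rhs(rhs):
    """Hypothetical helper: wrap rhs(t, y) and record every t it is called with."""
    calls = []
    def logged(t, y):
        calls.append(float(t))
        return rhs(t, y)
    return logged, calls

def rhs(t, y):
    # Toy stiff problem (assumption): y is pulled quickly towards cos(t).
    return -1000.0 * (y - np.cos(t))

logged, calls = make_logged_rhs(rhs)
sol = solve_ivp(logged, [0.0, 10.0], [0.0], method='BDF')

# Count calls whose t is smaller than the t of the previous call,
# e.g. retries after a rejected step.
backward = int(np.sum(np.diff(calls) < 0))
print(f"{len(calls)} RHS calls, {backward} of them went back in time")
```

The same wrapper can be passed to integrate.ode(...) instead; only the logging matters, not the solver.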

Thanks

Vasily

Answer:

The solver does actually go back and forth, and I think this is also due to the variable time stepping. But I think the real difficulty is that the result of f(t, X) is not only a function of t and X but also of the previous calls made to this function, which is not a good idea.
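To make the difference concrete, here is a minimal sketch contrasting a stateful lookup in the spirit of inputspike() with a pure lookup that depends on t only. The spike schedule and both function names are hypothetical; the call order imitates a trial step to 620.5 that the solver rejects, followed by a retry at 600.0.

```python
import numpy as np

# Hypothetical spike schedule, sorted and known before integrating.
spike_times = np.array([500.0, 620.0, 700.0])

# Stateful lookup (like inputspike + spiketrain[-1]): the result depends
# on which calls happened before, not only on t.
history = [0.0]
def last_spike_stateful(t):
    if int(t) in spike_times:
        history.append(t)
    return history[-1]

# Pure lookup: the result depends on t only.
def last_spike_pure(t):
    idx = np.searchsorted(spike_times, t, side='right') - 1
    return spike_times[idx] if idx >= 0 else -np.inf

# Rejected trial step at 620.5, then a retry at the earlier time 600.0.
for t in (620.5, 600.0):
    print(t, last_spike_stateful(t), last_spike_pure(t))
```

At t = 600.0 the stateful version still reports the spike recorded at 620.5, which lies in the future, so t - t0 < 0 and the synapse term turns negative; the pure version returns 500.0.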

Your code works by replacing:

inputspike(t)
g_syn = synapse(t, spiketrain[-1])

by:

last_spike_date = np.max( a[a<t] )
g_syn = synapse(t, last_spike_date)

And by adding an "old event" before the "settle time" with a = np.insert(a, 0, -1e4). This guarantees that last_spike_date is always defined (see the comment in the code below).
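A small check of why the old event is needed and why it is harmless, using a hypothetical three-spike schedule: without it, np.max(a[a<t]) has nothing to reduce for t before the first spike; with it, the lookup always succeeds and both exponentials in synapse underflow to 0.0. The last lines compare the searchsorted lookup with the a<t version on an array of times.

```python
import numpy as np

def synapse(t, t0):
    # Double-exponential conductance, copied from the question.
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1) ** (tau_rise / tau_1) - (tau_2 / tau_1) ** (tau_rise / tau_2)) ** -1
    return B * (np.exp(-(t - t0) / tau_1) - np.exp(-(t - t0) / tau_2))

a = np.array([500.0, 620.0, 700.0])   # hypothetical spike times

# Without the sentinel, there is no earlier spike for t < 500, so
# np.max(a[a < t]) would reduce over an empty array (ValueError).
print(a[a < 100.0].size)              # 0

# With a very old sentinel spike, a previous spike always exists ...
a_padded = np.insert(a, 0, -1e4)
t0 = np.max(a_padded[a_padded < 100.0])
# ... and both exponentials underflow to 0.0, so the synapse is silent.
print(t0, synapse(100.0, t0))

# searchsorted(side='right') accepts an array of times and uses a <= t,
# while the mask version uses a < t. They differ only when t equals a
# spike time exactly: the first picks the new spike, the second the previous one.
t = np.array([0.0, 500.0, 650.0])
print(a_padded[a_padded.searchsorted(t, side='right') - 1])
print([float(np.max(a_padded[a_padded < ti])) for ti in t])
```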

Here is a modified version of your code:

I modified how the time of the last spike is found (this time using the NumPy function searchsorted, so that the function can be vectorized). I also modified the way the array a is created. This is not my field, so I may have misunderstood the intent.
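For comparison, the question's loop places the first spike at exactly 500 and then adds 19 rounded exponential intervals with mean 0.1 * 1000 = 100, whereas np.cumsum(np.random.exponential(beta, size=nbr_spike)) + first_spike_date adds 20 unrounded intervals on top of 500, so its first spike comes after 500. If the question's semantics were the intent, a loop-free version could look like the sketch below (an assumption about the intent, checked against the original loop).

```python
import numpy as np

t_end = 2000
n_spikes = int(t_end / 100)  # 20, as in the question

# The question's loop (with the loop body indented).
np.random.seed(0)
a_loop = np.zeros(n_spikes)
a_loop[0] = 500
for i in range(1, n_spikes):
    a_loop[i] = a_loop[i - 1] + int(round(np.random.exponential(0.1) * 1000, 0))

# Same schedule without the loop: first spike at 500, then the cumulative
# sum of the rounded intervals.
np.random.seed(0)
intervals = np.round(np.random.exponential(0.1, size=n_spikes - 1) * 1000)
a_vec = 500 + np.concatenate(([0.0], np.cumsum(intervals)))

print(np.array_equal(a_loop, a_vec))
```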

I used solve_ivp instead of ode, but still with a BDF solver (however, it is not the same implementation as the one in ode, which is in Fortran).
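If results are wanted on the same fixed 0.1 grid that the question's while loop produces, solve_ivp accepts a t_eval array: the solver still picks its own internal steps and only reports the solution at the requested times. A minimal sketch on a toy stiff equation (an assumption standing in for the neuron model):

```python
import numpy as np
from scipy.integrate import solve_ivp

def rhs(t, y):
    # Toy stiff problem (assumption): y is pulled quickly towards cos(t).
    return -1000.0 * (y - np.cos(t))

t_start, t_end, dt = 0.0, 10.0, 0.1
# round() instead of floor() avoids losing a point to floating-point error.
num_steps = int(round((t_end - t_start) / dt)) + 1
t_grid = np.linspace(t_start, t_end, num_steps)

sol = solve_ivp(rhs, [t_start, t_end], [0.0], method='BDF', t_eval=t_grid)
print(sol.success, sol.y.shape)
```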

import numpy as np  # rather than scipy 
import matplotlib.pylab as plt
from scipy.integrate import solve_ivp

def synapse(t, t0):
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1)**(tau_rise / tau_1) - (tau_2 / tau_1)**(tau_rise / tau_2)) ** -1
    return B*(np.exp(-(t - t0) / tau_1) - np.exp(-(t - t0) / tau_2))

def alpha_m(v, vt):
    return -0.32*(v - vt -13)/(np.exp(-1*(v-vt-13)/4)-1)

def beta_m(v, vt):
    return 0.28 * (v - vt - 40) / (np.exp((v- vt - 40) / 5) - 1)

def alpha_h(v, vt):
    return 0.128 * np.exp(-1 * (v - vt - 17) / 18)

def beta_h(v, vt):
    return  4 / (np.exp(-1 * (v - vt - 40) / 5) + 1)

def alpha_n(v, vt):
    return -0.032*(v - vt - 15)/(np.exp(-1*(v-vt-15)/5) - 1)

def beta_n(v, vt):
    return 0.5* np.exp(-1*(v-vt-10)/40)

def f(t, X):
    V = X[0]
    m = X[1]
    h = X[2]
    n = X[3]

    # Find the largest value in `a` before t:
    last_spike_date = a[ a.searchsorted(t, side='right') - 1 ]

    # Another simpler way to write this is:
    # last_spike_date = np.max( a[a<t] )
    # but didn't work with an array for t        

    g_syn = synapse(t, last_spike_date)
    syn = 0.5 * g_syn * (V - 0)

    dVdt = - 50*m**3*h*(V-60) - 10*n**4*(V+100) - syn - 0.1*(V + 70)
    dmdt = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dhdt = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dndt = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return [dVdt, dmdt, dhdt, dndt]


# Define the spike events:
nbr_spike = 20
beta = 100
first_spike_date = 500

np.random.seed(0)
a = np.cumsum( np.random.exponential(beta, size=nbr_spike) ) + first_spike_date
a = np.insert(a, 0, -1e4)  # add a very old spike at t=-1e4:
                           # a hack so that a t0 exists for t < first_spike_date (model settle time),
                           # so that `synapse(t, t0)` can be called regardless of t;
                           # synapse(t, -1e4) = 0 for t > 0

# Solve:
t_start = 0.0
t_end = 2000

X_start = [-70, 0, 1,0]

sol = solve_ivp(f, [t_start, t_end], X_start, method='BDF', max_step=1, vectorized=True)
print(sol.message)

# Graph
V, m, h, n = sol.y
plt.plot(sol.t, V)
plt.xlabel('time')
plt.ylabel('V')
plt.show()

which gives:

[Figure: result for V]
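Because g_syn now depends only on t, the synaptic trace that the question plots through the global syn_inst can instead be recomputed after the integration on whatever time grid is wanted (for example sol.t). A minimal sketch reusing the synapse function, the spike construction and the searchsorted lookup above; the 0.1 grid is an assumption borrowed from the question's dt.

```python
import numpy as np

def synapse(t, t0):
    # Double-exponential conductance from the question; B scales the peak to 1.
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1) ** (tau_rise / tau_1) - (tau_2 / tau_1) ** (tau_rise / tau_2)) ** -1
    return B * (np.exp(-(t - t0) / tau_1) - np.exp(-(t - t0) / tau_2))

# Spike schedule built as in the answer, including the very old spike.
np.random.seed(0)
a = np.cumsum(np.random.exponential(100, size=20)) + 500
a = np.insert(a, 0, -1e4)

# Output grid with step 0.1 (in a real run this could be sol.t).
t = np.linspace(0.0, 2000.0, 20001)

# Last spike at or before each t, then the conductance: a pure function of t,
# so t - t0 is never negative and g_syn never goes below zero.
last_spike = a[a.searchsorted(t, side='right') - 1]
g_syn = synapse(t, last_spike)

print(g_syn.min(), g_syn.max())
```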

Note: solve_ivp also has an events parameter, which could be useful.
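As a minimal illustration of the events parameter on a toy problem (not the neuron model): an event function returns a value whose zero crossings solve_ivp locates, and the attributes terminal and direction set on that function control whether integration stops and which crossing direction counts. Something similar could, for instance, detect when V crosses a threshold, assuming that is the kind of event of interest.

```python
import numpy as np
from scipy.integrate import solve_ivp

def rhs(t, y):
    # Toy problem: with y(0) = 0 the solution is y(t) = sin(t).
    return [np.cos(t)]

def crosses_half(t, y):
    # Zero whenever y passes through 0.5.
    return y[0] - 0.5

crosses_half.direction = 1      # only count upward crossings
crosses_half.terminal = False   # keep integrating after an event

sol = solve_ivp(rhs, [0.0, 4 * np.pi], [0.0], events=crosses_half,
                rtol=1e-8, atol=1e-10)

# Upward crossings of sin(t) = 0.5 in [0, 4*pi] are pi/6 and pi/6 + 2*pi.
print(sol.t_events[0])
```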

xdze2