
I'd like to code, in Python, a coupled system of differential equations dF/dt = A(F), where F is a matrix and A(F) is a function of the matrix F.

When F and A(F) are vectors, the equation can be solved with scipy.integrate.odeint.
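For reference, here is a minimal sketch of that working vector case. It assumes a hypothetical coupled linear right-hand side dy/dt = M y with M = [[0, 1], [-1, 0]], whose exact solution for y(0) = (1, 0) is (cos t, -sin t), so the numerical result can be checked:

```python
import numpy as np
from scipy.integrate import odeint

# Hypothetical coupled linear system dy/dt = M @ y (a rotation).
M = np.array([[0.0, 1.0], [-1.0, 0.0]])

def rhs(y, t):
    # By default odeint calls func(y, t) with y as a 1-D array.
    return M @ y

y0 = np.array([1.0, 0.0])          # 1-D initial condition, as odeint requires
t = np.linspace(0.0, 5.0, 51)
sol = odeint(rhs, y0, t)            # shape (len(t), len(y0)) == (51, 2)

exact = np.column_stack([np.cos(t), -np.sin(t)])
print(np.max(np.abs(sol - exact)))  # should be small
```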

However, scipy.integrate.odeint doesn't work with matrices, and I get an error:

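import numpy as np
import scipy.integrate as si

tmin, tmax, tstep = (0., 200., 1.)
t_test = np.arange(tmin, tmax, tstep)  # time vector

dydt_testm = np.array([[0., 1.], [2., 3.]])  # constant right-hand side A(F)
Y0_test = np.array([[0., 1.], [0., 1.]])     # 2x2 initial condition

def dydt_test(y, t):
    return dydt_testm

result = si.odeint(dydt_test, Y0_test, t_test)  # raises the ValueError below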

ValueError: Initial condition y0 must be one-dimensional.
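The error arises because odeint only accepts a one-dimensional state. A matrix equation is still an ordinary system of n² coupled scalar equations, so flattening F loses nothing. As a sketch, assuming a linear right-hand side A(F) = M F chosen purely for illustration, the row-major flattening used by NumPy's ravel turns the matrix product into an ordinary matrix-vector product with kron(M, I):

```python
import numpy as np

# Illustrative assumption: a linear right-hand side A(F) = M @ F.
rng = np.random.default_rng(0)
n = 3
M = rng.standard_normal((n, n))
F = rng.standard_normal((n, n))

# ravel() flattens in row-major (C) order: entry (i, j) -> index i*n + j.
lhs = (M @ F).ravel()
# In that ordering, left-multiplication by M acts on the flat vector as kron(M, I_n).
rhs = np.kron(M, np.eye(n)) @ F.ravel()
print(np.allclose(lhs, rhs))  # True
```

So once the state is flattened and rebuilt consistently, the matrix ODE is just a vector ODE that odeint can accept.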

J.A
  • What did you try so far? – GPhilo Jan 28 '20 at 16:50
  • Take a look at [`odeintw`](https://pypi.org/project/odeintw/). The source is on GitHub at https://github.com/WarrenWeckesser/odeintw – Warren Weckesser Jan 28 '20 at 16:55
  • You could wrap the ODE function with `F=F.reshape(n,n)` and `return dF.flatten()` – Lutz Lehmann Jan 28 '20 at 16:59
  • @WarrenWeckesser How should I import odeintw while working with IPython in a Jupyter notebook? Does it work on Python 2? – J.A Jan 28 '20 at 17:06
  • `odeintw` is on [PyPI](https://pypi.org/), so you can install it with the `pip` command. (If you are not familiar with `pip`, do a search for a tutorial.) I haven't updated the released version in a while, so it should still work with Python 2.7. – Warren Weckesser Jan 28 '20 at 17:15
  • @WarrenWeckesser It works, thanks. – J.A Jan 28 '20 at 17:30
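Lutz Lehmann's reshape suggestion can be sketched with plain scipy.integrate.odeint on the question's own example; the wrapper name dydt_flat is hypothetical. Because the derivative is constant, the exact solution is F(t) = Y0_test + t · dydt_testm, which gives a quick check at the final time step:

```python
import numpy as np
from scipy.integrate import odeint

tmin, tmax, tstep = (0., 200., 1.)
t_test = np.arange(tmin, tmax, tstep)       # 200 time points, 0 ... 199

dydt_testm = np.array([[0., 1.], [2., 3.]])
Y0_test = np.array([[0., 1.], [0., 1.]])
n = Y0_test.shape[0]

def dydt_test(y, t):
    # Matrix-valued right-hand side from the question (constant here).
    return dydt_testm

def dydt_flat(f, t):
    # Hypothetical wrapper: odeint hands over a flat vector of length n*n ...
    F = f.reshape(n, n)
    dF = dydt_test(F, t)
    # ... and expects a flat derivative back.
    return dF.flatten()

sol = odeint(dydt_flat, Y0_test.flatten(), t_test)  # shape (200, 4)
result = sol.reshape(len(t_test), n, n)              # shape (200, 2, 2)

# Constant derivative, so the exact answer is F(t) = Y0_test + t * dydt_testm.
print(np.allclose(result[-1], Y0_test + t_test[-1] * dydt_testm))  # True
```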

1 Answer


As Warren Weckesser suggested in the comments, odeintw does the job.

from odeintw import odeintw
import numpy as np
import matplotlib.pyplot as plt

Y0_test = np.array([[0., 1.], [0., 1.]])
tmin, tmax, tstep = (0., 200., 1.)
t_test = np.arange(tmin, tmax, tstep)  # time vector

dydt_testm = np.array([[0., 1.], [2., 3.]])

def dydt_test(y, t):
    return dydt_testm

result = odeintw(dydt_test,  # computes the derivative of y at t
                 Y0_test,    # initial condition on y (may be a matrix)
                 t_test)     # result has shape (len(t_test), 2, 2)
plt.plot(t_test, result[:, 0, 1])  # time evolution of the (0, 1) entry of F
plt.show()
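If installing odeintw is not an option, a small helper built on the reshape trick from the comments should behave much like it for real-valued arrays (odeintw additionally handles complex-valued equations). The helper name odeint_array and the linear test system dF/dt = M F below are hypothetical, chosen because the exact solution F(t) = expm(M t) F0 makes the result easy to check:

```python
import numpy as np
from scipy.integrate import odeint
from scipy.linalg import expm

def odeint_array(func, y0, t, **kwargs):
    """Hypothetical helper: integrate dY/dt = func(Y, t) for a real array Y of any shape."""
    y0 = np.asarray(y0, dtype=float)
    shape = y0.shape

    def flat_rhs(y_flat, t_now):
        # odeint works on 1-D vectors: reshape on the way in, flatten on the way out.
        return np.asarray(func(y_flat.reshape(shape), t_now), dtype=float).ravel()

    sol = odeint(flat_rhs, y0.ravel(), t, **kwargs)
    return sol.reshape(len(t), *shape)   # result[k] is the matrix at time t[k]

# Illustrative linear system dF/dt = M @ F, with exact solution F(t) = expm(M t) @ F0.
M = np.array([[0.0, 1.0], [-2.0, -3.0]])
F0 = np.array([[0.0, 1.0], [0.0, 1.0]])
t = np.linspace(0.0, 2.0, 21)

result = odeint_array(lambda F, t_now: M @ F, F0, t)
exact = np.array([expm(M * tk) @ F0 for tk in t])
print(result.shape)                      # (21, 2, 2)
print(np.max(np.abs(result - exact)))    # should be small
```

Indexing the result works the same way as in the answer above, e.g. result[:, 0, 1] is the time history of one matrix entry.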

J.A