I see that this question has already been answered for SymPy, but I am trying to write an implementation of the chain rule for educational purposes, in a toy project with no third-party libraries.
Basically, the chain rule says that if k(x) = f(g(x)), then k'(x) = f'(g(x)) * g'(x).
I have the following functions:
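```python
def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    # forward-difference approximation of fn'(x)
    t1 = fn(x)
    t2 = fn(x + step)
    return (t2 - t1) / step

def chain(x):
    # k(x) = f(g(x))
    return f(g(x))

def de_chain(x, step):
    # chain rule k'(x) = f'(g(x)) * g'(x), both factors approximated with de
    d_g = de(g, x, step)
    gres = g(x)
    d_f_g = de(f, gres, step)
    return d_g * d_f_g
```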
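For reference, with these particular f and g the exact derivative can be worked out by hand: f'(u) = 2u and g'(x) = 3x^2, so k'(x) = 2*(x^3 + 2) * 3x^2, which at x = 1.2 is 2 * 3.728 * 4.32 = 32.20992. The sketch below is a hypothetical check (the helper `exact_chain_derivative` is not part of the original code) that prints both finite-difference estimates next to that exact value for several step sizes, to see how each behaves as the step shrinks.

```python
# Sketch: compare de(chain) and de_chain with the hand-derived derivative
# k'(x) = 2*g(x) * 3*x**2 for a range of step sizes.

def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    # forward-difference approximation of fn'(x)
    return (fn(x + step) - fn(x)) / step

def chain(x):
    return f(g(x))

def de_chain(x, step):
    return de(g, x, step) * de(f, g(x), step)

def exact_chain_derivative(x):
    # hypothetical helper: f'(g(x)) * g'(x) computed symbolically by hand
    return 2 * g(x) * 3 * x**2

x = 1.2
for step in (2.6, 0.1, 0.001, 1e-6):
    print(f"step={step:<8} de(chain)={de(chain, x, step):.6f} "
          f"de_chain={de_chain(x, step):.6f} "
          f"exact={exact_chain_derivative(x):.6f}")
```

With step=2.6 both estimates are far from 32.20992; as the step gets smaller, both should approach it.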
The problem is that when I evaluate de_chain(1.2, 2.6) and de(chain, 1.2, 2.6), i.e. x=1.2 and step=2.6, I get de_chain = 205.5446... and de(chain) = 1238.6639..., which are far apart.
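One way to see where the two numbers diverge: in de_chain, f's difference quotient is taken over an interval of width `step` (2.6) starting at g(x), but when x moves by 2.6, g actually moves by delta = g(x + step) - g(x) = 53.144. The sketch below uses a hypothetical variant, `de_chain_matched`, that measures f over the interval g really crosses and rescales by delta / step. It is only a diagnostic: it reproduces de(chain) (up to rounding), including its large-step error, and it would divide by zero if g did not move at all.

```python
# Sketch: the difference-quotient form of the chain rule matches de(chain)
# only when the outer quotient uses the step that g actually takes.

def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    # forward-difference approximation of fn'(x)
    return (fn(x + step) - fn(x)) / step

def chain(x):
    return f(g(x))

def de_chain_matched(x, step):
    # hypothetical variant of de_chain: measure f over the interval g moves across
    delta = g(x + step) - g(x)  # assumes delta != 0
    return de(f, g(x), delta) * (delta / step)

x, step = 1.2, 2.6
print(g(x + step) - g(x))          # how far g moves when x moves by step
print(de(chain, x, step))
print(de_chain_matched(x, step))   # agrees with de(chain, x, step) up to rounding
```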
Something must be wrong here, because when I applied the same approach to addition and subtraction, e.g. k'(x) = g'(x) + f'(x) where k(x) = g(x) + f(x), the results were very close. What am I doing wrong?
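A possible reason the addition case agreed: the forward difference is linear, so (f + g)(x + step) - (f + g)(x) divided by step equals de(f) + de(g) for any step, not just small ones. The sketch below (the function name `add` is hypothetical) checks this at x=1.2, step=2.6. Both sides come out to about 25.44, while the exact derivative f'(1.2) + g'(1.2) = 2.4 + 4.32 = 6.72, so agreement between the two forms does not by itself mean either estimate is accurate.

```python
# Sketch: forward differences distribute over addition exactly,
# regardless of step size, unlike composition.

def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    # forward-difference approximation of fn'(x)
    return (fn(x + step) - fn(x)) / step

def add(x):
    # hypothetical helper: k(x) = f(x) + g(x)
    return f(x) + g(x)

x, step = 1.2, 2.6
lhs = de(add, x, step)
rhs = de(f, x, step) + de(g, x, step)
print(lhs, rhs, abs(lhs - rhs))
```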
Thanks