I'm using the mpmath library to compute a user-defined function f(x)
and need to compute its higher-order derivatives.
I came across automatic differentiation on Wikipedia and found the following example useful.
# Automatic Differentiation
import math

class Var:
    def __init__(self, value, children=None):
        self.value = value
        self.children = children or []
        self.grad = 0
        print(children)

    def __add__(self, other):
        return Var(self.value + other.value, [(1, self), (1, other)])

    def __mul__(self, other):
        return Var(self.value * other.value, [(other.value, self), (self.value, other)])

    def sin(self):
        return Var(math.sin(self.value), [(math.cos(self.value), self)])

    def calc_grad(self, grad=1):
        self.grad += grad
        for coef, child in self.children:
            child.calc_grad(grad * coef)

# Example: f(x, y) = x * y + sin(x)
x = Var(2)
y = Var(3)
f = x * y + x.sin()

# Calculation of partial derivatives
f.calc_grad()
print("f =", f.value)
print("∂f/∂x =", x.grad)
print("∂f/∂y =", y.grad)
which returned
None
None
[(3, <__main__.Var object at 0x000001EDDCCA2260>), (2, <__main__.Var object at 0x000001EDDCCA1300>)]
[(-0.4161468365471424, <__main__.Var object at 0x000001EDDCCA2260>)]
[(1, <__main__.Var object at 0x000001EDDCCA0A30>), (1, <__main__.Var object at 0x000001EDDCCA0BB0>)]
f = 6.909297426825682
∂f/∂x = 2.5838531634528574
∂f/∂y = 2
However, I have some questions about how the code is implemented. For example, I guess that the children list holds the local derivatives used in the chain rule, but I'm not sure how a statement such as [(1, self), (1, other)] implements this recursion.
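To make my guess concrete, here is the chain rule unrolled by hand for this example, assuming each (coef, child) pair is one edge of the expression graph weighted by a local derivative. The variable names below are my own and are not part of the code above.

```python
import math

x_val, y_val = 2.0, 3.0

# Path f -> (x*y) -> x: edge weight 1 from "+", then y from "*".
path_via_product = 1 * y_val

# Path f -> sin(x) -> x: edge weight 1 from "+", then cos(x) from sin.
path_via_sin = 1 * math.cos(x_val)

# x is reached along two paths, so the contributions are summed,
# which is what "self.grad += grad" would do on each visit.
df_dx = path_via_product + path_via_sin

# y is reached along one path only: 1 from "+", then x from "*".
df_dy = 1 * x_val

print("df/dx =", df_dx)
print("df/dy =", df_dy)
```

These values agree with the printed x.grad and y.grad, which is why I suspect calc_grad multiplies the coefficients along each path from f to an input and sums over all paths.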
Also, the automatic differentiation in this code seems to be more of a "symbolic calculation" in disguise, and I'm not sure whether it works for arbitrary functions and at arbitrary order. For example, I wanted to try second-order derivatives, but f.calc_grad(grad=2).value and f.calc_grad().calc_grad().value didn't work. I suppose this is because the cos function is not defined.
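A minimal sketch of what a second-order version might look like, assuming forward propagation of the triple (value, first derivative, second derivative) with respect to a single variable. The class name Jet2 and its fields are hypothetical and not part of the code above. It covers only +, * and sin, so it does not settle the arbitrary-function or arbitrary-order case.

```python
import math

class Jet2:
    """Hypothetical helper: carries u, u' and u'' with respect to one variable."""

    def __init__(self, v, d1=0.0, d2=0.0):
        self.v, self.d1, self.d2 = v, d1, d2

    def __add__(self, other):
        # Derivatives of a sum are the sums of the derivatives.
        return Jet2(self.v + other.v, self.d1 + other.d1, self.d2 + other.d2)

    def __mul__(self, other):
        # Product rule, and (uv)'' = u''v + 2u'v' + uv''.
        return Jet2(
            self.v * other.v,
            self.d1 * other.v + self.v * other.d1,
            self.d2 * other.v + 2 * self.d1 * other.d1 + self.v * other.d2,
        )

    def sin(self):
        # (sin u)' = cos(u) u' and (sin u)'' = cos(u) u'' - sin(u) (u')^2.
        s, c = math.sin(self.v), math.cos(self.v)
        return Jet2(s, c * self.d1, c * self.d2 - s * self.d1 ** 2)

# Same example, differentiating with respect to x only.
x = Jet2(2.0, 1.0)  # seed: dx/dx = 1, d2x/dx2 = 0
y = Jet2(3.0)       # treated as a constant with respect to x
f = x * y + x.sin()

print("f =", f.v)
print("df/dx =", f.d1)
print("d2f/dx2 =", f.d2)
```

For this f, the second derivative with respect to x should equal -sin(x), which gives a simple way to check the sketch.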
How can I apply automatic differentiation to an arbitrary function in Python? Is it possible to avoid defining a class like this and work with mpmath directly?