I am trying to learn how to use SymPy to optimize the numerical evaluation of mathematical expressions in C. On the one hand, I know that SymPy can generate C code to evaluate a single expression as follows:
# Generate a C source/header pair for a single expression with SymPy's codegen.
# Explicit imports replace the original `from mpmath import *` /
# `from sympy import *` pair: the mpmath names were unused, and the snippet
# only worked because sympy's `sin` happened to shadow mpmath's `sin`.
from sympy import sin, symbols
from sympy.utilities.codegen import codegen

x, y, z = symbols('x y z')
# codegen takes (routine_name, expression) and returns a list of
# (filename, file_contents) pairs: first the .c source, then the .h header.
[(c_name, c_code), (h_name, c_header)] = codegen([('x', sin(x))], 'C')
I can then write `c_code` to the destination file. On the other hand, I know that `cse` (common subexpression elimination) can be used to simplify expressions as follows:
# Run SymPy's common-subexpression elimination on a list of expressions.
# Only `cse` and `symbols` are needed; the original's mpmath wildcard import
# and `codegen` import were unused.
from sympy import cse, symbols

x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
# cse returns a pair: a list of (helper_symbol, subexpression) replacements and
# the list of reduced expressions written in terms of those helpers. Keep the
# result (rather than discarding it, as happens when run as a script) so it
# can be printed and later turned into code.
CSE_results = cse([3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4])
print(CSE_results)
which produces the following output:
([(x0, z**2),
(x1, B3*x0),
(x2, B3*x),
(x3, x2*y),
(x4, 3.0*x3),
(x5, 3.0*B2),
(x6, y**2),
(x7, B3*x6),
(x8, x2*z),
(x9, x**2),
(x10, B3*x9),
(x11, B4*x**3),
(x12, x11*y),
(x13, x11*z),
(x14, B4*y),
(x15, x14*x9*z),
(x16, B4*x9),
(x17, x16*x6),
(x18, x0*x16),
(x19, 2*x8),
(x20, y**3),
(x21, B4*x),
(x22, x20*x21),
(x23, x0*x21*y),
(x24, x21*x6*z),
(x25, z**3),
(x26, x21*x25),
(x27, B3*y*z),
(x28, x10 + 3.0*x27),
(x29, B4*x20*z),
(x30, B4*x0*x6),
(x31, x14*x25)],
[B4*x**4 + x1 + 8.0*x10 + x12 + x13 + x15 + x17 + x18 + x4 + x5 + x7 + 4.0*x8,
B3*(x6 + x9) + x12 + x15 + x17 + x19 + x22 + x23 + x24 + 7.0*x3,
B3*(x0 + x9) + x13 + x15 + x18 + x23 + x24 + x26 + x3 + 8.0*x8,
B4*y**4 + x1 + x17 + x22 + x24 + x28 + x29 + x30 + x4 + x5 + 8.0*x7 + x8,
B3*(x0 + x6) + x15 + x19 + x23 + x24 + 6.0*x27 + x29 + x3 + x30 + x31,
B4*z**4 + 8.0*x1 + x18 + x23 + x26 + x28 + x3 + x30 + x31 + x5 + x7 + 3.0*x8])
My question is: how can I properly transform the above result into C code? It can sometimes be useful to convert the reduced expressions into strings and operate on those strings — how can that be done? The aim is to automate code generation after CSE so that it can be applied to many expressions.
EDIT:
Based on the answer below (thanks to Wrzlprmft), the code that produces the corresponding C code snippet is:
# Eliminate common subexpressions from a list of expressions and write the
# result to "snippet.c" as a sequence of C assignment statements.
from sympy.printing import ccode
from sympy import symbols, cse, numbered_symbols

x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
results = [3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4]

# cse returns (replacements, reduced_exprs). The helper symbols are named
# helper_0, helper_1, ... so they do not clash with the input symbols.
CSE_results = cse(results, symbols=numbered_symbols("helper_"))
replacements, reduced_exprs = CSE_results

with open("snippet.c", "w") as output:
    # Emit the helpers in the order cse returned them: later helpers refer to
    # earlier ones, so this order defines each temporary before it is used.
    # Each helper is declared as a local double.
    for helper_symbol, helper_expr in replacements:
        output.write("double ")
        output.write(ccode(helper_expr, assign_to=helper_symbol))
        output.write("\n")
    # The final values are only assigned, not declared: result_0, result_1, ...
    # presumably must be declared by the C code that includes this snippet.
    for i, result in enumerate(reduced_exprs):
        output.write(ccode(result, assign_to="result_%d" % i))
        output.write("\n")