
I am generating C code with SymPy, using the Common Subexpression Elimination (CSE) routine and the ccode printer.

However, I would like powers to be written as (x*x) instead of pow(x, 2).

Is there any way to do this?

Example:

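import sympy as sp

a = sp.MatrixSymbol('a', 3, 3)
b = sp.Matrix(a)*sp.Matrix(a)

res = sp.cse(b)

lines = []

# intermediate variables found by CSE
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))

# final results, assigned to result_0, result_1, ...
for i, result in enumerate(res[1]):
    lines.append(sp.ccode(result, "result_%i" % i))

print("\n".join(lines))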

Will output:

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;

Best Regards

Manuel Oliveira
    Maybe this helps: https://stackoverflow.com/questions/39173019/converting-squared-and-cube-terms-into-multiplication – hpaulj Jan 02 '21 at 00:36

3 Answers


There is a function called create_expand_pow_optimization that creates a wrapper to optimise your expressions in this respect. It takes as an argument the highest power it will replace by explicit multiplications.

The wrapper returns an UnevaluatedExpr that is protected against automatic simplifications that would revert this change.

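import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

# expand integer powers of symbols with |exponent| <= 3 into explicit products
expand_opt = create_expand_pow_optimization(3)

a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
res = sp.cse(a@a)

# intermediate CSE variables
for lhs, rhs in res[0]:
    print(sp.ccode(expand_opt(rhs), lhs))

# final results
for i, result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result), "result_%i" % i))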

Finally, be aware that at sufficiently high optimisation levels your compiler will typically do this itself (and is probably better at it); for powers above 2 this may require relaxed floating-point flags such as -ffast-math.

Wrzlprmft
  • That function looks interesting! However, I was not able to make it work in the `user_functions` of `ccode`. Additionally, it appears to only work for cases of symbol ^ integer. If I have `sp.ccode(expand_opt (a+b)**2)`, it gives me `'pow(a + b, 2)'` – Manuel Oliveira Jan 06 '21 at 14:31
  • That should be `sp.ccode(expand_opt((a+b)**2))`, but even then it does not work. Seems like a bug. – Wrzlprmft Jan 06 '21 at 14:44
  • I created [an issue](https://github.com/sympy/sympy/issues/20753) on this. – Wrzlprmft Jan 06 '21 at 14:51
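In recent SymPy versions, create_expand_pow_optimization also accepts a keyword-only base_req predicate; by default only powers whose base is a plain symbol are expanded, which matches the (a+b)**2 behaviour described in the comments. The sketch below assumes a SymPy version that has base_req and passes a predicate that accepts any base.

```python
import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

x, y = sp.symbols('x y')

# Default: only powers of plain symbols with |exponent| <= 3 are expanded,
# so x**3 becomes x*x*x while x**5 is left as pow(x, 5).
expand_opt = create_expand_pow_optimization(3)
out_default = sp.ccode(expand_opt(x**5 + x**3))
print(out_default)

# Assumes a SymPy version with the keyword-only base_req argument:
# accepting every base also expands compound bases such as (x + y)**2.
expand_any = create_expand_pow_optimization(3, base_req=lambda b: True)
out_any = sp.ccode(expand_any((x + y)**2))
print(out_any)
```

A more selective predicate (for example one rejecting function calls such as sin(x)) avoids duplicating expensive subexpressions.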

You can subclass the code printer and override only the one method you want to change. You need to look at the original SymPy source to find the correct method names and the default implementation, so that your override stays consistent with it. With a bit of care, the needed brackets can be generated automatically, exactly when and where they are needed.

Here is a minimal example:

import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base, PREC))
        else:
            return super()._print_Pow(expr)

default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint

expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))

Output:

Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]

PS: To support a wider range of exponents, you could use e.g.

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp in range(2, 7):
            # wrap the whole product so it stays intact inside quotients such as y/(x*x)
            return '(' + '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp)) + ')'
        elif expr.exp in range(-6, 0):
            return '1.0/(' + ('*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp))) + ')'
        else:
            return super()._print_Pow(expr)
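A minimal sketch of plugging a printer of this kind into the question's cse loop. The class name MulPowPrinter is illustrative; the sketch relies on the printer's doprint(expr, assign_to) method, which is the call sp.ccode makes internally, so the assignment output has the same shape as in the question.

```python
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence


class MulPowPrinter(C99CodePrinter):
    """C99 printer writing integer powers 2..6 as repeated products (illustrative)."""

    def _print_Pow(self, expr):
        if expr.exp.is_Integer and 2 <= expr.exp <= 6:
            base = self.parenthesize(expr.base, precedence(expr))
            # Outer parentheses keep the product intact inside a quotient like y/(x*x).
            return "(" + "*".join([base] * int(expr.exp)) + ")"
        return super()._print_Pow(expr)


a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
replacements, reduced = sp.cse(a * a)

printer = MulPowPrinter()
# intermediate CSE variables, then the final matrix assigned to result_0
lines = [printer.doprint(rhs, assign_to=lhs) for lhs, rhs in replacements]
for i, result in enumerate(reduced):
    lines.append(printer.doprint(result, assign_to="result_%i" % i))
print("\n".join(lines))
```

Because the override only touches _print_Pow, everything else (array indexing, assignments, other functions) is printed exactly as the stock C99 printer would.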
JohanC

I think I will go with the user_functions approach.

As suggested in the comments above, I will use the user_functions option of sp.ccode. Assuming we have an expression like a**3:

import sympy as sp

a = sp.Symbol('a')
# The conditions receive the SymPy base and exponent; the callable receives them already printed as strings.
sp.ccode(a**3, user_functions={'Pow': [
    (lambda x, y: y.is_integer, lambda x, y: '*'.join(['(' + x + ')'] * int(y))),
    (lambda x, y: not y.is_integer, 'pow'),
]})

should output: '(a)*(a)*(a)'

In the future, I will try to improve the function so that it only adds parentheses where they are needed.

Any improvements are welcome!
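One possible refinement, sketched under the assumption that only small positive integer exponents should be expanded: negative, fractional and large exponents fall through to pow(), the base is bracketed only when it is not a plain identifier, and the whole product is bracketed so that a quotient such as a/b**2 cannot turn into a/b*b. The helper name _product is hypothetical.

```python
import sympy as sp

a, b = sp.symbols('a b')


def _product(base, exp):
    # base and exp arrive here as already-printed C strings.
    factor = base if base.isidentifier() else '(' + base + ')'
    # Bracket the whole product so it stays intact inside a quotient.
    return '(' + '*'.join([factor] * int(exp)) + ')'


# Conditions receive the SymPy base and exponent; entries are tried in order.
pow_rules = {'Pow': [
    (lambda base, exp: exp.is_Integer and 2 <= exp <= 6, _product),
    (lambda base, exp: True, 'pow'),  # negative, fractional or large exponents
]}

for expr in [a**3, (a + b)**2, a**-2, a/b**2]:
    print(sp.ccode(expr, user_functions=pow_rules))
```

A drawback of routing Pow through user_functions is that every non-matching power, including square roots, is printed as pow(...) rather than with the printer's usual special cases.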

Manuel Oliveira
  • You still have to check for negative powers. E.g. it goes wrong for `1/a`, `a**(-2)` or `1/(a**2+1)`. For high powers you probably want to use [exponentiation via squaring](https://en.wikipedia.org/wiki/Exponentiation_by_squaring), which can result in a lot of speedup. (Also note that most C compilers do this kind of optimizations automatically.) – JohanC Jan 06 '21 at 21:15
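To make the exponentiation-by-squaring suggestion concrete, here is a minimal pure-Python sketch that emits C statements for a high power. pow_by_squaring_c is a hypothetical helper, not part of SymPy; it assumes the base is a plain C identifier (such as a CSE temporary like x1) and that the values are of type double.

```python
def pow_by_squaring_c(base, n, tmp="p"):
    """Emit C statements computing base**n (n >= 1) by exponentiation by squaring.

    Hypothetical helper, not part of SymPy. Assumes `base` is a plain C
    identifier of type double. Returns (statements, name_of_result).
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    statements = []
    result = None   # C expression holding the product accumulated so far
    square = base   # C expression holding base**(2**k) for the current bit k
    while True:
        if n & 1:  # this bit of the exponent is set: multiply it in
            if result is None:
                result = square
            else:
                name = f"{tmp}{len(statements)}"
                statements.append(f"double {name} = {result}*{square};")
                result = name
        n >>= 1
        if n == 0:
            break
        # square once more for the next bit of the exponent
        name = f"{tmp}{len(statements)}"
        statements.append(f"double {name} = {square}*{square};")
        square = name
    return statements, result


stmts, res = pow_by_squaring_c("x1", 13)
print("\n".join(stmts))
print("result:", res)
```

For x1**13 this emits five multiplications (three squarings plus two products) instead of the twelve of a naive x1*x1*...*x1.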