使用 Sympy 生成 C 代码用 x*x 替换 Pow(x,2)

如何解决使用 Sympy 生成 C 代码用 x*x 替换 Pow(x,2)

我正在使用通用子表达式消除 (CSE) 例程和 ccode 打印机生成带有 sympy 的 C 代码

但是,我希望将幂表达式设为 (x*x) 而不是 pow(x,2)。

无论如何要这样做?

示例:

import sympy as sp
a= sp.MatrixSymbol('a',3,3)
b=sp.Matrix(a)*sp.Matrix(a)

res = sp.cse(b)

lines = []
     
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1],tmp[0]))

for i,result in enumerate(res[1]):
    lines.append(sp.ccode(result,"result_%i"%i))

输出

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1,2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8,2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11,2) + x12 + x7;

最好的问候

解决方法

您可以将 code printer 子类化,并且只更改您想要不同的一个函数。您需要调查 the original sympy code 以找到正确的函数名称和默认实现,以确保不会出错。稍加注意,所需的括号就可以准确地在需要的时间和地点自动生成。

这是一个最小的例子:

import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self,expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base,PREC))
        else:
            return super()._print_Pow(expr)

default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint

expressions = [x,(2 + x) ** 2,x ** 3,x ** 15,sp.sqrt(5),sp.sqrt(x)**4,1 / x,1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))

输出:

Default: [x,pow(x + 2,2),pow(x,3),15),sqrt(5),1.0/x,-2)]
Custom: [x,((x + 2) * (x + 2)),(x * x),-2)]

PS:要支持更广泛的指数,您可以使用例如

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self,expr):
        PREC = precedence(expr)
        if expr.exp in range(2,7):
            return '*'.join([self.parenthesize(expr.base,PREC)] * int(expr.exp))
        elif expr.exp in range(-6,0):
            return '1.0/(' + ('*'.join([self.parenthesize(expr.base,PREC)] * int(-expr.exp))) + ')'
        else:
            return super()._print_Pow(expr)
,

有一个名为 create_expand_pow_optimization 的函数可以创建一个包装器来优化您在这方面的表达式。它将用显式乘法替换的最高幂作为参数。

包装器返回一个 UnevaluatedExpr,该 import sympy as sp from sympy.codegen.rewriting import create_expand_pow_optimization expand_opt = create_expand_pow_optimization(3) a = sp.Matrix(sp.MatrixSymbol('a',3,3)) res = sp.cse(a@a) for i,result in enumerate(res[1]): print(sp.ccode(expand_opt(result),"result_%i"%i)) 可以防止自动简化,从而恢复此更改。

mmc.exe gpedit.msc

最后,请注意,对于足够高的优化级别,您的编译器会处理这个问题(并且可能在这方面做得更好)。

,

我想我会采用 user_function 方法:

正如上面评论中所建议的,我将使用 sp.ccode 的 user_functions 功能: 假设我们有一个像 a^3

这样的数字

sp.ccode(a**3,user_functions={'Pow': [(lambda x,y: y.is_integer,lambda x,y: '*'.join(['('+x+')']*int(y))),(lambda x,y: not y.is_integer,'pow')]})

应该输出: '(a)*(a)*(a)'

以后会尽量改进功能,只在需要的时候加括号。

欢迎任何改进!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?