微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用表达式模板自动微分 c++

如何解决使用表达式模板自动微分 c++

简介

我正在尝试了解表达式模板,因为它似乎是一种适用于各种计算的非常强大的技术。我在网上看了不同的例子(例如wikipedia),我写了一堆做不同计算的小程序。然而,我发现了一个我无法解决的问题;即,将任意表达式赋值给变量,其“惰性求值”,以及将其重新赋值给另一个任意表达式。

可以使用auto,将一个表达式赋值给一个变量,但不能重新赋值给另一个表达式(你可以查看整个库sadET,以便充分理解我正在尝试做的)。此外,可以使用 CRTP 和重载 operator= 来完成分配和重新分配。然而,表达式是在赋值过程中计算的,我们基本上丢失了关于表达式是什么的所有信息。

因此,我尝试使用多态克隆 + CRTP(例如,参见 this),这是一种有效的方法,但是当我尝试将变量重新分配给包含相同变量的表达式时,出现了段错误

代码

在下面的代码中,我展示了如何实现加法表达式模板的简化版本。在随后的代码中,msgstd::cout<<typeid(*this).name()<<std::endl; 的宏,用于跟踪正在评估的内容

这是所有表达式的(纯)基类,它允许我将变量分配给通用表达式(使用多态克隆):

struct BaseExpression{
    BaseExpression()=default;
    virtual double evaluate()const =0 ;
};

这是一个所有表达式都继承自的类,将允许我使用CRTP

template<typename subExpr>
struct GenericExpression:BaseExpression{
    const subExpr& self() const {return static_cast<const subExpr&>(*this);}
    subExpr& self() {return static_cast<subExpr&>(*this);}    
    double evaluate() const { msg; return self().evaluate(); };
};

这个想法是所有表达式都由数字、基本函数和运算符组成。所以,我写了一个 Number 类,如下

class Number: public GenericExpression<Number>{
    double val;
    public:
    Number()=default;

    Number(const double &x):val(x){}
    Number(const Number &x):val(x.evaluate()){}
    
    double evaluate()const  { msg; return val;}
    double& evaluate() { msg; return val;}
};

按照表达式模板的思路,那么,添加的是

template<typename leftHand,typename rightHand>
class Addition:public GenericExpression<Addition<leftHand,rightHand>>{
    const leftHand &LH;
    const rightHand &RH;

    public:
    Addition(const leftHand &LH,const rightHand &RH):LH(LH),RH(RH){}

    double evaluate() const {msg; return LH.evaluate() + RH.evaluate();}
};

template<typename leftHand,typename rightHand>
Addition<leftHand,rightHand> 
operator+(const GenericExpression<leftHand> &LH,const GenericExpression<rightHand> &RH){
    return Addition<leftHand,rightHand>(LH.self(),RH.self()); 
}

为了能够使用 BaseExpression,我还编写了一个 Expression 类,用于将 GenericExpression 的实例分配给 Expression 变量。

class Expression: public GenericExpression<Expression>{
    public:
    BaseExpression *baseExpr;

    Expression()=default;
    Expression(const Expression &E){baseExpr = E.baseExpr;};

    double evaluate() const {msg;  return baseExpr->evaluate();}
    
    template<typename subExpr>
    void assign(const GenericExpression<subExpr> &RH){        
        baseExpr = new subExpr(RH.self());
    }

};

在这个类中,重要的一点是指针baseExpr,当我们调用assign时,它允许将evaluate函数更改为GenericExpression

一些例子

为了测试这是否有效,我声明了以下变量:

    Number x(3.2);
    Number y(-2.3);
    Expression z,w;

然后,我们可以看到下面的东西起作用了

    //assignment to Number
    z.assign(x);
    cout<<z.evaluate()<<endl;


    //assignment to Addition<Number,Number>
    z.assign(x+y);
    cout<<z.evaluate()<<endl;

    
    //assignment to Addition<Expression,Number>
    w.assign(z+y);
    cout<<w.evaluate()<<endl;

但是,当我执行以下操作时,当我运行 z.evaluate() 时会得到无限递归,因为 z.baseExpr 指向自身。

    // Segmentation fault of z.evaluate() due to infinite recursion between
    // LH.evaluate() (in Addition<Expression,Number>::evaluate()) and 
    // baseExpr->evaluate() (in Expression::evaluate())
    z.assign(z+x);
    cout<<z.evaluate()<<endl;

重现我描述的行为的完整代码

#include<iostream>
#include<cmath>


// message to print during evaluation,in order to track the evaluation path.
#define msg std::cout<<typeid(*this).name()<<std::endl;


struct BaseExpression{
    BaseExpression()=default;
    virtual double evaluate()const =0 ;
};

template<typename subExpr>
struct GenericExpression:BaseExpression{
    const subExpr& self() const {return static_cast<const subExpr&>(*this);}
    subExpr& self() {return static_cast<subExpr&>(*this);}
    
    double evaluate() const { msg; return self().evaluate(); };
};


class Number: public GenericExpression<Number>{
    double val;
    public:
    Number()=default;

    Number(const double &x):val(x){}
    Number(const Number &x):val(x.evaluate()){}
    
    double evaluate()const  { msg; return val;}
    double& evaluate() { msg; return val;}
};

template<typename leftHand,RH.self()); 
}


class Expression: public GenericExpression<Expression>{
    public:
    BaseExpression *baseExpr;

    Expression()=default;
    Expression(const Expression &E){baseExpr = E.baseExpr;};
    // Expression(Expression *E){baseExpr = E->baseExpr;};

    double evaluate() const {msg;  return baseExpr->evaluate();}


    template<typename subExpr>
    void assign(const GenericExpression<subExpr> &RH){
        
        baseExpr = new subExpr(RH.self());
    }

};


using std::cout;
using std::endl;


int main(){
    Number x(3.2);
    Number y(-2.3);
    Expression z,w;

    // works fine!
    z.assign(x);
    cout<<z.evaluate()<<endl;
    // works fine!
    z.assign(x+y);
    cout<<z.evaluate()<<endl;

    
    // works fine!
    w.assign(z+y);
    cout<<w.evaluate()<<endl;

    // Segmentation fault of z.evaluate() infinite recursion in between
    // LH.evaluate() (in Addition<Expression,Number>::evaluate()) and 
    // baseExpr->evaluate() (in Expression::evaluate())
    z.assign(z+x);
    cout<<z.evaluate()<<endl;

    return 0;
}

转向自动差异

自动微分的实现与上面的例子没有什么不同。最重要的变化是向所有派生自 Constant 的类引入了 derivative 类和 GenericExpression 成员函数,返回任意表达式。也就是说,我们添加

class Constant: public GenericExpression<Constant>{
    double val;
    public:
    Constant()=default;

    Constant(const double &x):val(x){}
    Constant(const Constant &x):val(x.evaluate()){}
    
    double evaluate()const  { msg; return val;}
    double& evaluate() { msg; return val;}
    
    auto derivative(){return Constant(0);}
};

其他类中的 derivative 成员应如下所示:

auto Number::derivative(){return Constant(1);}

template<typename leftHand,typename rightHand>
auto Addition<left,right>::derivative(){return LH.derivative() + RH.derivative();}

问题

即使我们找到了避免(或忽略)我之前展示的段错误方法,我也看不出我可以让 derivative 成为 BaseExpression 的虚拟成员函数,因为它返回一个不同的原则上,我们只有在调用它时才知道这种表达式(因此是 auto)。

最后一个问题

在我试图描述的上下文中,有什么办法可以做到以下几点

    Number x(5);
    Expression z;
    
    z.assign(x);
    z.assign(z+x);
    
    cout<<z.evaluate()<<endl;
    cout<<z.derivative().evaluate()<<endl;

最好没有段错误

感谢社区的意见或见解!

编辑

我简化了代码,让事情变得更简单,并避免潜在的危险指针。

在此示例中,我将成员函数 evaluate 设为指向 lambda 的函数指针。这样,如果我理解正确,我会直接将 ...::evaluate 复制到 Expression::evaluate。但是,我仍然遇到段错误...

奇怪的是,我在使用递归函数对向量求和时得到了段错误

代码

#include<iostream>
#include<functional>
#include<vector>


// message to print during evaluation,in order to track the evaluation path.
#define msg std::cout<<typeid(*this).name()<<std::endl


class Expression{
    public:
    std::function<double(void)> evaluate;
    
    template<typename T>
    Expression(const T &RH){evaluate = RH.evaluate;}

    template<typename T>
    Expression& operator=(const T &RH){evaluate = RH.evaluate; return *this;}
};




class Number{ 
    double val;
    public:
    Number()=default;

    std::function<double()> evaluate;

    Number(const double &x):val(x){ evaluate = [this](){msg; return this->val;};  }
    Number(const Number &x):val(x.evaluate()){ evaluate = [this](){msg; return this->val;}; }
};

template<typename leftHand,typename rightHand>
class Addition{ 
    const leftHand &LH;
    const rightHand &RH;

    public:
    std::function<double()> evaluate;

    Addition(const leftHand &LH,RH(RH)
    {evaluate = [this](){msg; return this->LH.evaluate() + this->RH.evaluate();};}
};

template<typename leftHand,typename rightHand>
auto operator+(const leftHand &LH,const rightHand &RH){return Addition<leftHand,rightHand>(LH,RH); }



using std::cout;
using std::endl;


inline Expression func (std::vector<Expression> x,int i){
    cout<<i<<endl;

    if (i==0){return static_cast<Expression>(x[0]);}    

    return static_cast<Expression>( x[i] + func(x,i-1) ) ;
};


int main(){
    Number y(-2.);
    Number x(1.33);

    Expression z(y+x);
    Expression w(x+y+x);
    
    // works
    z =x;
    cout<<z.evaluate()<<endl;
    cout<<(z+z+z+z).evaluate()<<endl;

    // Segfault due to recusion 
    // z =z+x;
    // cout<<z.evaluate()<<endl;

    // Unkown Segfault 
    // z = x+y ;
    // cout<<(z+z).evaluate()<<endl;
    // cout<<typeid(z+z).name()<<endl;

    // Unkown Segfault 
    // z = w+y+x+x;
    // cout<<z.evaluate()<<endl;

    
    
    // Unkown Segfault 
    // std::vector<Expression> X={x,y,x,y};

    // cout << typeid(func(X,X.size()-1)).name()  << endl;
    // cout << (func(X,X.size()-1)).evaluate()  << endl;
    return 0;
}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。