微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用__init__中的参数定义python类方法

如何解决使用__init__中的参数定义python类方法

我想实现的目标

我想定义一个具有两种模式A和B的类,以便该类的forward方法相应地更改。

class MyClass():
    def __init__(self,constant):
        self.constant=constant
            
    def forward(self,x1,x2,function):
        if function=='A':
            return x1+self.constant
        
        elif function=='B':
            return x1*x2+self.constant
        
        else:
            print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2,None,'A')
output>>>4
model2 = MyClass(2)
model2.forward(2,2,'B')
output>>>6

它有效,但不是最佳方法,因为每次调用forward方法时,它都会检查要使用的函数。但是,正向函数已经设置,并且一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward中使用哪个函数是多余的。 (对于那些注意到这一点的人,我正在使用PyTorch编写我的神经网络模型,两个模型共享90%的网络体系结构,唯一的10%差异是它们的前馈方式。)

我想要的版本

我想在定义类时设置forward方法,以便实现这一目标

model1 = MyClass(2,'A')
model1.forward(2)
output>>>4

model2 = MyClass(2,'B')
model2.forward(2,2)
output>>>6

所以我改写了我的班级:

class MyClass():
    def __init__(self,constant,function):
        self.constant=constant # There would be a lot of shared parameters for the two methods
        self.function=function # This controls the Feedforward method of this class
        
    if self.function=='A':
        def forward(self,x1):
            return x1+self.constant
        
    elif self.function=='B':
        def forward(self,x2):
            return x1*x2+self.constant
        
    else:
        print('please provide the correct function')

但是,它给了我以下错误

NameError:名称“ self”未定义

如何为该类编写基于forward的args定义不同的__init__方法的类?

解决方法

您一直在尝试用您的代码重新定义 class ,以便每个新对象都会在 all 对象之前和之后更改forward的定义。 幸运的是,您没有弄清楚该怎么做。

相反,使所选函数成为对象的属性。编写所需的两个函数,然后在创建每个实例时分配所需的变体。

class MyClass():
    def __init__(self,constant,function):
        self.constant=constant
        if function == 'A':
            self.forward = self.forwardA
        elif function=='B':
            self.forward = self.forwardB
        else:
            print('please provide the correct function')
            
    def forwardA(self,x1):
        return x1+self.constant
        
    def forwardB(self,x1,x2):
        return x1*x2+self.constant

# Main
model1 = MyClass(2,'A')
print(model1.forward(2))

model2 = MyClass(2,'B')
print(model2.forward(2,2))

输出:

4
6
,

您也可以尝试分解出基类。 mypy 可能会更好玩,并且可以更轻松地避免混淆正在使用的任何类。

class MyClassBase():                                      
    def __init__(self,constant):                         
         self.constant=constant                           
                                                          
    def forward(self,*args,**kwargs):              
        raise NotImplementedError('use a derived class')  
                                                          
class MyClassA(MyClassBase):                              
    def __init__(self,constant):                         
        super().__init__(constant)                        
                                                          
    def forward(self,x1):                                 
        return x1 + self.constant                         
                                                          
class MyClassB(MyClassBase):                              
    def __init__(self,x2):                            
        return x1*x2 + self.constant                      
                                                          
a = MyClassA(2)                                           
b = MyClassB(2)                                           
                                                          
print(a.forward(2))                                       
print(b.forward(2,2))                                     
,

我想说的是,如果您打算在类方法的返回值之间使用常量变量名称,则可以通过这种方式定义它来调整逻辑。 (我们可以通过其他方式来完成此操作,仅与Params一起使用)。希望这看起来不错。

class MyClass():
    def __init__(self,function):
        self.constant=constant
        self.function=function
    
    def forward(self,x1 = None,x2 = None):
        if self.function=='A':
            return x1+self.constant
        elif self.function=='B':
            return x1*x2+self.constant
        else:
            print('please provide the correct function')

model1 = MyClass(2,'A')
model1.forward(2)

model2 = MyClass(2,'B')
model2.forward(2,2)
,

定义

在这种情况下,我也将继承。给出您想要的:

model2 = MyClass(2,'B')

这样做会容易得多(其他人也更容易理解):

model2 = MyClassB(2)

鉴于此,与@Nathan Chappell provided in his answer类似,但更短(例如,无需重新定义__init__):

import torch


class Base(torch.nn.Module):
    def __init__(self,constant):
        super().__init__()
        self.constant = constant


class MyClassA(Base):
    def forward(self,x1):
        return x1 + self.constant


class MyClassB(Base):
    def forward(self,x2):
        return x1 * x2 + self.constant

呼叫

您应该使用torch.nn.Module的{​​{1}}方法而不是__call__,因为它可以与钩子一起正常工作(请参阅this answer),因此应该是:

forward

代替:

model1 = MyClassA(5)
model1(torch.randn(10,5))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。