如何解决使用__init__中的参数定义python类方法
我想实现的目标
我想定义一个具有两种模式A和B的类,以便该类的forward
方法相应地更改。
class MyClass():
def __init__(self,constant):
self.constant=constant
def forward(self,x1,x2,function):
if function=='A':
return x1+self.constant
elif function=='B':
return x1*x2+self.constant
else:
print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2,None,'A')
output>>>4
model2 = MyClass(2)
model2.forward(2,2,'B')
output>>>6
它有效,但不是最佳方法,因为每次调用forward
方法时,它都会检查要使用的函数。但是,正向函数已经设置,并且一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward
中使用哪个函数是多余的。 (对于那些注意到这一点的人,我正在使用PyTorch编写我的神经网络模型,两个模型共享90%的网络体系结构,唯一的10%差异是它们的前馈方式。)
我想要的版本
我想在定义类时设置forward
方法,以便实现这一目标
model1 = MyClass(2,'A')
model1.forward(2)
output>>>4
model2 = MyClass(2,'B')
model2.forward(2,2)
output>>>6
所以我改写了我的班级:
class MyClass():
def __init__(self,constant,function):
self.constant=constant # There would be a lot of shared parameters for the two methods
self.function=function # This controls the Feedforward method of this class
if self.function=='A':
def forward(self,x1):
return x1+self.constant
elif self.function=='B':
def forward(self,x2):
return x1*x2+self.constant
else:
print('please provide the correct function')
但是,它给了我以下错误。
NameError:名称“ self”未定义
如何为该类编写基于forward
的args定义不同的__init__
方法的类?
解决方法
您一直在尝试用您的代码重新定义 class ,以便每个新对象都会在 all 对象之前和之后更改forward的定义。 幸运的是,您没有弄清楚该怎么做。
相反,使所选函数成为对象的属性。编写所需的两个函数,然后在创建每个实例时分配所需的变体。
class MyClass():
def __init__(self,constant,function):
self.constant=constant
if function == 'A':
self.forward = self.forwardA
elif function=='B':
self.forward = self.forwardB
else:
print('please provide the correct function')
def forwardA(self,x1):
return x1+self.constant
def forwardB(self,x1,x2):
return x1*x2+self.constant
# Main
model1 = MyClass(2,'A')
print(model1.forward(2))
model2 = MyClass(2,'B')
print(model2.forward(2,2))
输出:
4
6
,
您也可以尝试分解出基类。 mypy 可能会更好玩,并且可以更轻松地避免混淆正在使用的任何类。
class MyClassBase():
def __init__(self,constant):
self.constant=constant
def forward(self,*args,**kwargs):
raise NotImplementedError('use a derived class')
class MyClassA(MyClassBase):
def __init__(self,constant):
super().__init__(constant)
def forward(self,x1):
return x1 + self.constant
class MyClassB(MyClassBase):
def __init__(self,x2):
return x1*x2 + self.constant
a = MyClassA(2)
b = MyClassB(2)
print(a.forward(2))
print(b.forward(2,2))
,
我想说的是,如果您打算在类方法的返回值之间使用常量变量名称,则可以通过这种方式定义它来调整逻辑。 (我们可以通过其他方式来完成此操作,仅与Params一起使用)。希望这看起来不错。
class MyClass():
def __init__(self,function):
self.constant=constant
self.function=function
def forward(self,x1 = None,x2 = None):
if self.function=='A':
return x1+self.constant
elif self.function=='B':
return x1*x2+self.constant
else:
print('please provide the correct function')
model1 = MyClass(2,'A')
model1.forward(2)
model2 = MyClass(2,'B')
model2.forward(2,2)
,
定义
在这种情况下,我也将继承。给出您想要的:
model2 = MyClass(2,'B')
这样做会容易得多(其他人也更容易理解):
model2 = MyClassB(2)
鉴于此,与@Nathan Chappell provided in his answer类似,但更短(例如,无需重新定义__init__
):
import torch
class Base(torch.nn.Module):
def __init__(self,constant):
super().__init__()
self.constant = constant
class MyClassA(Base):
def forward(self,x1):
return x1 + self.constant
class MyClassB(Base):
def forward(self,x2):
return x1 * x2 + self.constant
呼叫
您应该使用torch.nn.Module
的{{1}}方法而不是__call__
,因为它可以与钩子一起正常工作(请参阅this answer),因此应该是:
forward
代替:
model1 = MyClassA(5)
model1(torch.randn(10,5))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。