如何解决将Captum与Pytorch Lighting一起使用?
因此,我尝试将Captum与PyTorch Lightning一起使用。将模块传递给Captum时,我遇到了问题,因为它似乎对张量进行了奇怪的重塑。 例如,在下面的最小示例中,闪电代码容易工作。 但是,当我将IntegratedGradient与“ n_step> = 1”一起使用时,会遇到问题。 我想说LighningModule的代码不是那么重要,我想知道最底部的代码行有多少。
有人知道该如何解决吗?
from captum.attr import IntegratedGradients
from torch import nn,optim,rand,sum as tsum,reshape,device
import torch.nn.functional as F
from pytorch_lightning import seed_everything,LightningModule,Trainer
from torch.utils.data import DataLoader,Dataset
SAMPLE_DIM = 3
class CustomDataset(Dataset):
def __init__(self,samples=42):
self.dataset = rand(samples,SAMPLE_DIM).cuda().float() * 2 - 1
def __getitem__(self,index):
return (self.dataset[index],(tsum(self.dataset[index]) > 0).cuda().float())
def __len__(self):
return self.dataset.size()[0]
class OurModel(LightningModule):
def __init__(self):
super(OurModel,self).__init__()
# Network layers
self.linear = nn.Linear(SAMPLE_DIM,2048)
self.linear2 = nn.Linear(2048,1)
self.output = nn.Sigmoid()
# Hyper-parameters,that we will auto-tune using lightning!
self.lr = 0.001
self.batch_size = 512
def forward(self,x):
x = self.linear(x)
x = self.linear2(x)
output = self.output(x)
return reshape(output,(-1,))
def configure_optimizers(self):
return optim.Adam(self.parameters(),lr=self.lr)
def train_DataLoader(self):
loader = DataLoader(CustomDataset(samples=1000),batch_size=self.batch_size,shuffle=True)
return loader
def training_step(self,batch,batch_nb):
x,y = batch
loss = F.binary_cross_entropy(self(x),y)
return {'loss': loss,'log': {'train_loss': loss}}
if __name__ == '__main__':
seed_everything(42)
device = device("cuda")
model = OurModel().to(device)
trainer = Trainer(max_epochs=2,min_epochs=1,auto_lr_find=False,progress_bar_refresh_rate=10)
trainer.fit(model)
# ok Now the Problem
test_input = CustomDataset(samples=1).__getitem__(0)[0].requires_grad_()
ig = IntegratedGradients(model)
attr,delta = ig.attribute(test_input,target=1,return_convergence_delta=True)
解决方法
解决方案是包装转发功能。确保进入mode.foward()的形状正确!
# Solution is this wrapper function
def modified_f(in_vec):
# Shape here is wrong
print("IN:",in_vec.size())
x = torch.reshape(in_vec,(int(in_vec.size()[0]/SAMPLE_DIM),SAMPLE_DIM))
print("x:",x.size())
res = model.forward(x)
print("res:",res.size())
res = torch.reshape(res,(res.size()[0],1))
print("res2:",res.size())
return res
ig = IntegratedGradients(modified_f)
attr,delta = ig.attribute(test_input,return_convergence_delta=True,n_steps=STEP_AMOUNT)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。