如何解决PyTorch 层的输入和输出
我如何知道 PyTorch 中某个层的输入节点或层的名称?
假设我有一个 torch.cat,我怎么知道从哪里获取输入的张量或层的名称?
对于来自 https://rosenfelder.ai/multi-input-neural-network-pytorch/
class LitClassifier(pl.LightningModule):
def __init__(
self,lr: float = 1e-3,num_workers: int = 4,batch_size: int = 32,):
super().__init__()
self.lr = lr
self.num_workers = num_workers
self.batch_size = batch_size
self.conv1 = conv_block(3,16)
self.conv2 = conv_block(16,32)
self.conv3 = conv_block(32,64)
self.ln1 = nn.Linear(64 * 26 * 26,16)
self.relu = nn.ReLU()
self.batchnorm = nn.BatchNorm1d(16)
self.dropout = nn.Dropout2d(0.5)
self.ln2 = nn.Linear(16,5)
self.ln4 = nn.Linear(5,10)
self.ln5 = nn.Linear(10,10)
self.ln6 = nn.Linear(10,5)
self.ln7 = nn.Linear(10,1)
def forward(self,img,tab):
img = self.conv1(img)
img = self.conv2(img)
img = self.conv3(img)
img = img.reshape(img.shape[0],-1)
img = self.ln1(img)
img = self.relu(img)
img = self.batchnorm(img)
img = self.dropout(img)
img = self.ln2(img)
img = self.relu(img)
tab = self.ln4(tab)
tab = self.relu(tab)
tab = self.ln5(tab)
tab = self.relu(tab)
tab = self.ln6(tab)
tab = self.relu(tab)
x = torch.cat((img,tab),dim=1)
x = self.relu(x)
return self.ln7(x)
所以如果我想知道 torch.cat 从哪一层接收输入。
对于 keras,我们有 model.get_layer(id=idx).input.name
,PyTorch 是否也有类似的东西?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。