如何解决在nn.Sequential的类模块中访问函数
运行nn.Sequential时,我包括了一个类模块列表(这将是神经网络的各层)。运行nn.Sequential时,它将调用模块的前向功能。但是,每个类模块都有一个函数,我想在nn.Sequential运行时访问它。运行nn.Sequential时如何访问和运行此功能?
解决方法
您可以为此使用钩子。让我们考虑在 VGG16 上演示的以下示例:
这是网络体系结构:
假设我们要监视功能 Sequential
(您在上面看到的那个Conv2d层)中第(2)层的输入和输出。
为此,我们注册了一个名为my_hook
的前向挂钩,它将在任何前向传递中被调用:
import torch
from torchvision.models import vgg16
def my_hook(self,input,output):
print('my_hook\'s output')
print('input: ',input)
print('output: ',output)
# Sample net:
net = vgg16()
#Register forward hook:
net.features[2].register_forward_hook(my_hook)
# Test:
img = torch.randn(1,3,512,512)
out = net(img) # Will trigger my_hook and the data you are looking for will be printed
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。