如何解决同一批中的不同内核进行因果卷积
PyTorch(1.6)中是否有一种计算动态因果卷积的方法:每个样本的内核大小不同?可以使用F.conv1d的参数group
。但是,内核必须具有相同的形状,这会使填充变得复杂。
例如,我以5个输入和5个不同的内核为例,由于kernel_size最大,我得到了错误的输出:
import torch
import torch.nn as nn
import torch.nn.functional as F
inputs1 = [0 for _ in range(20)]
inputs1[8] = 1
inputs2 = [0 for _ in range(20)]
inputs2[13] = 1
inputs3 = [0 for _ in range(20)]
inputs3[1] = 1
inputs4 = [0 for _ in range(20)]
inputs4[0] = 1
inputs5 = [0 for _ in range(20)]
inputs5[18] = 1
inputs = torch.FloatTensor([inputs1,inputs2,inputs3,inputs4,inputs5])
kernel_size1 = 6
kernel_size2 = 4
kernel_size3 = 5
kernel_size4 = 8
kernel_size5 = 4
largest_kernel = max(kernel_size1,kernel_size2,kernel_size3,kernel_size4,kernel_size5)
kernel1 = torch.cat([torch.ones(kernel_size1),torch.zeros(kernel_size1 - 1 + max(largest_kernel + (largest_kernel - 1) - (kernel_size1 + (kernel_size1 - 1)),0))])
kernel2 = torch.cat([torch.ones(kernel_size2),torch.zeros(kernel_size2 - 1 + max(largest_kernel + (largest_kernel - 1) - (kernel_size2 + (kernel_size2 - 1)),0))])
kernel3 = torch.cat([torch.ones(kernel_size3),torch.zeros(kernel_size3 - 1 + max(largest_kernel + (largest_kernel - 1) - (kernel_size3 + (kernel_size3 - 1)),0))])
kernel4 = torch.cat([torch.ones(kernel_size4),torch.zeros(kernel_size4 - 1 + max(largest_kernel + (largest_kernel - 1) - (kernel_size4 + (kernel_size4 - 1)),0))])
kernel5 = torch.cat([torch.ones(kernel_size5),torch.zeros(kernel_size5 - 1 + max(largest_kernel + (largest_kernel - 1) - (kernel_size5 + (kernel_size5 - 1)),0))])
kernels = torch.cat([kernel1.unsqueeze(0),kernel2.unsqueeze(0),kernel3.unsqueeze(0),kernel4.unsqueeze(0),kernel5.unsqueeze(0)],axis=0)
inputs = inputs.unsqueeze(0) # Inverse channel and batch
kernels = kernels.unsqueeze(1)
print('Predicted')
conv_res = F.conv1d(inputs,kernels,padding=largest_kernel - 1,groups=inputs.size(1)).long()
for x in conv_res.squeeze().tolist():
print('({})'.format(len(x)),x)
print('Expected')
print('(20) [0,1,0]')
print('(20) [0,0]')
print('(20) [1,1]')
生产
Predicted
(20) [0,0]
(20) [0,1]
(20) [0,0]
(20) [1,0]
Expected
(20) [0,1]
我们可以看到由于内核最大,所以填充错误。知道如何正确进行填充吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。