如何解决PyTorch RNN使用`batch_first = False`更有效率吗?
在机器翻译中,我们总是需要在注释和预测中切出第一步(SOS令牌)。
使用batch_first=False
时,切出第一时间步仍使张量保持连续。
import torch
batch_size = 128
seq_len = 12
embedding = 50
# Making a dummy output that is `batch_first=False`
batch_not_first = torch.randn((seq_len,batch_size,embedding))
batch_not_first = batch_first[1:].view(-1,embedding) # slicing out the first time step
但是,如果我们使用batch_first=True
,则在切片后,张量将不再是连续的。我们必须先使其连续,然后才能执行诸如view
之类的不同操作。
batch_first = torch.randn((batch_size,seq_len,embedding))
batch_first[:,1:].view(-1,embedding) # slicing out the first time step
output>>>
"""
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-a9bd590a1679> in <module>
----> 1 batch_first[:,embedding) # slicing out the first time step
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""
这是否意味着batch_first=False
至少在机器翻译的上下文中更好?因为这样可以避免我们执行contiguous()
步骤。有没有任何情况batch_first=True
更好?
解决方法
性能
batch_first=True
和batch_first=False
之间似乎没有太大的区别。请参见下面的脚本:
import time
import torch
def time_measure(batch_first: bool):
torch.cuda.synchronize()
layer = torch.nn.RNN(10,20,batch_first=batch_first).cuda()
if batch_first:
inputs = torch.randn(100000,7,10).cuda()
else:
inputs = torch.randn(7,100000,10).cuda()
start = time.perf_counter()
for chunk in torch.chunk(inputs,100000 // 64,dim=0 if batch_first else 1):
_,last = layer(chunk)
return time.perf_counter() - start
print(f"Time taken for batch_first=False: {time_measure(False)}")
print(f"Time taken for batch_first=True: {time_measure(True)}")
在我的设备(GTX 1050 Ti)上,PyTorch 1.6.0
和CUDA 11.0的结果如下:
Time taken for batch_first=False: 0.3275816479999776
Time taken for batch_first=True: 0.3159054920001836
(并且这两种方式都会有所不同,因此没有定论)。
代码可读性
如果您想使用其他需要batch_first=True
作为第batch
维度的PyTorch图层, 0
会更简单(几乎所有torch.nn
图层(例如{{ 3}})。
在这种情况下,如果指定了permute
,则无论如何都必须batch_first=False
返回张量。
机器翻译
最好这样做,因为tensor
一直都是连续的,并且不需要复制数据。使用[1:]
而不是[:,1:]
切片看起来也更干净。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。