如何解决如何在不带bmm的pytorch中执行批量乘法?
我正在上UMich的在线计算机视觉课程,对PyTorch还是陌生的。分配问题之一是关于批矩阵乘法,我们必须找到具有和不具有bmm函数的批矩阵乘积。这是代码。
def batched_matrix_multiply(x,y,use_loop=True):
"""
Perform batched matrix multiplication between the tensor x of shape (B,N,M)
and the tensor y of shape (B,M,P).
If use_loop=True,then you should use an explicit loop over the batch
dimension B. If loop=False,then you should instead compute the batched
matrix multiply without an explicit loop using a single PyTorch operator.
Inputs:
- x: Tensor of shape (B,M)
- y: Tensor of shape (B,P)
- use_loop: Whether to use an explicit Python loop.
Hint: torch.stack,bmm
Returns:
- z: Tensor of shape (B,P) where z[i] of shape (N,P) is the result of
matrix multiplication between x[i] of shape (N,M) and y[i] of shape
(M,P). It should have the same dtype as x.
"""
z = None
#############################################################################
# TODO: Implement this function #
#############################################################################
# Replace "pass" statement with your code
z = torch.zeros(x.shape[0],x.shape[1],y.shape[2])
if use_loop == True:
for i in range(x.shape[0]):
z[i] = torch.mm(x[i],y[i])
else:
z = torch.bmm(x,y)
#############################################################################
# END OF YOUR CODE #
#############################################################################
return z
我设法在没有bmm的情况下做到了这一点,但是却没有使用torch.stack提示。我用输出矩阵的尺寸初始化了一个零矩阵'z',并使用for循环对每一批执行常规矩阵乘法。
我想知道使用torch.stack更有效的答案是什么。
解决方法
很好的问题。我只是尝试自己解决这个问题两个小时。这是我的解决方案,它确实根据需要加快了计算速度。
if use_loop == False:
z = torch.bmm(x,y)
else:
z = torch.zeros(x.shape[0],x.shape[1],y.shape[2])
for i in range(x.shape[0],2):
z[i] = torch.stack([x[i] @ y[i],x[i+1] @ y[i+1]])
希望这有帮助!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。