如何解决PyTorch中复数的矩阵乘法
我试图在PyTorch中将两个复杂的矩阵相乘,看来the torch.matmul functions is not added yet to PyTorch library for complex numbers.
您有任何建议吗?还是有另一种方法可以在PyTorch中乘以复杂矩阵?
解决方法
torch.matmul
等复杂张量目前不支持ComplexFloatTensor
,但是您可以执行以下代码一样紧凑的操作:
def matmul_complex(t1,t2):
return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag,t1.real @ t2.imag + t1.imag @ t2.real),dim=2))
在可能的情况下,避免使用for循环,因为这会导致实现速度大大降低。 通过使用我随附的代码中演示的内置方法来实现矢量化。 例如,对于2个尺寸为1000 X 1000的随机复杂矩阵,您的代码在CPU上花费大约6.1s,而矢量化版本仅花费101ms(快60倍)。
,我使用torch.mv为pytorch.matmul实现了此函数,以处理复数,并且在时间上运行良好:
def matmul_complex(t1,t2):
m = list(t1.size())[0]
n = list(t2.size())[1]
t = torch.empty((1,n),dtype=torch.cfloat)
t_total = torch.empty((m,dtype=torch.cfloat)
for i in range(0,n):
if i == 0:
t_total = torch.mv(t1,t2[:,i])
else:
t_total = torch.cat((t_total,torch.mv(t1,i])),0)
t_final = torch.reshape(t_total,(m,n))
return t_final
我是PyTorch的新手,所以如果我错了,请纠正我。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。