如何解决PyTorch中复杂矩阵的行列式
有没有一种方法可以计算PyTroch中复杂矩阵的行列式?
torch.det未针对“ ComplexFloat”实现
解决方法
很遗憾,当前尚未实现。一种方法是实现自己的版本或仅使用np.linalg.det
。
这是一个简短的函数,用于计算使用LU分解编写的复杂矩阵的行列式:
def complex_det(A):
def complex_diag(A):
return torch.view_as_complex(torch.stack((A.real.diag(),A.imag.diag()),dim=1))
#Perform LU decomposition to matrix A:
A_LU,pivots = A.lu()
P,A_L,A_U = torch.lu_unpack(A_LU,pivots)
#Det. of multiplied matrices is multiplcation of det.:
det = torch.prod(complex_diag(A_L)) * torch.prod(complex_diag(A_U)) * torch.det(P.real) #Could probably calculate det(P) [which is +-1] efficiently using Sylvester's determinant identity
return det
#Test it:
A = torch.view_as_complex(torch.randn(3,3,2))
complex_det(A)
,
从 1.8 版开始,PyTorch 原生支持 numpy 样式的 torch.linalg
操作。特别是,torch.linalg.det
支持 cfloat
和 cdouble
复数数据类型:
torch.linalg.det(input)
计算方阵 input
或批处理 input
中每个方阵的行列式。
此函数支持 float、double、cfloat 和 cdouble 数据类型。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。