如何解决在 PyTorch 中沿轴上的所有索引应用函数
我正在尝试在 PyTorch 中实现 Wasserstein 损失函数,为此我参考了 Scipy 实现。由于在 forward() 方法中使用 PyTorch 函数意味着不必编写 back() 函数,因此我已在我的代码中完成了此操作(Scipy 版本仅包含等效的 Numpy 函数)。这是我所拥有的:
class WassersteinLoss(nn.Module):
def __init__(self):
super(WassersteinLoss,self).__init__()
def forward(self,u,v):
result = torch.empty((len(u)))
for i in range(len(u)):
u_values,v_values = u[i],v[i]
u_sorter,v_sorter = torch.argsort(u_values),torch.argsort(v_values)
all_values = torch.cat((u_values,v_values))
all_values,idx = torch.sort(all_values)
# Compute the differences between pairs of successive values of u and v.
deltas = torch.sub(all_values[1:],all_values[:-1])
# Get the respective positions of the values of u and v among the values of
# both distributions.
u_cdf_indices = torch.searchsorted(u_values[u_sorter],all_values[:-1],right=True)
v_cdf_indices = torch.searchsorted(v_values[v_sorter],right=True)
# Calculate the CDFs of u and v
u_cdf = torch.div(u_cdf_indices,len(u_values))
v_cdf = torch.div(v_cdf_indices,len(v_values))
# Compute the value of the integral based on the CDFs.
result[i] = torch.sum(torch.multiply(torch.abs(u_cdf-v_cdf),deltas))
return result.mean()
在上面的函数中,u 和 v 是形状 (NxM) 的向量,其中 N 是批次中的样本数。由于我的 for 循环基本上遍历了所有样本,因此对每个样本的计算都是独立的,因为批次中的样本相互依赖。我相信如果我可以取消这个 for 循环,我会看到显着的加速。到目前为止,我已经尝试沿 dim=1 轴执行所有计算,但这不起作用。
from scipy.stats import wasserstein_distance
import torch
import torch.nn as nn
print(wasserstein_distance([0,1,3],[5,6,8]))
#Output is 5
criterion = WassersteinLoss()
print(criterion(torch.tensor([[0,3]]),torch.tensor([[5,8]])))
#Output is tensor(5.)
有关如何修改 forward() 函数以消除 for 循环的任何输入将不胜感激。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。