微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 PyTorch 中沿轴上的所有索引应用函数

如何解决在 PyTorch 中沿轴上的所有索引应用函数

我正在尝试在 PyTorch 中实现 Wasserstein 损失函数,为此我参考了 Scipy 实现。由于在 forward() 方法中使用 PyTorch 函数意味着不必编写 back() 函数,因此我已在我的代码中完成了此操作(Scipy 版本仅包含等效的 Numpy 函数)。这是我所拥有的:

class WassersteinLoss(nn.Module): 
  def __init__(self):
    super(WassersteinLoss,self).__init__()
  def forward(self,u,v):
    result = torch.empty((len(u)))
    for i in range(len(u)):
      u_values,v_values = u[i],v[i]
      u_sorter,v_sorter = torch.argsort(u_values),torch.argsort(v_values)
      all_values = torch.cat((u_values,v_values))
      all_values,idx = torch.sort(all_values)

      # Compute the differences between pairs of successive values of u and v.
      deltas = torch.sub(all_values[1:],all_values[:-1])

      # Get the respective positions of the values of u and v among the values of
      # both distributions.
      u_cdf_indices = torch.searchsorted(u_values[u_sorter],all_values[:-1],right=True)
      v_cdf_indices = torch.searchsorted(v_values[v_sorter],right=True)

      # Calculate the CDFs of u and v
      u_cdf = torch.div(u_cdf_indices,len(u_values))
      v_cdf = torch.div(v_cdf_indices,len(v_values))

      # Compute the value of the integral based on the CDFs.
      result[i] = torch.sum(torch.multiply(torch.abs(u_cdf-v_cdf),deltas))
    return result.mean()

在上面的函数中,u 和 v 是形状 (NxM) 的向量,其中 N 是批次中的样本数。由于我的 for 循环基本上遍历了所有样本,因此对每个样本的计算都是独立的,因为批次中的样本相互依赖。我相信如果我可以取消这个 for 循环,我会看到显着的加速。到目前为止,我已经尝试沿 dim=1 轴执行所有计算,但这不起作用。

这是一个测试用例的代码

from scipy.stats import wasserstein_distance
import torch
import torch.nn as nn

print(wasserstein_distance([0,1,3],[5,6,8]))
#Output is 5

criterion = WassersteinLoss()
print(criterion(torch.tensor([[0,3]]),torch.tensor([[5,8]])))
#Output is tensor(5.)

有关如何修改 forward() 函数以消除 for 循环的任何输入将不胜感激。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。