微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何选择Numpy张量轴

如何解决如何选择Numpy张量轴

我有两个形状为(436,1024,2)的numpy数组。最后一个维度(2)代表2D向量。我想逐个元素比较两个numpy数组的2D向量,以求出平均角度误差。

为此,我想使用点积,该点积在遍历数组的两个第一维时效果很好(python中的for循环可能很慢)。因此,我想使用一个numpy函数

我发现np.tensordot允许逐元素执行点积。但是,我无法成功使用其axes参数:

import numpy as np

def average_angular_error_vec(estimated_oc : np.array,target_oc : np.array):
    estimated_oc = np.float64(estimated_oc)
    target_oc = np.float64(target_oc)

    norm1 = np.linalg.norm(estimated_oc,axis=2)
    norm2 = np.linalg.norm(target_oc,axis=2)
    norm1 = norm1[...,np.newaxis]
    norm2 = norm2[...,np.newaxis]

    unit_vector1 = np.divide(estimated_oc,norm1)
    unit_vector2 = np.divide(target_oc,norm2)

    dot_product = np.tensordot(unit_vector1,unit_vector2,axes=2)
    angle = np.arccos(dot_product)

    return np.mean(angle)

我遇到以下错误

ValueError: shape-mismatch for sum

下面是我的函数,它可以正确计算平均角度误差:

def average_angular_error(estimated_oc : np.array,target_oc : np.array):
    h,w,c = target_oc.shape
    r = np.zeros((h,w),dtype="float64")

    estimated_oc = np.float64(estimated_oc)
    target_oc = np.float64(target_oc)

    for i in range(h):
        for j in range(w):

            unit_vector_1 = estimated_oc[i][j] / np.linalg.norm(estimated_oc[i][j])
            unit_vector_2 = target_oc[i][j] / np.linalg.norm(target_oc[i][j])
            dot_product = np.dot(unit_vector_1,unit_vector_2)

            angle = np.arccos(dot_product)

            r[i][j] = angle
       
    return np.mean(r)

解决方法

问题可能比您要解决的要简单得多。如果将np.tensordot沿最后一个轴应用于一对形状为(w,h,2)的数组,您将得到形状为(w,w,h)的结果。这不是您想要的。这里有三种简单的方法。除了显示选项之外,我还展示了一些提示和技巧,可以在不更改任何基本功能的情况下简化代码:

  1. 手动进行总和减少(使用+*

    def average_angular_error(estimated_oc : np.ndarray,target_oc : np.ndarray):
        # If you want to do in-place normalization,do x /= ... instead of x = x / ...
        estimated_oc = estimated_oc / np.linalg.norm(estimated_oc,axis=-1,keepdims=True)
        target_oc = target_oc / np.linalg.norm(target_oc,keepdims=True)
        # Use plain element-wise multiplication
        dots = np.sum(estimated_oc * target_oc,axis=-1)
        return np.arccos(dots).mean()
    
  2. 使用np.matmul(又称@)和正确广播的尺寸:

    def average_angular_error(estimated_oc : np.ndarray,target_oc : np.ndarray):
        estimated_oc = estimated_oc / np.linalg.norm(estimated_oc,keepdims=True)
        # Matrix multiplication needs two dimensions to operate on
        dots = estimated_oc[...,None,:] @ target_oc[...,:,None]
        return np.arccos(dots).mean()
    

    np.matmulnp.dot都要求第一个数组的最后一个维度与第二个数组的最后一个匹配,就像普通矩阵乘法一样。 Nonenp.newaxis的别名,它在您选择的位置引入了一个大小为1的新轴。在这种情况下,我制作了第一个数组(w,1,2)和第二个数组(w,2,1)。这样可以确保最后两个维度在每个对应元素处分别作为转置向量和正则向量相乘。

  3. 使用np.einsum

    def average_angular_error(estimated_oc : np.ndarray,keepdims=True)
        # Matrix multiplication needs two dimensions to operate on
        dots = np.einsum('ijk,ijk->ik',estimated_oc,target_oc)
        return np.arccos(dots).mean()
    

您不能为此使用np.dotnp.tensordotdottensordot保持两个数组的原始尺寸,如前所述。 matmul一起广播,这就是您想要的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?