微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Python - Scipy Multivariate normalized to 1维

如何解决Python - Scipy Multivariate normalized to 1维

运行 y = multivariate_normal(np.zeros(d),np.eye(d)).rvs() 时,我们获得维度为 (d,) 的样本。然而,当 d=1 获得一个标量时,这是有道理的,因为它是一维的。不幸的是,我有一些代码必须适用于任意数量的维度,包括 d=1,并且基本上采用 d 维向量 xy 的点积.这在 d=1 处中断。我该如何解决

import numpy as np
from scipy.stats import multivariate_normal as MVN

def mwe_function(d,x):
    """Minimal Working Example"""
    y = MVN(np.zeros(d),np.eye(d)).rvs()
    return x @ y

mwe_function(2,np.ones(2))  # This works
mwe_function(1,np.ones(1))  # This doesn't

重要提示:我想避免使用 if 语句。在这种情况下可以简单地使用 scipy.stats.norm,但我想避免使用 if 语句,因为它们会减慢代码速度。

解决方法

您可以使用 np.reshape 来固定样品的形状。通过使用 -1 指定第一维的长度,您将始终得到一个一维数组,没有标量。

import numpy as np
from scipy.stats import multivariate_normal as MVN

def mwe_function(d,x):
    """Minimal Working Example"""
    y = MVN(np.zeros(d),np.eye(d)).rvs().reshape([-1])
    return x @ y

v0 = mwe_function(2,np.ones(2))  # This works
print(v0)  # -0.5718013906409207
v1 = mwe_function(1,np.ones(1))  # This works as well :-)
print(v1)  # -0.20196038784485093

.reshape([-1]) 在哪里工作。

就个人而言,我更喜欢重塑而不是使用 np.atleast_1d,因为效果是直接可见的 - 但最终它是一个品味问题。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。