如何解决Pyro:伯努利随机变量的样本有多个元素
我是 Pyro 的新手,正在尝试使我的第一个随机过程模型发挥作用。我修改了 here 中的代码以适合我的示例问题,该示例问题只是两个高斯分布,样本来自一个或另一个的离散概率。
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC,MCMC
# Actual data sample
observations = torch.tensor(
[0.00528813,-0.00589001,-1.20608593,0.00190794,0.89052784,0.66690464,0.57295968,0.02605967]
)
# Define the process
def model(observations):
a_prior = dist.Beta(2,2)
a = pyro.sample("a",a_prior)
c = pyro.sample('c',dist.Bernoulli(a))
if c.item() == 1.0:
my_dist = dist.normal(0.785,1.0)
else:
my_dist = dist.normal(0.0,0.01)
for i,observation in enumerate(observations):
measurement = pyro.sample(f'obs_{i}',my_dist,obs=observation)
# Clear parameters
pyro.clear_param_store()
# Define the MCMC kernel function
my_kernel = HMC(model)
# Define the MCMC algorithm
my_mcmc = MCMC(my_kernel,num_samples=5000,warmup_steps=50)
# Run the algorithm,passing the observations
my_mcmc.run(observations)
引发的异常是:
<ipython-input-2-a668622a0fb9> in model(observations)
11 a = pyro.sample("a",a_prior)
12 c = pyro.sample('c',dist.Bernoulli(a))
---> 13 if c.item() == 1.0:
14 my_dist = dist.normal(0.785,1.0)
15 else:
ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
c dist |
value 2 |
我使用调试器查看了 c
,出于某种原因,第二次调用 model() 时它有两个元素:
tensor([0.,1.])
这是什么原因造成的?我希望它是一个简单的标量,其值为 0 或 1。
作为进一步的测试,条件语句在以正常方式采样时可以正常工作:
# Conditional switch test
a_prior = dist.Beta(2,2)
a = pyro.sample("a",a_prior)
for i in range(5):
c = pyro.sample('c',dist.Bernoulli(a))
if c.item() == 1.0:
print(1,end=' ')
else:
print(0,end=' ')
# 0 0 1 0 0
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。