微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Pyro:伯努利随机变量的样本有多个元素

如何解决Pyro:伯努利随机变量的样本有多个元素

我是 Pyro 的新手,正在尝试使我的第一个随机过程模型发挥作用。我修改here 中的代码以适合我的示例问题,该示例问题只是两个高斯分布,样本来自一个或另一个的离散概率。

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC,MCMC

# Actual data sample
observations = torch.tensor(
    [0.00528813,-0.00589001,-1.20608593,0.00190794,0.89052784,0.66690464,0.57295968,0.02605967]
)

# Define the process
def model(observations):
    
    a_prior = dist.Beta(2,2)
    a = pyro.sample("a",a_prior)
    c = pyro.sample('c',dist.Bernoulli(a))
    if c.item() == 1.0:
        my_dist = dist.normal(0.785,1.0)
    else:
        my_dist = dist.normal(0.0,0.01)
    
    for i,observation in enumerate(observations):
        measurement = pyro.sample(f'obs_{i}',my_dist,obs=observation)

# Clear parameters
pyro.clear_param_store()

# Define the MCMC kernel function
my_kernel = HMC(model)

# Define the MCMC algorithm
my_mcmc = MCMC(my_kernel,num_samples=5000,warmup_steps=50)

# Run the algorithm,passing the observations 
my_mcmc.run(observations)

引发的异常是:

<ipython-input-2-a668622a0fb9> in model(observations)
     11     a = pyro.sample("a",a_prior)
     12     c = pyro.sample('c',dist.Bernoulli(a))
---> 13     if c.item() == 1.0:
     14         my_dist = dist.normal(0.785,1.0)
     15     else:

ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:    
 Param Sites:    
Sample Sites:    
       a dist   |
        value   |
       c dist   |
        value 2 |

我使用调试器查看了 c,出于某种原因,第二次调用 model() 时它有两个元素:

tensor([0.,1.])

这是什么原因造成的?我希望它是一个简单的标量,其值为 0 或 1。

作为进一步的测试,条件语句在以正常方式采样时可以正常工作:

# Conditional switch test
a_prior = dist.Beta(2,2)
a = pyro.sample("a",a_prior)
for i in range(5):
    c = pyro.sample('c',dist.Bernoulli(a))
    if c.item() == 1.0:
        print(1,end=' ')
    else:
        print(0,end=' ')

# 0 0 1 0 0 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。