如何解决由于使用pyro和pytorch的样品中存在多类分布,因此在svi步骤中出现错误
我正在研究一种因果变化型自动编码器,该编码器使用类分割掩码,类标签和因果关系(0或1)作为输入。
由于svi步骤,使用大于1的批处理大小时出现错误。我之所以使用bernoulling函数,是因为我希望它学习图像中多个类别的概率分布。我认为分类分布也很适合这里,但是我也遇到同样的错误。
当我尝试缩小造成问题的代码行时,我认为它在模型函数中:
one_vec2 = torch.ones([batch_size,self.lbl_shape[0]],**options)
class_labels = pyro.sample('class_labels',dist.Bernoulli(one_vec2*0.5),obs = lbls)
错误:
ValueError Traceback (most recent call last)
<ipython-input-19-8cbc046dd2c1> in <module>()
6 vae = Vae_Model1(lbl_sz,ch,img_sz).to(device)
7 svi = SVI(vae.model,vae.guide,optimizer,loss = Trace_ELBO())
----> 8 train(svi,train_loader,USE_CUDA)
6 frames
/usr/local/lib/python3.6/dist-packages/pyro/util.py in check_site_shape(site,max_plate_nesting)
320 '- enclose the batched tensor in a with plate(...): context',321 '- .to_event(...) the distribution being sampled',--> 322 '- .permute() data dimensions']))
323
324 # Check parallel dimensions on the left of max_plate_nesting.
ValueError: at site "class_labels",invalid log_prob shape
Expected [-1],actual [32,21]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
当前批处理大小为32,lbl_shape [0]为21(VOC数据集(背景标签和其他标签))
有人可以帮我吗?非常感谢。 谢谢
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。