微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

用于多个输入变量的 JAX 自定义 VJP 函数不适用于 NumPyro/HMC-NUTS

如何解决用于多个输入变量的 JAX 自定义 VJP 函数不适用于 NumPyro/HMC-NUTS

我正在尝试使用自定义 VJP(向量雅可比积)函数作为 numpyro 中 HMC-NUTS 的模型。我能够制作一个适用于 HMC-NUTS 的单一变量函数,如下所示:

{' Four Spaces': 4,' One Space': 1,' Three Spaces': 3,' Zero Spaces': 0}

这里,我手动定义了 h(x)=sin(x)。然后,我做了一个测试数据

import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def h(x):
    return jnp.sin(x)

def h_fwd(x):
    return h(x),jnp.cos(x)

def h_bwd(res,u):
    cos_x  = res 
    return (cos_x * u,)

h.defvjp(h_fwd,h_bwd)

test data

在这种情况下,我能够在 NumPyro 中执行 HMC-NUTS

import numpy as np
np.random.seed(32)
sigin=0.3
N=20
x=np.sort(np.random.rand(N))*4*np.pi
data=hv(x)+np.random.normal(0,sigin,size=N)

它有效。

import numpyro
import numpyro.distributions as dist

def model(x,y):
    sigma = numpyro.sample('sigma',dist.Exponential(1.))
    x0 = numpyro.sample('x0',dist.Uniform(-1.,1.))
    #mu=jnp.sin(x-x0)
    #mu=hv(x-x0)
    mu=h(x-x0)
    numpyro.sample('y',dist.normal(mu,sigma),obs=y)

from jax import random
from numpyro.infer import MCMC,NUTS

rng_key = random.PRNGKey(0)
rng_key,rng_key_ = random.split(rng_key)
num_warmup,num_samples = 1000,2000
kernel = NUTS(model)
mcmc = MCMC(kernel,num_warmup,num_samples)
mcmc.run(rng_key_,x=x,y=data)
mcmc.print_summary()

但是,如果我将多变量函数定义为,

sample: 100%|██████████| 3000/3000 [00:15<00:00,193.84it/s,3 steps of size 7.67e-01. acc. prob=0.92]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     sigma      0.35      0.06      0.34      0.26      0.45   1178.07      1.00
        x0      0.07      0.11      0.07     -0.11      0.26   1243.73      1.00

Number of divergences: 0

后执行 HMC-NUTS

@custom_vjp
def h(x,A):
    return A*jnp.sin(x)

def h_fwd(x,A):
    res = (A*jnp.cos(x),jnp.sin(x))
    return h(x,A),res

def h_bwd(res,u):
    A_cos_x,sin_x = res
    return (A_cos_x * u,sin_x * u)

h.defvjp(h_fwd,h_bwd)

然后我得到一个错误

def model(x,1.))
    A = numpyro.sample('A',dist.Exponential(1.))
    mu=h(x-x0,A)
    numpyro.sample('y',obs=y)

rng_key = random.PRNGKey(0)
rng_key,y=data)
mcmc.print_summary()

我怀疑函数中的输出形状是错误的。但是,经过各种改变形状的尝试后,我无法弄清楚出了什么问题。

解决方法

def model(x,y):
sigma = numpyro.sample('sigma',dist.Exponential(1.))
x0 = numpyro.sample('x0',dist.Uniform(-1.,1.))
A = numpyro.sample('A',dist.Exponential(1.))
hv=vmap(h,(0,None),0)
mu=hv(x-x0,A)
numpyro.sample('y',dist.Normal(mu,sigma),obs=y)

vmap 解决了这个问题。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。