
TensorFlow Probability - MCMC - bijector issue in the transition kernel?


I am building a mixture of TensorFlow Probability models. The joint distribution of a single model is:

one_network_prior = tfd.JointDistributionNamed(
    model=dict(
        mu_g=tfb.Sigmoid(
            low=-1.0, high=1.0, validate_args=True, name="mu_g"
        )(
            tfd.Normal(
                loc=tf.zeros((D,)), scale=0.5, validate_args=True
            )
        ),
        epsilon=tfd.Gamma(
            concentration=0.4, rate=1.0, name="epsilon"
        ),
        mu_s=lambda mu_g, epsilon: tfb.Sigmoid(
            low=-1.0, high=1.0, name="mu_s"
        )(
            tfd.Normal(
                loc=tf.stack([mu_g] * S), scale=epsilon, validate_args=True
            )
        ),
        sigma=tfd.Gamma(
            concentration=0.3, rate=1.0, name="sigma"
        ),
        mu_s_t=lambda mu_s, sigma: tfb.Sigmoid(
            low=-1.0, high=1.0, name="mu_s_t"
        )(
            tfd.Normal(
                loc=tf.stack([mu_s] * T), scale=sigma, validate_args=True
            )
        ),
    )
)
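For intuition, the `Sigmoid(low=-1., high=1.)` transform used above maps an unconstrained normal draw into the open interval (-1, 1). A minimal NumPy sketch of that forward map (the dimension `D = 4` is a placeholder, not the model's actual value):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4  # hypothetical dimension; the real D comes from the data

# Unconstrained draw, analogous to tfd.Normal(loc=tf.zeros((D,)), scale=0.5)
z = rng.normal(loc=0.0, scale=0.5, size=D)

# Forward map of tfb.Sigmoid(low=-1., high=1.): low + (high - low) * sigmoid(z)
low, high = -1.0, 1.0
mu_g = low + (high - low) / (1.0 + np.exp(-z))
```

Every draw lands strictly inside (-1, 1), which is what makes this a convenient constraining transform.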

Then I need to "mix" these models, but the mixing is fairly custom, so I do it manually through a custom log-probability function, log_prob_fn:

def log_prob_fn(
    mu_g, epsilon, mu_s, sigma, mu_s_t, kappa, spatial, observed
):
    log_probs_per_network = []
    probs_per_network = []
    for l in range(L):
        log_probs_per_network.append(
            tf.reduce_sum(
                one_network_prior.log_prob(
                    {
                        "mu_g": mu_g[l],
                        "epsilon": epsilon[l],
                        "mu_s": mu_s[l],
                        "sigma": sigma[l],
                        "mu_s_t": mu_s_t[l],
                    }
                )
            )
        )

        dist = tfb.Sigmoid(
            low=-1.0, high=1.0, validate_args=True
        )(
            tfd.Normal(
                loc=tf.stack([mu_s_t[l]] * N), scale=kappa
            )
        )

        probs_per_network.append(
            tf.reduce_prod(
                dist.prob(observed), axis=-1
            )
        )

    kappa_log_prob = kappa_prior.log_prob(kappa)

    mixed_probs = (
        spatial * tf.stack(probs_per_network, axis=-1)
    )
    margin_prob = tf.reduce_sum(mixed_probs, axis=-1)

    mix_log_prob = tf.reduce_sum(
        tf.math.log(margin_prob)
    )

    return (
        tf.reduce_sum(log_probs_per_network)
        + kappa_log_prob
        + mix_log_prob
    )

(I know this function is not efficient, but I could not easily sample from the previous model with a batch shape, so for now I have to iterate over the models.)

Note that the distribution dist is created dynamically for each network.
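The mixture step at the end of log_prob_fn is a standard marginalization over components. A small NumPy sketch with made-up shapes (the names mirror the function's variables, but L = 3 networks and N = 5 observations are placeholders):

```python
import numpy as np

rng = np.random.default_rng(1)
L, N = 3, 5  # hypothetical: 3 networks, 5 observations

# Per-network probability of each observation, as built inside the loop
probs_per_network = rng.uniform(0.1, 1.0, size=(L, N))

# Mixture weights: one weight per network for each observation
spatial = rng.dirichlet(np.ones(L), size=N)  # shape (N, L), rows sum to 1

# spatial * tf.stack(probs_per_network, axis=-1)
mixed_probs = spatial * probs_per_network.T  # shape (N, L)
margin_prob = mixed_probs.sum(axis=-1)       # marginal prob of each observation
mix_log_prob = np.log(margin_prob).sum()     # summed log-likelihood
```

One caveat: multiplying raw probabilities like this can underflow to zero, after which the log produces -inf/nan; mixing in log space with `tf.reduce_logsumexp` over log-weights plus log-probs is the numerically safer equivalent.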

The goal is then to take this model and fit it to data. I generated an initial state using one_network_prior and mixed the data manually to obtain observed data of shape (N, T, S, D), which is fed to MCMC as follows:

hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda *params: log_prob_fn(
        *params, observed=observed
    ),
    step_size=0.065,
    num_leapfrog_steps=5,
)

unconstraining_bijectors = [
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_g
    tfb.Softplus(),                   # epsilon
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_s
    tfb.Softplus(),                   # sigma
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_s_t
    tfb.Softplus(),                   # kappa
    tfb.SoftmaxCentered(),            # spatial
]
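As a sanity check on the forward (unconstrained-to-constrained) direction that TransformedTransitionKernel uses, here is a plain-NumPy sketch of the Softplus map (illustrative only, not the TFP implementation):

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)): the forward map of tfb.Softplus
    return np.logaddexp(0.0, x)

# Any unconstrained value, however negative, maps to a strictly positive value
x = np.linspace(-30.0, 30.0, 121)
y = softplus(x)
```

For very negative inputs the output gets extremely close to zero, so a "positive" scale produced by Softplus can still be small enough to cause numerical trouble downstream.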

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel, bijector=unconstraining_bijectors
)

adapted_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=transformed_kernel,
    num_adaptation_steps=400,
    target_accept_prob=0.65,
)

@tf.function
def run_chain(initial_state, num_results=1000, num_burnin_steps=500):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adapted_kernel,
    )

samples, kernel_results = run_chain(
    initial_state=init_state, num_results=20000, num_burnin_steps=5000
)

But when I run the run_chain function, I get an error after a few iterations:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-45-2ff348713067> in <module>
----> 1 samples, kernel_results = run_chain(
      2     initial_state=init_state,
      3     num_results=20000,
      4     num_burnin_steps=5000
      5 )

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    844               *args, **kwds)
    845       # If we did not create any variables the trace we have is good enough.
--> 846       return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
    847 
    848     def fn_with_cond(*inner_args, **inner_kwds):

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1841       `args` and `kwargs`.
   1842     """
-> 1843     return self._call_flat(
   1844         [t for t in nest.flatten((args, kwargs), expand_composites=True)
   1845          if isinstance(t, (ops.Tensor,

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1921         and executing_eagerly):
   1922       # No tape is watching; skip to running the function.
-> 1923       return self._build_call_outputs(self._inference_function.call(
   1924           ctx, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, cancellation_manager)
    543       with _InterpolateFunctionError(self):
    544         if cancellation_manager is None:
--> 545           outputs = execute.execute(
    546               str(self.signature.name),
    547               num_outputs=self._num_outputs,

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError:  assertion failed: [Argument `scale` must be positive.] [Condition x > 0 did not hold element-wise:] [x (mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/strided_slice_1:0) = ] [-nan]
     [[{{node mcmc_sample_chain/trace_scan/while/body/_415/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/body/_2366/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/body/_3200/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointdistributionNamed/log_prob/normal/assert_positive/assert_less/Assert/AssertGuard/else/_3580/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointdistributionNamed/log_prob/normal/assert_positive/assert_less/Assert/AssertGuard/Assert}}]] [Op:__inference_run_chain_169987]

Function call stack:
run_chain

My understanding is that a negative kappa is being fed into dist, but that should not be possible after the Softplus bijector, should it? Also, the function still runs when I invert all my bijectors, which is strange, because the shapes should then break due to SoftmaxCentered.
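For reference on the shape argument: tfb.SoftmaxCentered does change the trailing dimension. Its forward map appends an implicit zero logit and applies softmax, taking a length-n unconstrained vector to a length-(n+1) probability vector. A NumPy sketch of that behavior (illustrative, not the TFP implementation):

```python
import numpy as np

def softmax_centered(x):
    # Forward map of tfb.SoftmaxCentered: append a zero logit, then softmax,
    # so an input of length n becomes a probability vector of length n + 1
    z = np.append(x, 0.0)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

y = softmax_centered(np.array([0.5, -1.2]))
```

So a state part routed through SoftmaxCentered should have different shapes in the unconstrained and constrained spaces, which is why a run that still works with the bijectors inverted looks suspicious.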

So I have the feeling that my bijectors are simply being ignored. What am I missing?

Thanks in advance for your help :)
