如何解决Tensorflow概率-MCMC-过渡核中的双射手问题?
我正在建立张量流概率模型的混合体。一个给定模型的联合分布为:
one_network_prior = tfd.JointdistributionNamed(
model=dict(
mu_g=tfb.Sigmoid(
low=-1.0,high=1.0,validate_args=True,name="mu_g"
)(
tfd.normal(
loc=tf.zeros((D,)),scale=0.5,validate_args=True
)
),epsilon=tfd.Gamma(
concentration=0.4,rate=1.0,name="epsilon"
),mu_s=lambda mu_g,epsilon: tfb.Sigmoid(
low=-1.0,name="mu_s"
)(
tfd.normal(
loc=tf.stack(
[
mu_g
] * S
),scale=epsilon,sigma=tfd.Gamma(
concentration=0.3,name="sigma"
),mu_s_t=lambda mu_s,sigma: tfb.Sigmoid(
low=-1.0,name="mu_s_t"
)(
tfd.normal(
loc=tf.stack(
[
mu_s
] * T
),scale=sigma,validate_args=True
)
)
)
)
然后,我需要“混合”模型,但是这种混合是相当自定义的,我可以通过自定义log_prob_fn
对数概率函数手动完成:
def log_prob_fn(
mu_g,epsilon,mu_s,sigma,mu_s_t,kappa,spatial,observed
):
log_probs_per_network = []
probs_per_network = []
for l in range(L):
log_probs_per_network.append(
tf.reduce_sum(
one_network_prior.log_prob(
{
"mu_g": mu_g[l],"epsilon": epsilon[l],"mu_s": mu_s[l],"sigma": sigma[l],"mu_s_t": mu_s_t[l]
}
)
)
)
dist = tfb.Sigmoid(
low=-1.0,validate_args=True
)(
tfd.normal(
loc=tf.stack(
[
mu_s_t[l]
] * N
),scale=kappa
)
)
probs_per_network.append(
tf.reduce_prod(
dist.prob(
observed
),axis=-1
)
)
kappa_log_prob = kappa_prior.log_prob(
kappa
)
mixed_probs = (
spatial
*
tf.stack(
probs_per_network,axis=-1
)
)
margin_prob = tf.reduce_sum(
mixed_probs,axis=-1
)
mix_log_prob = tf.reduce_sum(
tf.math.log(
margin_prob
)
)
return (
tf.reduce_sum(
log_probs_per_network
)
+ kappa_log_prob
+ mix_log_prob
)
(我知道此函数效率不高,但是我无法轻松地从以前的模型(批处理形状)中进行采样,因此我暂时不得不对模型进行迭代)
请注意,正在为每个网络动态创建分发dist
。
然后,目标是使用此模型并将其适合数据。我使用one_network_prior
生成了初始状态,并手动混合了数据以获得(N,T,S,D)
观察到的数据,该数据将被馈送到MCMC,如下所示:
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=lambda *params: log_prob_fn(
*params,observed=observed
),step_size=0.065,num_leapfrog_steps=5
)
unconstraining_bijectors = [
tfb.Sigmoid(
low=-1.0,high=1.0
),tfb.softplus(),tfb.Sigmoid(
low=-1.0,tfb.softmaxCentered()
]
transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
inner_kernel=hmc_kernel,bijector=unconstraining_bijectors
)
adapted_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
inner_kernel=transformed_kernel,num_adaptation_steps=400,target_accept_prob=0.65
)
@tf.function
def run_chain(initial_state,num_results=1000,num_burnin_steps=500):
return tfp.mcmc.sample_chain(
num_results=num_results,num_burnin_steps=num_burnin_steps,current_state=initial_state,kernel=adapted_kernel
)
samples,kernel_results = run_chain(
initial_state=init_state,num_results=20000,num_burnin_steps=5000
)
但是当我运行run_chain
函数时,经过几次迭代后,我得到一个错误:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-45-2ff348713067> in <module>
----> 1 samples,kernel_results = run_chain(
2 initial_state=init_state,3 num_results=20000,4 num_burnin_steps=5000
5 )
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self,*args,**kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args,**kwds)
781
782 new_tracing_count = self._get_tracing_count()
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self,**kwds)
844 *args,**kwds)
845 # If we did not create any variables the trace we have is good enough.
--> 846 return self._concrete_stateful_fn._filtered_call(canon_args,canon_kwds) # pylint: disable=protected-access
847
848 def fn_with_cond(*inner_args,**inner_kwds):
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _filtered_call(self,args,kwargs,cancellation_manager)
1841 `args` and `kwargs`.
1842 """
-> 1843 return self._call_flat(
1844 [t for t in nest.flatten((args,kwargs),expand_composites=True)
1845 if isinstance(t,(ops.Tensor,~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self,captured_inputs,cancellation_manager)
1921 and executing_eagerly):
1922 # No tape is watching; skip to running the function.
-> 1923 return self._build_call_outputs(self._inference_function.call(
1924 ctx,cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self,ctx,cancellation_manager)
543 with _InterpolateFunctionError(self):
544 if cancellation_manager is None:
--> 545 outputs = execute.execute(
546 str(self.signature.name),547 num_outputs=self._num_outputs,~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name,num_outputs,inputs,attrs,name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle,device_name,op_name,60 inputs,num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: assertion Failed: [Argument `scale` must be positive.] [Condition x > 0 did not hold element-wise:] [x (mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/strided_slice_1:0) = ] [-nan]
[[{{node mcmc_sample_chain/trace_scan/while/body/_415/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/body/_2366/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/body/_3200/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointdistributionNamed/log_prob/normal/assert_positive/assert_less/Assert/AssertGuard/else/_3580/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointdistributionNamed/log_prob/normal/assert_positive/assert_less/Assert/AssertGuard/Assert}}]] [Op:__inference_run_chain_169987]
Function call stack:
run_chain
我的理解是,否定的kappa
被馈送到dist
,但是经过Softplus
双射器,这不可能吗?并且当反转所有我的双射点时,该函数仍在运行,这很奇怪,因为尺寸应该由于softmaxCentered
而损坏。
所以我有种感觉,我的射手正被忽略。我想念什么?
提前感谢您的帮助:)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。