微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

自定义 DistributionLambda 层中的变量未更新

如何解决自定义 DistributionLambda 层中的变量未更新

我想在 tensorflow-probability 中构建一个自定义层,然后我可以使用它来构建 DenseVariational 层的后部。

作为第一步,我构建了以下后验,它相当于 tutorial 中使用的平均场后验,但不是学习正态分布的参数,而是学习两个双射器的参数。

def posterior_trainable_bijector(kernel_size,bias_size=0,dtype=None):
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.0))

    return tf.keras.Sequential(
        [
            tfp.layers.VariableLayer(2 * n,dtype=dtype),tfp.layers.distributionLambda(
                lambda t: tfp.distributions.Transformeddistribution(
                    tfd.Independent(
                        tfd.normal(loc=tf.zeros(n),scale=tf.ones(n)),reinterpreted_batch_ndims=1,),tfp.bijectors.Chain(
                        bijectors=[
                            tfp.bijectors.Shift(t[...,n:]),tfp.bijectors.Scale(
                                1e-5 + 0.01 * tf.math.softplus(c + t[...,:n])
                            ),]
                    ),)
            ),]
    )

作为下一步,我认为将 distributionLambda 层子类化是一个好主意,因为这样可以设置更复杂的双射器。 不幸的是,我的初稿似乎不起作用。更具体地说,我仍然能够运行我的代码,但似乎 loc_params/scale_params 在训练期间没有更新,但我不知道为什么会这样。 有什么建议吗?

class LocScaleBijectorLayer(tfp.layers.distributionLambda):
    def __init__(
        self,event_shape=(),convert_to_tensor_fn=tfd.distribution.sample,validate_args=False,name="LocScaleBijectorLayer",**kwargs,):

        c = np.log(np.expm1(1.0))
        with tf.name_scope(name) as name:
            loc_params = tf.Variable(
                tf.zeros(event_shape),name="loc_var",trainable=True
            )
            scale_params = tf.Variable(
                tf.ones(event_shape),name="scale_var",trainable=True
            )

            self.base_distribution = tfd.Independent(
                tfd.normal(loc=tf.zeros(event_shape),scale=tf.ones(event_shape)),reinterpreted_batch_ndims=-1,)

            self.bijector = tfp.bijectors.Chain(
                bijectors=[
                    tfp.bijectors.Shift(loc_params),tfp.bijectors.Scale(
                        1e-5 + 0.01 * tf.math.softplus(c + scale_params)
                    ),]
            )

            super(LocScaleBijectorLayer,self).__init__(
                lambda t: tfp.distributions.Transformeddistribution(
                    self.base_distribution,self.bijector
                ),convert_to_tensor_fn,name=name,)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。