微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用tf.scatter_nd使Keras'None'批处理大小不变

如何解决使用tf.scatter_nd使Keras'None'批处理大小不变

我需要将池化模块输入到LSTM解码器,然后使用定制层构造该模块,并将编码器LSTM状态和Keras输入层作为输入。在此自定义层中,我需要分散对索引的更新:

updates: <tf.Tensor --- shape=(None,225,5,32) dtype=float32>
indices: <tf.Tensor --- shape=(None,225) dtype=int32>

使用tf.scatter_nd创建一个具有shape = {None,960,5,32)的张量,如下所示:

tf.scatter_nd(tf.expand_dims(indices,2),updates,shape=[None,960,32])

但是问题在于这样做会由于形状为nonetype而引起错误,并且我不想在其中声明batch_size,因为它是Keras层并且只能在学习过程中确定。在这种状态下,代码的工作版本是这样的:

tf.scatter_nd(tf.expand_dims(indices,shape=[960,32])
        >>> <tf.Tensor 'ScatterNd_4:0' shape=(960,32) dtype=float32>

已忽略输出中的batch_size。 有没有其他方法可以代替tf.scatter_nd来构建所需的输出张量,还是可以使其正常工作?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。