如何解决使用tf.scatter_nd使Keras'None'批处理大小不变
我需要将池化模块输入到LSTM解码器,然后使用定制层构造该模块,并将编码器LSTM状态和Keras输入层作为输入。在此自定义层中,我需要分散对索引的更新:
updates: <tf.Tensor --- shape=(None,225,5,32) dtype=float32>
indices: <tf.Tensor --- shape=(None,225) dtype=int32>
使用tf.scatter_nd
创建一个具有shape = {None,960,5,32)的张量,如下所示:
tf.scatter_nd(tf.expand_dims(indices,2),updates,shape=[None,960,32])
但是问题在于这样做会由于形状为nonetype而引起错误,并且我不想在其中声明batch_size
,因为它是Keras层并且只能在学习过程中确定。在这种状态下,代码的工作版本是这样的:
tf.scatter_nd(tf.expand_dims(indices,shape=[960,32])
>>> <tf.Tensor 'ScatterNd_4:0' shape=(960,32) dtype=float32>
已忽略输出中的batch_size。
有没有其他方法可以代替tf.scatter_nd
来构建所需的输出张量,还是可以使其正常工作?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。