微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tf.gather_nd param.rank和索引不起作用

如何解决Tf.gather_nd param.rank和索引不起作用

num_class_range = tf.range(0,self.num_class,delta=1,dtype=tf.float32,name='range')
doc_type_indices = num_class_range+(self.num_class*self.doc_type)
doc_type_indices = tf.dtypes.cast(doc_type_indices,tf.int32)

y_pred = tf.reshape(y_pred,tf.stack([self.batch_size,-1,self.num_class*self.num_doc_type]))
#add dims (20,1,25)
doc_type_expanded = tf.expand_dims(doc_type_indices,axis=1)
#repeat (20,25) to 128*64
doc_type_indices = tf.repeat(doc_type_expanded,num_rows*num_cols,axis=1)
doc_type_indices = tf.expand_dims(doc_type_indices,axis=3)     

print("=================",self.doc_type.shape,"=============")

class_y_pred = tf.gather_nd(y_pred,doc_type_indices,batch_dims=0,name=None)

我正在尝试根据索引对参数进行切片,但是上面提到的代码向我显示了在线错误

indice.shape [-1]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。