如何解决需要有关tensorflow_addons对象的更多信息
我正在尝试使用带有张量流的注意力机制构建编码器-解码器模型。 我正在使用tensorflow_addons存储库,试图重现和理解此模型:https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt
很遗憾,BasicDecoder
,Sampler
和AttentionWrapper
对象上没有足够的文档供我完美使用。在研究期间,我能找到的最明确的文档是https://medium.com/@dhirensk/tensorflow-addons-seq2seq-example-using-attention-and-beam-search-9f463b58bc6b。
最模糊的阶段是使用TrainingSampler()
和GreedyEmbeddingSampler()
时,但他没有更深入地了解采样器的上下文,我唯一需要了解的信息是在{ {3}}:
#Sampler instances are used by BasicDecoder. The normal usage of a sampler is like below:
sampler = Sampler(init_args)
(initial_finished,initial_inputs) = sampler.initialize(input_tensors)
cell_input = initial_inputs
cell_state = cell.get_initial_state(...)
for time_step in tf.range(max_output_length):
cell_output,cell_state = cell(cell_input,cell_state)
sample_ids = sampler.sample(time_step,cell_output,cell_state)
(finished,cell_input,cell_state) = sampler.next_inputs(
time_step,cell_state,sample_ids)
if tf.reduce_all(finished):
break
此外,我的模型不包含嵌入层,因为我的输入向量不需要它。因此,我想在测试/推断期间必须使用另一个采样器代替GreedyEmbeddingSampler()。
我希望我足够清楚,希望有人可以帮助我理解。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。