How to do batched beam search in PyTorch
I am trying to implement a beam search decoding strategy in a text generation model. This is the function I am using to decode the output probabilities:
import torch

def beam_search_decoder(data, k):
    # data: (seq_length, vocab_size) probabilities for a single sample
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand every current prefix by every token in the vocabulary
        for seq, score in sequences:
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # keep the k prefixes with the lowest negative log likelihood
        sequences = sorted(all_candidates, key=lambda tup: tup[1])[:k]
    return sequences
As you can see, the function is written assuming a batch size of 1. Adding another loop over the batch dimension would make the algorithm O(n^4), and it is already slow. Is there any way to speed this function up? My model output is typically of size (32, 150, 9907), in the format (batch_size, max_len, vocab_size).
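To make the cost concrete, here is a rough count of the inner-loop iterations the per-sample decoder performs on an output of the stated size (k = 3 is an assumed beam size; the true count is slightly lower because the beam starts from a single empty prefix):

```python
# Rough work estimate for running the Python-loop decoder over one batch.
# Dimensions are from the question; the beam size k = 3 is an assumption.
batch_size, max_len, vocab_size = 32, 150, 9907
k = 3

# Each of the max_len steps expands up to k prefixes by vocab_size tokens,
# and that happens once per sample in the batch.
candidates_per_batch = batch_size * max_len * k * vocab_size
print(candidates_per_batch)  # 142660800
```

Well over a hundred million candidate lists built and sorted in pure Python per batch, which is why the loop version cannot keep up.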
Solution
Below is my implementation, which should be faster than the for-loop version because it vectorizes over both the batch and the beam with topk:
import torch

def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:
        post (Tensor): the posterior of the network.
        k (int): beam size of the decoder.

    Outputs:
        indices (Tensor): a beam of index sequences.
        log_prob (Tensor): a beam of log likelihoods of the sequences.

    Shape:
        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:
        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)
    """
    batch_size, seq_length, vocab_size = post.shape
    log_post = post.log()
    # initialize the beam with the top-k tokens of the first step
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)  # (batch_size, k, 1)
    for t in range(1, seq_length):
        # add this step's scores to every beam; broadcasts to (batch_size, k, vocab_size)
        log_prob = log_prob.unsqueeze(-1) + log_post[:, t, :].unsqueeze(1)
        # pick the k best (beam, token) pairs from the flattened k * vocab_size candidates
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        beam_idx = index // vocab_size   # which surviving beam each winner extends
        token_idx = index % vocab_size   # which token extends it
        # reorder the kept prefixes to follow their beams, then append the new tokens
        indices = torch.gather(indices, 1, beam_idx.unsqueeze(-1).expand(-1, -1, t))
        indices = torch.cat([indices, token_idx.unsqueeze(-1)], dim=-1)
    return indices, log_prob
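As a sanity check, the batched decoder can be compared against the per-sample loop on a small random tensor. The sketch below restates both decoders so it runs standalone (the names beam_search_naive and beam_search_batched are illustrative; the batched version includes the beam-reordering bookkeeping it needs to return valid token indices):

```python
import torch

def beam_search_naive(probs, k):
    """Per-sample beam search over a (seq_length, vocab_size) probability matrix."""
    sequences = [([], 0.0)]
    for row in probs:
        all_candidates = [
            (seq + [j], score - float(torch.log(row[j])))
            for seq, score in sequences
            for j in range(len(row))
        ]
        # keep the k prefixes with the lowest negative log likelihood
        sequences = sorted(all_candidates, key=lambda t: t[1])[:k]
    return sequences

def beam_search_batched(post, k):
    """Vectorized beam search over a (batch_size, seq_length, vocab_size) tensor."""
    batch_size, seq_length, vocab_size = post.shape
    log_post = post.log()
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for t in range(1, seq_length):
        log_prob = log_prob.unsqueeze(-1) + log_post[:, t, :].unsqueeze(1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        beam_idx, token_idx = index // vocab_size, index % vocab_size
        indices = torch.gather(indices, 1, beam_idx.unsqueeze(-1).expand(-1, -1, t))
        indices = torch.cat([indices, token_idx.unsqueeze(-1)], dim=-1)
    return indices, log_prob

torch.manual_seed(0)
post = torch.softmax(torch.randn(4, 6, 8), dim=-1)  # tiny (batch, seq, vocab)
indices, log_prob = beam_search_batched(post, k=3)
for b in range(post.shape[0]):
    best_seq, best_nll = beam_search_naive(post[b], k=3)[0]
    # both decoders run the same beam procedure, so the best hypothesis must match
    assert best_seq == indices[b, 0].tolist()
    assert abs(best_nll + float(log_prob[b, 0])) < 1e-4
print("batched and per-sample decoders agree")
```

The only behavioral difference is sign convention: the loop version minimizes negative log likelihood while the batched version maximizes log probability, so the scores are negatives of each other.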