How to fix an RNN variational autoencoder that reconstructs well but generates poorly
I am trying to reproduce these results by training an RNN-based variational autoencoder. Reconstruction of the original text works well, but generation of new text is very poor. My model architecture is given below; it is roughly based on this paper.
import torch
import torch.nn as nn

class SentenceVAE(nn.Module):
    def __init__(self, embedding_size, vocab_size, hidden_size, latent_dim, dropout,
                 device, max_len=50, pad_idx=0, start_idx=1, end_idx=2, unk_idx=3):
        super(SentenceVAE, self).__init__()
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
        self.hidden_to_mu = nn.Linear(hidden_size, latent_dim)
        self.hidden_to_logvar = nn.Linear(hidden_size, latent_dim)
        self.dropout = dropout  # word-dropout probability, used in forward()
        self.encoder_gru = nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder_gru = nn.GRU(embedding_size, hidden_size, batch_first=True)
        # Maps a latent sample z to the decoder's initial hidden state
        self.flow_fc = nn.Sequential(
            nn.Linear(latent_dim, 1024), nn.GELU(), nn.Linear(1024, hidden_size)
        )
        self.out = nn.Linear(hidden_size, vocab_size)
        self.device = device
        self.latent_dim = latent_dim
        self.unk_idx = unk_idx
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.pad_idx = pad_idx

    def reparameterize(self, mu, logvar):
        # z = mu + eps * sigma, with eps ~ N(0, I)
        eps = torch.randn_like(logvar)
        std = torch.exp(0.5 * logvar)
        return mu + eps * std

    def decode(self, hidden, dec_in):
        decoder_input = self.embed(dec_in)
        if hidden.dim() < 3:
            hidden = hidden.unsqueeze(0)  # GRU expects (num_layers, batch, hidden_size)
        outputs, hidden = self.decoder_gru(decoder_input, hidden)
        out = self.out(outputs)
        return out, hidden

    def sample_sentence(self, z=None):
        max_len = 20
        batch = 1
        if z is None:
            z = torch.randn((batch, self.latent_dim))
        z = z.to(self.device)
        hidden = self.flow_fc(z)
        pred = [[self.start_idx]]
        out_sent = []
        for i in range(max_len):
            pred_tensor = torch.tensor(pred).to(self.device)
            preds, hidden = self.decode(hidden, pred_tensor)
            preds = preds[:, -1, :]            # logits for the last time step
            pred_index = torch.argmax(preds, dim=-1)
            pred[0] = [pred_index.item()]      # feed the predicted token back in
            out_sent.append(pred_index.item())
            if pred_index.item() == self.end_idx:
                break
        return out_sent

    def forward(self, x):
        enc_in, dec_in = x, x
        encoder_input = self.embed(enc_in)
        _, rnn_hidden = self.encoder_gru(encoder_input)
        rnn_hidden = rnn_hidden.squeeze(0)
        mu = self.hidden_to_mu(rnn_hidden)
        logvar = self.hidden_to_logvar(rnn_hidden)
        z = self.reparameterize(mu, logvar)
        # Word dropout: randomly replace decoder-input words with <unk>,
        # but never touch <SOS>, <EOS>, or padding
        dec_in_copy = dec_in.clone()
        prob = torch.rand(dec_in.size(), device=dec_in.device)
        prob[(dec_in == self.start_idx) | (dec_in == self.end_idx) | (dec_in == self.pad_idx)] = 1
        dec_in_copy[prob < self.dropout] = self.unk_idx
        hidden = self.flow_fc(z)
        out, _ = self.decode(hidden, dec_in_copy)
        return mu, logvar, out
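My training loop is not shown here; for context, the following is only a minimal sketch of the standard ELBO objective that the returned mu, logvar, and out would feed into (vae_loss, targets, and beta are placeholder names, not from the model above):

import torch
import torch.nn.functional as F

def vae_loss(out, targets, mu, logvar, pad_idx, beta=1.0):
    # Reconstruction term: token-level cross-entropy, ignoring padding.
    # out: (batch, seq_len, vocab_size); targets: (batch, seq_len)
    recon = F.cross_entropy(
        out.view(-1, out.size(-1)), targets.view(-1),
        ignore_index=pad_idx, reduction="sum"
    )
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl, recon, kl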
Two reconstruction examples (in each pair, the first line is the input and the second is the reconstruction):

#1
<SOS> gondry 's direction is adequate ... but what gives human nature its unique feel is kaufman 's script . <EOS>
<SOS> it 's direction is adequate ... but what gives human nature its unique feel is kaufman 's approach . <EOS>
#2
<SOS> there seems to be no clear path as to where the story 's going,or how long it 's going to take to get there . <EOS>
<SOS> there seems to be no amount path,to where the most 's going,or even long it 's going to take to get there . <EOS>
Now, when I use the sample_sentence method of the SentenceVAE class to generate new sentences, the output is always:
<SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS>
While debugging, I noticed that the output is always the seed token from pred = [[self.start_idx]], repeated max_len times. In the case above, <SOS> is that input token in sample_sentence.
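Below is a small diagnostic sketch of two things I am checking (model and batch are placeholder names, not code from the class above): whether the KL term has collapsed to near zero during training, and whether sampling with a temperature instead of argmax changes the output.

import torch

# Hypothetical diagnostic, assuming `model` is a trained SentenceVAE and
# `batch` is a (batch_size, seq_len) tensor of token ids.
with torch.no_grad():
    mu, logvar, _ = model(batch)
    # Per-sentence KL(q(z|x) || N(0, I)); values near zero suggest
    # posterior collapse, i.e. the decoder is ignoring z.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    print("mean KL per sentence:", kl.mean().item())

# Inside the sampling loop, temperature sampling instead of argmax:
# probs = torch.softmax(preds / 0.8, dim=-1)            # 0.8 is a placeholder temperature
# pred_index = torch.multinomial(probs, 1).squeeze(-1)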