PyTorch: model parameters' grad attribute is None after backward()

How do I fix PyTorch model parameters whose grad attribute is None after backward()?

I am working on how to use attention for an image-captioning task and trying to get working code.

I have an encoder that returns image vectors of shape (batch, 2048).
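
(The encoder itself is not shown here. As a stand-in, something along these lines produces (batch, 2048) vectors; the ResNet-50 backbone below is only illustrative, my actual encoder is the usual Inception-style feature extractor mentioned in the loss docstring.)

import torch
import torchvision

# illustrative stand-in for the encoder: a ResNet-50 trunk with the classifier
# removed, so every image maps to a single 2048-dim feature vector
backbone = torchvision.models.resnet50()   # load pretrained weights in practice
backbone.fc = torch.nn.Identity()          # keep the pooled 2048-dim features
backbone.eval()

with torch.no_grad():
    images = torch.randn(4, 3, 224, 224)   # (batch, channels, height, width)
    image_vectors = backbone(images)       # (batch, 2048)

print(image_vectors.shape)                 # torch.Size([4, 2048])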

The decoder part:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CaptionNetAttention(nn.Module):
    def __init__(self,n_tokens=n_tokens,emb_size=128,lstm_units=256,cnn_feature_size=2048,attention_dim = 100):
        """ A recurrent 'head' network for image captioning. See scheme above. """
        super(self.__class__,self).__init__()
        
        # layers that convert conv features to the initial LSTM hidden/cell states
        self.cnn_to_h0 = nn.Linear(cnn_feature_size,lstm_units)
        self.cnn_to_c0 = nn.Linear(cnn_feature_size,lstm_units)
        
        #attention
        self.encoder_att = nn.Linear(encoder_dim,attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim,attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim,encoder_dim)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights
        

        # create embedding for input words. Use the parameters (e.g. emb_size).
        self.emb = nn.Embedding(n_tokens,emb_size)
            
        # lstm cell
        self.lstm = nn.LSTMCell(input_size=emb_size,hidden_size = lstm_units )
            
        # create logits: linear layer that takes lstm hidden state as input and computes one number per token
        self.logits = nn.Linear(lstm_units,n_tokens) 
        
    def forward(self,image_vectors,captions_ix):
       # initial cell state
        cx = self.cnn_to_c0(image_vectors)
        hx = self.cnn_to_h0(image_vectors)

        # compute embeddings for captions_ix
        input = self.emb(captions_ix).permute(1,0,2)
        input.retain_grad()
        
#ATTENTION PART
        output = []
        for i in range(input.size()[0]):
         
          att1 = self.encoder_att(image_vectors)  
          att1.retain_grad()
          
        
          att2 = self.decoder_att(hx)  
          att2.retain_grad()
        
          att_sum = att1 + att2
          att_sum.retain_grad()
        
          att = self.full_att(self.relu(att_sum))  # (batch_size,num_pixels)
          att.retain_grad()
        
          alpha = self.softmax(att)  # (batch_size,num_pixels)  
          alpha.retain_grad()
          
          hx = (image_vectors * alpha)  # (batch_size,encoder_dim)
          hx.retain_grad()
          
          hx = self.cnn_to_h0(image_vectors)
          hx,cx = self.lstm(input[i],(hx,cx))
        
          output.append(hx)
        
        outputs = torch.stack(output,dim=0)
        outputs.retain_grad()
        
        logits = self.logits(outputs)

        return F.log_softmax(logits,-1).permute(1,0,2)
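
This is roughly how I smoke-test the decoder's output shape. The sizes and the globals n_tokens / encoder_dim / decoder_dim below are placeholder values standing in for the real ones in my notebook (encoder_dim and decoder_dim are read inside __init__ above):

# the notebook defines these globals before the class; shown here only for clarity
n_tokens, encoder_dim, decoder_dim = 10000, 2048, 256

network = CaptionNetAttention(n_tokens=n_tokens)
dummy_img_vec = torch.randn(4, 2048)                    # (batch, cnn_feature_size)
dummy_capt_ix = torch.randint(0, n_tokens, (4, 15))     # (batch, caption_length)

out = network(dummy_img_vec, dummy_capt_ix)
print(out.shape)    # (batch, caption_length, n_tokens) after the final permute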

The attention part was checked separately, and there it does compute grads. Here is that code, to be sure (a minimal isolated check is sketched right after the class):

class Attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self,encoder_dim,decoder_dim,attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super(Attention,self).__init__()
        self.encoder_att = nn.Linear(encoder_dim,attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim,attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim,encoder_dim)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self,encoder_out,decoder_hidden):
        """
        Forward propagation.
        :param encoder_out: encoded images,a tensor of dimension (batch_size,num_pixels,encoder_dim)
        :param decoder_hidden: previous decoder output,a tensor of dimension (batch_size,decoder_dim)
        :return: attention weighted encoding,weights
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size,attention_dim)
        
        att2 = self.decoder_att(decoder_hidden)  # (batch_size,attention_dim)
        
        att = self.full_att(self.relu(att1 + att2))  # (batch_size,num_pixels)
        
        alpha = self.softmax(att)  # (batch_size,num_pixels)
        
        attention_weighted_encoding = (encoder_out * alpha)  # (batch_size,encoder_dim)
        

        return attention_weighted_encoding,alpha
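
This is the kind of isolated check I mean (the dimensions are dummy values), and here every parameter of Attention does receive a grad:

att_net = Attention(encoder_dim=2048, decoder_dim=256, attention_dim=100)
enc_out = torch.randn(4, 2048)          # same 2D image vectors as above
dec_hidden = torch.randn(4, 256)

weighted, alpha = att_net(enc_out, dec_hidden)
weighted.sum().backward()               # dummy scalar objective

print(all(p.grad is not None for p in att_net.parameters()))   # True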

Then I apply the loss function to the CaptionNetAttention network.

The loss:

def compute_loss(network,image_vectors,captions_ix):
    """
    :param image_vectors: torch tensor containing inception vectors. shape: [batch,cnn_feature_size]
    :param captions_ix: torch tensor containing captions as matrix. shape: [batch,word_i]. 
        padded with pad_ix
    :returns: crossentropy (neg llh) loss for next captions_ix given previous ones. Scalar float tensor
    """
    # network.eval()
    # captions for input - all except the last one, because we don't know the next token for the last one.
    criterion = nn.CrossEntropyLoss(ignore_index = pad_ix)
    captions_ix_inp = captions_ix[:,:-1].contiguous()
    captions_ix_next = captions_ix[:,1:].contiguous()
    

    logits_for_next = network.forward(image_vectors,captions_ix_inp)
    print(logits_for_next.shape,captions_ix_next.shape )
    print(logits_for_next.reshape((-1,logits_for_next.shape[2])).shape,captions_ix_next.reshape(-1).shape)
    loss = criterion(logits_for_next.reshape((-1,logits_for_next.shape[2])),captions_ix_next.reshape(-1)) 
    
    return loss
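
(Side note on the loss: ignore_index=pad_ix simply drops padded target positions from the average. A tiny illustration, with the pad index assumed to be 0 here:)

# toy illustration of ignore_index: targets equal to the pad index add nothing to the loss
demo_criterion = nn.CrossEntropyLoss(ignore_index=0)
demo_logits = torch.randn(4, 5)                  # 4 positions, vocabulary of 5 tokens
demo_targets = torch.tensor([2, 4, 0, 1])        # the third position is padding
demo_loss = demo_criterion(demo_logits, demo_targets)   # averaged over the 3 non-padded positions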

I check it like this:

dummy_loss = compute_loss(network,dummy_img_vec,dummy_capt_ix)

dummy_loss.backward()

assert all(param.grad is not None for param in network.parameters()),\
        'loss should depend differentiably on all neural network weights'

I get an AssertionError.

The following code

for i,j in network.named_parameters():
  grad_is_none = ''
  if j.grad is None:
    grad_is_none = 'grad is None'
  print(i,grad_is_none)

prints this output:

cnn_to_h0.weight 
cnn_to_h0.bias 
cnn_to_c0.weight 
cnn_to_c0.bias 
encoder_att.weight grad is None
encoder_att.bias grad is None
decoder_att.weight grad is None
decoder_att.bias grad is None
full_att.weight grad is None
full_att.bias grad is None
emb.weight 
lstm.weight_ih 
lstm.weight_hh 
lstm.bias_ih 
lstm.bias_hh 
logits.weight 
logits.bias

What is the reason for this behaviour? Where exactly does the gradient graph get broken?
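
For reference, this is a rough sketch of how I try to inspect it myself (my own debugging attempt, so it may be naive): walk the autograd graph from the loss via grad_fn.next_functions and list which leaf tensors are reachable at all. reachable_leaf_ids is just my helper name.

# rough sketch: traverse the autograd graph from the loss and collect reachable
# leaf tensors; any parameter absent from this set can never receive a grad
def reachable_leaf_ids(loss):
    seen, leaf_ids, stack = set(), set(), [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, 'variable'):              # AccumulateGrad node wraps a leaf tensor
            leaf_ids.add(id(fn.variable))
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return leaf_ids

leaves = reachable_leaf_ids(dummy_loss)
for name, param in network.named_parameters():
    print(name, 'reachable' if id(param) in leaves else 'NOT in graph')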

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。