
Loss becomes NaN after a few iterations

How can I fix a loss that becomes NaN after a few iterations?

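In my model, the input is graph data in the form of an edge index and node features. After a few training iterations on this graph data, the loss becomes NaN (EDIT: the loss is a combination of an MSE loss and a negated loss term, i.e. L1 + (-L2)). After roughly 40 iterations, both L1 and -L2 are NaN.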
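Written out from the training code shown later in this question, the combined loss has the form below. Here N denotes the number of entries in Y (a symbol introduced only for this formula), and Y_1 = func(Y), where func is not shown in the question.

```latex
\mathcal{L}
  = \underbrace{\frac{1}{N}\sum_{i=1}^{N}\bigl(Y_{1,i} - Y_i\bigr)^2}_{L_1}
  \;+\;
  \underbrace{\Bigl(-\frac{1}{N}\sum_{i=1}^{N}\lvert Y_i\rvert\Bigr)}_{-L_2}
```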

The learning rate is 0.00001. I also checked the input data for invalid values and found none.
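For reference, a check of this kind might look like the sketch below. The question does not show how the inputs were actually checked, so `check_graph_inputs` and the toy tensors are hypothetical; the sketch only looks for non-finite features and for edge indices that fall outside the node range.

```python
import torch

def check_graph_inputs(x, edge_index):
    """Return a list of human-readable problems found in the inputs."""
    problems = []
    if not torch.isfinite(x).all():
        problems.append("x contains NaN or Inf")
    if edge_index.dtype != torch.long:
        problems.append("edge_index is not a LongTensor")
    if edge_index.dim() != 2 or edge_index.size(0) != 2:
        problems.append("edge_index does not have shape [2, num_edges]")
    elif edge_index.numel() > 0:
        num_nodes = x.size(0)
        if edge_index.min() < 0 or edge_index.max() >= num_nodes:
            problems.append("edge_index refers to a node outside [0, num_nodes)")
    return problems

# Hypothetical toy graph: 4 nodes with 1 feature each, 3 edges.
x = torch.tensor([[0.5], [1.0], [-2.0], [0.0]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
print(check_graph_inputs(x, edge_index))  # → []
```

An empty list only means these particular checks passed; it does not rule out other problems such as extremely large feature values.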

import time

import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

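class Model(nn.Module):
    def __init__(self, nin, nhid1, nout, inp_l, hid_l, out_l=1):
        super().__init__()

        # Two graph convolutions produce node embeddings of size nout.
        self.g1 = GCNConv(in_channels=nin, out_channels=nhid1)
        self.g2 = GCNConv(in_channels=nhid1, out_channels=nout)
        self.dropout = 0.5  # dropout probability applied after the first GCN layer
        # Two linear layers map each embedding to out_l values; inp_l must equal nout.
        self.lay1 = nn.Linear(inp_l, hid_l)
        self.lay2 = nn.Linear(hid_l, out_l)

    def forward(self, x, adj):
        # adj is the edge_index tensor of shape [2, num_edges].
        x = F.relu(self.g1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.g2(x, adj)

        x = self.lay1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = F.relu(x)  # final ReLU keeps every output >= 0

        return x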

Inputs to the model:

x (Tensor, optional): node feature matrix with shape [num_nodes, num_node_features].

edge_index (LongTensor, optional): graph connectivity in COO format with shape [2, num_edges].

Here num_nodes = 1000, num_node_features = 1, and num_edges = 5000.
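As an illustration of these shapes, the sketch below builds placeholder tensors of the stated sizes. The values are random stand-ins, since the question does not show the real features or edges.

```python
import torch

num_nodes, num_node_features, num_edges = 1000, 1, 5000

# Random placeholder data; the real features and edges are not shown in the question.
torch.manual_seed(0)
x = torch.randn(num_nodes, num_node_features)             # float node features
edge_index = torch.randint(0, num_nodes, (2, num_edges))  # int64 node indices (COO)

print(x.shape, edge_index.shape, edge_index.dtype)
```

Row 0 of `edge_index` holds the source node of each edge and row 1 the target node, so every entry must lie in [0, num_nodes).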

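GCNConv is a graph embedding layer that returns a [num_nodes, dim] matrix. It takes the edge list and the node features as input and produces that matrix.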
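To make the input and output shapes concrete, here is a simplified dense version of one GCN propagation step (self-loops plus symmetric degree normalization, no bias). It is a sketch for illustration only and not the `torch_geometric` implementation, which works on sparse edge lists; `dense_gcn_layer` and the toy graph are hypothetical.

```python
import torch

def dense_gcn_layer(x, edge_index, weight):
    """One simplified GCN propagation step using a dense adjacency matrix."""
    num_nodes = x.size(0)
    adj = torch.zeros(num_nodes, num_nodes)
    src, dst = edge_index
    adj[dst, src] = 1.0                    # message flows from src to dst
    adj = adj + torch.eye(num_nodes)       # add self-loops
    deg = adj.sum(dim=1)                   # degree including the self-loop (always >= 1)
    d_inv_sqrt = deg.pow(-0.5)
    norm_adj = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight           # shape [num_nodes, out_dim]

# Toy undirected path graph 0 - 1 - 2, stored as directed edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.tensor([[1.0], [2.0], [3.0]])    # one feature per node
weight = torch.ones(1, 4)                  # maps 1 input feature to 4 output channels
out = dense_gcn_layer(x, edge_index, weight)
print(out.shape)
```

Each output row mixes a node's own features with those of its neighbours, which is why the layer needs both the edge list and the feature matrix.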

EDIT 2: added how the loss is computed

def train_model(epoch):
    # device, features, adjacency_list and func are defined elsewhere (not shown).
    # Note: the model and optimizer are created anew on every call.
    model = Model(nin=1, nhid1=128, nout=128, inp_l=128, hid_l=64, out_l=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.00001)

    model.train()
    t = time.time()
    optimizer.zero_grad()
    Y = model(features, adjacency_list)

    # Y1 is computed from Y by a function func and has the same shape as Y.
    Y1 = func(Y)

    loss1 = ((Y1 - Y) ** 2).mean()  # MSE loss

    # Negated mean absolute value, meant to keep the Y values from going to 0.
    loss2 = -Y.abs().mean()

    loss_train = loss1 + loss2
    loss_train.backward(retain_graph=True)
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optimizer.step()

    if epoch % 20 == 0:
        print("MSE loss = ", loss1, "\t", "Mean Loss = ", loss2)
        print('Epoch: {:04d}'.format(epoch + 1), 'loss_train: {:.4f}'.format(loss_train.item()), 'time: {:.4f}s'.format(time.time() - t))
        print("\n\n")

    return Y
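One way to narrow down where the NaN first appears is to test each intermediate tensor for non-finite values before updating the weights. The sketch below does this on a stand-in model: `first_nonfinite` is a hypothetical helper, `nn.Linear` replaces the GCN model, and `Y.detach() * 0.5` is only a placeholder for `func(Y)`, which the question does not show. It locates the first non-finite value; it does not by itself explain the cause.

```python
import torch
import torch.nn as nn

def first_nonfinite(named_tensors):
    """Return the name of the first tensor holding NaN or Inf, or None."""
    for name, tensor in named_tensors:
        if not torch.isfinite(tensor).all():
            return name
    return None

torch.manual_seed(0)
model = nn.Linear(1, 1)        # stand-in for the GCN model in the question
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
features = torch.randn(8, 1)   # placeholder inputs

for step in range(3):
    optimizer.zero_grad()
    Y = model(features)
    Y1 = Y.detach() * 0.5      # placeholder for func(Y), which is not shown
    loss1 = ((Y1 - Y) ** 2).mean()
    loss2 = -Y.abs().mean()
    bad = first_nonfinite([("Y", Y), ("Y1", Y1), ("loss1", loss1), ("loss2", loss2)])
    if bad is not None:
        print(f"step {step}: {bad} is not finite, stopping before the update")
        break
    (loss1 + loss2).backward()
    bad_grad = first_nonfinite(
        (name, p.grad) for name, p in model.named_parameters() if p.grad is not None
    )
    if bad_grad is not None:
        print(f"step {step}: gradient of {bad_grad} is not finite")
        break
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
```

PyTorch also provides `torch.autograd.set_detect_anomaly(True)`, which raises an error at the backward operation that produced a NaN, at the cost of slower training.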
