
Training and validation loss suddenly re-initialize and convergence slows while training a CNN in PyTorch


I am a beginner in deep learning, especially in designing neural networks from scratch. I am trying to build a B/W image colorizer using the autoencoder concept in PyTorch, using the dataset referenced here while following the Keras architecture referenced here. The Keras architecture, for reference:

# Design the neural network
model = Sequential()
model.add(InputLayer(input_shape=(256, 256, 1)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))

# Finish model
model.compile(optimizer='rmsprop', loss='mse')

# Image transformer
datagen = ImageDataGenerator(
    shear_range=0.2,zoom_range=0.2,rotation_range=20,horizontal_flip=True)

# Generate training data
batch_size = 50

def image_a_b_gen(batch_size):
    for batch in datagen.flow(Xtrain,batch_size=batch_size):
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield (X_batch.reshape(X_batch.shape+(1,)),Y_batch)

# Train model
tensorboard = TensorBoard(log_dir='/output')
model.fit_generator(image_a_b_gen(batch_size), callbacks=[tensorboard], steps_per_epoch=10000, epochs=1)
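The slicing in `image_a_b_gen` splits the channels-last LAB batch into the L channel (the network input) and the two ab channels divided by 128 so they land in roughly [-1, 1], matching the final `tanh`. A minimal shape check (the zero-filled array is just a stand-in for `rgb2lab` output):

```python
import numpy as np

# Stand-in for rgb2lab(batch): a channels-last LAB batch
lab_batch = np.zeros((50, 256, 256, 3), dtype=np.float32)

X_batch = lab_batch[:, :, :, 0]                  # L channel -> input
Y_batch = lab_batch[:, :, :, 1:] / 128           # ab channels, scaled to ~[-1, 1]
X_batch = X_batch.reshape(X_batch.shape + (1,))  # add channel axis for Keras

print(X_batch.shape)  # (50, 256, 256, 1)
print(Y_batch.shape)  # (50, 256, 256, 2)
```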

This is the architecture I built in PyTorch:

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1, stride=2), nn.ReLU(),
            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 2, 3, padding=1), nn.Tanh(),
        )

    def forward(self, input):
        output = self.model(input)
        return output
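The Keras reference uses three stride-2 convolutions and, in the snippet shown, a single 2x upsample. A quick check with the standard Conv2d output-size formula shows what that does to a 256x256 input: the output comes out at 64x64, so it is worth verifying that the target tensors the loss compares against have the same spatial size (the helper below is illustrative, not part of the original code):

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    """Spatial size after a square nn.Conv2d: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 256
for stride in (1, 2, 1, 2, 1, 2, 1):  # strides of the seven encoder convs
    size = conv2d_out(size, stride=stride)
size *= 2  # one Upsample(scale_factor=2)
# the remaining stride-1, padding-1 convs leave the size unchanged
print(size)  # 64
```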

and initialized it like this:

model = AutoEncoder()
loss_func = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001)

This is the training function (trained on Google Colab), which also contains the validation loop:

def train(model, optimizer, loss_func, train_loader, val_loader, epochs, device='cuda'):
  model = model.to(device)
  loss_func = loss_func.to(device)
  training_loss = 0
  valid_loss = 0
  save_path = {'grayscale': '/content/outputs/gray/', 'colorized': '/content/outputs/color/', 'target': '/content/outputs/target/'}

  for epoch in range(e + 1, e + 1 + epochs):  # e: last completed epoch (0 on a fresh run)
    print('Starting epoch {}'.format(epoch))
    epoch_start = time.time()
    training_loss = 0.0
    valid_loss = 0.0

    model.train()
    b = 1
    for batch in train_loader:
      batch_start = time.time()
      optimizer.zero_grad()
      img_gray,img_ab,og_image = batch
      input,target = img_gray,img_ab
      input,target = input.to(device),target.to(device)
      output = model(input)
      loss = loss_func(output,target)
      loss.backward()
      optimizer.step()

      training_loss += loss.data.item()
      if b % 20 == 0:
        print('Epoch: {},Batch: {} Training Loss: {:.4f},Time: {:.2f}'.format(
        epoch,b,loss.data.item(),time.time() - batch_start))
      b += 1

    training_loss /= len(train_loader)

    model.eval()
    b = 1
    with torch.no_grad():  # no gradients needed during validation
      for batch in val_loader:
        batch_start = time.time()
        img_gray, img_ab, og_image = batch
        input, target = img_gray.to(device), img_ab.to(device)
        output = model(input)
        loss = loss_func(output, target)

        valid_loss += loss.data.item()

        j = random.randint(1, 9)
        save_name = 'epoch-{}batch-{}img-{}.jpg'.format(epoch, b, j)
        save_image(input[j].detach().cpu(), output[j].detach().cpu(), target[j].detach().cpu(), og_image[j].detach().cpu(), save_path=save_path, save_name=save_name)

        print('Epoch: {},Batch: {} Valid Loss: {:.4f},Time: {:.2f}'.format(
            epoch, b, loss.data.item(), time.time() - batch_start))

        b += 1

    valid_loss /= len(val_loader)

    # model saving code here
  
    print('Epoch: {},Training Loss: {:.4f},Validation Loss: {:.4f},Time: {:.2f}'
    .format(epoch,training_loss,valid_loss,time.time() - epoch_start))

However, while training this model I found that it converges slower than I expected, and the training and validation loss suddenly spike at a certain epoch (epoch 40), re-initializing to roughly the epoch-1 loss as shown in the logs below, and then increase again. Even after searching online, I could not determine why this happens. I did overfit the model on a single image to check whether it is able to learn at all, and after about 200 epochs on that single image it gave good results.
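The single-image overfit described above is a standard sanity check: if the optimization setup is healthy, repeatedly fitting one example should drive the loss toward zero. A toy illustration of the idea with plain gradient descent on a squared-error loss (numpy stand-ins, not the actual model):

```python
import numpy as np

rng = np.random.default_rng(0)
target = rng.standard_normal((2, 8, 8))  # stand-in for one ab target
pred = np.zeros_like(target)             # stand-in for the model's output
lr = 0.1
for _ in range(200):
    grad = 2.0 * (pred - target)  # gradient of sum((pred - target)**2) w.r.t. pred
    pred -= lr * grad
final_loss = np.mean((pred - target) ** 2)
print(final_loss < 1e-6)  # True: the loss collapses when nothing is broken
```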

Any other suggestions to improve training speed are also welcome.

Logs:

Starting epoch 35
Epoch: 35,Batch: 20 Training Loss: 106.5381,Time: 0.68
Epoch: 35,Batch: 40 Training Loss: 164.4004,Time: 0.67
Epoch: 35,Batch: 60 Training Loss: 187.2683,Batch: 80 Training Loss: 125.9881,Batch: 100 Training Loss: 178.2198,Batch: 120 Training Loss: 145.3067,Batch: 140 Training Loss: 172.2734,Batch: 160 Training Loss: 161.0511,Batch: 180 Training Loss: 154.7270,Batch: 200 Training Loss: 105.7089,Batch: 220 Training Loss: 136.7364,Batch: 240 Training Loss: 144.3123,Batch: 1 Valid Loss: 150.3127,Time: 0.36
Epoch: 35,Batch: 2 Valid Loss: 136.5788,Batch: 3 Valid Loss: 145.5104,Time: 0.37
Epoch: 35,Batch: 4 Valid Loss: 174.9318,Time: 0.35
Epoch: 35,Batch: 5 Valid Loss: 169.0320,Batch: 6 Valid Loss: 104.9420,Batch: 7 Valid Loss: 139.7437,Batch: 8 Valid Loss: 142.1164,Batch: 9 Valid Loss: 162.5985,Batch: 10 Valid Loss: 195.1854,Batch: 11 Valid Loss: 134.6287,Batch: 12 Valid Loss: 136.0625,Batch: 13 Valid Loss: 210.8838,Batch: 14 Valid Loss: 171.1289,Batch: 15 Valid Loss: 158.6969,Time: 0.34
Epoch: 35,Batch: 16 Valid Loss: 245.2010,Batch: 17 Valid Loss: 163.9438,Batch: 18 Valid Loss: 148.4983,Batch: 19 Valid Loss: 153.7577,Batch: 20 Valid Loss: 180.0079,Batch: 21 Valid Loss: 123.2093,Batch: 22 Valid Loss: 125.0577,Batch: 23 Valid Loss: 167.9037,Batch: 24 Valid Loss: 142.5038,Batch: 25 Valid Loss: 204.0228,Batch: 26 Valid Loss: 138.4986,Time: 0.31
Epoch: 35,Batch: 27 Valid Loss: 245.4016,Batch: 28 Valid Loss: 147.9031,Batch: 29 Valid Loss: 86.9302,Batch: 30 Valid Loss: 144.6022,Batch: 31 Valid Loss: 171.0132,Batch: 32 Valid Loss: 171.6215,Time: 0.26

Epoch: 35,Training Loss: 158.2567,Validation Loss: 159.1384,Time: 419.94

Starting epoch 36
Epoch: 36,Batch: 20 Training Loss: 105.9032,Time: 0.68
Epoch: 36,Batch: 40 Training Loss: 166.2677,Time: 0.67
Epoch: 36,Batch: 60 Training Loss: 187.5138,Batch: 80 Training Loss: 126.4907,Batch: 100 Training Loss: 177.6801,Batch: 120 Training Loss: 145.7679,Batch: 140 Training Loss: 172.1092,Batch: 160 Training Loss: 161.7235,Batch: 180 Training Loss: 154.2764,Batch: 200 Training Loss: 105.0589,Batch: 220 Training Loss: 137.1901,Batch: 240 Training Loss: 143.7092,Batch: 1 Valid Loss: 151.7165,Time: 0.36
Epoch: 36,Batch: 2 Valid Loss: 137.6912,Time: 0.35
Epoch: 36,Batch: 3 Valid Loss: 148.1911,Batch: 4 Valid Loss: 178.2061,Batch: 5 Valid Loss: 171.2887,Batch: 6 Valid Loss: 105.7988,Batch: 7 Valid Loss: 139.2858,Batch: 8 Valid Loss: 143.3132,Batch: 9 Valid Loss: 164.0428,Batch: 10 Valid Loss: 197.8514,Batch: 11 Valid Loss: 135.2920,Batch: 12 Valid Loss: 136.8798,Time: 0.37
Epoch: 36,Batch: 13 Valid Loss: 213.7531,Batch: 14 Valid Loss: 173.2047,Time: 0.34
Epoch: 36,Batch: 15 Valid Loss: 159.9836,Batch: 16 Valid Loss: 250.1139,Batch: 17 Valid Loss: 166.5989,Batch: 18 Valid Loss: 149.8430,Batch: 19 Valid Loss: 156.4111,Time: 0.32
Epoch: 36,Batch: 20 Valid Loss: 182.4751,Batch: 21 Valid Loss: 125.2163,Batch: 22 Valid Loss: 127.7892,Batch: 23 Valid Loss: 171.6647,Batch: 24 Valid Loss: 145.8683,Batch: 25 Valid Loss: 205.0257,Batch: 26 Valid Loss: 141.0681,Batch: 27 Valid Loss: 248.7428,Batch: 28 Valid Loss: 148.8966,Batch: 29 Valid Loss: 88.4646,Batch: 30 Valid Loss: 143.6590,Batch: 31 Valid Loss: 173.3025,Batch: 32 Valid Loss: 174.2860,Time: 0.24

Epoch: 36,Training Loss: 158.3800,Validation Loss: 161.1226,Time: 419.45

Starting epoch 37
Epoch: 37,Batch: 20 Training Loss: 105.6317,Time: 0.68
Epoch: 37,Batch: 40 Training Loss: 165.0392,Time: 0.67
Epoch: 37,Batch: 60 Training Loss: 188.5125,Batch: 80 Training Loss: 126.0623,Batch: 100 Training Loss: 177.5551,Batch: 120 Training Loss: 145.3477,Batch: 140 Training Loss: 175.2243,Batch: 160 Training Loss: 160.9049,Batch: 180 Training Loss: 153.3031,Batch: 200 Training Loss: 104.0302,Batch: 220 Training Loss: 135.7108,Batch: 240 Training Loss: 142.8567,Batch: 1 Valid Loss: 150.1673,Time: 0.36
Epoch: 37,Batch: 2 Valid Loss: 136.7532,Batch: 3 Valid Loss: 144.9570,Batch: 4 Valid Loss: 173.6571,Time: 0.32
Epoch: 37,Batch: 5 Valid Loss: 166.9340,Time: 0.35
Epoch: 37,Batch: 6 Valid Loss: 104.1471,Batch: 7 Valid Loss: 138.5468,Batch: 8 Valid Loss: 141.9434,Batch: 9 Valid Loss: 163.1604,Batch: 10 Valid Loss: 195.1403,Batch: 11 Valid Loss: 134.0441,Batch: 12 Valid Loss: 135.6687,Batch: 13 Valid Loss: 210.3593,Batch: 14 Valid Loss: 170.5062,Batch: 15 Valid Loss: 157.5619,Batch: 16 Valid Loss: 243.9199,Batch: 17 Valid Loss: 163.0798,Batch: 18 Valid Loss: 147.0687,Batch: 19 Valid Loss: 152.7307,Batch: 20 Valid Loss: 179.4877,Time: 0.34
Epoch: 37,Batch: 21 Valid Loss: 122.9611,Batch: 22 Valid Loss: 124.0443,Batch: 23 Valid Loss: 165.3088,Batch: 24 Valid Loss: 142.3539,Batch: 25 Valid Loss: 202.2605,Batch: 26 Valid Loss: 137.7361,Batch: 27 Valid Loss: 242.8415,Batch: 28 Valid Loss: 148.3895,Batch: 29 Valid Loss: 86.0574,Batch: 30 Valid Loss: 144.1157,Batch: 31 Valid Loss: 170.8322,Batch: 32 Valid Loss: 171.6427,Time: 0.25

Epoch: 37,Training Loss: 158.1220,Validation Loss: 158.3868,Time: 417.97

Starting epoch 38
Epoch: 38,Batch: 20 Training Loss: 105.7252,Time: 0.68
Epoch: 38,Batch: 40 Training Loss: 163.4334,Time: 0.67
Epoch: 38,Batch: 60 Training Loss: 185.7129,Batch: 80 Training Loss: 125.1632,Batch: 100 Training Loss: 178.0677,Batch: 120 Training Loss: 145.0643,Batch: 140 Training Loss: 172.1782,Batch: 160 Training Loss: 160.3347,Batch: 180 Training Loss: 151.9679,Batch: 200 Training Loss: 103.2668,Batch: 220 Training Loss: 134.6251,Batch: 240 Training Loss: 141.8546,Batch: 1 Valid Loss: 150.4089,Time: 0.35
Epoch: 38,Batch: 2 Valid Loss: 136.9565,Time: 0.36
Epoch: 38,Batch: 3 Valid Loss: 144.5437,Batch: 4 Valid Loss: 172.9728,Batch: 5 Valid Loss: 166.5804,Batch: 6 Valid Loss: 103.4765,Batch: 7 Valid Loss: 138.3381,Batch: 8 Valid Loss: 141.9496,Batch: 9 Valid Loss: 163.2548,Batch: 10 Valid Loss: 194.6338,Batch: 11 Valid Loss: 134.2889,Batch: 12 Valid Loss: 136.0695,Batch: 13 Valid Loss: 209.8143,Batch: 14 Valid Loss: 170.3180,Batch: 15 Valid Loss: 157.3526,Batch: 16 Valid Loss: 243.4948,Batch: 17 Valid Loss: 162.8065,Batch: 18 Valid Loss: 146.0056,Batch: 19 Valid Loss: 152.3638,Batch: 20 Valid Loss: 179.5833,Time: 0.34
Epoch: 38,Batch: 21 Valid Loss: 122.8715,Batch: 22 Valid Loss: 123.2680,Batch: 23 Valid Loss: 164.3962,Batch: 24 Valid Loss: 142.2115,Time: 0.33
Epoch: 38,Batch: 25 Valid Loss: 202.0091,Batch: 26 Valid Loss: 137.4782,Batch: 27 Valid Loss: 241.4298,Batch: 28 Valid Loss: 148.7050,Batch: 29 Valid Loss: 86.0582,Batch: 30 Valid Loss: 144.0733,Batch: 31 Valid Loss: 170.6064,Batch: 32 Valid Loss: 171.6329,Time: 0.25

Epoch: 38,Training Loss: 157.1817,Validation Loss: 158.1235,Time: 418.16

Starting epoch 39
Epoch: 39,Batch: 20 Training Loss: 105.1230,Time: 0.68
Epoch: 39,Batch: 40 Training Loss: 165.4860,Time: 0.67
Epoch: 39,Batch: 60 Training Loss: 185.0349,Batch: 80 Training Loss: 124.9884,Batch: 100 Training Loss: 176.9023,Batch: 120 Training Loss: 144.6198,Batch: 140 Training Loss: 170.2020,Batch: 160 Training Loss: 158.3938,Batch: 180 Training Loss: 151.8574,Batch: 200 Training Loss: 103.9747,Batch: 220 Training Loss: 134.2870,Batch: 240 Training Loss: 141.7696,Batch: 1 Valid Loss: 151.1508,Time: 0.34
Epoch: 39,Batch: 2 Valid Loss: 136.8730,Time: 0.35
Epoch: 39,Batch: 3 Valid Loss: 145.3472,Time: 0.36
Epoch: 39,Batch: 4 Valid Loss: 173.3746,Batch: 5 Valid Loss: 167.3553,Batch: 6 Valid Loss: 103.7990,Batch: 7 Valid Loss: 139.3748,Batch: 8 Valid Loss: 142.6575,Batch: 9 Valid Loss: 163.5934,Batch: 10 Valid Loss: 194.6525,Batch: 11 Valid Loss: 135.1157,Batch: 12 Valid Loss: 136.4958,Batch: 13 Valid Loss: 209.0710,Batch: 14 Valid Loss: 170.3045,Batch: 15 Valid Loss: 157.5657,Batch: 16 Valid Loss: 243.3302,Batch: 17 Valid Loss: 163.2397,Time: 0.37
Epoch: 39,Batch: 18 Valid Loss: 146.7547,Batch: 19 Valid Loss: 153.3413,Batch: 20 Valid Loss: 179.2610,Batch: 21 Valid Loss: 123.0379,Batch: 22 Valid Loss: 123.5644,Batch: 23 Valid Loss: 165.4950,Batch: 24 Valid Loss: 142.4183,Batch: 25 Valid Loss: 203.7861,Batch: 26 Valid Loss: 137.3413,Batch: 27 Valid Loss: 241.8771,Batch: 28 Valid Loss: 148.8467,Batch: 29 Valid Loss: 85.9060,Batch: 30 Valid Loss: 145.1030,Time: 0.31
Epoch: 39,Batch: 31 Valid Loss: 171.1640,Batch: 32 Valid Loss: 172.1705,Time: 0.26

Epoch: 39,Training Loss: 156.5736,Validation Loss: 158.5428,Time: 417.88

Starting epoch 40
Epoch: 40,Batch: 20 Training Loss: 104.7642,Time: 0.68
Epoch: 40,Batch: 40 Training Loss: 199.1840,Time: 0.64
Epoch: 40,Batch: 60 Training Loss: 203.7560,Batch: 80 Training Loss: 164.9527,Batch: 100 Training Loss: 186.1097,Batch: 120 Training Loss: 155.4772,Batch: 140 Training Loss: 192.1701,Time: 0.65
Epoch: 40,Batch: 160 Training Loss: 193.4585,Batch: 180 Training Loss: 171.5463,Batch: 200 Training Loss: 117.4412,Batch: 220 Training Loss: 152.0421,Batch: 240 Training Loss: 143.3178,Batch: 1 Valid Loss: 156.7238,Time: 0.35
Epoch: 40,Batch: 2 Valid Loss: 147.8145,Time: 0.36
Epoch: 40,Batch: 3 Valid Loss: 159.7137,Batch: 4 Valid Loss: 191.3246,Time: 0.32
Epoch: 40,Batch: 5 Valid Loss: 178.6363,Batch: 6 Valid Loss: 109.2589,Batch: 7 Valid Loss: 144.4912,Batch: 8 Valid Loss: 150.4129,Batch: 9 Valid Loss: 171.8096,Batch: 10 Valid Loss: 210.0710,Batch: 11 Valid Loss: 142.5819,Batch: 12 Valid Loss: 140.7213,Batch: 13 Valid Loss: 221.1351,Batch: 14 Valid Loss: 182.1034,Batch: 15 Valid Loss: 172.3420,Batch: 16 Valid Loss: 264.6909,Batch: 17 Valid Loss: 177.7539,Batch: 18 Valid Loss: 159.2539,Batch: 19 Valid Loss: 167.3184,Batch: 20 Valid Loss: 188.2323,Batch: 21 Valid Loss: 134.6375,Batch: 22 Valid Loss: 137.9939,Batch: 23 Valid Loss: 188.5730,Batch: 24 Valid Loss: 157.5786,Batch: 25 Valid Loss: 215.2611,Batch: 26 Valid Loss: 146.4763,Batch: 27 Valid Loss: 257.7494,Batch: 28 Valid Loss: 156.5165,Batch: 29 Valid Loss: 94.2914,Time: 0.34
Epoch: 40,Batch: 30 Valid Loss: 146.4372,Batch: 31 Valid Loss: 180.1848,Batch: 32 Valid Loss: 178.5307,Time: 0.24

Epoch: 40,Training Loss: 176.2027,Validation Loss: 169.7069,Time: 414.79

Starting epoch 41
Epoch: 41,Batch: 20 Training Loss: 128.6025,Time: 0.64
Epoch: 41,Batch: 80 Training Loss: 164.9526,Batch: 120 Training Loss: 155.4771,Time: 0.65
Epoch: 41,Batch: 180 Training Loss: 171.5462,Batch: 220 Training Loss: 152.0420,Time: 0.35
Epoch: 41,Time: 0.36
Epoch: 41,Time: 0.34
Epoch: 41,Batch: 11 Valid Loss: 142.5818,Batch: 15 Valid Loss: 172.3419,Batch: 17 Valid Loss: 177.7538,Batch: 18 Valid Loss: 159.2538,Batch: 19 Valid Loss: 167.3183,Batch: 20 Valid Loss: 188.2322,Batch: 21 Valid Loss: 134.6374,Batch: 24 Valid Loss: 157.5785,Batch: 30 Valid Loss: 146.4371,Time: 0.26
Epoch: 41,Training Loss: 178.2872,Validation Loss: 169.7068,Time: 414.65
