微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 Google Colab 和本地机器上训练 DeepLab ResNet V3 之间的巨大差异

如何解决在 Google Colab 和本地机器上训练 DeepLab ResNet V3 之间的巨大差异

我正在尝试训练 Deeplab resnet V3 以对自定义数据集执行语义分割。我一直在我的本地机器上工作,但我的 GPU 只是一个小型 Quadro T1000,所以我决定将我的模型移到 Google Colab 上,以利用他们的 GPU 实例并获得更好的结果。

虽然我获得了我希望的速度提升,但与我的本地机器相比,我在 colab 上的训练损失大不相同。我已经复制并粘贴了完全相同的代码,所以我能找到的唯一区别是在数据集中。我使用的是完全相同的数据集,除了 colab 上的一个是 Google Drive 上本地数据集的副本。我注意到 Drive 订单文件在 Windows 上有所不同,但我看不出这是一个问题,因为我随机打乱了数据集。我知道这些随机分裂可能会导致输出的微小差异,但是训练损失大约 10 倍的差异是没有意义的。

我也尝试过在 colab 上使用不同的随机种子、不同的批次大小、不同的 train_test_split 参数运行该版本,并将优化器从 SGD 更改为 Adam,但是,这仍然导致模型很早收敛,损失约为0.5.

这是我的代码

import torch
from torch.utils import data
from torchvision import transforms
from customdatasets import SegmentationDataSet
import pathlib
from sklearn.model_selection import train_test_split
from customtransforms import Compose,AlbuSeg2d,DenseTarget
from customtransforms import MoveAxis,normalize01,Resize
import albumentations
import matplotlib.pyplot as plt
import time
import GPUtil



def get_filenames_of_path(path: pathlib.Path,ext: str = '*'):
    """Returns a list of files in a directory/path. Uses pathlib."""
    filenames = [file for file in path.glob(ext) if file.is_file()]
    return filenames


if __name__ == '__main__':

    root = pathlib.Path.cwd() / 'train'
    inputs = get_filenames_of_path(root / 'input')
    targets = get_filenames_of_path(root / 'target')

    

# training transformations and augmentations
    transforms_training = Compose([
        Resize(input_size=(128,128,3),target_size=(128,128)),AlbuSeg2d(albu=albumentations.HorizontalFlip(p=0.5)),MoveAxis(),normalize01()
    ])
# validation transformations
    transforms_validation = Compose([
        Resize(input_size=(128,normalize01()
    ])
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    random_seed = 142
    train_size = 0.8

    inputs_train,inputs_valid = train_test_split(
        inputs,random_state=random_seed,train_size=train_size,shuffle=True)
    targets_train,targets_valid = train_test_split(
        targets,shuffle=True)

    dataset_train = SegmentationDataSet(inputs=inputs_train,targets=targets_train,transform=transforms_training,device=device)

    dataset_valid = SegmentationDataSet(inputs=inputs_valid,targets=targets_valid,transform=transforms_validation,device=device)


    DataLoader_training = data.DataLoader(dataset=dataset_train,batch_size=15,shuffle=True,num_workers=4,pin_memory=True)

    DataLoader_validation = data.DataLoader(dataset=dataset_valid,pin_memory=True)


    model = torch.hub.load('pytorch/vision:v0.6.0','deeplabv3_resnet101',pretrained=False)


    criterion = torch.nn.CrossEntropyLoss()

    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.99)


    epochs = 10
    steps = 0
    running_loss = 0
    print_every = 10
    train_losses,valid_losses = [],[]

    start_time = time.time()
    prev_time = time.time()


    for epoch in range(epochs):
        #Training
        for inputs,labels in DataLoader_training:
            steps += 1
            inputs,labels = inputs.to(device,non_blocking=True),labels.to(device,non_blocking=True)
            optimizer.zero_grad()
            logps = model(inputs)
            loss = criterion(logps['out'],labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if steps % print_every == 0:
                train_losses.append(running_loss / len(DataLoader_training))
                epoch_time = time.time()
                elasped_time = epoch_time - prev_time
                prev_time = epoch_time
                print(f"Epoch {epoch + 1}/{epochs}.. "
                    f"Train loss: {running_loss / print_every:.3f}.. "
                    f"Elapsed time: {elasped_time}")

                running_loss = 0
                model.train()
        # Evaluation
        valid_loss = 0
        accuracy = 0
        model.eval()
        with torch.no_grad():
            for inputs,labels in DataLoader_validation:
                inputs,non_blocking=True)
                logps = model.forward(inputs)
                batch_loss = criterion(logps['out'],labels)
                valid_loss += batch_loss.item()

                ps = torch.exp(logps['out'])
                top_p,top_class = ps.topk(1,dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        valid_losses.append(valid_loss / len(DataLoader_validation))
        print(f"Epoch {epoch + 1}/{epochs}.. "
            f"Validation loss: {valid_loss / len(DataLoader_training):.3f}.. "
            f"Validation accuracy: {accuracy / len(DataLoader_training):.3f} ")
        model.train()
    torch.save(model,'model.pth')

    end_time = time.time()
    total_time = end_time - start_time
    print("Total Time: ",total_time)
    plt.plot(train_losses,label='Training loss')
    plt.plot(valid_losses,label='Validation loss')
    plt.legend(frameon=False)
    plt.show()

这是 Colab 上一个 epoch 的输出

Epoch 1/10.. Train loss: 2.080.. Elapsed time: 12.156640768051147
Epoch 1/10.. Train loss: 1.231.. Elapsed time: 8.76858925819397
Epoch 1/10.. Train loss: 1.051.. Elapsed time: 8.315532445907593
Epoch 1/10.. Train loss: 0.890.. Elapsed time: 8.249168634414673
Epoch 1/10.. Train loss: 0.839.. Elapsed time: 8.248667478561401
Epoch 1/10.. Train loss: 0.807.. Elapsed time: 8.120820999145508
Epoch 1/10.. Train loss: 0.742.. Elapsed time: 8.298616886138916
Epoch 1/10.. Train loss: 0.726.. Elapsed time: 8.170734167098999
Epoch 1/10.. Train loss: 0.677.. Elapsed time: 8.221246004104614
Epoch 1/10.. Train loss: 0.698.. Elapsed time: 8.124614000320435
Epoch 1/10.. Train loss: 0.675.. Elapsed time: 8.197462558746338
Epoch 1/10.. Train loss: 0.682.. Elapsed time: 8.263437509536743
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.156179189682007
Epoch 1/10.. Train loss: 0.632.. Elapsed time: 8.268096446990967
Epoch 1/10.. Train loss: 0.616.. Elapsed time: 8.214547872543335
Epoch 1/10.. Train loss: 0.585.. Elapsed time: 8.31475019454956
Epoch 1/10.. Train loss: 0.598.. Elapsed time: 8.388074398040771
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.179292440414429
Epoch 1/10.. Train loss: 0.612.. Elapsed time: 8.252359390258789
Epoch 1/10.. Train loss: 0.592.. Elapsed time: 8.284745693206787
Epoch 1/10.. Train loss: 0.597.. Elapsed time: 8.31213927268982
Epoch 1/10.. Train loss: 0.566.. Elapsed time: 8.164374113082886
Epoch 1/10.. Train loss: 0.556.. Elapsed time: 8.300082206726074
Epoch 1/10.. Train loss: 0.568.. Elapsed time: 8.26304841041565
Epoch 1/10.. Train loss: 0.572.. Elapsed time: 8.309881448745728
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.211671352386475
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.321797609329224
Epoch 1/10.. Train loss: 0.535.. Elapsed time: 8.318871021270752
Epoch 1/10.. Train loss: 0.543.. Elapsed time: 8.152915239334106
Epoch 1/10.. Train loss: 0.569.. Elapsed time: 8.251380205154419
Epoch 1/10.. Train loss: 0.526.. Elapsed time: 8.29153847694397
Epoch 1/10.. Train loss: 0.565.. Elapsed time: 8.15071702003479
Epoch 1/10.. Train loss: 0.542.. Elapsed time: 8.253364562988281
Epoch 1/10.. Validation loss: 0.182.. Validation accuracy: 0.271 

这是我本地机器上的输出

Epoch 1/10.. Train loss: 2.932.. Elapsed time: 32.148621797561646
Epoch 1/10.. Train loss: 1.852.. Elapsed time: 14.120505809783936
Epoch 1/10.. Train loss: 0.887.. Elapsed time: 14.210048198699951
Epoch 1/10.. Train loss: 0.618.. Elapsed time: 14.23294186592102
Epoch 1/10.. Train loss: 0.549.. Elapsed time: 14.212541103363037
Epoch 1/10.. Train loss: 0.519.. Elapsed time: 14.047481775283813
Epoch 1/10.. Train loss: 0.506.. Elapsed time: 14.060708045959473
Epoch 1/10.. Train loss: 0.347.. Elapsed time: 14.301624059677124
Epoch 1/10.. Train loss: 0.399.. Elapsed time: 13.9844491481781
Epoch 1/10.. Train loss: 0.361.. Elapsed time: 13.957871913909912
Epoch 1/10.. Train loss: 0.305.. Elapsed time: 14.164010763168335
Epoch 1/10.. Train loss: 0.296.. Elapsed time: 14.001536846160889
Epoch 1/10.. Train loss: 0.298.. Elapsed time: 14.019971132278442
Epoch 1/10.. Train loss: 0.271.. Elapsed time: 13.951345443725586
Epoch 1/10.. Train loss: 0.252.. Elapsed time: 14.037938594818115
Epoch 1/10.. Train loss: 0.283.. Elapsed time: 13.944657564163208
Epoch 1/10.. Train loss: 0.299.. Elapsed time: 13.977224826812744
Epoch 1/10.. Train loss: 0.219.. Elapsed time: 13.941975355148315
Epoch 1/10.. Train loss: 0.242.. Elapsed time: 13.936140060424805
Epoch 1/10.. Train loss: 0.244.. Elapsed time: 13.942122459411621
Epoch 1/10.. Train loss: 0.216.. Elapsed time: 13.960899114608765
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.956881523132324
Epoch 1/10.. Train loss: 0.241.. Elapsed time: 13.944581985473633
Epoch 1/10.. Train loss: 0.203.. Elapsed time: 13.934357404708862
Epoch 1/10.. Train loss: 0.189.. Elapsed time: 13.938358306884766
Epoch 1/10.. Train loss: 0.181.. Elapsed time: 13.944468021392822
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.946297407150269
Epoch 1/10.. Train loss: 0.164.. Elapsed time: 13.940366744995117
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 13.938241720199585
Epoch 1/10.. Train loss: 0.176.. Elapsed time: 14.015569925308228
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 14.019208669662476
Epoch 1/10.. Train loss: 0.175.. Elapsed time: 14.149503469467163
Epoch 1/10.. Train loss: 0.159.. Elapsed time: 14.128302097320557
Epoch 1/10.. Train loss: 0.155.. Elapsed time: 13.935027837753296
Epoch 1/10.. Train loss: 0.137.. Elapsed time: 13.937382221221924
Epoch 1/10.. Train loss: 0.127.. Elapsed time: 13.929635524749756
Epoch 1/10.. Train loss: 0.133.. Elapsed time: 13.935472011566162
Epoch 1/10.. Train loss: 0.152.. Elapsed time: 13.922808647155762
Epoch 1/10.. Validation loss: 0.032.. Validation accuracy: 0.239

我不会粘贴更多,因为它很长并且需要一段时间才能运行,但到第 3 个时期结束时,Colab 模型的损失仍在 0.5 左右反弹,而在本地达到 0.02。

如果有人能帮我解决这个问题,我将不胜感激。

解决方法

我通过将训练数据解压缩到 Google Drive 并从那里读取文件来解决这个问题,而不是使用 Colab 命令将文件夹直接解压缩到我的工作区。我完全不知道为什么会导致这个问题;对图像及其相应的张量进行快速目视检查看起来不错,但我无法逐一检查 6,000 个左右的图像来检查每一个。 如果有人知道这导致问题的原因,请告诉我!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?