基于Resnet的钢材表面缺陷分类问题

目录

前言

一、数据集

二、数据预处理

1.将图片随机分到不同的文件夹

1.1 创建对应类别文件夹并且进行同类分类

1.2 将训练数据按比例分开训练集和验证集

2.加载数据集和验证集

二.模型加载

1.模型选择

2.模型搭建

3.实例化模型

三.损失函数

四.训练

五.训练结果

总结

参考​​​​​​​


前言

       本文是主要讲述在灰度图如何进行迁移学习的图像分类任务,采用的是钢材表面缺陷数据集,并且使用resnet迁移学习进行训练。


一、数据集

本文采用的是钢材表面缺陷数据集,该数据集的特点有:

  1. 该数据集主要分有七类,在训练中分别用0到6代替其中的类型
  2. 该数据集全都是单通道的灰度图,图片后缀名为.bmp
  3. 该数据集经过分类训练集有2000张照片,测试集有400张
  4. 该数据集有training.csv,包括了训练集的id和label

数据集类别如图所示,从左到右分别为0、1、2、3、4、5、6类别:

training.csv部分内容如图所示:

二、数据预处理

1.将图片随机分到不同的文件

1.1 创建对应类别文件夹并且进行同类分类

      - 创建一个TrainAll文件夹,其中包含0、1、2、3、4、5、6七个文件

      - 输入混合的训练集图片路径,将混合的照片按照training.csv分到对应标签文件

def integration():
    """根据csv文件构建文件夹"""
    #读取csv文件,计算长度
    data_x= pd.read_csv(filepath_or_buffer = 'training.csv', sep = ',')["Image_name	/(.bmp)"].values
    data_y= pd.read_csv(filepath_or_buffer = 'training.csv', sep = ',')["categories"].values
    num = len(data_x)
    #如果原路径存在该文件,先清空了,保证数据准确
    if os.path.exists('TrainAll') :
        shutil.rmtree('TrainAll')
    #创建训练分类文件夹
    for i in range(8):
        i = str(i)
        os.makedirs(os.path.join('TrainAll',i))
    #输入混合图片路径
    dir = input('请输入训练集的文件夹的路径:')
    root = ''
    for _root, _dirs, _files in os.walk(dir):  # root为输入的文件夹名,dirs为子文件夹的名字,files为图片
        root = _root
        print("训练集上有{}张照片".format(len(_files)))
    #对训练集上的数据进行分类
    for j in range(8):
        label = str(j)
        for i in range(num) :
            if data_y[i] == j:
                i = ''+ str(i+1)+'.bmp'
                shutil.copy(os.path.join(root,i), os.path.join('TrainAll',label))

1.2 将训练数据按比例分开训练集和验证集

      - 因为该数据集的测试集未公布,所以训练的过程中需要在训练集中划分出验证集

      - 输入训练集和验证集的比例,代码自动按照比例进行分配

def Classify ():
    """按比分类训练集和测试集"""
    dir = 'TrainAll'
    proportion = input ('请输入训练集占全集(验证集和训练集总和)的比例(小数):') #输入比例
    proportion  = float(proportion)
    if os.path.exists('data'):
        shutil.rmtree('data')  # 删除文件夹,保证数据准确
    print("正在分类")
    for _root,_dirs,_files in os.walk(dir):     # root为输入的文件夹名,dirs为子文件夹的名字,files为图片
        for _name in _dirs:                    # 遍历
            folder = os.path.join(_root,_name) # 将文件名和子文件名合成一个路径
            moveFile(folder,_name,proportion)     #划分新的数据集
        print("分类完成!")
        break #循环一次就行

2.加载数据集和验证集

      - 训练集和验证集分别进行不同的图像增强,分别进行加载数据

      - 必须有transforms.Grayscale(1)读入灰度图,否则为RGB

      - transforms.normalize中的0.485, 0.229是在训练集中灰度图中计算得到的均值和方差

    # 加载数据
    data_transform = {
        "train": transforms.Compose([transforms.Grayscale(1)
                                     transforms.RandomresizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.normalize(0.485, 0.229, inplace=True)]),
        "val": transforms.Compose([transforms.Grayscale(1)
                                   transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.normalize(0.485, 0.229, inplace=True)])}
    image_path = os.path.join("data")
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # {'0':0, '1':1, '2':2, '3':3, '4':4, '5':5, '6':6, '7':7}
    list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in list.items())
    # 写json文件,方便对应类别
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 8
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
    print("训练集有{}张照片, 测试集有{}张照片".format(train_num,val_num))

二.模型加载

1.模型选择

      - 在保证准确率的同时,提高分类速度,选择了resnet-18的模型结构

      - 因为分类任务类似,因此使用resnet-18 提供的预训练权重进行训练

2.模型搭建

      - 因为是灰度图,所以class resnet中的self.conv1中的输入通道应该为1

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.Batchnorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.Batchnorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out


class resnet(nn.Module):
    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(resnet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group
        self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.Batchnorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.Batchnorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

def resnet18(num_classes=1000, include_top=True):
    return resnet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

3.实例化模型

      - 将下载好的预训练权重放到与代码同以路径下,读入权重

      - 将全连接层1000改为8

    #加载模型
    net = resnet18()
    model_weight_path = "./resnet18-5c106cde.pth"
    pre_state_dict = torch.load(model_weight_path)
    new_state_dict = {}
    #遍历修改模型的各个层
    for k, v in net.state_dict().items():
        # 如果原模型的层也在新模型的层里面, 那新模型就加载原先训练好的权重
        if k in pre_state_dict.keys() and k != 'conv1.weight':
            new_state_dict[k] = pre_state_dict[k]  
    net.load_state_dict(new_state_dict, False)
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 8)
    net.to(device)

三.损失函数

        - 使用交叉熵损失函数

        - 使用Afam学习率为0.0001的优化方法

    #损失函数
    loss_function = nn.CrossEntropyLoss()
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)
    

四.训练

      - 运用了tqdm进行可视化进程

      - 方向梯度更新权重

      - 比较验证集的准确率,保存准确率高的权重

#训练
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            #可视化数据
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs, loss)
        #检测
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        #只保留最好的预测模型
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('训练完成!')
    print('训练模型已保存到当前目录')

五.训练结果

以下为部分的训练数据:

可看出在原来的权重上进行训练可以更快得到更高的准确率

训练得到的模型在未知的测试集中的400张照片可以达到100%的准确率


总结

以上就是今天要讲的内容,本文仅仅简单介绍了如何在灰度图中进行迁移学习,在此之前试过将灰度图转为RGB进行迁移学习,但是在验证集的准确率一直稳定在50%,分析了之后,原因是灰度图转为RGB时,颜色是不可控的,然而图像的颜色也是图像进行分类的依据之一,因此只能用改模型的方法进行迁移学习。

参考

ResNet训练单通道图像分类网络(Pytorch)_望~的博客-CSDN博客_resnet train.py

GitHub - WZMIAOMIAO/deep-learning-for-image-processing: deep learning for image processing including classification and object-detection etc.

(55条消息) ResNet50修改网络适应灰度图片并加载预训练模型_吕大娟的博客-CSDN博客_resnet50修改
 

不当之处请多多指教
 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


学习编程是顺着互联网的发展潮流,是一件好事。新手如何学习编程?其实不难,不过在学习编程之前你得先了解你的目的是什么?这个很重要,因为目的决定你的发展方向、决定你的发展速度。
IT行业是什么工作做什么?IT行业的工作有:产品策划类、页面设计类、前端与移动、开发与测试、营销推广类、数据运营类、运营维护类、游戏相关类等,根据不同的分类下面有细分了不同的岗位。
女生学Java好就业吗?女生适合学Java编程吗?目前有不少女生学习Java开发,但要结合自身的情况,先了解自己适不适合去学习Java,不要盲目的选择不适合自己的Java培训班进行学习。只要肯下功夫钻研,多看、多想、多练
Can’t connect to local MySQL server through socket \'/var/lib/mysql/mysql.sock问题 1.进入mysql路径
oracle基本命令 一、登录操作 1.管理员登录 # 管理员登录 sqlplus / as sysdba 2.普通用户登录
一、背景 因为项目中需要通北京网络,所以需要连vpn,但是服务器有时候会断掉,所以写个shell脚本每五分钟去判断是否连接,于是就有下面的shell脚本。
BETWEEN 操作符选取介于两个值之间的数据范围内的值。这些值可以是数值、文本或者日期。
假如你已经使用过苹果开发者中心上架app,你肯定知道在苹果开发者中心的web界面,无法直接提交ipa文件,而是需要使用第三方工具,将ipa文件上传到构建版本,开...
下面的 SQL 语句指定了两个别名,一个是 name 列的别名,一个是 country 列的别名。**提示:**如果列名称包含空格,要求使用双引号或方括号:
在使用H5混合开发的app打包后,需要将ipa文件上传到appstore进行发布,就需要去苹果开发者中心进行发布。​
+----+--------------+---------------------------+-------+---------+
数组的声明并不是声明一个个单独的变量,比如 number0、number1、...、number99,而是声明一个数组变量,比如 numbers,然后使用 nu...
第一步:到appuploader官网下载辅助工具和iCloud驱动,使用前面创建的AppID登录。
如需删除表中的列,请使用下面的语法(请注意,某些数据库系统不允许这种在数据库表中删除列的方式):
前不久在制作win11pe,制作了一版,1.26GB,太大了,不满意,想再裁剪下,发现这次dism mount正常,commit或discard巨慢,以前都很快...
赛门铁克各个版本概览:https://knowledge.broadcom.com/external/article?legacyId=tech163829
实测Python 3.6.6用pip 21.3.1,再高就报错了,Python 3.10.7用pip 22.3.1是可以的
Broadcom Corporation (博通公司,股票代号AVGO)是全球领先的有线和无线通信半导体公司。其产品实现向家庭、 办公室和移动环境以及在这些环境...
发现个问题,server2016上安装了c4d这些版本,低版本的正常显示窗格,但红色圈出的高版本c4d打开后不显示窗格,
TAT:https://cloud.tencent.com/document/product/1340