微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

CNN-LSTM 无法与 CTC loss 收敛

如何解决CNN-LSTM 无法与 CTC loss 收敛

我想使用 resnet50-LSTM 训练 LPR(车牌识别)模型来识别中国车牌。首先,我随机使用了十个省的车牌,该模型在一个时期内可以很好地收敛,但是当我在训练数据集中再添加一个省时,无论哪个省,它都无法收敛。 每个省有2000张图片,看起来像下面这些图片train dataset pics

模型

class res50LPRModel(Module):
    """ResNet-50 backbone + conv reducer + 2-layer (bi)LSTM head for
    CTC-based license plate recognition.

    forward() returns raw scores of shape (seq_len, batch, n_classes),
    the layout expected by ``F.log_softmax(dim=-1)`` + ``F.ctc_loss``.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128), pretrained=False,
                 LSTMbidirectional=True, lstmHidenSize=256 * 2):
        super(res50LPRModel, self).__init__()
        self.input_shape = input_shape
        # Keep only the first 7 children of resnet50 (through layer3),
        # which outputs 1024 channels at 1/16 spatial resolution.
        self.resnet50 = models.resnet50(pretrained=pretrained)
        self.resnet50 = nn.Sequential(*list(self.resnet50.children())[:7])
        if LSTMbidirectional:
            # A bidirectional LSTM concatenates both directions -> 2x hidden.
            self.fc = nn.Sequential(
                nn.Linear(in_features=lstmHidenSize * 2, out_features=128),
                nn.Dropout(0.25),
                nn.Linear(in_features=128, out_features=n_classes))
        else:
            # BUGFIX: use integer division — nn.Linear requires int features,
            # the original passed a float (lstmHidenSize / 2).
            self.fc = nn.Sequential(
                nn.Linear(in_features=lstmHidenSize,
                          out_features=lstmHidenSize // 2),
                nn.Dropout(0.5),
                nn.Linear(in_features=lstmHidenSize // 2,
                          out_features=n_classes))
        # BUGFIX: nn.Batchnorm2d does not exist (-> nn.BatchNorm2d), and
        # conv2/conv3 were missing their kernel arguments and closing parens.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True))
        self.lstm = nn.LSTM(input_size=self.infer_features(),
                            hidden_size=lstmHidenSize,
                            num_layers=2,
                            bidirectional=LSTMbidirectional)

    def infer_features(self):
        """Run a dummy forward pass to measure the per-timestep feature
        size (channels * height after the conv stack) fed to the LSTM."""
        x = torch.zeros((1,) + self.input_shape)
        x = self.resnet50(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # (N, C, H, W) -> (N, C*H, W): width is the time axis.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return x.shape[1]

    def forward(self, x):
        x = self.resnet50(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # BUGFIX: the original reshape dropped the feature dimension;
        # mirror infer_features(): (N, C, H, W) -> (N, C*H, W).
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        # BUGFIX: permute(2, 1) is invalid for a 3-D tensor; the LSTM
        # (batch_first=False) expects (seq_len, batch, input_size).
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

数据加载器:

class labelTestDataLoader(Dataset):
    """Dataset for CCPD-style cropped license-plate images.

    File names encode the plate as '_'-separated index triplets, e.g.
    ``...-0_0_25_27_7_26_29-...jpg``: the first index selects a province
    character, the second an alphabet letter, the rest alphanumerics
    (``ads``).

    __getitem__ returns: (transformed image, label index tensor, file
    name, input_length, target_length, raw resized image, plate text).
    """

    # Default character tables kept as class-level constants so that
    # instances never share (and never mutate) default-argument lists.
    _PROVINCES = ["皖","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","京","闽","赣","鲁","豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","警","学"]
    _ALPHABETS = ['A','B','C','D','E','F','G','H','J','K','L','M','N','P','Q','R','S','T','U','V','W','X','Y','Z']
    _ADS = ['A','Z','0','1','2','3','4','5','6','7','8','9']

    def __init__(self, img_dir, imgSize, provinces=None, alphabets=None,
                 ads=None, transform=None):
        self.img_dir = img_dir
        self.img_paths = os.listdir(img_dir)
        self.img_size = imgSize
        self.transform = transform

        # BUGFIX: the original used mutable list defaults and then grew
        # them in place via extend(), corrupting the shared defaults (and
        # any caller-supplied list) across instances. Copy everything.
        self.provinces = list(provinces) if provinces is not None else list(self._PROVINCES)
        self.alphabets = list(alphabets) if alphabets is not None else list(self._ALPHABETS)
        self.ads = list(ads) if ads is not None else list(self._ADS)

        # CTC vocabulary: blank '_' first, then the sorted, de-duplicated
        # union of all character sets.
        self.all_class = ['_'] + sorted(set(self.provinces + self.alphabets + self.ads))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_name = self.img_paths[index]
        img = cv2.imread(os.path.join(self.img_dir, img_name))
        resizedImage_out = cv2.resize(img, self.img_size)
        resizedImage = self.transform(resizedImage_out)

        # e.g. 01-0_1-...-0_0_25_27_7_26_29-131-21.jpg -> the third
        # '-'-field from the end holds the '_'-separated label indices.
        lbl = img_name.split('/')[-1].split('.')[0].split('-')[-3]
        new_label = []
        label_text = ""
        for label_index, label_ in enumerate(lbl.split('_')):
            if label_index == 0:
                ch = self.provinces[int(label_)]
            elif label_index == 1:
                ch = self.alphabets[int(label_)]
            else:
                ch = self.ads[int(label_)]
            new_label.append(self.all_class.index(ch))
            label_text += ch
        new_label = torch.tensor(new_label)
        label_length = int(len(new_label))
        # CTC requires the model output to be longer than the target so
        # blanks can separate repeated characters; 3x margin is used here.
        input_length = torch.full(size=(1,), fill_value=label_length * 3,
                                  dtype=torch.long)
        # BUGFIX: the original torch.full call had mismatched parentheses.
        target_length = torch.full(size=(1,), fill_value=label_length,
                                   dtype=torch.long)
        return (resizedImage, new_label, img_name, input_length,
                target_length, resizedImage_out, label_text)

训练代码

import torch.nn.functional as F
from tqdm import tqdm
from PIL import ImageFont,ImageDraw,Image
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from tensorboardX import SummaryWriter
from pylab import mpl
import setproctitle
from model.allModel import rawModel,res18LPRModel,res50LPRModel
from data import labelTestDataLoader
from torchvision.utils import make_grid
import cv2 as cv
import matplotlib.pylab as plt
import torchvision
from torchvision.transforms import ToTensor,normalize

# Process title shown in ps/top, and a TensorBoard writer for all metrics.
setproctitle.setproctitle("lpr")

writer = SummaryWriter(log_dir="outlog")
mpl.rcParams['font.sans-serif'] = ['FangSong']  # default font so matplotlib can render CJK text
mpl.rcParams['axes.unicode_minus'] = False  # keep minus signs from rendering as boxes when saving figures


def decode(sequence, charset=None):
    """Greedy CTC decode: collapse consecutive repeats, then drop blanks.

    sequence -- iterable of class indices (argmax per timestep).
    charset  -- index->char table whose first entry is the CTC blank;
                defaults to the module-level ``characters`` list built
                in ``__main__``.

    BUGFIX: the original implementation dropped a character that was
    repeated after a blank separator (e.g. indices for 'A_A' decoded to
    'A' instead of 'AA'), under-counting accuracy.
    """
    if charset is None:
        charset = characters  # module-level vocabulary
    blank = charset[0]
    out = []
    prev = None
    for idx in sequence:
        ch = charset[idx]
        # emit only on a symbol change, and never emit the blank itself
        if ch != prev and ch != blank:
            out.append(ch)
        prev = ch
    return ''.join(out)

def decode_target(sequence, charset=None):
    """Map ground-truth label indices to plate text (spaces stripped).

    charset defaults to the module-level ``characters`` list; it is
    exposed as a parameter so the function is testable in isolation.
    """
    if charset is None:
        charset = characters  # module-level vocabulary
    return ''.join([charset[x] for x in sequence]).replace(' ', '')

def calc_acc(target, output):
    """Sequence-level exact-match accuracy.

    target -- (batch, label_len) ground-truth index tensor.
    output -- (seq_len, batch, n_classes) raw model scores.
    Returns the fraction of samples whose greedy CTC decode equals the
    decoded target string.
    """
    # BUGFIX: permute(1, 2) is invalid for a 3-D tensor; move batch to the
    # front instead: (seq_len, batch, n_classes) -> (batch, seq_len, n_classes).
    output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
    target = target.cpu().numpy()
    output_argmax = output_argmax.cpu().numpy()
    a = np.array([decode_target(true) == decode(pred)
                  for true, pred in zip(target, output_argmax)])
    return a.mean()

def train(model, optimizer, epoch, DataLoader):
    """Run one training epoch with CTC loss.

    model     -- network producing (seq_len, batch, n_classes) scores.
    optimizer -- torch optimizer over the model parameters.
    epoch     -- 1-based epoch index, used for logging only.
    DataLoader-- loader yielding the 7-field tuples from labelTestDataLoader.

    Logs exponentially-smoothed loss/accuracy and a labelled image grid
    to the module-level TensorBoard ``writer``.
    """
    model.train()
    loss_mean = 0
    acc_mean = 0
    with tqdm(DataLoader) as pbar:
        # BUGFIX: the dataset yields 7 fields per item; the original
        # unpacked only 6, shifting every value one slot to the left.
        for batch_index, (data, target, img_name, input_lengths,
                          target_lengths, imgs, label_text) in enumerate(pbar):

            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)  # (seq_len, batch, n_classes)
            output_log_softmax = F.log_softmax(output, dim=-1)
            # BUGFIX: ctc_loss needs log-probs, targets and BOTH length
            # tensors; the original call dropped three of the arguments.
            # Lengths arrive as (batch, 1) tensors -> flatten to (batch,).
            loss = F.ctc_loss(output_log_softmax, target,
                              input_lengths.view(-1),
                              target_lengths.view(-1))

            loss.backward()
            optimizer.step()

            loss = loss.item()
            acc = calc_acc(target, output)

            if batch_index == 0:
                loss_mean = loss
                acc_mean = acc

            # Exponential moving average for a smoother progress display.
            loss_mean = 0.1 * loss + 0.9 * loss_mean
            acc_mean = 0.1 * acc + 0.9 * acc_mean

            pbar.set_description(f'Epoch: {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')
            writer.add_scalar("train_loss", loss_mean, epoch)
            writer.add_scalar("train_acc", acc_mean, epoch)

            # Log the raw batch images with their label text burned in.
            # NOTE(review): Hershey fonts cannot render Chinese province
            # glyphs — consider text_name() for CJK labels.
            rawImgTensorList = []
            for index, img in enumerate(imgs):
                # BUGFIX: cv.FONT_HERShey_SIMPLEX -> cv.FONT_HERSHEY_SIMPLEX
                tempImg = cv.putText(img.numpy(), label_text[index], (0, 20),
                                     cv.FONT_HERSHEY_SIMPLEX, 0.7,
                                     (255, 255, 255))
                # BUGFIX: HWC -> CHW needs a 3-axis permute, not (2, 1).
                rawImgTensorList.append(torch.from_numpy(tempImg).permute(2, 0, 1))
            showGrid = make_grid(rawImgTensorList, padding=50)
            writer.add_image("batchImg", showGrid, global_step=epoch)


def valid(model, DataLoader, epoch=0):
    """Evaluate the model over a validation loader.

    epoch -- logging step for TensorBoard/progress display (new optional
             parameter; the original read an undefined global ``epoch``).

    Logs running-mean CTC loss and exact-match accuracy.
    """
    model.eval()
    with tqdm(DataLoader) as pbar, torch.no_grad():
        loss_sum = 0
        acc_sum = 0
        # BUGFIX: the original loop header and body were truncated —
        # broken unpacking, no .cuda() assignment, no loss computation.
        # Mirror the 7-field dataset output used by train().
        for batch_index, (data, target, img_name, input_lengths,
                          target_lengths, imgs, label_text) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            output = model(data)  # (seq_len, batch, n_classes)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target,
                              input_lengths.view(-1),
                              target_lengths.view(-1))

            loss = loss.item()
            acc = calc_acc(target, output)

            loss_sum += loss
            acc_sum += acc

            # Plain running means (no smoothing during evaluation).
            loss_mean = loss_sum / (batch_index + 1)
            acc_mean = acc_sum / (batch_index + 1)

            pbar.set_description(f'Test : {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')
            # BUGFIX: add_scalar was called without a value argument.
            writer.add_scalar("test_loss", loss_mean, epoch)
            writer.add_scalar("test_acc", acc_mean, epoch)


def text_name(name, image, position, size, color, fontpath="msyh.ttf"):
    """Draw Unicode text on an image via PIL (cv2.putText cannot render
    CJK glyphs).

    name     -- text to draw.
    image    -- HxWxC uint8 ndarray (OpenCV image).
    position -- (x, y) top-left corner of the text.
    size     -- font point size.
    color    -- fill color tuple; e.g. (255, 255, 255) white, (0, 0) black.
    fontpath -- path to a TrueType font file supporting the glyphs.
    Returns a new ndarray with the text rendered.
    """
    font = ImageFont.truetype(fontpath, size)
    img_pil = Image.fromarray(image)
    draw = ImageDraw.Draw(img_pil)
    draw.text(position, name, font=font, fill=color)
    return np.array(img_pil)

if __name__ == '__main__':
    # Plate label character sets.
    # NOTE(review): these lists look truncated by the paste (the dataset
    # class defines the full tables) — confirm before training.
    provinces = ["皖", "学"]
    alphabets = ['A', 'Z']
    ads = ['A', '9']

    batch_size = 16

    # Build the CTC vocabulary: blank '_' + sorted, de-duplicated union.
    # BUGFIX: the original aliased all_class = provinces and then mutated
    # it via extend(), corrupting the provinces list itself.
    all_class = sorted(set(provinces + alphabets + ads))
    characters = ['_'] + all_class  # '_' is the CTC blank / "no prediction"
    # Plates have 7 characters, e.g. 鲁K12345.
    target_len = 7
    model_output_len = 7 * 3  # leave room for blanks between characters

    # The backbone downsamples width by 16, so width = 21 * 16 = 336
    # yields 21 output timesteps for 7 target characters.
    width, height = int(model_output_len * 16), 168
    # BUGFIX: class count must include the CTC blank; the original used
    # len(all_class), which excluded '_'.
    n_classes = len(characters)

    # BUGFIX: transforms.normalize is a module/function — the transform
    # class is Normalize (capital N).
    transformImage = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])])

    train_set = labelTestDataLoader(
        r'License plate recognition\datasets\eleven',
        (width, height), transform=transformImage)
    # BUGFIX: valid_set was missing the required imgSize argument.
    valid_set = labelTestDataLoader(
        r'License plate recognition\datasets\LicensePlaceCropped\valid',
        (width, height), transform=transformImage)

    # num_workers=0: multi-process loading raised an error still under
    # investigation.
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_set, num_workers=0)

    # Quick sanity peek at a few batches (7 fields per batch).
    for batch_index, (resizedImage, new_labels, img_name, input_lengths,
                      target_lengths, imgs, label_text) in enumerate(train_loader):
        if batch_index > 2:
            break

    # BUGFIX: the original constructor call had broken parentheses;
    # input_shape must be passed as a single (C, H, W) tuple.
    model = res50LPRModel(n_classes, input_shape=(3, height, width),
                          pretrained=True)
    # model.load_state_dict(torch.load("ctc_LPR_model_dict.pt"))
    print(model)
    model = model.cuda()

    # Stage 1: warm up — smaller lr for the pretrained backbone than for
    # the freshly initialized conv/LSTM/fc layers.
    # BUGFIX: every param group below was missing its 'lr' key (and most
    # of the dict/paren syntax) in the original.
    finetuneLr = 8e-5
    lr = 1e-4
    optimizer = torch.optim.Adam(
        [{'params': model.resnet50.parameters(), 'lr': finetuneLr},
         {'params': model.lstm.parameters(), 'lr': lr},
         {'params': model.conv1.parameters(), 'lr': lr},
         {'params': model.conv2.parameters(), 'lr': lr},
         {'params': model.conv3.parameters(), 'lr': lr},
         {'params': model.fc.parameters(), 'lr': lr}],
        amsgrad=True, weight_decay=1e-5)

    epochs = 1
    for epoch in range(1, epochs + 1):
        # BUGFIX: train() takes (model, optimizer, epoch, loader).
        train(model, optimizer, epoch, train_loader)
        # valid(model, valid_loader)

    # Stage 2: lower the backbone lr further and train longer.
    finetuneLr = 1e-5
    lr = 1e-4
    optimizer = torch.optim.Adam(
        [{'params': model.resnet50.parameters(), 'lr': finetuneLr},
         {'params': model.lstm.parameters(), 'lr': lr},
         {'params': model.conv1.parameters(), 'lr': lr},
         {'params': model.conv2.parameters(), 'lr': lr},
         {'params': model.conv3.parameters(), 'lr': lr},
         {'params': model.fc.parameters(), 'lr': lr}],
        amsgrad=True, weight_decay=1e-5)

    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, optimizer, epoch, train_loader)
        valid(model, valid_loader)
    torch.save(model.state_dict(), 'ctc_car_FiveClass_20210409.pth')

这让我困惑了很长时间。有人知道为什么会这样吗? 非常感谢您的帮助!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。