
Tensorflow (Keras) U-Net segmentation training fails


I am trying to train a U-Net with Tensorflow and Keras. The model is shown below.

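from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# jacard_coef_loss and jaccard_distance are custom functions that are not shown in the post.

def get_unet(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS):
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Contracting path: two 3x3 convolutions per level, then 2x2 max pooling
    conv1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: upsample, 2x2 convolution, concatenate the skip connection, two 3x3 convolutions
    up6 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    up7 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = Conv2D(32, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    # 1x1 convolution with sigmoid: one foreground probability per pixel (single class)
    conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    #model.compile(optimizer=Adam(learning_rate=1e-3), loss=[dice_coef_loss], metrics=[dice_coef])
    model.compile(optimizer=Adam(learning_rate=1e-4), loss=[jacard_coef_loss], metrics=[jaccard_distance])
    #model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

    model.summary()

    return model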
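The compile call refers to jacard_coef_loss and jaccard_distance, but their definitions are not part of the post. As an assumption for illustration only, here is a NumPy transcription of one widely circulated soft Jaccard distance, modelled on the version distributed with keras-contrib (smooth=100, reduction over axis=-1); the post's own functions may differ. Because the result is multiplied by smooth, for inputs in [0, 1] it lies between 0 and 100 rather than between 0 and 1. The `axis` parameter is added here so the per-pixel and per-image reductions can be compared.

```python
import numpy as np


def jaccard_distance(y_true, y_pred, smooth=100.0, axis=-1):
    """Soft Jaccard distance scaled by `smooth` (NumPy sketch, not the post's code).

    axis=-1 reduces over the last axis only; for (batch, H, W, 1) masks that is the
    single channel, i.e. a per-pixel score. axis=(1, 2, 3) gives one score per image.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(np.abs(y_true * y_pred), axis=axis)
    total = np.sum(np.abs(y_true) + np.abs(y_pred), axis=axis)
    jac = (intersection + smooth) / (total - intersection + smooth)
    return (1.0 - jac) * smooth


# Usage: two 4x4 single-channel masks in a batch of one, shape (1, 4, 4, 1)
mask_a = np.zeros((1, 4, 4, 1))
mask_a[0, :2] = 1.0  # foreground in the top half
mask_b = np.zeros((1, 4, 4, 1))
mask_b[0, 2:] = 1.0  # foreground in the bottom half: no overlap with mask_a

print(jaccard_distance(mask_a, mask_a, axis=(1, 2, 3)))  # [0.]
print(jaccard_distance(mask_a, mask_b, axis=(1, 2, 3)))
```

With the default axis=-1 the score is computed per pixel, so its magnitude depends strongly on `smooth`; reducing over the spatial axes gives one overlap score per image, which is closer to the usual meaning of IoU.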

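The problem lies in the training (or possibly in the data). The data are JPG images, and the masks contain only one class. Because some of the objects are small, I introduced jaccard_distance. Training is done as follows: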

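import os
import numpy as np
import tensorflow as tf
from skimage.io import imread
from skimage.transform import resize
from tqdm import tqdm

seed = 42
np.random.seed(seed)  # call the function; `np.random.seed = seed` only overwrites the attribute

IMG_WIDTH = 512
IMG_HEIGHT = 512
IMG_CHANNELS = 1

TRAIN_PATH = 'data/train/'
TEST_PATH = 'data/test/'

# Each sample is a sub-directory with images/<id>.jpg and one or more files in masks/
train_ids = next(os.walk(TRAIN_PATH))[1]
test_ids = next(os.walk(TEST_PATH))[1]

X_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.float32)
Y_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.float32)

print('Resizing training images and masks')
for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    path = TRAIN_PATH + id_
    print(path)
    img = imread(path + '/images/' + id_ + '.jpg', as_gray=True)
    # as_gray=True returns colour images as floats in [0, 1] (only images already stored
    # as grayscale keep their 0-255 range), so this division may scale them down a second time
    img = img / 255
    img = img[:, :, np.newaxis]
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    X_train[n] = img
    # Union of all mask files of this sample; grayscale mask files keep their stored range (e.g. 0/255)
    mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.float32)
    for mask_file in next(os.walk(path + '/masks/'))[2]:
        mask_ = imread(path + '/masks/' + mask_file, as_gray=True)
        mask_ = np.expand_dims(resize(mask_, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True), axis=-1)
        mask = np.maximum(mask, mask_)

    Y_train[n] = mask

print('Done!')

# Builds the U-Net defined above (the original post calls a variant named get_unet_middle)
model = get_unet(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

################################

callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=5, monitor='val_loss'),
    tf.keras.callbacks.ModelCheckpoint('model_for_leak_middle.h5', verbose=1, monitor='val_loss',
                                       save_freq='epoch', save_best_only=True),
]
results = model.fit(X_train, Y_train, validation_split=0.1, batch_size=4, epochs=50, callbacks=callbacks)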
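Since both arrays are built by hand before model.fit, a quick look at their shapes and value ranges can rule out data problems before looking at the loss. The helper below, check_segmentation_arrays, is a hypothetical sketch that is not part of the post; its checks (float32 dtype, values in [0, 1], masks strictly 0/1, and a max below 1/255 as a hint that images were divided by 255 twice) are assumptions about what a sigmoid output and a Jaccard-style loss expect.

```python
import numpy as np


def check_segmentation_arrays(X, Y):
    """Return a list of warnings about image array X and mask array Y (heuristic checks)."""
    problems = []
    if X.shape[:3] != Y.shape[:3]:
        problems.append(f"image/mask shape mismatch: {X.shape} vs {Y.shape}")
    if X.dtype != np.float32 or Y.dtype != np.float32:
        problems.append(f"dtypes are {X.dtype}/{Y.dtype}, expected float32")
    if X.min() < 0.0 or X.max() > 1.0:
        problems.append(f"images outside [0, 1]: min={X.min()}, max={X.max()}")
    elif X.max() < 1.0 / 255.0:
        # Heuristic: a whole dataset this dark usually means a double 1/255 scaling
        problems.append("images look scaled by 1/255 twice (max < 1/255)")
    if Y.min() < 0.0 or Y.max() > 1.0:
        problems.append(f"masks outside [0, 1]: min={Y.min()}, max={Y.max()}")
    if not np.isin(Y, (0.0, 1.0)).all():
        problems.append("masks are not strictly 0/1 (e.g. interpolated by resize)")
    return problems


# Usage on small synthetic arrays shaped like X_train / Y_train
X = np.linspace(0.0, 1.0, 2 * 8 * 8).reshape(2, 8, 8, 1).astype(np.float32)
Y = np.zeros((2, 8, 8, 1), dtype=np.float32)
Y[:, :4] = 1.0
print(check_segmentation_arrays(X, Y))        # []
print(check_segmentation_arrays(X, Y * 255))  # reports the range and the non-binary values
```

In the script above it would be called as check_segmentation_arrays(X_train, Y_train) just before model.fit. Soft masks are not necessarily wrong for a soft Jaccard loss, so the last check is informational rather than an error.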

Because I use jaccard_distance, I use np.float32 as the dtype of X_train and Y_train. After several epochs in which the loss decreases, the loss and jaccard_distance suddenly change their values.

Train on 180 samples, validate on 21 samples
...
Epoch 20/50
176/180 [============================>.] - ETA: 0s - loss: 0.6334 - jaccard_distance: 63.7578
Epoch 00020: val_loss improved from 0.14278 to 0.14249, saving model to model_for_leak_middle.h5
180/180 [==============================] - 47s 260ms/sample - loss: 0.6345 - jaccard_distance: 63.6719 - val_loss: 0.1425 - val_jaccard_distance: 82.2346
Epoch 21/50
176/180 [============================>.] - ETA: 0s - loss: 0.6041 - jaccard_distance: 63.4807
Epoch 00021: val_loss improved from 0.14249 to 0.14220, saving model to model_for_leak_middle.h5
180/180 [==============================] - 47s 260ms/sample - loss: 0.6041 - jaccard_distance: 63.7875 - val_loss: 0.1422 - val_jaccard_distance: 82.2784
Epoch 22/50
176/180 [============================>.] - ETA: 0s - loss: 0.5860 - jaccard_distance: 63.4502
Epoch 00022: val_loss did not improve from 0.14220
180/180 [==============================] - 47s 259ms/sample - loss: 0.6741 - jaccard_distance: 54.1866 - val_loss: 0.3247 - val_jaccard_distance: 44.9204

After training finishes, the history looks like this:

[Image: training history plot]

What am I doing wrong?
