微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow 训练在每个 epoch 后减慢

如何解决Tensorflow 训练在每个 epoch 后减慢

我使用 Titan Xp GPU。代码在下面,但我不知道问题出在哪里。为什么每个epoch的训练时间不断增加?最初我每分钟可以处理大约 180 个批次,但在三个 epoch 之后我每分钟只能处理 5 个批次。

train_image = tf.data.Dataset.from_tensor_slices(X_train)
train_image_dataset = train_image.map(lambda x: tf.io.decode_png(tf.io.read_file(x)),num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_label_dataset = tf.data.Dataset.from_tensor_slices(tf.one_hot(Y_train,depth=num_classes))
train_dataset = tf.data.Dataset.zip((train_image_dataset,train_label_dataset)).shuffle(batch_size*3).batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

val_image = tf.data.Dataset.from_tensor_slices(X_val)
val_image_dataset = val_image.map(lambda x: tf.io.decode_png(tf.io.read_file(x)),num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_label_dataset = tf.data.Dataset.from_tensor_slices(tf.one_hot(Y_val,depth=num_classes))
val_dataset = tf.data.Dataset.zip((val_image_dataset,val_label_dataset)).shuffle(batch_size*3).batch(batch_size)
val_dataset = val_dataset.prefetch(tf.data.experimental.AUTOTUNE)


##### CALLBACKS
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath,monitor='val_loss',# checkpoint callback
                                            verbose=1,save_weights_only=True,save_freq='epoch',save_best_only=True)

tensorboard_callbacks = tf.keras.callbacks.TensorBoard(log_dir=log_folder,update_freq='epoch')
reduce_lr = LearningRateScheduler(step_decay,verbose=1)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=10)
callback = [reduce_lr,early_stopping,checkpoint,tensorboard_callbacks]

opt = tf.keras.optimizers.Adam()

METRICS = [
tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy'),tf.keras.metrics.AUC(name='auc'),f1_score,]

model = tf.keras.applications.MobileNetV2(include_top=True,weights=None,input_shape=(224,224,1),classes=num_classes)
model.load_weights("/weights-05-val_loss-0.215.hdf5")
model.summary()
model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=[METRICS])
model.fit(train_dataset,epochs=100,validation_data=val_dataset,callbacks=callback,verbose=1,shuffle=True,initial_epoch=5)

纪元 00006:77428/77428 [==============================] - 59041s 763ms/步

纪元 00007:77427/77428 [============================>.] - 68783s 888ms/步

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?