如何在 eager_few_shot_od_training_tflite Colab 中正确恢复已训练的检查点？
我正在使用 TensorFlow 的 eager few-shot 目标检测训练 Colab（eager_few_shot_od_training_tflite），并希望用之前训练输出的检查点恢复模型、继续训练。然而，恢复早先训练生成的检查点后，损失又回到了训练前的水平。
下面列出了相关的代码块；完整的 Colab 请参见原帖中的链接（此处未保留该链接）。
感谢您的帮助!
tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = '/content/output/pipeline.config'
checkpoint_path = '/content/output/checkpoint/ckpt-1'

# Which kind of checkpoint `checkpoint_path` points to:
#   True  -> a checkpoint written by this script's own `ckpt_manager.save()`,
#            i.e. `Checkpoint(model=detection_model)`. The WHOLE model,
#            including the already fine-tuned classification head, is
#            restored, so training resumes instead of starting over.
#   False -> the original COCO checkpoint. Only the feature extractor and the
#            box-regression head are restored; the classification head is
#            initialized from scratch.
# Restoring a fine-tuned checkpoint through the partial (False) path silently
# re-initializes the classification head (expect_partial() hides the unused
# values), which is why the loss jumps back to its pre-training level.
resume_from_finetuned_checkpoint = True

# This will be where we save checkpoint & config for TFLite conversion later.
output_directory = 'output/'
output_checkpoint_dir = os.path.join(output_directory, 'checkpoint')

# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
    model_config=model_config, is_training=True)

# Save new pipeline config.
pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_proto, output_directory)

if resume_from_finetuned_checkpoint:
    # Same object structure that `exported_ckpt` below uses when saving, so
    # every saved variable (feature extractor, box head AND class head) finds
    # its counterpart on restore.
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
else:
    # Set up object-based checkpoint restore --- SSD has two prediction
    # `heads` --- one for classification, the other for box regression. We
    # restore the box regression head but initialize the classification head
    # from scratch (the commented-out line is what we would add if we wanted
    # to restore both heads).
    fake_box_predictor = tf.compat.v2.train.Checkpoint(
        _base_tower_layers_for_heads=(
            detection_model._box_predictor._base_tower_layers_for_heads),
        # _prediction_heads=detection_model._box_predictor._prediction_heads,
        #    (i.e., the classification head that we *will not* restore)
        _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
    fake_model = tf.compat.v2.train.Checkpoint(
        _feature_extractor=detection_model._feature_extractor,
        _box_predictor=fake_box_predictor)
    ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)

# expect_partial() silences warnings about checkpoint values that are not
# consumed (e.g. the class head in the COCO case). Model variables do not
# exist yet; their restored values are applied once the dummy forward pass
# below creates them.
ckpt.restore(checkpoint_path).expect_partial()

# To save checkpoint for TFLite conversion.
exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt_manager = tf.train.CheckpointManager(
    exported_ckpt, output_checkpoint_dir, max_to_keep=1)

# Run model through a dummy image so that variables are created.
# The input must be rank 4: [batch, height, width, channels].
image, shapes = detection_model.preprocess(tf.zeros([1, 320, 320, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
tf.keras.backend.set_learning_phase(True)

# Tunable training hyper-parameters. The training set has only 5 images, so a
# much larger batch would not help (even though more examples would fit in
# memory).
batch_size = 5
learning_rate = 0.15
num_batches = 1000

# Fine-tune only the top layers: a trainable variable is selected when its
# name begins with one of these scopes (the weight-shared box-regression head
# and classification head). Every other variable is left out of
# `to_fine_tune` and therefore never receives a gradient update.
prefixes_to_train = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead',
]
trainable_variables = detection_model.trainable_variables
# str.startswith accepts a tuple and is true if any prefix matches.
to_fine_tune = [
    candidate
    for candidate in trainable_variables
    if candidate.name.startswith(tuple(prefixes_to_train))
]
# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
    """Get a tf.function for training step.

    Args:
      model: The detection model to train; it must provide `preprocess`,
        `provide_groundtruth`, `predict` and `loss`.
      optimizer: Optimizer whose `apply_gradients` updates the variables.
      vars_to_fine_tune: Variables that receive gradient updates; every other
        model variable is left unchanged by the returned step.

    Returns:
      A `tf.function` taking (image_tensors, groundtruth_boxes_list,
      groundtruth_classes_list) and returning the scalar total loss
      (localization + classification).
    """

    # Use tf.function for a bit of speed.
    # Comment out the tf.function decorator if you want the inside of the
    # function to run eagerly.
    @tf.function
    def train_step_fn(image_tensors,
                      groundtruth_boxes_list,
                      groundtruth_classes_list):
        """A single training iteration.

        Args:
          image_tensors: A list of [1, height, width, 3] Tensor of type
            tf.float32. Note that the height and width can vary across
            images, as they are reshaped by the model's preprocessing
            (320x320 for this pipeline).
          groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with
            type tf.float32 representing groundtruth boxes for each image in
            the batch.
          groundtruth_classes_list: A list of Tensors of shape
            [N_i, num_classes] with type tf.float32 representing one-hot
            groundtruth classes for each image in the batch.

        Returns:
          A scalar tensor representing the total loss for the input batch.
        """
        model.provide_groundtruth(
            groundtruth_boxes_list=groundtruth_boxes_list,
            groundtruth_classes_list=groundtruth_classes_list)
        with tf.GradientTape() as tape:
            # preprocess() returns (images, true_image_shapes); keep both so
            # the shapes always match the actual batch (one [h, w, c] row per
            # image) instead of relying on a hard-coded size and the global
            # batch_size.
            preprocessed = [model.preprocess(image_tensor)
                            for image_tensor in image_tensors]
            preprocessed_images = tf.concat(
                [images for images, _ in preprocessed], axis=0)
            shapes = tf.concat(
                [true_shapes for _, true_shapes in preprocessed], axis=0)
            prediction_dict = model.predict(preprocessed_images, shapes)
            losses_dict = model.loss(prediction_dict, shapes)
            total_loss = (losses_dict['Loss/localization_loss']
                          + losses_dict['Loss/classification_loss'])
        gradients = tape.gradient(total_loss, vars_to_fine_tune)
        optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
        return total_loss

    return train_step_fn
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
# The factory's signature is (model, optimizer, vars_to_fine_tune); calling it
# without the optimizer raises a TypeError before training even starts.
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

print('Start fine-tuning!', flush=True)
# The pool of example indices is the same for every batch, so build it once.
all_keys = list(range(len(train_images_np)))
# Never request more examples than exist (same as slicing a shuffled list).
examples_per_batch = min(batch_size, len(all_keys))
for idx in range(num_batches):
    # Grab keys for a random subset of examples (drawn without replacement).
    example_keys = random.sample(all_keys, examples_per_batch)

    # Note that we do not do data augmentation in this demo. If you want a
    # fun exercise, we recommend experimenting with random horizontal flipping
    # and random cropping :)
    # NOTE(review): confirm `gt_Box_tensors` matches the spelling used where
    # it is defined; the capital B looks like a copy/paste artifact.
    gt_boxes_list = [gt_Box_tensors[key] for key in example_keys]
    gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
    image_tensors = [train_image_tensors[key] for key in example_keys]

    # Training step (forward pass + backwards pass).
    total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

    if idx % 100 == 0:
        print(f'batch {idx} of {num_batches},loss={total_loss.numpy()}',
              flush=True)

print('Done fine-tuning!')
save_path = ckpt_manager.save()
print('Checkpoint saved!')
# Report where the checkpoint was written so a later run can point
# `checkpoint_path` at it when resuming.
print('Checkpoint path:', save_path)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。