
Eager_few_shot_od_training_tflite checkpoint restore


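I'm working through TensorFlow's Eager few-shot training Colab and want to restore the model from a checkpoint written by an earlier run. However, after I restore the checkpoint produced by my earlier training, the loss goes back to where it was before training.

The relevant code block is below; the full Colab is at the link above.

Thanks for your help!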

```python
# Imports used by this excerpt. train_images_np, train_image_tensors,
# gt_box_tensors and gt_classes_one_hot_tensors are built in earlier Colab cells.
import os
import random

import tensorflow as tf
from object_detection.builders import model_builder
from object_detection.utils import config_util

tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = '/content/output/pipeline.config'
checkpoint_path = '/content/output/checkpoint/ckpt-1'

# This will be where we save checkpoint & config for TFLite conversion later.
output_directory = 'output/'
output_checkpoint_dir = os.path.join(output_directory, 'checkpoint')

# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Save new pipeline config
pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_proto, output_directory)

# Set up object-based checkpoint restore --- SSD has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# To save checkpoint for TFLite conversion.
exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt_manager = tf.train.CheckpointManager(
    exported_ckpt, output_checkpoint_dir, max_to_keep=1)

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 320, 320, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
tf.keras.backend.set_learning_phase(True)

# These parameters can be tuned; since our training set has 5 images
# it doesn't make sense to have a much larger batch size, though we could
# fit more examples in memory if we wanted to.
batch_size = 5
learning_rate = 0.15
num_batches = 1000

# Select variables in top layers to fine-tune.
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
  if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
    to_fine_tune.append(var)

# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
  """Get a tf.function for training step."""

  # Use tf.function for a bit of speed.
  # Comment out the tf.function decorator if you want the inside of the
  # function to run eagerly.
  @tf.function
  def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
    """A single training iteration.

    Args:
      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
        Note that the height and width can vary across images, as they are
        reshaped within this function to be 320x320.
      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
        tf.float32 representing groundtruth boxes for each image in the batch.
      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
        with type tf.float32 representing groundtruth classes for each image in
        the batch.

    Returns:
      A scalar tensor representing the total loss for the input batch.
    """
    shapes = tf.constant(batch_size * [[320, 320, 3]], dtype=tf.int32)
    model.provide_groundtruth(
        groundtruth_boxes_list=groundtruth_boxes_list,
        groundtruth_classes_list=groundtruth_classes_list)
    with tf.GradientTape() as tape:
      preprocessed_images = tf.concat(
          [detection_model.preprocess(image_tensor)[0]
           for image_tensor in image_tensors], axis=0)
      prediction_dict = model.predict(preprocessed_images, shapes)
      losses_dict = model.loss(prediction_dict, shapes)
      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
      gradients = tape.gradient(total_loss, vars_to_fine_tune)
      optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
    return total_loss

  return train_step_fn

optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

print('Start fine-tuning!', flush=True)
for idx in range(num_batches):
  # Grab keys for a random subset of examples
  all_keys = list(range(len(train_images_np)))
  random.shuffle(all_keys)
  example_keys = all_keys[:batch_size]

  # Note that we do not do data augmentation in this demo.  If you want a
  # fun exercise, we recommend experimenting with random horizontal flipping
  # and random cropping :)
  gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
  gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
  image_tensors = [train_image_tensors[key] for key in example_keys]

  # Training step (forward pass + backwards pass)
  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

  if idx % 100 == 0:
    print('batch ' + str(idx) + ' of ' + str(num_batches)
          + ', loss=' + str(total_loss.numpy()), flush=True)

print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')
```
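One likely explanation, offered as a hedged sketch rather than a confirmed fix: the `fake_box_predictor` / `fake_model` wrapper in this code was written to pull weights out of the COCO checkpoint, and it deliberately leaves out `_prediction_heads`. So the classification head is initialized from scratch on every run, even when `checkpoint_path` points at a checkpoint this same script saved earlier. Since `exported_ckpt` saves the whole `detection_model`, restoring through `tf.compat.v2.train.Checkpoint(model=detection_model)` should bring every head back. The toy below (TensorFlow 2.x assumed; `ToyDetector`, `box_head` and `class_head` are hypothetical stand-ins, not Object Detection API names) contrasts the two restore styles.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ToyDetector(tf.Module):
    """Hypothetical stand-in for the SSD model: one box head, one class head."""

    def __init__(self, init_value):
        super().__init__()
        self.box_head = tf.Variable(tf.fill([2], init_value), name="box_head")
        self.class_head = tf.Variable(tf.fill([2], init_value), name="class_head")


# 1. "Fine-tuned" weights, saved the way the question saves them:
#    Checkpoint(model=<whole model>).
trained = ToyDetector(5.0)
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
save_path = tf.train.Checkpoint(model=trained).save(ckpt_prefix)

# 2. Partial restore through a wrapper that only exposes box_head, mirroring
#    the fake_model pattern. class_head has no matching edge in the wrapper,
#    so it keeps its fresh initial value.
partial_model = ToyDetector(0.0)
fake_model = tf.train.Checkpoint(box_head=partial_model.box_head)
tf.train.Checkpoint(model=fake_model).restore(save_path).expect_partial()

# 3. Full restore: the object graph matches the one used at save time,
#    so both heads receive the saved values.
full_model = ToyDetector(0.0)
tf.train.Checkpoint(model=full_model).restore(save_path)

print("partial:", partial_model.box_head.numpy(), partial_model.class_head.numpy())
print("full:   ", full_model.box_head.numpy(), full_model.class_head.numpy())
```

Carried over to the question's code, that would mean restoring with something like `tf.compat.v2.train.Checkpoint(model=detection_model).restore(checkpoint_path).expect_partial()` when resuming from your own `ckpt-1`, and keeping the `fake_model` wrapper only for the initial COCO checkpoint. This sketch does not cover the optimizer: `exported_ckpt` does not include it, so SGD momentum restarts from zero on each run, which is probably a smaller effect than a re-initialized head.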

