如何解决使用 AWS SageMaker 将存储在 AWS s3 中的 jpeg 图像数据转换为 TFRecords,其中子目录是与这些图像关联的唯一标签
我在 AWS s3 中有一个 jpeg 图像目录,其中子目录是与这些图像关联的唯一标签。我正在尝试使用 AWS SageMaker 遵循此 example 并且我在对标志缺乏经验的情况下将输入和输出路径弄得一团糟。任何有关使用 s3 和 SageMaker 应用链接解决方案或其他方法来实现 TFRecords 输出然后保存回 s3 的指导将不胜感激。
解决方法
下面是一个精简的示例,可能会对您有所帮助。我根据我使用的一些代码进行了调整,其中包括对象检测标签(每个图像有多个边界框),因此添加了一些 TODO,您可以在其中调整标签。此外,根据您的问题不确定是否有替代方法是在本地处理文件。此示例在本地使用这些文件,并创建一个 TFRecord 文件,然后您将其上传到 S3。
def create_tfrecords(train_img_dir,tfrecord_file):
""" Create TFRecord file from images/labels.
train_img_dir: A directory that contains .jpg images
tfrecord_file: Name of a file where the tfrecords are written to
"""
from object_detection.utils import dataset_util
# these same functions are available here:
# https://www.tensorflow.org/tutorials/load_data/tfrecord
with tf.io.TFRecordWriter(tfrecord_file) as writer:
#TODO: Modify this to recursivley list files in subdirs of labeled images or pull from S3
train_img_files = [f for f in listdir(train_img_dir) if isfile(join(train_img_dir,f))]
for i,f in enumerate(train_img_files):
try:
file_path = os.path.join(train_img_dir,f)
name,ext = path.splitext(f)
#TODO extract labe name based on path
label = ...
#TODO: convert the label name to an index if you want to store the index value
label_index = ...
# Skip non-jpegs
if ext not in ['.jpg','.jpeg']:
continue
with tf.io.gfile.GFile(file_path,'rb') as fid:
encoded_jpg = fid.read()
# Pil image to extract h/w
im = Image.open(file_path)
image_w,image_h = im.size
# Create TFRecord
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(image_h),'image/width': dataset_util.int64_feature(image_w),'image/filename': dataset_util.bytes_feature(f.encode('utf8')),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),'image/label': dataset_util.int64_feature(label_index),}))
if tf_example:
writer.write(tf_example.SerializeToString())
except ValueError:
print('Invalid example,ignoring.')
pass
except IOError:
print("Can't read example,ignoring.")
pass
print('TFRecord file created: ',tfrecord_file)
然后这样称呼它:
create_tfrecords(train_image_dir,tfrecord_file)
编辑: 如果您没有安装对象检测框架,请使用这些方法source:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value,type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。