微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 tensorflow 中将图像列表numpy 数组转换为 TFRecords 文件?

如何解决如何在 tensorflow 中将图像列表numpy 数组转换为 TFRecords 文件?

在 Tensorflow 2.0 中,我想在写入 TFRecord 文件之前处理图像。我的图像是 numpy 数组。似乎我找不到读者和作者之间的良好组合,可以给我有效的结果。我的 tf.Train.example fct 是:

def image_example(image_string,w,h,c,value):
   feature = {
     'height': _int64_feature(h),'width': _int64_feature(w),'channels': _int64_feature(c),'value': _float_feature(value),'image_raw': _bytes_feature(image_string),}

return tf.train.Example(features=tf.train.Features(feature=feature))

虽然我的作者是(只写一个文件来说明):

with tf.io.TFRecordWriter(record_file) as writer:
   # Image.open is from PIL
   my_image = np.ascontiguousarray(Image.open('myfile.jpg').convert('RGB'))[:,:,::-1]
   # make here some transformations/augmentations to my_image
   b = my_image.tobytes() ### IS IT CORRECT?
   w,value = (224,224,3,0.32)
   tf_example = image_example(b,value)
   writer.write(tf_example.SerializetoString())

现在我将示例消息写入 tfrecord 文件,然后我构建数据集并恢复图像如下:

def _parse_image_function(example_proto):
   # Create a dictionary describing the features.
   image_feature_description = {
       'height': tf.io.FixedLenFeature([],tf.int64),'width': tf.io.FixedLenFeature([],'channels': tf.io.FixedLenFeature([],'value': tf.io.FixedLenFeature([],tf.float32),'image_raw': tf.io.FixedLenFeature([],tf.string),}

   # Parse the input tf.train.Example proto using the dictionary above.
   parsed_features = tf.io.parse_single_example(example_proto,image_feature_description)

   width = tf.cast(parsed_features['width'],tf.int64)
   height = tf.cast(parsed_features['height'],tf.int64)
   channels = tf.cast(parsed_features['channels'],tf.int64)

   image = tf.io.decode_raw(parsed_features['image_raw'],out_type='float') ### IS IT CORRECT?
   image_shape = [parsed_features['height'],parsed_features['width'],parsed_features['channels']]
   image = tf.reshape(image,image_shape)

   return width,height,channels,image

我的 load_dataset fct:

def load_dataset(input_path,batch_size,shuffle_buffer):
   dataset = tf.data.TFRecordDataset(input_path)
   dataset = dataset.map(_parse_image_function,num_parallel_calls=16)
   dataset = dataset.batch(batch_size).prefetch(1)  # batch and prefetch
   return dataset

现在定义了我的数据集:

ds = load_dataset('images.tfrecords',1,1000)

我从 TFRecord 文件中恢复图像:

for w,image in ds.take(3):
   # plot the image

问题是绘制的图像不是原始图像(根本不是!)。似乎我的处理中有一些错误,它可能在于用于读取和写入图像文件(numpy.tobytes 和 tf.image.decode_image)的函数的选择。有人可以帮我吗?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。