如何解决如何在 tensorflow 中将图像列表numpy 数组转换为 TFRecords 文件?
在 Tensorflow 2.0 中,我想在写入 TFRecord 文件之前处理图像。我的图像是 numpy 数组。似乎我找不到读者和作者之间的良好组合,可以给我有效的结果。我的 tf.Train.example fct 是:
def image_example(image_string,w,h,c,value):
feature = {
'height': _int64_feature(h),'width': _int64_feature(w),'channels': _int64_feature(c),'value': _float_feature(value),'image_raw': _bytes_feature(image_string),}
return tf.train.Example(features=tf.train.Features(feature=feature))
with tf.io.TFRecordWriter(record_file) as writer:
# Image.open is from PIL
my_image = np.ascontiguousarray(Image.open('myfile.jpg').convert('RGB'))[:,:,::-1]
# make here some transformations/augmentations to my_image
b = my_image.tobytes() ### IS IT CORRECT?
w,value = (224,224,3,0.32)
tf_example = image_example(b,value)
writer.write(tf_example.SerializetoString())
现在我将示例消息写入 tfrecord 文件,然后我构建数据集并恢复图像如下:
def _parse_image_function(example_proto):
# Create a dictionary describing the features.
image_feature_description = {
'height': tf.io.FixedLenFeature([],tf.int64),'width': tf.io.FixedLenFeature([],'channels': tf.io.FixedLenFeature([],'value': tf.io.FixedLenFeature([],tf.float32),'image_raw': tf.io.FixedLenFeature([],tf.string),}
# Parse the input tf.train.Example proto using the dictionary above.
parsed_features = tf.io.parse_single_example(example_proto,image_feature_description)
width = tf.cast(parsed_features['width'],tf.int64)
height = tf.cast(parsed_features['height'],tf.int64)
channels = tf.cast(parsed_features['channels'],tf.int64)
image = tf.io.decode_raw(parsed_features['image_raw'],out_type='float') ### IS IT CORRECT?
image_shape = [parsed_features['height'],parsed_features['width'],parsed_features['channels']]
image = tf.reshape(image,image_shape)
return width,height,channels,image
我的 load_dataset fct:
def load_dataset(input_path,batch_size,shuffle_buffer):
dataset = tf.data.TFRecordDataset(input_path)
dataset = dataset.map(_parse_image_function,num_parallel_calls=16)
dataset = dataset.batch(batch_size).prefetch(1) # batch and prefetch
return dataset
现在定义了我的数据集:
ds = load_dataset('images.tfrecords',1,1000)
我从 TFRecord 文件中恢复图像:
for w,image in ds.take(3):
# plot the image
问题是绘制的图像不是原始图像(根本不是!)。似乎我的处理中有一些错误,它可能在于用于读取和写入图像文件(numpy.tobytes 和 tf.image.decode_image)的函数的选择。有人可以帮我吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。