微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow:从 tfrecord 文件读取图像后如何设置张量形状以进行数据增强?

如何解决Tensorflow:从 tfrecord 文件读取图像后如何设置张量形状以进行数据增强?

我有一个tf.data.Dataset 文件中读取的 tfrecords,如下所示:

import tensorflow as tf

# given an existing record_file

raw_dataset = tf.data.TFRecordDataset(record_file)
example_description = {
        "height": tf.io.FixedLenFeature([],tf.int64),"width": tf.io.FixedLenFeature([],"channels": tf.io.FixedLenFeature([],"image": tf.io.FixedLenFeature([],tf.string),}
dataset = raw_dataset.map(
    lambda example: tf.io.parse_single_example(example,example_description)
)

接下来,我将这些特征组合成一个图像,如下所示:

dataset = dataset.map(_extract_image_from_sample)

# and

def _extract_image_from_sample(sample):
    height = tf.cast(sample["height"],tf.int32) # always 1038
    width = tf.cast(sample["width"],tf.int32) # always 1366
    depth = tf.cast(sample["channels"],tf.int32) # always 3
    shape = [height,width,depth]

    image = sample["image"]
    image = decode_tf_image(image)
    image = tf.reshape(image,shape)

    return image

此时,数据集中的任何图像都具有 (None,None,None) 形状(这让我感到惊讶,因为我重塑了它们)。 当我尝试使用 tf.keras.preprocessing.image.ImageDataGenerator 扩充数据集时,我相信这是导致错误的原因:

augmented_dataset = dataset.map(random_image_augmentation)

# and

image_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=45,width_shift_range=0.1,height_shift_range=0.1,shear_range=5.0,zoom_range=[0.9,1.2],fill_mode="reflect",horizontal_flip=True,vertical_flip=True,)

def random_image_augmentation(image: tf.Tensor) -> tf.Tensor:
    transform = image_data_generator.get_random_transform(img_shape=image.shape)
    image = image_data_generator.apply_transform(image,transform)
    return image

这会导致错误消息:

TypeError: in user code:
    # ...
    C:\Users\[PATH_TO_ENVIRONMENT]\lib\site-packages\keras_preprocessing\image\image_data_generator.py:778 get_random_transform  *
        tx *= img_shape[img_row_axis]

    TypeError: unsupported operand type(s) for *=: 'float' and 'nonetype'

但是,如果我不使用图形模式,而是使用急切模式,这就像一个魅力:

it = iter(dataset)
for i in range(3):
    image = it.next()
    image = random_image_augmentation(image.numpy())

这使我得出结论,主要错误是读入数据集后缺少形状信息。但我不知道如何比我已经做的更明确地定义它。有什么想法吗?

解决方法

使用 tf.py_function 包装预处理函数,该函数要求张量具有如下形状:

augmented_dataset = dataset.map(
    lambda x: tf.py_function(random_image_augmentation,inp=[x],Tout=tf.float32),num_parallel_calls=tf.data.experimental.AUTOTUNE
)

# and

def random_image_augmentation(image: tf.Tensor) -> tf.Tensor:
    image = image.numpy()  # now we can do this,because tensors have this function in eager mode
    transform = image_data_generator.get_random_transform(img_shape=image.shape)
    image = image_data_generator.apply_transform(image,transform)
    return image

这对我有用,但我不确定它是唯一的还是最好的解决方案。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。