微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在tensorflow数据集API中使用白化增强会产生此错误:形状不兼容[?,224,224,3]但得到了[8,1,224,224,3]

如何解决在tensorflow数据集API中使用白化增强会产生此错误:形状不兼容[?,224,224,3]但得到了[8,1,224,224,3]

我在尝试使用使用tf.numpy_function来包装python函数以通过以下链接从Tensorflow中包装增强功能的Albumentation库来增强图像时遇到此错误https://albumentations.ai/docs/examples/tensorflow-example/

我已经使用tensorflow数据集API加载了图像和目标标签的数据集。 代码

img_paths = df['image_path'].values
target = df['target_label'].values

path_lis = tf.data.Dataset.from_tensor_slices(img_paths)
target_lis = tf.data.Dataset.from_tensor_slices(target)
list_ds = tf.data.Dataset.zip((path_lis,target_lis))

image_count = len(df)
val_size = int(image_count * 0.3)
train = list_ds.skip(val_size)
val = list_ds.take(val_size)


def process_path(file_path,target):

  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = tf.image.decode_jpeg(img,channels=3)

  return img,target

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_data = train.map(process_path,num_parallel_calls=AUTOTUNE)
val_data = val.map(process_path,num_parallel_calls=AUTOTUNE)

# Augmentation using albumentations library
transforms = A.Compose([
            A.Rotate(limit=40),A.RandomBrightness(limit=0.1),A.RandomContrast(limit=0.9,p=1),A.HorizontalFlip(),A.Resize(224,224)
            ])

def aug_fn(image):

    data = {"image": image}
    aug_data = transforms(**data)
    aug_img = aug_data["image"]
    #target = aug_data["keypoints"][0]
    aug_img = tf.cast(aug_img/255.0,tf.float32)
    #aug_img = tf.image.resize(aug_img,size=[224,224])

    return aug_img

def process_aug(img,label):

    aug_img = tf.numpy_function(func=aug_fn,inp=[img],Tout=[tf.float32])
    return aug_img,label

# create dataset
train_ds = train_data.map(process_aug,num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds = val_data.map(process_aug,num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

def set_shapes(img,label):

    img.set_shape([224,224,3])
    label.set_shape([])

    return img,label

train_ds = train_ds.map(set_shapes,num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
val_ds = val_ds.map(set_shapes,num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)


def view_image(ds):

    image,label = next(iter(ds)) # extract 1 batch from the dataset
    image = image.numpy()
    label = label.numpy()

    fig = plt.figure(figsize=(22,22))
    for i in range(20):
        ax = fig.add_subplot(4,5,i+1,xticks=[],yticks=[])
        ax.imshow(image[i])
        ax.set_title(f"Label: {label[i]}")

view_image(train_ds)

完整的错误消息:

Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py",line 2102,in execution_mode
    yield
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py",line 758,in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py",line 2610,in iterator_get_next
    _ops.raise_from_not_ok_status(e,name)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py",line 6843,in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code,message),None)
  File "<string>",line 3,in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,3] but got [4,1,3]. [Op:IteratorGetNext]

During handling of the above exception,another exception occurred:
Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py",line 2963,in run_code
    exec(code_obj,self.user_global_ns,self.user_ns)
  File "<ipython-input-20-23a37450bee7>",line 13,in <module>
    view_image(train_ds)
  File "<ipython-input-20-23a37450bee7>",in view_image
    image,label = next(iter(ds)) # extract 1 batch from the dataset
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py",line 736,in __next__
    return self.next()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py",line 772,in next
    return self._next_internal()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py",line 764,in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec,ret)
  File "C:\Users\Arun\Anaconda3\lib\contextlib.py",line 99,in __exit__
    self.gen.throw(type,value,traceback)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py",line 2105,in execution_mode
    executor_new.wait()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\executor.py",line 67,in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,3] but got [8,3].

至少有人可以告诉我为什么会发生此错误吗?预先感谢!

解决方法

img_shape 应该是 (120,120,3) 而不是 [224,224,3]

例如:

def set_shapes(img,label,img_shape=(120,3)):
    img.set_shape(img_shape)
    label.set_shape([])
    return img,label

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?