How to solve this error when using Albumentations augmentation with the TensorFlow dataset API: Incompatible shapes [?,224,224,3] but got [8,1,224,224,3]
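I get this error when I try to augment images with the Albumentations library, using tf.numpy_function to wrap the Python augmentation function as in the TensorFlow example at this link: https://albumentations.ai/docs/examples/tensorflow-example/

I have loaded the dataset of images and target labels with the TensorFlow dataset API. Code: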
```python
import tensorflow as tf
import albumentations as A
import matplotlib.pyplot as plt

AUTOTUNE = tf.data.experimental.AUTOTUNE

# df is a pandas DataFrame with 'image_path' and 'target_label' columns
img_paths = df['image_path'].values
target = df['target_label'].values

path_lis = tf.data.Dataset.from_tensor_slices(img_paths)
target_lis = tf.data.Dataset.from_tensor_slices(target)
list_ds = tf.data.Dataset.zip((path_lis, target_lis))

# hold out the first 30% of the rows for validation
image_count = len(df)
val_size = int(image_count * 0.3)
train = list_ds.skip(val_size)
val = list_ds.take(val_size)

def process_path(file_path, target):
    # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    return img, target

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_data = train.map(process_path, num_parallel_calls=AUTOTUNE)
val_data = val.map(process_path, num_parallel_calls=AUTOTUNE)
```
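The skip/take split above yields disjoint training and validation sets only because list_ds is not shuffled before splitting. A minimal sketch on toy in-memory data to check this (the file names and labels are made up, standing in for the asker's df):

```python
import tensorflow as tf

# Hypothetical stand-ins for df['image_path'] and df['target_label'].
img_paths = [f"img_{i}.jpg" for i in range(10)]
target = list(range(10))

list_ds = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices(img_paths),
    tf.data.Dataset.from_tensor_slices(target),
))

val_size = int(len(img_paths) * 0.3)  # 3 of 10 elements
train = list_ds.skip(val_size)        # elements 3..9
val = list_ds.take(val_size)          # elements 0..2

train_labels = [int(label) for _, label in train]
val_labels = [int(label) for _, label in val]
print(train_labels)  # [3, 4, 5, 6, 7, 8, 9]
print(val_labels)    # [0, 1, 2]
```

If a shuffle is added before the split, it needs reshuffle_each_iteration=False; otherwise each pass over train or val sees a different shuffle and the two splits can overlap.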
```python
# Augmentation using albumentations library
transforms = A.Compose([
    A.Rotate(limit=40),
    A.RandomBrightness(limit=0.1),
    A.RandomContrast(limit=0.9, p=1),
    A.HorizontalFlip(),
    A.Resize(224, 224),
])

def aug_fn(image):
    data = {"image": image}
    aug_data = transforms(**data)
    aug_img = aug_data["image"]
    #target = aug_data["keypoints"][0]
    aug_img = tf.cast(aug_img / 255.0, tf.float32)
    #aug_img = tf.image.resize(aug_img, size=[224, 224])
    return aug_img

def process_aug(img, label):
    aug_img = tf.numpy_function(func=aug_fn, inp=[img], Tout=[tf.float32])
    return aug_img, label

# create dataset
train_ds = train_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds = val_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

def set_shapes(img, label):
    img.set_shape([224, 224, 3])
    label.set_shape([])
    return img, label

train_ds = train_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
val_ds = val_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)

def view_image(ds):
    image, label = next(iter(ds))  # extract 1 batch from the dataset
    image = image.numpy()
    label = label.numpy()
    fig = plt.figure(figsize=(22, 22))
    # a batch holds only 8 images, so don't index past its end
    for i in range(min(20, len(image))):
        ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(image[i])
        ax.set_title(f"Label: {label[i]}")

view_image(train_ds)
```
Full error message:

```
Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2102, in execution_mode
    yield
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 758, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py", line 2610, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 6843, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,3] but got [4,1,3]. [Op:IteratorGetNext]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-23a37450bee7>", line 13, in <module>
    view_image(train_ds)
  File "<ipython-input-20-23a37450bee7>", in view_image
    image, label = next(iter(ds)) # extract 1 batch from the dataset
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 736, in __next__
    return self.next()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 772, in next
    return self._next_internal()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 764, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "C:\Users\Arun\Anaconda3\lib\contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2105, in execution_mode
    executor_new.wait()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\executor.py", line 67, in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,3] but got [8,3].
```
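Can someone at least tell me why this error occurs? Thanks in advance!

Solution

img_shape should be (120,120,3) instead of [224,224,3]; the shape passed to set_shape has to match the actual size of the images coming out of the pipeline.
For example: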
```python
def set_shapes(img, label, img_shape=(120, 120, 3)):
    # declare the static shape of each (unbatched) image and its scalar label
    img.set_shape(img_shape)
    label.set_shape([])
    return img, label
```
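The title's [8,1,224,224,3] also contains an extra axis of size 1, which points at Tout=[tf.float32] in process_aug: when Tout is a list, tf.numpy_function returns a list of tensors, and tf.data converts that list (nested inside the returned (image, label) tuple) into one stacked tensor of shape [1, 224, 224, 3]. Because the stacked tensor's static shape is unknown, set_shape([224, 224, 3]) succeeds without any runtime check, and the mismatch only surfaces when the batched element is read from the iterator. A minimal sketch reproducing this; fake_aug and the all-zero images are hypothetical stand-ins for the Albumentations call and the real data:

```python
import numpy as np
import tensorflow as tf

def fake_aug(image):
    # Hypothetical stand-in for aug_fn: rescale to [0, 1] as float32.
    return (image / 255.0).astype(np.float32)

def process_aug_list(img, label):
    # Tout=[tf.float32] as in the question: numpy_function returns a list
    aug_img = tf.numpy_function(func=fake_aug, inp=[img], Tout=[tf.float32])
    return aug_img, label

def process_aug_single(img, label):
    # Tout=tf.float32: numpy_function returns a single tensor
    aug_img = tf.numpy_function(func=fake_aug, inp=[img], Tout=tf.float32)
    return aug_img, label

def set_shapes(img, label, img_shape=(224, 224, 3)):
    img.set_shape(img_shape)
    label.set_shape([])
    return img, label

# Dummy data: 16 black 224x224 RGB images with integer labels.
images = np.zeros((16, 224, 224, 3), dtype=np.uint8)
labels = np.arange(16)
base = tf.data.Dataset.from_tensor_slices((images, labels))

# 1) List Tout, no set_shapes: the list is stacked, adding an axis of size 1.
raw_img, _ = next(iter(base.map(process_aug_list).batch(8)))
raw_shape = tuple(raw_img.shape)
print(raw_shape)  # (8, 1, 224, 224, 3)

# 2) List Tout + set_shapes: the static shape claims (224, 224, 3) while the
#    runtime tensor is (1, 224, 224, 3), so the iterator rejects the batch.
bad_error = None
try:
    next(iter(base.map(process_aug_list).map(set_shapes).batch(8)))
except tf.errors.InvalidArgumentError as e:
    bad_error = e
print(type(bad_error).__name__)

# 3) Single Tout: static and runtime shapes agree.
good_img, _ = next(iter(base.map(process_aug_single).map(set_shapes).batch(8)))
good_shape = tuple(good_img.shape)
print(good_shape)  # (8, 224, 224, 3)
```

Under these assumptions, passing Tout=tf.float32 (a single dtype rather than a one-element list) makes tf.numpy_function return a single tensor, and set_shapes then describes the real element shape.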