如何解决InvalidArgumentError:键:标签无法解析序列化示例:如何找到一种方法来解析来自 TFRecords 的单热编码标签?
我有 12 个包含图像的文件夹(它们是我的数据的类别)。此代码通过有效压缩将图像及其相应标签转换为 tfrecord 数据:
import tensorflow as tf
from pathlib import Path
from tensorflow.keras.utils import to_categorical
import cv2
from tqdm import tqdm
from os import listdir
import numpy as np
import matplotlib.image as mpimg
from tqdm import tqdm
labels = {v:k for k,v in enumerate(listdir('train/'))}
labels
class GenerateTFRecord:
def __init__(self,path):
self.path = Path(path)
self.labels = {v:k for k,v in enumerate(listdir(path))}
def convert_image_folder(self,tfrecord_file_name):
# Get all file names of images present in folder
img_paths = list(self.path.rglob('*.jpg'))
with tf.io.TFRecordWriter(tfrecord_file_name) as writer:
for img_path in tqdm(img_paths,desc='images converted'):
example = self._convert_image(img_path)
writer.write(example.SerializetoString())
def _convert_image(self,img_path):
label = self.labels[img_path.parent.stem]
img_shape = mpimg.imread(img_path).shape
# Read image data in terms of bytes
with tf.io.gfile.GFile(img_path,'rb') as fid:
image_data = fid.read()
example = tf.train.Example(features = tf.train.Features(feature = {
'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [3])),'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),'label': tf.train.Feature(int64_list = tf.train.Int64List(value = tf.one_hot(label,depth=len(labels),on_value=1,off_value=0))),}))
return example
t = GenerateTFRecord(path='train/')
t.convert_image_folder('data.tfrecord')
然后我在这里使用此代码读取 tfrecord 数据并创建我的 tf.data.Dataset
:
def _parse_function(tfrecord):
# Extract features using the keys set during creation
features = {
'rows': tf.io.FixedLenFeature([],tf.int64),'cols': tf.io.FixedLenFeature([],'channels': tf.io.FixedLenFeature([],'image': tf.io.FixedLenFeature([],tf.string),'label': tf.io.FixedLenFeature([],tf.int64)
}
# Extract the data record
sample = tf.io.parse_single_example(tfrecord,features)
image = tf.image.decode_image(sample['image'])
label = sample['label']
# label = tf.one_hot(label,off_value=0)
return image,label
def configure_for_performance(ds,buffer_size,batch_size):
ds = ds.cache()
ds = ds.batch(batch_size)
ds = ds.prefetch(buffer_size=buffer_size)
return ds
def generator(tfrecord_file,batch_size,n_data,validation_ratio,reshuffle_each_iteration=False):
reader = tf.data.TFRecordDataset(filenames=[tfrecord_file])
reader.shuffle(n_data,reshuffle_each_iteration=reshuffle_each_iteration)
AUTOTUNE = tf.data.experimental.AUTOTUNE
val_size = int(n_data * validation_ratio)
train_ds = reader.skip(val_size)
val_ds = reader.take(val_size)
train_ds = train_ds.map(_parse_function,num_parallel_calls=AUTOTUNE)
train_ds = configure_for_performance(train_ds,AUTOTUNE,batch_size)
val_ds = val_ds.map(_parse_function,num_parallel_calls=AUTOTUNE)
val_ds = configure_for_performance(val_ds,batch_size)
return train_ds,val_ds
在这里我创建我的模型:
from os.path import isdir,dirname,abspath,join
from os import makedirs
from tensorflow.keras import Sequential
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense,GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD,Adam
def create_model(optimizer,freeze_layer=False):
densenet = DenseNet121(weights='imagenet',include_top=False)
if freeze_layer:
for layer in densenet_model.layers:
if 'conv5' in layer.name:
layer.trainable = True
else:
layer.trainable = False
model = Sequential()
model.add(densenet)
model.add(GlobalAveragePooling2D())
model.add(Dense(12,activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer=optimizer,metrics=['accuracy'])
return model
if __name__ == '__main__':
optimizer = Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.99,epsilon=1e-6)
densenet_model = create_model(optimizer)
tfrecord_file = 'data.tfrecord'
n_data = len(list(Path('train').rglob('*.jpg')))
train,val = generator(tfrecord_file,2,True)
validation_ratio = 0.2
val_size = int(n_data * validation_ratio)
train_size = n_data - val_size
batch_size = 32
n_epochs = 300
n_workers = 5
filename = '/content/drive/MyDrive/data.tfrecord'
train_ds,val_ds = generator(filename,batch_size=batch_size,n_data=n_data,validation_ratio=validation_ratio,reshuffle_each_iteration=True)
hist = densenet_model.fit(train_ds,validation_data=val_ds,epochs=n_epochs,workers=n_workers,steps_per_epoch=train_size//batch_size,validation_steps=val_size)
这是我每次得到的错误:
InvalidArgumentError: Key: label. Can't parse serialized Example. [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]] [[IteratorGetNext]] [Op:__inference_train_function_343514]
显然我的 tfrecord 数据中的 label
有问题。
我真的需要知道,基于我的模型输出形状 (12,),我怎样才能安全地在我的 tfrecord 中存储一个热编码标签并在 tf.data.Dataset
中解析?
谢谢大家。
解决方法
正如答案所暗示的那样 here 数据数组应该是固定大小的,所以我认为它可以解决您的问题。
,用特征大小初始化可能会解决。
'label': tf.FixedLenFeature([SIZE_OF_FEATURE],tf.int64,default_value=[0,0])
祝你好运:)
,在您的 _parse_function
函数中读取 tfrecord 文件时需要转换标签:
label = tf.cast(sample['label'],dtype=tf.int32)
我希望这能解决 InvalidArgumentError 消息。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。