如何解决如何重塑TFRecord数据集以训练RNN模型?
我正在尝试使用.tfrecords
数据集(训练,测试)来提供RNN模型以进行训练。但是我收到有关input0大小的错误。
我认为train_data
希望重塑为4个暗角,但我不确定。
波纹管函数从.tfrecords
返回具有特征和一个热编码标签的数据集。
def get_dataset(directory,num_classes=60,batch_size=32,drop_remainder=False,shuffle=False,shuffle_size=1000):
# dictionary describing the features.
feature_description = {
'features': tf.io.FixedLenFeature([],tf.string),'label': tf.io.FixedLenFeature([],tf.int64)
}
# parse each proto and,the features within
def _parse_feature_function(example_proto):
features = tf.io.parse_single_example(example_proto,feature_description)
data = tf.io.parse_tensor(features['features'],tf.float32)
label = tf.one_hot(features['label'],num_classes)
data = tf.reshape(data,(3,300,25,2))
return data,label
records = [os.path.join(directory,file) for file in os.listdir(directory) if file.endswith("tfrecord")]
dataset = tf.data.TFRecordDataset(records,num_parallel_reads=len(records))
dataset = dataset.map(_parse_feature_function)
dataset = dataset.batch(batch_size,drop_remainder=drop_remainder)
dataset = dataset.prefetch(batch_size)
if shuffle:
dataset = dataset.shuffle(shuffle_size)
return dataset
if __name__ == "__main__":
train_data = get_dataset('/TfRecords/xsub/')
test_data = get_dataset('/TfRecords/xview/')
print(train_data)
# create LSTM
verbose,epochs,batch_size = 1,100,32
n_features=60
n_length = 32
model = Sequential()
model.add(
Timedistributed(Conv1D(filters=64,kernel_size=3,activation='relu'),input_shape=(None,n_length,n_features)))
model.add(Timedistributed(Conv1D(filters=64,activation='relu')))
model.add(Timedistributed(Dropout(0.5)))
model.add(Timedistributed(MaxPooling1D(pool_size=2)))
model.add(Timedistributed(Flatten()))
model.add(LSTM(100))
model.add(Dropout(0.5))
model.add(Dense(100,activation='relu'))
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
model.summary()
history = model.fit(train_data,epochs=epochs,batch_size=batch_size,verbose=verbose,validation_data=test_data)
ValueError:层顺序的输入0与层不兼容:预期ndim = 4,找到的ndim = 5。收到完整的图形:[无,3、300、25、2]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。