微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取

这篇文章主要介绍了Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

单一数据读取方式:

第一种:slice_input_producer()

# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...] [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

第二种:string_input_producer()

# 需要定义文件读取器,然后通过读取器中的 read()方法获取数据(返回值类型 key,value),再通过 Session.run(value)查看 file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True) reader = tf.WholeFileReader() # 定义文件读取器 key, value = reader.read(file_queue) # key:文件名;value:文件中的内容

!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

!!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

!!!以上两种方法都可以生成文件名队列。

随机)批量数据读取方式:

batchsize=2# 每次读取的样本数量 tf.train.batch(tensors, batch_size=batchsize) tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

 TFRecord文件的打包与读取

 一、单一数据读取方式

第一种:slice_input_producer()

def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

案例1:

import tensorflow as tf images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg'] labels = [1, 2, 3, 4] # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True) # 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错 # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True) data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True) print(type(data)) # with tf.Session() as sess: # sess.run(tf.local_variables_initializer()) sess.run(tf.local_variables_initializer()) coord = tf.train.Coordinator() # 线程的协调器 threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器 for i in range(10): print(sess.run(data)) coord.request_stop() coord.join(threads) """

运行结果:

[b'image2.jpg', 2]

[b'image1.jpg', 1]

[b'image3.jpg', 3]

[b'image4.jpg', 4]

[b'image2.jpg', 2]

[b'image1.jpg', 1]

[b'image3.jpg', 3]

[b'image4.jpg', 4]

[b'image2.jpg', 2]

[b'image3.jpg', 3]

"""

!!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

!!!num_epochs设置

 第二种:string_input_producer()

def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)

文件读取器

不同类型的文件对应不同的文件读取器,我们称为 reader对象;

该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容

reader = tf.TextLineReader() ### 一行一行读取,适用于所有文本文件 reader = tf.TFRecordReader() ### A Reader that outputs the records from a TFRecords file reader = tf.WholeFileReader() ### 一次读取整个文件,适用图片

案例2:读取csv文件

import tensorflow as tf filename = ['data/A.csv', 'data/B.csv', 'data/C.csv'] file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2) # 生成文件名队列 reader = tf.WholeFileReader() # 定义文件读取器(一次读取整个文件) # reader = tf.TextLineReader() # 定义文件读取器(一行一行的读) key, value = reader.read(file_queue) # key:文件名;value:文件中的内容 print(type(file_queue)) init = [tf.global_variables_initializer(), tf.local_variables_initializer()] with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: while not coord.should_stop(): for i in range(6): print(sess.run([key, value])) break except tf.errors.OutOfRangeError: print('read done') finally: coord.request_stop() coord.join(threads) """ reader = tf.WholeFileReader() # 定义文件读取器(一次读取整个文件) 运行结果: [b'data/C.csv', b'7.jpg,7n8.jpg,8n9.jpg,9n'] [b'data/B.csv', b'4.jpg,4n5.jpg,5n6.jpg,6n'] [b'data/A.csv', b'1.jpg,1n2.jpg,2n3.jpg,3n'] [b'data/A.csv', b'1.jpg,1n2.jpg,2n3.jpg,3n'] [b'data/B.csv', b'4.jpg,4n5.jpg,5n6.jpg,6n'] [b'data/C.csv', b'7.jpg,7n8.jpg,8n9.jpg,9n'] """ """ reader = tf.TextLineReader() # 定义文件读取器(一行一行的读) 运行结果: [b'data/B.csv:1', b'4.jpg,4'] [b'data/B.csv:2', b'5.jpg,5'] [b'data/B.csv:3', b'6.jpg,6'] [b'data/C.csv:1', b'7.jpg,7'] [b'data/C.csv:2', b'8.jpg,8'] [b'data/C.csv:3', b'9.jpg,9'] """

案例3:读取图片(每次读取全部图片内容,不是一行一行)

import tensorflow as tf filename = ['1.jpg', '2.jpg'] filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1) reader = tf.WholeFileReader() # 文件读取器 key, value = reader.read(filename_queue) # 读取文件 key:文件名;value:图片数据,bytes with tf.Session() as sess: tf.local_variables_initializer().run() coord = tf.train.Coordinator() # 线程的协调器 threads = tf.train.start_queue_runners(sess, coord) for i in range(filename.__len__()): image_data = sess.run(value) with open('img_%d.jpg' % i, 'wb') as f: f.write(image_data) coord.request_stop() coord.join(threads)

 二、(随机)批量数据读取方式:

功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;

案例4:slice_input_producer() 与 batch()

import tensorflow as tf import numpy as np images = np.arange(20).reshape([10, 2]) label = np.asarray(range(0, 10)) images = tf.cast(images, tf.float32)# 可以注释掉,不影响运行结果 label = tf.cast(label, tf.int32) # 可以注释掉,不影响运行结果 batchsize = 6 # 每次获取元素的数量 input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False) image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize) # 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大 # image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10) with tf.Session() as sess: coord = tf.train.Coordinator() # 线程的协调器 threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器 for cnt in range(2): print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize)) image_batch_v, label_batch_v = sess.run([image_batch, label_batch]) print(image_batch_v, label_batch_v, label_batch_v.__len__()) coord.request_stop() coord.join(threads) """

运行结果:

第1次获取数据,每次batch=6...

[[ 0.  1.]

 [ 2.  3.]

 [ 4.  5.]

 [ 6.  7.]

 [ 8.  9.]

 [10. 11.]] [0 1 2 3 4 5] 6

第2次获取数据,每次batch=6...

[[12. 13.]

 [14. 15.]

 [16. 17.]

 [18. 19.]

 [ 0.  1.]

 [ 2.  3.]] [6 7 8 9 0 1] 6

"""

 案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

import tensorflow as tf import glob import cv2 as cv def read_imgs(filename, picture_format, input_image_shape, batch_size=): """ 从本地批量的读取图片 :param filename: 图片路径(包括图片文件名),[] :param picture_format: 图片的格式,如 bmp,jpg,png等; string :param input_image_shape: 输入图像的大小; (h,w,c)或[] :param batch_size: 每次从文件队列中加载图片数量; int :return: batch_size张图片数据, Tensor """ global new_img # 创建文件队列 file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True) # 创建文件读取器 reader = tf.WholeFileReader() # 读取文件队列中的文件 _, img_bytes = reader.read(file_queue) # print(img_bytes) # Tensor("ReaderReadV2_19:1", shape=(), dtype=string) # 对图片进行解码 if picture_format == ".bmp": new_img = tf.image.decode_bmp(img_bytes, channels=1) elif picture_format == ".jpg": new_img = tf.image.decode_jpeg(img_bytes, channels=3) else: pass # 重新设置图片的大小 # new_img = tf.image.resize_images(new_img, input_image_shape) new_img = tf.reshape(new_img, input_image_shape) # 设置图片的数据类型 new_img = tf.image.convert_image_dtype(new_img, tf.uint) # return new_img return tf.train.batch([new_img], batch_size) def main(): image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp') image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5) print(type(image_batch)) # image_path = glob.glob(r'.*.jpg') # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1) sess = tf.Session() sess.run(tf.local_variables_initializer()) tf.train.start_queue_runners(sess=sess) image_batch = sess.run(image_batch) print(type(image_batch)) # for i in range(image_batch.__len__()): cv.imshow("win_"+str(i), image_batch[i]) cv.waitKey() cv.destroyAllWindows() def start(): image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp') image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5) print(type(image_batch)) # with tf.Session() as sess: sess.run(tf.local_variables_initializer()) coord = tf.train.Coordinator() # 线程的协调器 threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器 image_batch = sess.run(image_batch) print(type(image_batch)) # for i in range(image_batch.__len__()): cv.imshow("win_"+str(i), image_batch[i]) cv.waitKey() cv.destroyAllWindows() # 若使用 with 方式打开 Session,且没加如下行语句,则会出错 # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled; # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程 coord.request_stop() coord.join(threads) if __name__ == "__main__": # main() start()

案列6:TFRecord文件打包与读取

 TFRecord文件打包案列

def write_TFRecord(filename, data, labels, is_shuffler=True): """ 将数据打包成TFRecord格式 :param filename: 打包后路径名,认在工程目录下创建该文件;String :param data: 需要打包的文件路径名;list :param labels: 对应文件标签;list :param is_shuffler:是否随机初始化打包后的数据,认:True;Bool :return: None """ im_data = list(data) im_labels = list(labels) index = [i for i in range(im_data.__len__())] if is_shuffler: np.random.shuffle(index) # 创建写入器,然后使用该对象写入样本example writer = tf.python_io.TFRecordWriter(filename) for i in range(im_data.__len__()): im_d = im_data[index[i]] # im_d:存放着第index[i]张图片的路径信息 im_l = im_labels[index[i]] # im_l:存放着对应图片标签信息 # # 获取当前的图片数据 方式一: # data = cv2.imread(im_d) # # 创建样本 # ex = tf.train.Example( # features=tf.train.Features( # feature={ # "image": tf.train.Feature( # bytes_list=tf.train.BytesList( # value=[data.tobytes()])), # 需要打包成bytes类型 # "label": tf.train.Feature( # int64_list=tf.train.Int64List( # value=[im_l])), # } # ) # ) # 获取当前的图片数据 方式二:相对于方式一,打包文件占用空间小了一半多 data = tf.gfile.FastGFile(im_d, "rb").read() ex = tf.train.Example( features=tf.train.Features( feature={ "image": tf.train.Feature( bytes_list=tf.train.BytesList( value=[data])), # 此时的data已经是bytes类型 "label": tf.train.Feature( int_list=tf.train.IntList( value=[im_l])), } ) ) # 写入将序列化之后的样本 writer.write(ex.SerializetoString()) # 关闭写入器 writer.close()

TFReord文件的读取案列

import tensorflow as tf import cv2 def read_TFRecord(file_list, batch_size=): """ 读取TFRecord文件 :param file_list: 存放TFRecord的文件名,List :param batch_size: 每次读取图片数量 :return: 解析后图片及对应的标签 """ file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True) reader = tf.TFRecordReader() _, ex = reader.read(file_queue) batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5) feature = { 'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64) } example = tf.parse_example(batch, features=feature) images = tf.decode_raw(example['image'], tf.uint) images = tf.reshape(images, [-1, 32, 32, 3]) return images, example['label'] def main(): # filelist = ['data/train.tfrecord'] filelist = ['data/test.tfrecord'] images, labels = read_TFRecord(filelist, 2) with tf.Session() as sess: sess.run(tf.local_variables_initializer()) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: while not coord.should_stop(): for i in range(): image_bth, _ = sess.run([images, labels]) print(_) cv2.imshow("image_0", image_bth[0]) cv2.imshow("image_1", image_bth[1]) break except tf.errors.OutOfRangeError: print('read done') finally: coord.request_stop() coord.join(threads) cv2.waitKey(0) cv2.destroyAllWindows() if __name__ == "__main__": main()

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐