微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

tensorflow学习笔记之tfrecord文件的生成与读取

这篇文章主要介绍了tensorflow学习笔记之tfrecord文件生成与读取,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签

1、生成tfrecord文件

import os import numpy as np import tensorflow as tf from PIL import Image filenames = [ 'images/cat/1.jpg', 'images/cat/2.jpg', 'images/dog/1.jpg', 'images/dog/2.jpg', 'images/pig/1.jpg', 'images/pig/2.jpg',] labels = {'cat':0, 'dog':1, 'pig':2} def int64_feature(values): if not isinstance(values, (tuple, list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) with tf.Session() as sess: output_filename = os.path.join('images/train.tfrecords') with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: for filename in filenames: #读取图像 image_data = Image.open(filename) #图像灰度化 image_data = np.array(image_data.convert('L')) #将图像转化为bytes image_data = image_data.tobytes() #读取label label = labels[filename.split('/')[-2]] #生成protocol数据类型 example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data), 'label': int64_feature(label)})) tfrecord_writer.write(example.SerializetoString())

2、读取tfrecord文件

import tensorflow as tf import matplotlib.pyplot as plt from PIL import Image # 根据文件生成一个队列 filename_queue = tf.train.string_input_producer(['images/train.tfrecords']) reader = tf.TFRecordReader() # 返回文件名和文件 _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) # 获取图像数据 image = tf.decode_raw(features['image'], tf.uint8) # 恢复图像原始尺寸[高,宽] image = tf.reshape(image, [60, 160]) # 获取label label = tf.cast(features['label'], tf.int32) with tf.Session() as sess: # 创建一个协调器,管理线程 coord = tf.train.Coordinator() # 启动QueueRunner, 此时文件名队列已经进队 threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(6): image_b, label_b = sess.run([image, label]) img = Image.fromarray(image_b, 'L') plt.imshow(img) plt.axis('off') plt.show() print(label_b) # 通知其他线程关闭 coord.request_stop() # 其他所有线程关闭之后,这一函数才能返回 coord.join(threads)

到此这篇关于tensorflow学习笔记之tfrecord文件生成与读取的文章就介绍到这了,更多相关tfrecord文件生成与读取内容搜索编程之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程之家!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐