微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

TFRecord 解析 3-D 特征

如何解决TFRecord 解析 3-D 特征

我对 this 有类似的问题,但如果我的特征形状是 3-D 呢?例如,它不是价格 (1,288),而是 (1,288,3)tf.io.FixedLenFeature()的形状我应该怎么写?是 tf.io.FixedLenFeature(shape=[288,3],tf.float32)tf.io.FixedLenFeature(shape=[864],tf.float32) 还是其他什么?谢谢!

解决方法

有几种方法可以做到这一点。一种是使用 BytesList 功能

def _bytes_feature(value):
  return tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[value]))

另一个正在使用 FloatList 功能

def _float_feature(value):
  return tf.train.Feature(
    float_list=tf.train.FloatList(value=value))

示例

import numpy as np
import tensorflow as tf


# make some data
img = np.random.normal(size=(5,3))
img = img.astype(np.float32)

writer = tf.io.TFRecordWriter("/tmp/data.tfrec")

example = tf.train.Example(
  features=tf.train.Features(
    feature = {
      "img_b": _bytes_feature(img.tobytes()),"img_f": _float_feature(img.flatten()),}))

writer.write(example.SerializeToString())
writer.close()

def parse_fn(example):
  features = {
    "img_b": tf.io.FixedLenFeature([],tf.string),"img_f": tf.io.FixedLenFeature([5,3],tf.float32),}
  parsed_example = tf.io.parse_single_example(example,features)
  img_b = tf.io.decode_raw(
      parsed_example['img_b'],out_type=tf.float32)
  img_b = tf.reshape(img_b,(5,3))
  img_f = parsed_example['img_f']
  return img_b,img_f

让我们导入数据,看看它是否有效

dataset = tf.data.TFRecordDataset(["/tmp/data.tfrec"])
dataset = dataset.map(parse_fn).batch(1)

arr_b,arr_f = next(iter(dataset))

np.testing.assert_almost_equal(arr_b.numpy(),arr_f.numpy())
# passes

这假设您知道图像的形状并且它们都是相同的形状。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。