微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何解析 tfrecords 的值,将其转换为其他内容,然后放回去

如何解决如何解析 tfrecords 的值,将其转换为其他内容,然后放回去

我有一个 tfrecords 文件,不幸的是它没有标签值。它有两个值: 图片ID

因此,要获取标签,我需要查看 Pandas DataFrame 中的 Id 以驱动其值,然后根据其值创建标签,例如:

if df[df['id'] == Id]['value'] > threshold_value:
   label = 1
else:
   label = 0

但是,我不知道如何将 Tensor("ParseSingleExample/ParseExample/ParseExampleV2:1",shape=(),dtype=string) 转换为 python 字符串。

在这里复制了解析 tfrecords 的代码

def parse_tf_records(example_input):
    feature_description_dict = {
        IMAGE_FIELD: tf.io.FixedLenFeature(IMAGE_SIZE,tf.float32),ID_FIELD: tf.io.FixedLenFeature([],tf.string)
    }
    parsed_example = tf.io.parse_single_example(example_input,feature_description_dict)
    return parsed_example

def read_tfrecord(example_input):
    parsed_example = parse_tf_records(example_input)
    image_data = parsed_example[IMAGE_FIELD]
    id_data = parsed_example[ID_FIELD]
    
    # label = Look for the value of id_data in a Pandas Dataframe and compare the value to threshold_value
    label_data = tf.cast(label,tf.int32)
    return image_data,label_data

我使用的是 tensorflow 2.4.1。如果有人可以帮助我,真的很感激。谢谢。

解决方法

好的,tf.py_fuction 就是答案。这是我的代码,它运行良好:

def get_label(tf_id):
    _id = tf_id.numpy().decode('utf-8')
    if df[df['id'] == _id]['value'] > threshold_value:
       label = 1
    else:
       label = 0
    return tf.cast(label,tf.int32)

def read_tfrecord(example_input):
    parsed_example = parse_tf_records(example_input)
    image_data = parsed_example[IMAGE_FIELD]
    id_data = parsed_example[ID_FIELD]
    label = tf.py_function(get_label,[id_data],tf.int32)
    return image_data,label

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。