微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在tf.data

如何解决如何在tf.data

对于一个项目,我正在使用tf.data.Dataset编写输入管道。 输入是图像RGB 标签是图像中用于生成热图的对象的2D坐标的列表

这里是MWE重现该问题。

 def encode_images(image,label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """
        # load image
        # here the normal code
        # img_contents = tf.io.read_file(image)
        # # decode the image
        # img = tf.image.decode_jpeg(img_contents,channels=3)
        # img = tf.image.resize(img,(256,256))
        # img = tf.cast(img,tf.float32)

        # this is just for testing
        image = tf.random.uniform(
            (256,256,3),minval=0,maxval=255,dtype=tf.dtypes.float32,seed=None,name=None
        )
        return image,label

    def generate_heatmap(image,label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """

        start = 0.5
        sigma=3
        img_shape = (image.shape[0],image.shape[1] )
        density_map = np.zeros(img_shape,dtype=np.float32)
        for center_x,center_y in label[0]:
            for v_y in range(img_shape[0]):
                for v_x in range(img_shape[1]):
                    x = start + v_x
                    y = start + v_y
                    d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
                    exp = d2 / (2.0 * sigma**2)
                    if exp > math.log(100):
                        continue
                    density_map[v_y,v_x] = math.exp(-exp)
        return density_map


    X = ["img1.png","img2.png","img3.png","img4.png","img5.png"]
    y = [[[2,3],[100,120],120]],[[2,[2,1]],[10,10],[11,12]],12],2]],120]]
         ]
    dataset = tf.data.Dataset.from_tensor_slices((X,tf.ragged.constant(y)))
    dataset = dataset.map(encode_images,num_parallel_calls=8)
    dataset = dataset.map(generate_heatmap,num_parallel_calls=8)
    dataset = dataset.batch(1,drop_remainder=False)

问题是在generate_heatmap()函数中,我使用numpy数组通过索引修改元素,这在tensorflow中是不可能的。我尝试遍历标签张量,到目前为止,在张量流中是不可能的。另一件事是,tf.data.Dataset中似乎没有启用急切模式!有什么建议可以解决!我认为在pytorch中,这样的代码可以很快完成,而不会受苦:)!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?