如何解决如何在tf.data
对于一个项目,我正在使用tf.data.Dataset编写输入管道。 输入是图像RGB 标签是图像中用于生成热图的对象的2D坐标的列表
这里是MWE重现该问题。
def encode_images(image,label):
"""
Parameters
----------
image
label
Returns
-------
"""
# load image
# here the normal code
# img_contents = tf.io.read_file(image)
# # decode the image
# img = tf.image.decode_jpeg(img_contents,channels=3)
# img = tf.image.resize(img,(256,256))
# img = tf.cast(img,tf.float32)
# this is just for testing
image = tf.random.uniform(
(256,256,3),minval=0,maxval=255,dtype=tf.dtypes.float32,seed=None,name=None
)
return image,label
def generate_heatmap(image,label):
"""
Parameters
----------
image
label
Returns
-------
"""
start = 0.5
sigma=3
img_shape = (image.shape[0],image.shape[1] )
density_map = np.zeros(img_shape,dtype=np.float32)
for center_x,center_y in label[0]:
for v_y in range(img_shape[0]):
for v_x in range(img_shape[1]):
x = start + v_x
y = start + v_y
d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
exp = d2 / (2.0 * sigma**2)
if exp > math.log(100):
continue
density_map[v_y,v_x] = math.exp(-exp)
return density_map
X = ["img1.png","img2.png","img3.png","img4.png","img5.png"]
y = [[[2,3],[100,120],120]],[[2,[2,1]],[10,10],[11,12]],12],2]],120]]
]
dataset = tf.data.Dataset.from_tensor_slices((X,tf.ragged.constant(y)))
dataset = dataset.map(encode_images,num_parallel_calls=8)
dataset = dataset.map(generate_heatmap,num_parallel_calls=8)
dataset = dataset.batch(1,drop_remainder=False)
问题是在generate_heatmap()
函数中,我使用numpy数组通过索引修改元素,这在tensorflow中是不可能的。我尝试遍历标签张量,到目前为止,在张量流中是不可能的。另一件事是,tf.data.Dataset
中似乎没有启用急切模式!有什么建议可以解决!我认为在pytorch中,这样的代码可以很快完成,而不会受苦:)!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。