如何解决Tensorflow 2.x –具有周围单元格平均值的张量
我正在尝试在Tensorflow 2.x中编写自定义损失函数,以鼓励在输出空间(2D矩阵)中实现渐变。因此,作为损失函数的一个组成部分,我想获取一个Tensor并返回一个Tensor,其中每个像元代表原始张量中相应相邻像元的平均值。
例如,取左上方的单元格:6.3 =(7 + 9 + 3)/ 3。或取中间格:4.5 =(1 + 3 + 5 + 7 + 8 + 6 + 4 + 2)/ 8。
考虑以下代码:
def gradient_encouraging_loss(y_true: Tensor,y_pred: Tensor) -> Tensor:
gradient_loss: Tensor = tf.divide(tf.reduce_sum(tf.abs(tf.subtract(
y_pred,tensor_harmonic(y_pred)
))),tf.cast(tf.size(y_pred),tf.float32))
return gradient_loss
我将如何实施tensor_harmonic()
? y_pred
的形状为(None,X,Y)
,其中X和Y是输出矩阵尺寸。
解决方法
大多数情况下,您都可以使用2D卷积运算来完成此操作,但是随后您需要特别注意外部值。这是您的操作方法:
import tensorflow as tf
def surround_average(x):
x = tf.convert_to_tensor(x)
dt = x.dtype
# Compute surround sum
filter = tf.constant([[1,1,1],[1,1]],dtype=dt)
x2 = x[tf.newaxis,:,tf.newaxis]
filter2 = filter[:,tf.newaxis,tf.newaxis]
y2 = tf.nn.conv2d(x2,filter2,strides=1,padding='SAME')
y = y2[0,0]
# Make matrix of number of surrounding elements
s = tf.shape(x)
d = tf.fill(s - 2,tf.constant(8,dtype=dt))
d = tf.pad(d,[[0,0],constant_values=5)
top_row = tf.concat([[3],tf.fill([s[1] - 2],tf.constant(5,dtype=dt)),[3]],axis=0)
d = tf.concat([[top_row],d,[top_row]],axis=0)
# Return average
return y / d
# Test
x = tf.reshape(tf.range(24.),(4,6))
print(x.numpy())
# [[ 0. 1. 2. 3. 4. 5.]
# [ 6. 7. 8. 9. 10. 11.]
# [12. 13. 14. 15. 16. 17.]
# [18. 19. 20. 21. 22. 23.]]
print(surround_average(x).numpy())
# [[ 4.6666665 4.6 5.6 6.6 7.6 8.333333 ]
# [ 6.6 7. 8. 9. 10. 10.4 ]
# [12.6 13. 14. 15. 16. 16.4 ]
# [14.666667 15.4 16.4 17.4 18.4 18.333334 ]]
编辑:上面的代码可以进行一些小的改动即可用于批次矩阵:
import tensorflow as tf
def surround_average_batch(x):
x = tf.convert_to_tensor(x)
dt = x.dtype
# Compute surround sum
filter = tf.constant([[1,dtype=dt)
x2 = tf.expand_dims(x,axis=-1)
filter2 = filter[:,padding='SAME')
y = tf.squeeze(y2,axis=-1)
# Make matrix of number of surrounding elements
s = tf.shape(x)
d = tf.fill(s[1:] - 2,tf.fill([s[2] - 2],(2,4,3))
print(x.numpy())
# [[[ 0. 1. 2.]
# [ 3. 4. 5.]
# [ 6. 7. 8.]
# [ 9. 10. 11.]]
#
# [[12. 13. 14.]
# [15. 16. 17.]
# [18. 19. 20.]
# [21. 22. 23.]]]
print(surround_average_batch(x).numpy())
# [[[ 2.6666667 2.8 3.3333333]
# [ 3.6 4. 4.4 ]
# [ 6.6 7. 7.4 ]
# [ 7.6666665 8.2 8.333333 ]]
#
# [[14.666667 14.8 15.333333 ]
# [15.6 16. 16.4 ]
# [18.6 19. 19.4 ]
# [19.666666 20.2 20.333334 ]]]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。