如何解决张量流中的二进制搜索和内插
我不知道您的错误来源,但我可以告诉您,这tf.while_loop
很可能非常缓慢。您可以实现没有循环的线性插值,如下所示:
import numpy as np
import tensorflow as tf
xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
query = tf.placeholder(tf.float32, name='query')
# Add additional elements at the beginning and end for extrapolation
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)
# Find the index of the interval containing query
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
diff = cmp[1:] - cmp[:-1]
idx = tf.argmin(diff)
# Interpolate
alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]
# Test with f(x) = 2 * x
q = 5.4
x = np.arange(100)
y = 2 * x
with tf.Session() as sess:
q_interp = sess.run(res, Feed_dict={xaxis: x, yaxis: y, query: q})
print(q_interp)
>>> 10.8
填充部分只是为了避免麻烦(如果您将值传递到范围之外),否则只是比较和查找值开始大于的问题query
。
解决方法
我正在尝试在张量流中插值一维张量(我实际上想要等效于np.interp)。由于找不到类似的tensorflow op,因此我必须自己执行插值。
第一步是在x值的排序列表中搜索y值中的相应索引,即执行二进制搜索。我尝试为此使用while循环,但出现了神秘的运行时错误。这是一些代码:
xaxis = tf.placeholder(tf.float32,shape=100,name='xaxis')
query = tf.placeholder(tf.float32,name='query')
with tf.name_scope("binsearch"):
up = tf.Variable(0,dtype=tf.int32,name='up')
mid = tf.Variable(0,name='mid')
down = tf.Variable(0,name='down')
done = tf.Variable(-1,name='done')
def cond(up,down,mid,done):
return tf.logical_and(done<0,up-down>1)
def body(up,done):
val = tf.gather(xaxis,mid)
done = tf.cond(val>query,tf.cond(tf.gather(xaxis,mid-1)<query,lambda:mid-1,lambda: -1),mid+1)>query,lambda:mid,lambda: -1) )
up = tf.cond(val>query,lambda: mid,lambda: up )
down = tf.cond(val<query,lambda: down )
with tf.control_dependencies([done,up,down]):
return up,(up+down)//2,done
up,done = tf.while_loop(cond,body,(xaxis.shape[0]-1,(xaxis.shape[0]-1)//2,-1))
这导致
AttributeError: 'int' object has no attribute 'name'
我正在Windows 7和tensorflow 1.1上使用具有gpu支持的Python 3.6。知道有什么问题吗?谢谢。
这是完整的堆栈跟踪:
AttributeError Traceback (most recent call last)
<ipython-input-185-693d3873919c> in <module>()
19 return up,done
20
---> 21 up,-1))
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond,loop_vars,shape_invariants,parallel_iterations,back_prop,swap_memory,name)
2621 context = WhileContext(parallel_iterations,name)
2622 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT,context)
-> 2623 result = context.BuildLoop(cond,shape_invariants)
2624 return result
2625
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildLoop(self,pred,shape_invariants)
2454 self.Enter()
2455 original_body_result,exit_vars = self._BuildLoop(
-> 2456 pred,original_loop_vars,shape_invariants)
2457 finally:
2458 self.Exit()
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _BuildLoop(self,shape_invariants)
2404 structure=original_loop_vars,2405 flat_sequence=vars_for_body_with_tensor_arrays)
-> 2406 body_result = body(*packed_vars_for_body)
2407 if not nest.is_sequence(body_result):
2408 body_result = [body_result]
<ipython-input-185-693d3873919c> in body(up,done)
11 val = tf.gather(xaxis,mid)
12 done = tf.cond(val>query,---> 13 tf.cond(tf.gather(xaxis,14 tf.cond(tf.gather(xaxis,lambda: -1) )
15 up = tf.cond(val>query,lambda: up )
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred,fn1,fn2,name)
1746 context_f = CondContext(pred,pivot_2,branch=0)
1747 context_f.Enter()
-> 1748 _,res_f = context_f.BuildCondBranch(fn2)
1749 context_f.ExitResult(res_f)
1750 context_f.Exit()
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self,fn)
1666 real_v = sparse_tensor.SparseTensor(indices,values,dense_shape)
1667 else:
-> 1668 real_v = self._ProcessOutputTensor(v)
1669 result.append(real_v)
1670 return original_r,result
c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _ProcessOutputTensor(self,val)
1624 """Process an output tensor of a conditional branch."""
1625 real_val = val
-> 1626 if val.name not in self._values:
1627 # Handle the special case of lambda: x
1628 self._values.add(val.name)
AttributeError: 'int' object has no attribute 'name'
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。