如何解决使用 tf.gradienttape, loss = mse or huber or cross-entropy, y_true=constant, y_pred=my_network
1.演员评论模型
使用 tf.gradienttape,
loss_function = mse or huber or cross-entropy
y_true=constant
,y_pred=my_network_output
,例如y_pred = my_netword(input)
例如loss_actor = tf.losses.MSE(y_true,y_pred)
或其他类似的东西,比如
loss_actor = Huber(y_true,action_probs)
loss_actor = cross_entropy(y_true,y_pred)
意图,y_true = constant
,是我的网络收敛的
y_pred = my_network(input)
2.问题
最终,我将我的问题浓缩如下
y_true,我用的是人工数据(假数据)
if n < 130:
self.ret = 0.1
elif n >= 130:
self.ret = -0.1
其中,n从124开始,n最终到inf
这里,self.ret 是我的 y_true,我的标签
我想要,当我提供 self.ret,即 y_true,= 0.1
时,我的网络输出 [0.0,1.0] 代表 Invest
当我输入 self.ret,即 y_true,= (- 0.1)
时,我的网络输出 [1.0,0.0] 代表 Uninvest
但是当我将网络输入作为真实的股票数据提供时,这个模型会出错
- 当 n > 130 时,my_network 输出 [0.0,1.0],永远,代表投资,但我想要 [1.0,0.0],即取消投资
4. tf.gradienttape 的错误使用
我知道,问题在于我以错误的方式使用 tf.gradienttape
tf.gradienttape 没有正确计算梯度
我的代码是:
if n < 130:
self.ret = 0.1
elif n >= 130:
self.ret = -0.1
# n starts from 124,self.ret is y_true
with tf.GradientTape(persistent=True) as tape:
tape.watch(self.actor.trainable_variables)
#y_pred = action_probs = self.actor(self.get_input(n))[0]
action_probs = self.actor(self.get_input(n))[0] # i.e. y_pred
#''' # below use huber or mse as loss func
y_true = tf.nn.softmax([0.0,1e2] * tf.stop_gradient(self.ret))
#loss_actor = tf.losses.MSE(y_true,action_probs)
huber = tf.keras.losses.Huber()
loss_actor = huber(y_true,action_probs)
#'''
''' # below use cross-entropy as loss func
r_t = self.ret
delta_t = 1.0
prediction = tf.keras.backend.clip(tf.nn.softmax([0.0,1e2] * tf.stop_gradient(self.ret if NO_CRITIC else r_t)),eps,1 - eps)
log_probabilities = action_probs * tf.keras.backend.log(prediction)
# self.ret or r_t is y_true
loss_actor = tf.keras.backend.sum(-log_probabilities * tf.stop_gradient(delta_t))
#'''
loss_actor_gradients = tape.gradient(loss_actor,self.actor.trainable_variables)
self.opt_actor.apply_gradients(zip(loss_actor_gradients,self.actor.trainable_variables))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。