微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

DeepLearning4J 批量索引

如何解决DeepLearning4J 批量索引

我正在尝试利用 DeepLearning4j 进行强化学习。为了速度,我需要将 Bellman Batch 更新代码转换为矢量化代码

这是当前版本,速度较慢。

for (int i = 0; i < BATCH_SIZE; i++) {
    var index = ndarrayIndex.point(i);
    var qPred = this.policy.predict(states.get(index).reshape(1,STATE_SPACE));
    var qNew = this.target.predict(newStates.get(index).reshape(1,STATE_SPACE));
    var max = qNew.max(1);
    var maxDones = max.mul(this.gamma).mul(dones.get(index));
    var maxDonesRewards = maxDones.add(rewards.get(index));
    var actionIndex = as.get(index).getNumber(0).intValue();

    qPred.putScalar(actionIndex,maxDonesRewards.getNumber(0).doubleValue());
    features.putRow(i,states.get(index));
    labels.putRow(i,qPred);
}

在 Keras、Numpy 和 Python 中,我会像下面的代码一样索引它。提醒一下,这是一个功能性的 java,用于显示与我如何用 Python 编写它的关系。

var indices = Nd4j.linspace(DataType.INT32,BATCH_SIZE,1);

var features = Nd4j.zeros(BATCH_SIZE,STATE_SPACE);
var labels = Nd4j.zeros(BATCH_SIZE,actions.length);

var pred = this.policy.predict(states.reshape(-1,9));
var prednew = this.target.predict(newStates.reshape(-1,9));
var qMax = prednew.max(1);
    
// in python
// targetQ = qMax * gamma * dones[batch_indices] + rewards[batch_indices]
// this is also wrong.
var bestQ = qMax.mul(this.gamma).mul(dones.get(indices)).add(rewards.get(indices));
    
var oldQValues = pred.dup();
    
// set the q value for the batch and action by this line of code in numpy/keras
// qs[indices][action_batch] = q_new
oldQValues.put(as.get(indices),qMax);

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。