
DL4J: Binarizing the middle layer of an autoencoder for semantic hashing


I'm trying to implement semantic hashing for the DL4J MNISTAutoencoder example. How can I binarize the activations of the middle layer? Ideally, I'm looking for some change to my network setup that yields (near-)binary middle-layer activations out of the box. Alternatively, I'd be happy with some "recipe" for binarizing the current ReLU activations. Which of the two approaches is preferable in terms of generalization?

My current network setup is:

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(12345)
    .weightInit(WeightInit.XAVIER)
    .updater(new AdaGrad(0.05))
    .activation(Activation.RELU)
    .l2(0.0001)
    .list()
    .layer(new DenseLayer.Builder().nIn(784).nOut(250)
               .build())
    .layer(new DenseLayer.Builder().nIn(250).nOut(10)
               .build())
    .layer(new DenseLayer.Builder().nIn(10).nOut(250)
               .build())
    .layer(new OutputLayer.Builder().nIn(250).nOut(784)
               .activation(Activation.LEAKYRELU)
               .lossFunction(LossFunctions.LossFunction.MSE)
               .build())
    .build();

After 30 epochs, typical middle-layer activations look like this:

[[11.3044, 12.3678,  7.3547, 1.6518,  1.0068,  5.4340,  2.1388, 2.0708, 2.5764]]
[[ 9.9051, 12.5345, 11.1941, 4.7900,  1.2935,  7.9786,  4.1915, 3.1802, 7.5659]]
[[ 6.4629, 11.1013, 10.8903, 5.4528,  0.8009,  9.4881,  3.6684, 6.4524, 7.2334]]
[[ 2.3953,  0.2429,  3.7125, 4.1561,  0.8607, 11.2486,  7.0178, 2.8771, 2.1996]]
[[      0,  1.6378,  0.8993, 0.3347,  0.7708,  3.7053,  1.6704, 2.1380]]
[[      0,  1.5158,  0.7937, 0.8190,  4.7548,  0.0655,  1.4635, 1.8173]]
[[ 6.8344,  5.9989, 10.1286, 2.8528,  1.1178,  9.1865, 10.3677, 5.3564, 4.3420]]
[[ 7.0942,  7.0364,  4.8538, 0.5096,  0.0442,  8.4336,  8.2783, 5.6474, 3.8944]]
[[ 3.6895, 14.9696,  6.5351, 8.0446, 12.7816, 12.7445,  7.8495, 3.8600]]
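For illustration, "binarizing" here means thresholding each activation elementwise. A minimal plain-Java sketch of one common sign-based convention (negative → -1, non-negative → +1; the class and method names are illustrative, not part of DL4J):

```java
import java.util.Arrays;

public class BinarizeDemo {

    /** Elementwise sign thresholding: negative -> -1, non-negative -> +1. */
    static int[] binarize(double[] activations) {
        int[] code = new int[activations.length];
        for (int i = 0; i < activations.length; i++) {
            code[i] = activations[i] < 0 ? -1 : 1;
        }
        return code;
    }

    public static void main(String[] args) {
        // First row of middle-layer activations from above
        double[] row = {11.3044, 12.3678, 7.3547, 1.6518, 1.0068, 5.4340, 2.1388, 2.0708, 2.5764};
        // Note: ReLU outputs are never negative, so sign-thresholding them
        // after the fact yields all +1s; a binary activation inside the
        // network (applied to pre-activations) avoids that degenerate case.
        System.out.println(Arrays.toString(binarize(row)));
    }
}
```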

Solution

This can be achieved by assigning a custom IActivation function to the middle layer. For example:

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

public static class ActivationBinary extends BaseActivationFunction {

    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        // Map each activation to its sign: negative -> -1, non-negative -> +1
        BooleanIndexing.replaceWhere(in, -1.0, Conditions.lessThan(0.0));
        BooleanIndexing.replaceWhere(in, 1.0, Conditions.greaterThanOrEqual(0.0));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        // The sign function has zero gradient almost everywhere, so use
        // tanh's derivative as a reasonable smooth surrogate during backprop
        Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
        return new Pair<>(in, null);
    }

    @Override
    public int hashCode() {
        return 1;
    }

    @Override
    public boolean equals(Object obj) {
        return obj instanceof ActivationBinary;
    }

    @Override
    public String toString() {
        return "Binary";
    }
}
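With such a class in place, it can be wired into the network by overriding the activation on the middle (code) layer only, leaving the rest of the setup unchanged. A sketch against the configuration from the question (only the `.activation(...)` call on that one layer is new):

```java
// Middle (code) layer from the original configuration, with the
// per-layer activation overriding the builder-wide Activation.RELU:
.layer(new DenseLayer.Builder().nIn(250).nOut(10)
           .activation(new ActivationBinary()) // binary code layer
           .build())
```

DL4J resolves activations per layer, so a layer-level `.activation(...)` takes precedence over the one set on the `NeuralNetConfiguration.Builder`.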
