微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

训练感知器模型

如何解决训练感知器模型

The Coding Train awesome video 之后,我使用感知器模型在 XOR 门上训练神经网络。

我有 2 个输入和 1 个输出

我的大部分代码与视频非常相似,除了我使用了不同的异或门数据集。

我在训练模型后遇到了问题,即使经过 10 万个训练数据,它也没有猜出正确答案,我不知道为什么。

这是我的完整代码

感知器.h

#include <stdio.h>      /* printf,scanf,puts,NULL */
#include <stdlib.h>     /* srand,rand */
#include <time.h>       /* time */
#include <vector>

using std::vector;

class Perceptron {
private:
    vector<float> weights;
    float lr = 0.15;
public:
    Perceptron() {
        // initialize the weights randomly 
        srand(time(NULL));
        for (int i = 0; i < 2; i++)
        {
            int x = -1 + rand() % (3);
            while (x == 0)
                x = -1 + rand() % (3);
            weights.push_back(x);
        }
    }
    int sign(float n) // activation function
    {
        if (n >= 0)
            return 1;
        else
            return 0;
    }
    int guess(vector<float> inputs)
    {
        float sum = 0;
        for (int i = 0; i < weights.size(); i++)
        {
            sum += inputs[i] * weights[i];
        }

        int output = sign(sum);
        return output;
    }
    void train(vector<float> inputs,int target)
    {
        int guess1 = guess(inputs);
        int error = target - guess1;
        for (int i = 0; i < weights.size(); i++) {
            weights[i] += error * inputs[i] * lr;
        }
    }
};

训练.h

#include <stdio.h>      /* printf,rand */
#include <time.h>       /* time */
#include <iostream>  
#include <vector>

using std::vector;

class XOR {
private:
    float x1;
    float x2;
    float label;

public:
    XOR() {

        x1 = rand() % (2);
        x2 = rand() % (2);

        if (x1 == 0 && x2 == 0)
            label = 0;
        else if (x1 == 0 && x2 == 1)
            label = 1;
        else if (x1 == 1 && x2 == 0)
            label = 1;
        else
            label = 0;
    }

    float getX1(){ return x1; };
    float getX2() { return x2; };
    float getLabel() { return label; };

    vector<float> getInputs() {
        return vector<float> {x1,x2};
    }
    float getTarget() {
        return label;
    }
    
};

main.cpp

#include "Perceptron.h"
#include <iostream>
#include "Training.h"

using std::cout;
using std::endl;

int main()
{
    Perceptron brain;
    vector<XOR> trainingData(100);
    for (int i = 0; i < 100; i++)
    {
        brain.train(trainingData[i].getInputs(),trainingData[i].getTarget());
    }
    
    vector<float> inputs = { 0,0 };
    vector<float> inputs2 = { 0,1 };
    vector<float> inputs3 = { 1,0 };
    vector<float> inputs4 = { 1,1 };

    int guess1 = 0;

    guess1 = brain.guess(inputs);
    cout  << "guess: " << guess1 << endl;

    guess1 = brain.guess(inputs2);
    cout << "guess: " << guess1 << endl;

    guess1 = brain.guess(inputs3);
    cout << "guess: " << guess1 << endl;

    guess1 = brain.guess(inputs4);
    cout << "guess: " << guess1 << endl;

    return 0;
}

您可以在您的机器上运行此代码并自行测试,如果您注意到多次运行它会得到不同的输出,这更奇怪。

解决方法

答案是使用值为 1 的附加偏置权重,以防输入为 0,0。

另外,我用来学习 XOR 操作的模型太基础了,因为它不适用于 2 个以上的场景。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?