IT技术互动交流平台

LSTM神经网络的详细推导及C++实现

来源:IT165收集??发布日期:2016-10-18 20:56:54

LSTM隐层神经元结构:
这里写图片描述

LSTM隐层神经元详细结构:
这里写图片描述

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

//让程序自己学会是否需要进位,从而学会加法

#include 'iostream'
#include 'math.h'
#include 'stdlib.h'
#include 'time.h'
#include 'vector'
#include 'assert.h'
using namespace std;

#define innode  2       //输入结点数,将输入2个加数
#define hidenode  26    //隐藏结点数,存储“携带位”
#define outnode  1      //输出结点数,将输出一个预测数字
#define alpha  0.1      //学习速率
#define binary_dim 8    //二进制数的最大长度

#define randval(high) ( (double)rand() / RAND_MAX * high )
#define uniform_plus_minus_one ( (double)( 2.0 * rand() ) / ((double)RAND_MAX + 1.0) - 1.0 )  //均匀随机分布


int largest_number = ( pow(2, binary_dim) );  //跟二进制最大长度对应的可以表示的最大十进制数

//激活函数
double sigmoid(double x) 
{
    return 1.0 / (1.0 + exp(-x));
}

//激活函数的导数,y为激活函数值
double dsigmoid(double y)
{
    return y * (1.0 - y);  
}           

//tanh的导数,y为tanh值
double dtanh(double y)
{
    return 1.0 - y * y;  
}

//将一个10进制整数转换为2进制数
void int2binary(int n, int *arr)
{
    int i = 0;
    while(n)
    {
        arr[i++] = n % 2;
        n /= 2;
    }
    while(i < binary_dim)
        arr[i++] = 0;
}

class RNN
{
public:
    RNN();
    virtual ~RNN();
    void train();

public:
    double W_I[innode][hidenode];     //连接输入与隐含层单元中输入门的权值矩阵
    double U_I[hidenode][hidenode];   //连接上一隐层输出与本隐含层单元中输入门的权值矩阵
    double W_F[innode][hidenode];     //连接输入与隐含层单元中遗忘门的权值矩阵
    double U_F[hidenode][hidenode];   //连接上一隐含层与本隐含层单元中遗忘门的权值矩阵
    double W_O[innode][hidenode];     //连接输入与隐含层单元中遗忘门的权值矩阵
    double U_O[hidenode][hidenode];   //连接上一隐含层与现在时刻的隐含层的权值矩阵
    double W_G[innode][hidenode];     //用于产生新记忆的权值矩阵
    double U_G[hidenode][hidenode];   //用于产生新记忆的权值矩阵
    double W_out[hidenode][outnode];  //连接隐层与输出层的权值矩阵

    double *x;             //layer 0 输出值,由输入向量直接设定
    //double *layer_1;     //layer 1 输出值
    double *y;             //layer 2 输出值
};

void winit(double w[], int n) //权值初始化
{
    for(int i=0; i I_vector;      //输入门
    vector F_vector;      //遗忘门
    vector O_vector;      //输出门
    vector G_vector;      //新记忆
    vector S_vector;      //状态值
    vector h_vector;      //输出值
    vector y_delta;        //保存误差关于输出层的偏导

    for(epoch=0; epoch<11000; epoch++)  //训练次数
    {
        double e = 0.0;  //误差

        int predict[binary_dim];               //保存每次生成的预测值
        memset(predict, 0, sizeof(predict));

        int a_int = (int)randval(largest_number/2.0);  //随机生成一个加数 a
        int a[binary_dim];
        int2binary(a_int, a);                 //转为二进制数

        int b_int = (int)randval(largest_number/2.0);  //随机生成另一个加数 b
        int b[binary_dim];
        int2binary(b_int, b);                 //转为二进制数

        int c_int = a_int + b_int;            //真实的和 c
        int c[binary_dim];
        int2binary(c_int, c);                 //转为二进制数

        //在0时刻是没有之前的隐含层的,所以初始化一个全为0的
        double *S = new double[hidenode];     //状态值
        double *h = new double[hidenode];     //输出值

        for(i=0; i=0 ; p--)
        {
            x[0] = a[p];
            x[1] = b[p];

            //当前隐藏层
            double *in_gate = I_vector[p];     //输入门
            double *out_gate = O_vector[p];    //输出门
            double *forget_gate = F_vector[p]; //遗忘门
            double *g_gate = G_vector[p];      //新记忆
            double *state = S_vector[p+1];     //状态值
            double *h = h_vector[p+1];         //隐层输出值

            //前一个隐藏层
            double *h_pre = h_vector[p];   
            double *state_pre = S_vector[p];

            for(k=0; k=0; k--)
                cout << predict[k];
            cout << endl;

            cout << 'true:' ;
            for(k=binary_dim-1; k>=0; k--)
                cout << c[k];
            cout << endl;

            int out = 0;
            for(k=binary_dim-1; k>=0; k--)
                out += predict[k] * pow(2, k);
            cout << a_int << ' + ' << b_int << ' = ' << out << endl << endl;
        }

        for(i=0; i

这里写图片描述

参考:
http://lib.csdn.net/article/deeplearning/45380
http://www.open-open.com/lib/view/open1440843534638.html

延伸阅读:

Tag标签: C++?? 神经网络??
  • 专题推荐

About IT165 - 广告服务 - 隐私声明 - 版权申明 - 免责条款 - 网站地图 - 网友投稿 - 联系方式
本站内容来自于互联网,仅供用于网络技术学习,学习中请遵循相关法律法规