資源簡介
c++版神經網絡的實現,內部實現了矩陣運算,前向傳播,反向傳播的基本邏輯,使用的是交叉熵損失函數,二分類問題,數據集是64*64*3的貓的圖片,已被展成12288向量存于txt文件中,由于數據量大訓練時間可能較長

代碼片段和文件信息
#include?“layer.h“
#include?
#include?
#include?
layer::layer(int?input_num?int?output_num?string?activation?=?“relu“)
:?Z()
?input()
{
//初始化權重矩陣和偏差矩陣
weight.init_matrix(input_num?output_num);
bias.init_matrix(1?output_num);
//使用Xavier算法初始化系數矩陣
//seed=2018?保證結果穩定性
default_random_engine?generator(2018);
uniform_real_distribution?dis(-sqrt(6.0?/?(input_num?+?output_num))?sqrt(6.0?/?(input_num?+?output_num)));
for?(int?i?=?0;?i? for?(int?j?=?0;?j? {
weight.write(i?j?dis(generator));
}
this->activation?=?activation;
}
layer::~layer()
{
}
_Matrix?layer::forward(const?_Matrix?&?input)
{
this->input?=?input;
Z?=?input*weight?+?bias;
//cout?<“Z:“?< //print1(Z);
_Matrix?A;
if?(activation?==?“relu“)
A?=?relu(Z);
else?if?(activation?==?“sigmoid“)
A?=?sigmoid(Z);
else
A?=?Z;
//cout?<“A:“?< //print1(A);
return?A;
}
/*
grads_A:?下一層的梯度值
返回這一層的梯度值
*/
const?_Matrix?layer::backward(const?_Matrix?&?grads_A?float?learning_rate?=?1.)
{
_Matrix?grads;
/*求梯度*/
_Matrix?grads_Z;
if?(activation?==?“relu“)
grads_Z?=?d_relu(Z);
else?if?(activation?==?“sigmoid“)
grads_Z?=?d_sigmoid(Z);
/*
前一層(離輸入近)的grads_A
d_A_L_1?=?d_Z_L?*?W_L
*/
grads?=?grads_A.multiply(grads_Z)?*?(weight.T());
/*參數更新*/
/*
更新公式:?W?=?W?-?lr/m*grads_A*grads_Z*X
??b?=?b?-?lr/m*grads_A*grads_Z
*/
//cout?<“grads_Z“?< //print1(grads_Z);
//cout?<“weight“?< //print1(weight);
assert(input.rows()?>?0?&&?input.cols()?>?0);
int?m?=?input.rows();
weight?=?weight?-?((input.T()?*?(grads_A.multiply(grads_Z))*learning_rate*(1.?/?m)));
bias?=?bias?-?(((grads_A.multiply(grads_Z))*learning_rate*(1.?/?m)).sum(0));
return?grads;
}
/*activation*/
_Matrix?layer::sigmoid(_Matrix?&?z)?const
{
return?((-z).calc_exp()?+?1).reciprocal();
}
_Matrix?layer::relu(_Matrix&?z)?const
{
_Matrix?C(z.rows()?z.cols());
for?(int?i?=?0;?i? for?(int?j?=?0;?j? {
C.write(i?j?max(0.0?z.read(i?j)));
}
return?C;
}
_Matrix?layer::d_sigmoid(_Matrix&?z)?const
{
_Matrix?C?=?sigmoid(z);
return?C.multiply(-C?+?1);
}
_Matrix?layer::d_relu(_Matrix&?z)?const
{
_Matrix?C(z.rows()?z.cols());
for?(int?i?=?0;?i? for?(int?j?=?0;?j? {
if?(z.read(ij)?>?0)
C.write(i?j?1.);
else
C.write(i?j?0.);
}
return?C;
}
?屬性????????????大小?????日期????時間???名稱
-----------?---------??----------?-----??----
?????目錄???????????0??2018-06-19?15:12??NN\
?????目錄???????????0??2018-06-19?15:12??NN\Debug\
?????文件??????185344??2018-06-18?22:33??NN\Debug\honework.exe
?????文件?????1761180??2018-06-18?22:33??NN\Debug\honework.ilk
?????文件?????2198528??2018-06-18?22:33??NN\Debug\honework.pdb
?????目錄???????????0??2018-06-19?15:12??NN\honework\
?????目錄???????????0??2018-06-19?15:12??NN\honework\Debug\
?????文件??????207350??2018-06-18?22:33??NN\honework\Debug\la
?????文件??????312373??2018-06-18?22:01??NN\honework\Debug\Network.obj
?????文件??????209718??2018-06-18?19:37??NN\honework\Debug\_Matrix.obj
?????文件????????2712??2018-06-18?22:33??NN\honework\Debug\honework.log
?????目錄???????????0??2018-06-19?15:12??NN\honework\Debug\honework.tlog\
?????文件???????61824??2018-06-18?22:33??NN\honework\Debug\honework.tlog\CL.read.1.tlog
?????文件????????6758??2018-06-18?22:33??NN\honework\Debug\honework.tlog\CL.write.1.tlog
?????文件????????2438??2018-06-18?22:33??NN\honework\Debug\honework.tlog\cl.command.1.tlog
?????文件?????????157??2018-06-18?22:33??NN\honework\Debug\honework.tlog\honework.lastbuildstate
?????文件????????3172??2018-06-18?22:33??NN\honework\Debug\honework.tlog\li
?????文件????????7124??2018-06-18?22:33??NN\honework\Debug\honework.tlog\li
?????文件?????????668??2018-06-18?22:33??NN\honework\Debug\honework.tlog\li
?????文件??????383549??2018-06-18?21:58??NN\honework\Debug\main.obj
?????文件??????453632??2018-06-18?22:33??NN\honework\Debug\vc120.idb
?????文件??????765952??2018-06-18?22:33??NN\honework\Debug\vc120.pdb
?????文件????????2600??2018-06-18?22:33??NN\honework\la
?????文件?????????779??2018-06-17?22:57??NN\honework\la
?????文件????????3312??2018-06-18?22:01??NN\honework\Network.cpp
?????文件?????????430??2018-06-18?18:45??NN\honework\Network.h
?????文件????????8531??2018-06-18?19:37??NN\honework\_Matrix.cpp
?????文件????????1539??2018-06-18?19:24??NN\honework\_Matrix.h
?????文件????????4399??2018-06-16?17:13??NN\honework\honework.vcxproj
?????文件????????1512??2018-06-16?17:13??NN\honework\honework.vcxproj.filters
?????文件?????????165??2018-06-16?15:54??NN\honework\honework.vcxproj.user
............此處省略131個文件信息
- 上一篇:C++編程有限元公式
- 下一篇:Essential C++中文版(全)
評論
共有 條評論