зеркало из https://github.com/microsoft/EdgeML.git
C reference (#171)
fixed errors in backward rnn pass and changed naming convention of zeta and nu
This commit is contained in:
Родитель
fa04d004eb
Коммит
9dd584338c
|
@ -21,8 +21,8 @@
|
|||
* @var uRank rank of U matrix
|
||||
* @var Bg pointer to bias for sigmoid
|
||||
* @var Bh pointer to bias for tanh
|
||||
* @var zeta first weight parameter for update from input from next step
|
||||
* @var nu second weight parameter for update from input from next step
|
||||
* @var sigmoid_zeta first weight parameter of the hidden-state update; scales (1 - gate) in the coefficient applied to the tanh update term
|
||||
* @var sigmoid_nu second weight parameter of the hidden-state update; constant added to sigmoid_zeta * (1 - gate) in the coefficient applied to the tanh update term
|
||||
*/
|
||||
typedef struct FastGRNN_LR_Params {
|
||||
float* mean;
|
||||
|
@ -35,8 +35,8 @@ typedef struct FastGRNN_LR_Params {
|
|||
unsigned uRank;
|
||||
float* Bg;
|
||||
float* Bh;
|
||||
float zeta;
|
||||
float nu;
|
||||
float sigmoid_zeta;
|
||||
float sigmoid_nu;
|
||||
} FastGRNN_LR_Params;
|
||||
|
||||
/**
|
||||
|
@ -82,8 +82,8 @@ int fastgrnn_lr(float* const hiddenState, unsigned hiddenDims,
|
|||
* @var U pointer U matrix
|
||||
* @var Bg pointer to bias for sigmoid
|
||||
* @var Bh pointer to bias for tanh
|
||||
* @var zeta first weight parameter for update from input from next step
|
||||
* @var nu second weight parameter for update from input from next step
|
||||
* @var sigmoid_zeta first weight parameter of the hidden-state update; scales (1 - gate) in the coefficient applied to the tanh update term
|
||||
* @var sigmoid_nu second weight parameter of the hidden-state update; constant added to sigmoid_zeta * (1 - gate) in the coefficient applied to the tanh update term
|
||||
*/
|
||||
typedef struct FastGRNN_Params {
|
||||
float* mean;
|
||||
|
@ -92,8 +92,8 @@ typedef struct FastGRNN_Params {
|
|||
float* U;
|
||||
float* Bg;
|
||||
float* Bh;
|
||||
float zeta;
|
||||
float nu;
|
||||
float sigmoid_zeta;
|
||||
float sigmoid_nu;
|
||||
} FastGRNN_Params;
|
||||
|
||||
/**
|
||||
|
|
|
@ -36,7 +36,7 @@ void v_add(float scalar1, const float* const vec1,
|
|||
float scalar2, const float* const vec2,
|
||||
unsigned len, float* const ret);
|
||||
|
||||
// point-wise vector division ret = vec2 / vec1
|
||||
// point-wise vector multiplication ret = vec2 * vec1
|
||||
void v_mult(const float* const vec1, const float* const vec2,
|
||||
unsigned len, float* const ret);
|
||||
|
||||
|
@ -44,6 +44,10 @@ void v_mult(const float* const vec1, const float* const vec2,
|
|||
void v_div(const float* const vec1, const float* const vec2,
|
||||
unsigned len, float* const ret);
|
||||
|
||||
// Return squared Euclidean distance between vec1 and vec2
|
||||
float l2squared(const float* const vec1,
|
||||
const float* const vec2, unsigned dim);
|
||||
|
||||
// Return the index of the maximum value; on ties, return the first tied index.
|
||||
unsigned argmax(const float* const vec, unsigned len);
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ int fastgrnn_lr(float* const hiddenState, unsigned hiddenDims,
|
|||
// #steps iterations of the RNN cell starting from hiddenState
|
||||
for (unsigned t = 0; t < steps; t++) {
|
||||
// Normalize the features
|
||||
unsigned offset = backward ? steps - t : t;
|
||||
unsigned offset = backward ? steps - 1 - t : t;
|
||||
if (normalize) {
|
||||
v_add(1.0f, input + offset * inputDims, -1.0f, tparams->mean + t * inputDims,
|
||||
inputDims, tbuffers->normFeatures);
|
||||
|
@ -45,7 +45,7 @@ int fastgrnn_lr(float* const hiddenState, unsigned hiddenDims,
|
|||
for (unsigned i = 0; i < hiddenDims; i++) {
|
||||
float gate = sigmoid(tbuffers->preComp[i] + tparams->Bg[i]);
|
||||
float update = tanh(tbuffers->preComp[i] + tparams->Bh[i]);
|
||||
hiddenState[i] = gate * hiddenState[i] + (tparams->zeta * (1.0 - gate) + tparams->nu) * update;
|
||||
hiddenState[i] = gate * hiddenState[i] + (tparams->sigmoid_zeta * (1.0 - gate) + tparams->sigmoid_nu) * update;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
|
@ -63,7 +63,7 @@ int fastgrnn(float* const hiddenState, unsigned hiddenDims,
|
|||
|
||||
for (unsigned t = 0; t < steps; t++) {
|
||||
// Normalize the features
|
||||
unsigned offset = backward ? steps - t : t;
|
||||
unsigned offset = backward ? steps - 1 - t : t;
|
||||
if (normalize) {
|
||||
v_add(1.0f, input + offset * inputDims, -1.0f, tparams->mean + t * inputDims,
|
||||
inputDims, tbuffers->normFeatures);
|
||||
|
@ -86,7 +86,7 @@ int fastgrnn(float* const hiddenState, unsigned hiddenDims,
|
|||
for (unsigned i = 0; i < hiddenDims; i++) {
|
||||
float gate = sigmoid(tbuffers->preComp[i] + tparams->Bg[i]);
|
||||
float update = tanh(tbuffers->preComp[i] + tparams->Bh[i]);
|
||||
hiddenState[i] = gate * hiddenState[i] + (tparams->zeta * (1.0 - gate) + tparams->nu) * update;
|
||||
hiddenState[i] = gate * hiddenState[i] + (tparams->sigmoid_zeta * (1.0 - gate) + tparams->sigmoid_nu) * update;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -90,6 +90,14 @@ void v_div(const float* const vec1, const float* const vec2,
|
|||
ret[i] = vec2[i] / vec1[i];
|
||||
}
|
||||
|
||||
/**
 * Squared Euclidean distance between two float vectors.
 *
 * @param vec1 first vector; elements [0, dim) are read
 * @param vec2 second vector; elements [0, dim) are read
 * @param dim  number of elements to compare
 * @return sum over i in [0, dim) of (vec1[i] - vec2[i])^2; 0.0f when dim is 0
 */
float l2squared(const float* const vec1,
  const float* const vec2, unsigned dim) {
  float sum = 0.0f;
  for (unsigned i = 0; i < dim; i++) {
    // Compute the element-wise difference once and square it.
    const float diff = vec1[i] - vec2[i];
    sum += diff * diff;
  }
  return sum;
}
|
||||
|
||||
unsigned argmax(const float* const vec, unsigned len) {
|
||||
unsigned maxId = 0;
|
||||
float maxScore = FLT_MIN;
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -16,51 +16,53 @@
|
|||
|
||||
int main() {
|
||||
FastGRNN_Params rnn1_params = {
|
||||
.mean = mean1,
|
||||
.stdDev = stdDev1,
|
||||
.mean = NULL,
|
||||
.stdDev = NULL,
|
||||
.W = W1,
|
||||
.U = U1,
|
||||
.Bg = Bg1,
|
||||
.Bh = Bh1,
|
||||
.zeta = zeta1,
|
||||
.nu = nu1
|
||||
.sigmoid_zeta = sigmoid_zeta1,
|
||||
.sigmoid_nu = sigmoid_nu1
|
||||
};
|
||||
|
||||
FastGRNN_Params rnn2_params = {
|
||||
.mean = mean2,
|
||||
.stdDev = stdDev2,
|
||||
.mean = NULL,
|
||||
.stdDev = NULL,
|
||||
.W = W2,
|
||||
.U = U2,
|
||||
.Bg = Bg2,
|
||||
.Bh = Bh2,
|
||||
.zeta = zeta2,
|
||||
.nu = nu2
|
||||
.sigmoid_zeta = sigmoid_zeta2,
|
||||
.sigmoid_nu = sigmoid_nu2
|
||||
};
|
||||
|
||||
float preComp1[HIDDEN_DIMS1] = { 0.0f };
|
||||
float normFeatures1[INPUT_DIMS] = { 0.0f };
|
||||
float preComp1[HIDDEN_DIMS1];
|
||||
float normFeatures1[INPUT_DIMS];
|
||||
memset(preComp1, 0, sizeof(float) * HIDDEN_DIMS1);
|
||||
memset(normFeatures1, 0, sizeof(float) * INPUT_DIMS);
|
||||
FastGRNN_Buffers rnn1_buffers = {
|
||||
.preComp = preComp1,
|
||||
.normFeatures = normFeatures1
|
||||
};
|
||||
|
||||
float preComp2[HIDDEN_DIMS2] = { 0.0f };
|
||||
float normFeatures2[HIDDEN_DIMS1] = { 0.0f };
|
||||
float preComp2[HIDDEN_DIMS2];
|
||||
float normFeatures2[HIDDEN_DIMS1];
|
||||
memset(preComp2, 0, sizeof(float) * HIDDEN_DIMS2);
|
||||
memset(normFeatures2, 0, sizeof(float) * HIDDEN_DIMS1);
|
||||
FastGRNN_Buffers rnn2_buffers = {
|
||||
.preComp = preComp2,
|
||||
.normFeatures = normFeatures2
|
||||
};
|
||||
|
||||
float output_test[4 * HIDDEN_DIMS2] = { 0.0f };
|
||||
float output_test[4 * HIDDEN_DIMS2];
|
||||
float buffer[HIDDEN_DIMS1 * PATCH_DIM];
|
||||
|
||||
memset(output_test, 0, sizeof(float) * 4 * HIDDEN_DIMS2);
|
||||
memset(buffer, 0, sizeof(float) * HIDDEN_DIMS1 * PATCH_DIM);
|
||||
rnnpool_block(input, INPUT_DIMS, PATCH_DIM, PATCH_DIM,
|
||||
fastgrnn, HIDDEN_DIMS1, (const void*)(&rnn1_params), (void*)(&rnn1_buffers),
|
||||
fastgrnn, HIDDEN_DIMS2, (const void*)(&rnn2_params), (void*)(&rnn2_buffers),
|
||||
output_test, buffer);
|
||||
|
||||
float error = 0.0f;
|
||||
for (unsigned d = 0; d < 4 * HIDDEN_DIMS2; ++d)
|
||||
error += (output[d] - output_test[d]) * (output[d] - output_test[d]);
|
||||
printf("Error: %f\n", error);
|
||||
printf("Error: %f\n", l2squared(output, output_test, 4 * HIDDEN_DIMS2));
|
||||
}
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
#define HIDDEN_DIMS1 8
|
||||
|
||||
static float mean1[INPUT_DIMS * PATCH_DIM * PATCH_DIM] = {0.0};
|
||||
static float stdDev1[INPUT_DIMS * PATCH_DIM * PATCH_DIM] = {1.0};
|
||||
|
||||
static float W1[INPUT_DIMS * HIDDEN_DIMS1] = { -6.669194251298904419e-02,4.241880401968955994e-02,1.779496856033802032e-02,2.434188500046730042e-02,2.040588408708572388e-01,-6.289862841367721558e-02,8.665277808904647827e-02,-4.211997613310813904e-02,-4.081934988498687744e-01,3.874431252479553223e-01,-9.131602197885513306e-02,-4.231898188591003418e-01,1.137487962841987610e-01,-7.153615355491638184e-02,1.098951250314712524e-01,-5.568345189094543457e-01,-9.548854231834411621e-01,6.743898242712020874e-02,1.308344118297100067e-02,-5.101327300071716309e-01,-2.034226208925247192e-01,-3.552019000053405762e-01,3.574696481227874756e-01,-1.339200735092163086e-01,-1.199822783470153809e+00,4.235469996929168701e-01,-2.176112309098243713e-02,-5.763933062553405762e-01,-4.363142326474189758e-02,-3.619972467422485352e-01,4.203165769577026367e-01,-1.205071806907653809e+00,3.212589919567108154e-01,7.616437077522277832e-01,7.672815918922424316e-01,2.576289772987365723e-01,1.962029747664928436e-02,6.437804698944091797e-01,1.115008115768432617e+00,8.542865514755249023e-02,-8.398564457893371582e-01,-3.183248341083526611e-01,-3.293828964233398438e-01,-1.621776819229125977e-01,5.834429860115051270e-01,3.947886824607849121e-01,-6.720532774925231934e-01,3.155398648232221603e-03,1.377462625503540039e+00,-6.704510450363159180e-01,8.417400717735290527e-02,6.603993773460388184e-01,-1.158756241202354431e-01,4.007513821125030518e-01,-4.624451696872711182e-01,1.793413400650024414e+00,1.293107271194458008e-01,3.916886448860168457e-01,-4.810366034507751465e-01,-2.212110720574855804e-02,-6.547657251358032227e-01,3.415162265300750732e-01,-3.584066927433013916e-01,3.155157715082168579e-02 };
|
||||
static float U1[HIDDEN_DIMS1 * HIDDEN_DIMS1] = { 4.257509708404541016e-01,-2.243996411561965942e-01,4.599139690399169922e-01,-3.530698418617248535e-01,2.927841432392597198e-02,1.921992599964141846e-01,7.122249808162450790e-03,2.478006668388843536e-02,2.090227305889129639e-01,6.178687140345573425e-02,-8.031798005104064941e-01,-9.047465324401855469e-01,1.150295138359069824e-01,-1.216369643807411194e-01,8.977604657411575317e-02,-3.528899326920509338e-02,-9.029741585254669189e-02,1.964072138071060181e-02,-2.055168338119983673e-02,5.608395114541053772e-02,-1.545069087296724319e-02,-3.085457533597946167e-02,7.987973280251026154e-03,-8.097411133348941803e-03,-1.581576466560363770e-01,-7.558400928974151611e-02,-1.338823318481445312e+00,1.076061371713876724e-02,7.246796786785125732e-02,-6.944761611521244049e-03,-8.065721951425075531e-03,-2.095492556691169739e-02,-2.624719589948654175e-02,-3.073738422244787216e-03,1.721778139472007751e-02,-3.143182024359703064e-03,1.096169114112854004e+00,-4.858187958598136902e-02,-8.912781253457069397e-03,-3.243596106767654419e-02,-1.679758429527282715e-01,-7.507130037993192673e-03,-1.409216374158859253e-01,1.100406423211097717e-01,-1.400157362222671509e-01,1.513722836971282959e-01,-1.210445910692214966e-02,-8.220742642879486084e-02,1.883537024259567261e-01,5.114595890045166016e-01,1.616048932075500488e+00,1.620839357376098633e+00,-1.330980509519577026e-01,9.179102256894111633e-03,7.160065323114395142e-02,2.354629104956984520e-03,8.153508603572845459e-02,9.405165910720825195e-03,1.780232340097427368e-01,-5.950598046183586121e-02,-5.966191366314888000e-02,1.019131317734718323e-01,2.581899240612983704e-02,1.182158961892127991e-01 };
|
||||
static float Bg1[HIDDEN_DIMS1] = { -1.213001132011413574e+00,-3.357241451740264893e-01,-1.219372987747192383e+00,-2.052456587553024292e-01,-7.755046486854553223e-01,-7.103578448295593262e-01,-8.887031674385070801e-01,-5.140971541404724121e-01};
|
||||
static float Bh1[HIDDEN_DIMS1] = { -2.761822193861007690e-02,3.516794741153717041e-01,7.828953266143798828e-01,7.023379206657409668e-01,5.146764516830444336e-01,7.880625128746032715e-01,1.113372668623924255e-01,6.815895438194274902e-02 };
|
||||
|
||||
static float zeta1 = 9.999979734420776367e-01;
|
||||
static float nu1 = 1.968302512977970764e-06;
|
||||
static float sigmoid_zeta1 = 9.999979734420776367e-01;
|
||||
static float sigmoid_nu1 = 1.968302512977970764e-06;
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
static float mean2[HIDDEN_DIMS1 * PATCH_DIM] = {0.0};
|
||||
static float stdDev2[HIDDEN_DIMS1 * PATCH_DIM] = {1.0};
|
||||
|
||||
static float W2[HIDDEN_DIMS1 * HIDDEN_DIMS2] = { 3.328921273350715637e-02,4.839901626110076904e-02,8.630067855119705200e-02,-3.623228371143341064e-01,3.617477416992187500e-02,-1.013135910034179688e-01,-1.616892337799072266e+00,9.920979291200637817e-02,-3.237796723842620850e-01,-2.648477256298065186e-01,2.549230158329010010e-01,1.056594774127006531e-01,-5.211604386568069458e-02,-5.569029450416564941e-01,7.901080884039402008e-03,4.476135373115539551e-01,-6.069063544273376465e-01,4.720874130725860596e-01,5.098583698272705078e-01,4.149395599961280823e-02,-2.221289835870265961e-02,2.892487943172454834e-01,1.637366861104965210e-01,-4.640549048781394958e-02,3.176147341728210449e-01,-2.042243480682373047e-01,3.439170718193054199e-01,2.087370865046977997e-02,2.785135433077812195e-02,3.186852931976318359e-01,5.935984104871749878e-02,-3.889196217060089111e-01,-8.808585256338119507e-02,1.615403443574905396e-01,-4.946086406707763672e-01,-1.174109801650047302e-02,2.072153845801949501e-03,4.303685724735260010e-01,-7.063516974449157715e-02,-3.901343047618865967e-01,4.995678067207336426e-01,4.996369481086730957e-01,-8.396717905998229980e-02,-3.785453736782073975e-01,1.230286993086338043e-02,-3.816023766994476318e-01,1.118507459759712219e-01,1.944433003664016724e-01,-8.141241967678070068e-02,-5.713146924972534180e-02,6.110856309533119202e-02,2.926556952297687531e-02,-1.405209898948669434e-01,-5.008732676506042480e-01,5.522043444216251373e-03,-8.167469501495361328e-01,-5.527251958847045898e-02,-6.109619140625000000e-02,2.462622970342636108e-01,-1.656581298448145390e-04,1.710500240325927734e+00,-6.598105430603027344e-01,2.383397147059440613e-02,-5.196015834808349609e-01 };
|
||||
static float U2[HIDDEN_DIMS2 * HIDDEN_DIMS2] = { 3.934212923049926758e-01,-4.295524582266807556e-02,1.002475544810295105e-01,-1.167084053158760071e-01,7.208147644996643066e-02,7.301057875156402588e-02,7.414811104536056519e-02,-1.667873002588748932e-02,2.627835571765899658e-01,2.479245215654373169e-01,-2.347197830677032471e-01,2.314671277999877930e-01,1.421915292739868164e-01,-4.386771023273468018e-01,4.338171705603599548e-02,1.125133633613586426e-01,-9.242479503154754639e-02,1.447719633579254150e-01,3.752615749835968018e-01,2.763805091381072998e-01,-7.362117618322372437e-02,-1.142739504575729370e-01,1.064843088388442993e-01,-8.647724986076354980e-03,1.917631626129150391e-01,8.766186982393264771e-02,-2.752942740917205811e-01,2.459737062454223633e-01,4.787993803620338440e-02,8.056888729333877563e-02,2.883537299931049347e-02,6.972444709390401840e-03,2.150612622499465942e-01,2.176617681980133057e-01,-9.746207296848297119e-02,3.679830208420753479e-02,2.094942629337310791e-01,-2.684409022331237793e-01,6.217004358768463135e-02,-3.187925368547439575e-02,7.392763346433639526e-02,4.104424268007278442e-03,-1.972307413816452026e-01,-2.362748086452484131e-01,3.649697601795196533e-01,3.465250432491302490e-01,7.446446269750595093e-02,5.720932036638259888e-02,-5.659309402108192444e-02,-4.538772255182266235e-02,6.283282488584518433e-02,-3.104292228817939758e-02,-1.333466079086065292e-02,-7.922663539648056030e-02,6.666561216115951538e-02,4.965405911207199097e-02,-4.473730921745300293e-02,-2.271021008491516113e-01,1.190942153334617615e-02,-6.096216291189193726e-02,2.375180423259735107e-01,-1.497552990913391113e-01,-4.494012892246246338e-01,-6.579961627721786499e-02 };
|
||||
|
||||
static float Bg2[HIDDEN_DIMS2] = { -1.537145614624023438e+00,-6.593755483627319336e-01,-8.165745735168457031e-01,-1.047435641288757324e+00,-1.003585577011108398e+00,-1.275580763816833496e+00,-9.717565178871154785e-01,-1.349884271621704102e+00 };
|
||||
static float Bh2[HIDDEN_DIMS2] = { 1.249589323997497559e+00,2.501939237117767334e-01,3.707601428031921387e-01,1.205096021294593811e-01,6.529558449983596802e-02,-2.186506539583206177e-01,-1.120083108544349670e-01,-1.578094959259033203e+00 };
|
||||
|
||||
static float zeta2 = 9.999979734420776367e-01;
|
||||
static float nu2 = 1.968388687600963749e-06;
|
||||
static float sigmoid_zeta2 = 9.999979734420776367e-01;
|
||||
static float sigmoid_nu2 = 1.968388687600963749e-06;
|
||||
|
|
Загрузка…
Ссылка в новой задаче