зеркало из https://github.com/microsoft/EdgeML.git
Fixed for the right npy model saving
This commit is contained in:
Родитель
31230dd6fd
Коммит
da82f47a9c
|
@ -416,13 +416,13 @@ class LSTMLRCell(RNNCell):
|
|||
self._hidden_size = hidden_size
|
||||
self._gate_non_linearity = gate_non_linearity
|
||||
self._update_non_linearity = update_non_linearity
|
||||
self._num_weight_matrices = [1, 1]
|
||||
self._num_weight_matrices = [4, 4]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 4
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 4
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._name = name
|
||||
self._reuse = reuse
|
||||
|
||||
|
@ -620,12 +620,12 @@ class LSTMLRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 1:
|
||||
if self._num_weight_matrices[0] == 4:
|
||||
Vars.extend([self.W1, self.W2, self.W3, self.W4])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2, self.W3, self.W4])
|
||||
|
||||
if self._num_weight_matrices[1] == 1:
|
||||
if self._num_weight_matrices[1] == 4:
|
||||
Vars.extend([self.U1, self.U2, self.U3, self.U4])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2, self.U3, self.U4])
|
||||
|
@ -672,13 +672,13 @@ class GRULRCell(RNNCell):
|
|||
self._hidden_size = hidden_size
|
||||
self._gate_non_linearity = gate_non_linearity
|
||||
self._update_non_linearity = update_non_linearity
|
||||
self._num_weight_matrices = [1, 1]
|
||||
self._num_weight_matrices = [3, 3]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 3
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 3
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._name = name
|
||||
self._reuse = reuse
|
||||
|
||||
|
@ -849,12 +849,12 @@ class GRULRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 1:
|
||||
if self._num_weight_matrices[0] == 3:
|
||||
Vars.extend([self.W1, self.W2, self.W3])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2, self.W3])
|
||||
|
||||
if self._num_weight_matrices[1] == 1:
|
||||
if self._num_weight_matrices[1] == 3:
|
||||
Vars.extend([self.U1, self.U2, self.U3])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2, self.U3])
|
||||
|
@ -900,13 +900,13 @@ class UGRNNLRCell(RNNCell):
|
|||
self._hidden_size = hidden_size
|
||||
self._gate_non_linearity = gate_non_linearity
|
||||
self._update_non_linearity = update_non_linearity
|
||||
self._num_weight_matrices = [1, 1]
|
||||
self._num_weight_matrices = [2, 2]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 2
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 2
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._name = name
|
||||
self._reuse = reuse
|
||||
|
||||
|
@ -1040,12 +1040,12 @@ class UGRNNLRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 1:
|
||||
if self._num_weight_matrices[0] == 2:
|
||||
Vars.extend([self.W1, self.W2])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2])
|
||||
|
||||
if self._num_weight_matrices[1] == 1:
|
||||
if self._num_weight_matrices[1] == 2:
|
||||
Vars.extend([self.U1, self.U2])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2])
|
||||
|
|
|
@ -236,12 +236,12 @@ class FastTrainer:
|
|||
'''
|
||||
if self.numMatrices[0] == 1:
|
||||
np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval())
|
||||
elif self.numMatrices[0] == 2:
|
||||
np.save(os.path.join(currDir, "W1.npy"),
|
||||
self.FastParams[0].eval())
|
||||
np.save(os.path.join(currDir, "W2.npy"),
|
||||
self.FastParams[1].eval())
|
||||
elif self.FastObj.wRank is None:
|
||||
if self.numMatrices[0] == 2:
|
||||
np.save(os.path.join(currDir, "W1.npy"),
|
||||
self.FastParams[0].eval())
|
||||
np.save(os.path.join(currDir, "W2.npy"),
|
||||
self.FastParams[1].eval())
|
||||
if self.numMatrices[0] == 3:
|
||||
np.save(os.path.join(currDir, "W1.npy"),
|
||||
self.FastParams[0].eval())
|
||||
|
@ -259,6 +259,11 @@ class FastTrainer:
|
|||
np.save(os.path.join(currDir, "W4.npy"),
|
||||
self.FastParams[3].eval())
|
||||
elif self.FastObj.wRank is not None:
|
||||
if self.numMatrices[0] == 2:
|
||||
np.save(os.path.join(currDir, "W1.npy"),
|
||||
self.FastParams[0].eval())
|
||||
np.save(os.path.join(currDir, "W2.npy"),
|
||||
self.FastParams[1].eval())
|
||||
if self.numMatrices[0] == 3:
|
||||
np.save(os.path.join(currDir, "W.npy"),
|
||||
self.FastParams[0].eval())
|
||||
|
@ -290,12 +295,12 @@ class FastTrainer:
|
|||
idx = self.numMatrices[0]
|
||||
if self.numMatrices[1] == 1:
|
||||
np.save(os.path.join(currDir, "U.npy"), self.FastParams[idx + 0].eval())
|
||||
elif self.numMatrices[1] == 2:
|
||||
np.save(os.path.join(currDir, "U1.npy"),
|
||||
self.FastParams[idx + 0].eval())
|
||||
np.save(os.path.join(currDir, "U2.npy"),
|
||||
self.FastParams[idx + 1].eval())
|
||||
elif self.FastObj.uRank is None:
|
||||
if self.numMatrices[1] == 2:
|
||||
np.save(os.path.join(currDir, "U1.npy"),
|
||||
self.FastParams[idx + 0].eval())
|
||||
np.save(os.path.join(currDir, "U2.npy"),
|
||||
self.FastParams[idx + 1].eval())
|
||||
if self.numMatrices[1] == 3:
|
||||
np.save(os.path.join(currDir, "U1.npy"),
|
||||
self.FastParams[idx + 0].eval())
|
||||
|
@ -313,6 +318,11 @@ class FastTrainer:
|
|||
np.save(os.path.join(currDir, "U4.npy"),
|
||||
self.FastParams[idx + 3].eval())
|
||||
elif self.FastObj.uRank is not None:
|
||||
if self.numMatrices[1] == 2:
|
||||
np.save(os.path.join(currDir, "U1.npy"),
|
||||
self.FastParams[idx + 0].eval())
|
||||
np.save(os.path.join(currDir, "U2.npy"),
|
||||
self.FastParams[idx + 1].eval())
|
||||
if self.numMatrices[1] == 3:
|
||||
np.save(os.path.join(currDir, "U.npy"),
|
||||
self.FastParams[idx + 0].eval())
|
||||
|
|
Загрузка…
Ссылка в новой задаче