Fixed for the right npy model saving

This commit is contained in:
Aditya Kusupati 2019-05-30 23:42:23 +05:30
Родитель 31230dd6fd
Коммит da82f47a9c
2 изменённых файлов: 35 добавлений и 25 удалений

Просмотреть файл

@ -416,13 +416,13 @@ class LSTMLRCell(RNNCell):
self._hidden_size = hidden_size
self._gate_non_linearity = gate_non_linearity
self._update_non_linearity = update_non_linearity
self._num_weight_matrices = [1, 1]
self._num_weight_matrices = [4, 4]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 4
self._num_weight_matrices[0] += 1
if uRank is not None:
self._num_weight_matrices[1] += 4
self._num_weight_matrices[1] += 1
self._name = name
self._reuse = reuse
@ -620,12 +620,12 @@ class LSTMLRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 1:
if self._num_weight_matrices[0] == 4:
Vars.extend([self.W1, self.W2, self.W3, self.W4])
else:
Vars.extend([self.W, self.W1, self.W2, self.W3, self.W4])
if self._num_weight_matrices[1] == 1:
if self._num_weight_matrices[1] == 4:
Vars.extend([self.U1, self.U2, self.U3, self.U4])
else:
Vars.extend([self.U, self.U1, self.U2, self.U3, self.U4])
@ -672,13 +672,13 @@ class GRULRCell(RNNCell):
self._hidden_size = hidden_size
self._gate_non_linearity = gate_non_linearity
self._update_non_linearity = update_non_linearity
self._num_weight_matrices = [1, 1]
self._num_weight_matrices = [3, 3]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 3
self._num_weight_matrices[0] += 1
if uRank is not None:
self._num_weight_matrices[1] += 3
self._num_weight_matrices[1] += 1
self._name = name
self._reuse = reuse
@ -849,12 +849,12 @@ class GRULRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 1:
if self._num_weight_matrices[0] == 3:
Vars.extend([self.W1, self.W2, self.W3])
else:
Vars.extend([self.W, self.W1, self.W2, self.W3])
if self._num_weight_matrices[1] == 1:
if self._num_weight_matrices[1] == 3:
Vars.extend([self.U1, self.U2, self.U3])
else:
Vars.extend([self.U, self.U1, self.U2, self.U3])
@ -900,13 +900,13 @@ class UGRNNLRCell(RNNCell):
self._hidden_size = hidden_size
self._gate_non_linearity = gate_non_linearity
self._update_non_linearity = update_non_linearity
self._num_weight_matrices = [1, 1]
self._num_weight_matrices = [2, 2]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 2
self._num_weight_matrices[0] += 1
if uRank is not None:
self._num_weight_matrices[1] += 2
self._num_weight_matrices[1] += 1
self._name = name
self._reuse = reuse
@ -1040,12 +1040,12 @@ class UGRNNLRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 1:
if self._num_weight_matrices[0] == 2:
Vars.extend([self.W1, self.W2])
else:
Vars.extend([self.W, self.W1, self.W2])
if self._num_weight_matrices[1] == 1:
if self._num_weight_matrices[1] == 2:
Vars.extend([self.U1, self.U2])
else:
Vars.extend([self.U, self.U1, self.U2])

Просмотреть файл

@ -236,12 +236,12 @@ class FastTrainer:
'''
if self.numMatrices[0] == 1:
np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval())
elif self.numMatrices[0] == 2:
np.save(os.path.join(currDir, "W1.npy"),
self.FastParams[0].eval())
np.save(os.path.join(currDir, "W2.npy"),
self.FastParams[1].eval())
elif self.FastObj.wRank is None:
if self.numMatrices[0] == 2:
np.save(os.path.join(currDir, "W1.npy"),
self.FastParams[0].eval())
np.save(os.path.join(currDir, "W2.npy"),
self.FastParams[1].eval())
if self.numMatrices[0] == 3:
np.save(os.path.join(currDir, "W1.npy"),
self.FastParams[0].eval())
@ -259,6 +259,11 @@ class FastTrainer:
np.save(os.path.join(currDir, "W4.npy"),
self.FastParams[3].eval())
elif self.FastObj.wRank is not None:
if self.numMatrices[0] == 2:
np.save(os.path.join(currDir, "W1.npy"),
self.FastParams[0].eval())
np.save(os.path.join(currDir, "W2.npy"),
self.FastParams[1].eval())
if self.numMatrices[0] == 3:
np.save(os.path.join(currDir, "W.npy"),
self.FastParams[0].eval())
@ -290,12 +295,12 @@ class FastTrainer:
idx = self.numMatrices[0]
if self.numMatrices[1] == 1:
np.save(os.path.join(currDir, "U.npy"), self.FastParams[idx + 0].eval())
elif self.numMatrices[1] == 2:
np.save(os.path.join(currDir, "U1.npy"),
self.FastParams[idx + 0].eval())
np.save(os.path.join(currDir, "U2.npy"),
self.FastParams[idx + 1].eval())
elif self.FastObj.uRank is None:
if self.numMatrices[1] == 2:
np.save(os.path.join(currDir, "U1.npy"),
self.FastParams[idx + 0].eval())
np.save(os.path.join(currDir, "U2.npy"),
self.FastParams[idx + 1].eval())
if self.numMatrices[1] == 3:
np.save(os.path.join(currDir, "U1.npy"),
self.FastParams[idx + 0].eval())
@ -313,6 +318,11 @@ class FastTrainer:
np.save(os.path.join(currDir, "U4.npy"),
self.FastParams[idx + 3].eval())
elif self.FastObj.uRank is not None:
if self.numMatrices[1] == 2:
np.save(os.path.join(currDir, "U1.npy"),
self.FastParams[idx + 0].eval())
np.save(os.path.join(currDir, "U2.npy"),
self.FastParams[idx + 1].eval())
if self.numMatrices[1] == 3:
np.save(os.path.join(currDir, "U.npy"),
self.FastParams[idx + 0].eval())