From da82f47a9c477843a266ebf0c1c41245c7fee0cc Mon Sep 17 00:00:00 2001 From: Aditya Kusupati Date: Thu, 30 May 2019 23:42:23 +0530 Subject: [PATCH] Fixed for the right npy model saving --- tf/edgeml/graph/rnn.py | 30 +++++++++++++++--------------- tf/edgeml/trainer/fastTrainer.py | 30 ++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/tf/edgeml/graph/rnn.py b/tf/edgeml/graph/rnn.py index a9a12308..f600280f 100644 --- a/tf/edgeml/graph/rnn.py +++ b/tf/edgeml/graph/rnn.py @@ -416,13 +416,13 @@ class LSTMLRCell(RNNCell): self._hidden_size = hidden_size self._gate_non_linearity = gate_non_linearity self._update_non_linearity = update_non_linearity - self._num_weight_matrices = [1, 1] + self._num_weight_matrices = [4, 4] self._wRank = wRank self._uRank = uRank if wRank is not None: - self._num_weight_matrices[0] += 4 + self._num_weight_matrices[0] += 1 if uRank is not None: - self._num_weight_matrices[1] += 4 + self._num_weight_matrices[1] += 1 self._name = name self._reuse = reuse @@ -620,12 +620,12 @@ class LSTMLRCell(RNNCell): def getVars(self): Vars = [] - if self._num_weight_matrices[0] == 1: + if self._num_weight_matrices[0] == 4: Vars.extend([self.W1, self.W2, self.W3, self.W4]) else: Vars.extend([self.W, self.W1, self.W2, self.W3, self.W4]) - if self._num_weight_matrices[1] == 1: + if self._num_weight_matrices[1] == 4: Vars.extend([self.U1, self.U2, self.U3, self.U4]) else: Vars.extend([self.U, self.U1, self.U2, self.U3, self.U4]) @@ -672,13 +672,13 @@ class GRULRCell(RNNCell): self._hidden_size = hidden_size self._gate_non_linearity = gate_non_linearity self._update_non_linearity = update_non_linearity - self._num_weight_matrices = [1, 1] + self._num_weight_matrices = [3, 3] self._wRank = wRank self._uRank = uRank if wRank is not None: - self._num_weight_matrices[0] += 3 + self._num_weight_matrices[0] += 1 if uRank is not None: - self._num_weight_matrices[1] += 3 + self._num_weight_matrices[1] += 1 self._name = name self._reuse = reuse @@ -849,12 +849,12 @@ class GRULRCell(RNNCell): def getVars(self): Vars = [] - if self._num_weight_matrices[0] == 1: + if self._num_weight_matrices[0] == 3: Vars.extend([self.W1, self.W2, self.W3]) else: Vars.extend([self.W, self.W1, self.W2, self.W3]) - if self._num_weight_matrices[1] == 1: + if self._num_weight_matrices[1] == 3: Vars.extend([self.U1, self.U2, self.U3]) else: Vars.extend([self.U, self.U1, self.U2, self.U3]) @@ -900,13 +900,13 @@ class UGRNNLRCell(RNNCell): self._hidden_size = hidden_size self._gate_non_linearity = gate_non_linearity self._update_non_linearity = update_non_linearity - self._num_weight_matrices = [1, 1] + self._num_weight_matrices = [2, 2] self._wRank = wRank self._uRank = uRank if wRank is not None: - self._num_weight_matrices[0] += 2 + self._num_weight_matrices[0] += 1 if uRank is not None: - self._num_weight_matrices[1] += 2 + self._num_weight_matrices[1] += 1 self._name = name self._reuse = reuse @@ -1040,12 +1040,12 @@ class UGRNNLRCell(RNNCell): def getVars(self): Vars = [] - if self._num_weight_matrices[0] == 1: + if self._num_weight_matrices[0] == 2: Vars.extend([self.W1, self.W2]) else: Vars.extend([self.W, self.W1, self.W2]) - if self._num_weight_matrices[1] == 1: + if self._num_weight_matrices[1] == 2: Vars.extend([self.U1, self.U2]) else: Vars.extend([self.U, self.U1, self.U2]) diff --git a/tf/edgeml/trainer/fastTrainer.py b/tf/edgeml/trainer/fastTrainer.py index 7c83a1b5..37561b9e 100644 --- a/tf/edgeml/trainer/fastTrainer.py +++ b/tf/edgeml/trainer/fastTrainer.py @@ -236,12 +236,12 @@ class FastTrainer: ''' if self.numMatrices[0] == 1: np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval()) - elif self.numMatrices[0] == 2: - np.save(os.path.join(currDir, "W1.npy"), - self.FastParams[0].eval()) - np.save(os.path.join(currDir, "W2.npy"), - self.FastParams[1].eval()) elif self.FastObj.wRank is None: + if self.numMatrices[0] == 2: + np.save(os.path.join(currDir, "W1.npy"), + self.FastParams[0].eval()) + np.save(os.path.join(currDir, "W2.npy"), + self.FastParams[1].eval()) if self.numMatrices[0] == 3: np.save(os.path.join(currDir, "W1.npy"), self.FastParams[0].eval()) @@ -259,6 +259,11 @@ class FastTrainer: np.save(os.path.join(currDir, "W4.npy"), self.FastParams[3].eval()) elif self.FastObj.wRank is not None: + if self.numMatrices[0] == 2: + np.save(os.path.join(currDir, "W1.npy"), + self.FastParams[0].eval()) + np.save(os.path.join(currDir, "W2.npy"), + self.FastParams[1].eval()) if self.numMatrices[0] == 3: np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval()) @@ -290,12 +295,12 @@ class FastTrainer: idx = self.numMatrices[0] if self.numMatrices[1] == 1: np.save(os.path.join(currDir, "U.npy"), self.FastParams[idx + 0].eval()) - elif self.numMatrices[1] == 2: - np.save(os.path.join(currDir, "U1.npy"), - self.FastParams[idx + 0].eval()) - np.save(os.path.join(currDir, "U2.npy"), - self.FastParams[idx + 1].eval()) elif self.FastObj.uRank is None: + if self.numMatrices[1] == 2: + np.save(os.path.join(currDir, "U1.npy"), + self.FastParams[idx + 0].eval()) + np.save(os.path.join(currDir, "U2.npy"), + self.FastParams[idx + 1].eval()) if self.numMatrices[1] == 3: np.save(os.path.join(currDir, "U1.npy"), self.FastParams[idx + 0].eval()) @@ -313,6 +318,11 @@ class FastTrainer: np.save(os.path.join(currDir, "U4.npy"), self.FastParams[idx + 3].eval()) elif self.FastObj.uRank is not None: + if self.numMatrices[1] == 2: + np.save(os.path.join(currDir, "U1.npy"), + self.FastParams[idx + 0].eval()) + np.save(os.path.join(currDir, "U2.npy"), + self.FastParams[idx + 1].eval()) if self.numMatrices[1] == 3: np.save(os.path.join(currDir, "U.npy"), self.FastParams[idx + 0].eval())