From ce32db7e1ec967d1bc90a627037979fcb1510d86 Mon Sep 17 00:00:00 2001 From: Harsha Vardhan Simhadri Date: Tue, 30 Jul 2019 01:31:53 +0530 Subject: [PATCH] added model size export --- edgeml/pytorch/graph/rnn.py | 20 ++++++++++++++++++- edgeml/pytorch/trainer/fastTrainer.py | 10 +++++----- edgeml/pytorch/trainer/fastmodel.py | 8 ++++++++ edgeml/pytorch/utils.py | 15 +++++++++++++- .../pytorch/FastCells/train_classifier.py | 3 ++- 5 files changed, 48 insertions(+), 8 deletions(-) diff --git a/edgeml/pytorch/graph/rnn.py b/edgeml/pytorch/graph/rnn.py index 40ac715d..fd01283e 100644 --- a/edgeml/pytorch/graph/rnn.py +++ b/edgeml/pytorch/graph/rnn.py @@ -292,9 +292,27 @@ class FastGRNNCell(nn.Module): Vars.extend([self.bias_gate, self.bias_update]) Vars.extend([self.zeta, self.nu]) - return Vars + def getModelSize(self): + ''' + Function to get aimed model size + ''' + totalnnz = 2 # For Zeta and Nu + totalnnz += utils.countNNZ(self.bias_gate, False) + totalnnz += utils.countNNZ(self.bias_update, False) + if self._wRank is None: + totalnnz += utils.countNNZ(self.W, self._wSparsity) + else: + totalnnz += utils.countNNZ(self.W1, self._wSparsity) + totalnnz += utils.countNNZ(self.W2, self._wSparsity) + + if self._uRank is None: + totalnnz += utils.countNNZ(self.U, self._uSparsity) + else: + totalnnz += utils.countNNZ(self.U1, self._uSparsity) + totalnnz += utils.countNNZ(self.U2, self._uSparsity) + return totalnnz class FastRNNCell(nn.Module): ''' diff --git a/edgeml/pytorch/trainer/fastTrainer.py b/edgeml/pytorch/trainer/fastTrainer.py index efe6698b..9ae78590 100644 --- a/edgeml/pytorch/trainer/fastTrainer.py +++ b/edgeml/pytorch/trainer/fastTrainer.py @@ -151,29 +151,29 @@ class FastTrainer: totalSize = 0 hasSparse = False for i in range(0, self.numMatrices[0]): - nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sW) + nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sW) totalnnZ += nnz totalSize += size hasSparse = hasSparse or sparseFlag for i in range(self.numMatrices[0], self.totalMatrices): - nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sU) + nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sU) totalnnZ += nnz totalSize += size hasSparse = hasSparse or sparseFlag for i in range(self.totalMatrices, len(self.FastParams)): - nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], 1.0) + nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], 1.0) totalnnZ += nnz totalSize += size hasSparse = hasSparse or sparseFlag # Replace this with classifier class call - nnz, size, sparseFlag = utils.countnnZ(self.FC, 1.0) + nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0) totalnnZ += nnz totalSize += size hasSparse = hasSparse or sparseFlag - nnz, size, sparseFlag = utils.countnnZ(self.FCbias, 1.0) + nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0) totalnnZ += nnz totalSize += size hasSparse = hasSparse or sparseFlag diff --git a/edgeml/pytorch/trainer/fastmodel.py b/edgeml/pytorch/trainer/fastmodel.py index dfc84917..ff0c27a1 100644 --- a/edgeml/pytorch/trainer/fastmodel.py +++ b/edgeml/pytorch/trainer/fastmodel.py @@ -85,6 +85,14 @@ def fastgrnnmodel(inheritance_class=nn.Module): if self.num_layers > 2: self.fastgrnn3.cell.sparsify() + def getModelSize(self): + total_size = self.fastgrnn1.cell.getModelSize() + if self.num_layers > 1: + total_size += self.fastgrnn2.cell.getModelSize() + if self.num_layers > 2: + total_size += self.fastgrnn3.cell.getModelSize() + total_size += self.hidden_units_list[2] * self.num_classes + return total_size def normalize(self, mean, std): self.mean = mean diff --git a/edgeml/pytorch/utils.py b/edgeml/pytorch/utils.py index 2a60282c..5d95ab4f 100644 --- a/edgeml/pytorch/utils.py +++ b/edgeml/pytorch/utils.py @@ -70,7 +70,7 @@ def copySupport(src, dest): return dest -def countnnZ(A, s, bytesPerVar=4): +def estimateNNZ(A, s, bytesPerVar=4): ''' Returns # of non-zeros and representative size of the tensor Uses dense for s >= 0.5 - 4 byte @@ -89,6 +89,19 @@ def countnnZ(A, s, bytesPerVar=4): return nnZ, nnZ * bytesPerVar, hasSparse +def countNNZ(A: torch.nn.Parameter, isSparse): + ''' + Returns # of non-zeros + ''' + A_ = A.detach().numpy() + if isSparse: + return np.count_nonzero(A_) + else: + nnzs = 1 + for i in range(0, len(A.shape)): + nnzs *= int(A.shape[i]) + return nnzs + def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes): ''' Restructures a matrix from [nNodes*nClasses, Proj] to diff --git a/examples/pytorch/FastCells/train_classifier.py b/examples/pytorch/FastCells/train_classifier.py index 6c98cbb7..704b49cb 100644 --- a/examples/pytorch/FastCells/train_classifier.py +++ b/examples/pytorch/FastCells/train_classifier.py @@ -21,8 +21,9 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from torch.autograd import Variable, Function import torch.onnx + +from torch.autograd import Variable, Function from torch.utils.data import Dataset, DataLoader from training_config import TrainingConfig