This commit is contained in:
Harsha Vardhan Simhadri 2019-07-30 01:31:53 +05:30
Родитель 39f8a1743a
Коммит ce32db7e1e
5 изменённых файлов: 48 добавлений и 8 удалений

Просмотреть файл

@ -292,9 +292,27 @@ class FastGRNNCell(nn.Module):
Vars.extend([self.bias_gate, self.bias_update])
Vars.extend([self.zeta, self.nu])
return Vars
def getModelSize(self):
'''
Function to get aimed model size
'''
totalnnz = 2 # For Zeta and Nu
totalnnz += utils.countNNZ(self.bias_gate, False)
totalnnz += utils.countNNZ(self.bias_update, False)
if self._wRank is None:
totalnnz += utils.countNNZ(self.W, self._wSparsity)
else:
totalnnz += utils.countNNZ(self.W1, self._wSparsity)
totalnnz += utils.countNNZ(self.W2, self._wSparsity)
if self._uRank is None:
totalnnz += utils.countNNZ(self.U, self._uSparsity)
else:
totalnnz += utils.countNNZ(self.U1, self._uSparsity)
totalnnz += utils.countNNZ(self.U2, self._uSparsity)
return totalnnz
class FastRNNCell(nn.Module):
'''

Просмотреть файл

@ -151,29 +151,29 @@ class FastTrainer:
totalSize = 0
hasSparse = False
for i in range(0, self.numMatrices[0]):
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sW)
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sW)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag
for i in range(self.numMatrices[0], self.totalMatrices):
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sU)
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sU)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag
for i in range(self.totalMatrices, len(self.FastParams)):
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], 1.0)
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], 1.0)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag
# Replace this with classifier class call
nnz, size, sparseFlag = utils.countnnZ(self.FC, 1.0)
nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag
nnz, size, sparseFlag = utils.countnnZ(self.FCbias, 1.0)
nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag

Просмотреть файл

@ -85,6 +85,14 @@ def fastgrnnmodel(inheritance_class=nn.Module):
if self.num_layers > 2:
self.fastgrnn3.cell.sparsify()
def getModelSize(self):
total_size = self.fastgrnn1.cell.getModelSize()
if self.num_layers > 1:
total_size += self.fastgrnn2.cell.getModelSize()
if self.num_layers > 2:
total_size += self.fastgrnn3.cell.getModelSize()
total_size += self.hidden_units_list[2] * self.num_classes
return total_size
def normalize(self, mean, std):
self.mean = mean

Просмотреть файл

@ -70,7 +70,7 @@ def copySupport(src, dest):
return dest
def countnnZ(A, s, bytesPerVar=4):
def estimateNNZ(A, s, bytesPerVar=4):
'''
Returns # of non-zeros and representative size of the tensor
Uses dense for s >= 0.5 - 4 byte
@ -89,6 +89,19 @@ def countnnZ(A, s, bytesPerVar=4):
return nnZ, nnZ * bytesPerVar, hasSparse
def countNNZ(A: torch.nn.Parameter, isSparse):
'''
Returns # of non-zeros
'''
A_ = A.detach().numpy()
if isSparse:
return np.count_nonzero(A_)
else:
nnzs = 1
for i in range(0, len(A.shape)):
nnzs *= int(A.shape[i])
return nnzs
def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
'''
Restructures a matrix from [nNodes*nClasses, Proj] to

Просмотреть файл

@ -21,8 +21,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable, Function
import torch.onnx
from torch.autograd import Variable, Function
from torch.utils.data import Dataset, DataLoader
from training_config import TrainingConfig