зеркало из https://github.com/microsoft/EdgeML.git
added model size export
This commit is contained in:
Родитель
39f8a1743a
Коммит
ce32db7e1e
|
@ -292,9 +292,27 @@ class FastGRNNCell(nn.Module):
|
|||
|
||||
Vars.extend([self.bias_gate, self.bias_update])
|
||||
Vars.extend([self.zeta, self.nu])
|
||||
|
||||
return Vars
|
||||
|
||||
def getModelSize(self):
    '''
    Function to get aimed model size

    Adds up, in this order:
      * 2 for the zeta and nu scalars,
      * utils.countNNZ of bias_gate and bias_update (isSparse=False),
      * utils.countNNZ of W (or of W1 and W2 when _wRank is set),
        with _wSparsity passed as the isSparse argument,
      * utils.countNNZ of U (or of U1 and U2 when _uRank is set),
        with _uSparsity passed as the isSparse argument.
    '''
    size = 2  # For Zeta and Nu

    # Bias vectors are always counted with isSparse=False.
    for bias in (self.bias_gate, self.bias_update):
        size += utils.countNNZ(bias, False)

    # Input-to-hidden weights: one full matrix, or two low-rank factors.
    if self._wRank is not None:
        w_parts = (self.W1, self.W2)
    else:
        w_parts = (self.W,)
    for mat in w_parts:
        size += utils.countNNZ(mat, self._wSparsity)

    # Hidden-to-hidden weights: same full / low-rank split as above.
    if self._uRank is not None:
        u_parts = (self.U1, self.U2)
    else:
        u_parts = (self.U,)
    for mat in u_parts:
        size += utils.countNNZ(mat, self._uSparsity)

    return size
|
||||
|
||||
class FastRNNCell(nn.Module):
|
||||
'''
|
||||
|
|
|
@ -151,29 +151,29 @@ class FastTrainer:
|
|||
totalSize = 0
|
||||
hasSparse = False
|
||||
for i in range(0, self.numMatrices[0]):
|
||||
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sW)
|
||||
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sW)
|
||||
totalnnZ += nnz
|
||||
totalSize += size
|
||||
hasSparse = hasSparse or sparseFlag
|
||||
|
||||
for i in range(self.numMatrices[0], self.totalMatrices):
|
||||
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sU)
|
||||
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], self.sU)
|
||||
totalnnZ += nnz
|
||||
totalSize += size
|
||||
hasSparse = hasSparse or sparseFlag
|
||||
for i in range(self.totalMatrices, len(self.FastParams)):
|
||||
nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], 1.0)
|
||||
nnz, size, sparseFlag = utils.estimateNNZ(self.FastParams[i], 1.0)
|
||||
totalnnZ += nnz
|
||||
totalSize += size
|
||||
hasSparse = hasSparse or sparseFlag
|
||||
|
||||
# Replace this with classifier class call
|
||||
nnz, size, sparseFlag = utils.countnnZ(self.FC, 1.0)
|
||||
nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0)
|
||||
totalnnZ += nnz
|
||||
totalSize += size
|
||||
hasSparse = hasSparse or sparseFlag
|
||||
|
||||
nnz, size, sparseFlag = utils.countnnZ(self.FCbias, 1.0)
|
||||
nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0)
|
||||
totalnnZ += nnz
|
||||
totalSize += size
|
||||
hasSparse = hasSparse or sparseFlag
|
||||
|
|
|
@ -85,6 +85,14 @@ def fastgrnnmodel(inheritance_class=nn.Module):
|
|||
if self.num_layers > 2:
|
||||
self.fastgrnn3.cell.sparsify()
|
||||
|
||||
def getModelSize(self):
    '''
    Return the sum of the model sizes reported by the active FastGRNN
    cells plus hidden_units_list[2] * num_classes.

    fastgrnn1 is always counted; fastgrnn2 and fastgrnn3 are counted only
    when num_layers is greater than 1 and 2 respectively.

    NOTE(review): the last term presumably stands for the final classifier
    weights - confirm that index 2 is right when num_layers < 3 and
    whether a bias term should be counted as well.
    '''
    parts = [self.fastgrnn1.cell.getModelSize()]
    if self.num_layers > 1:
        parts.append(self.fastgrnn2.cell.getModelSize())
    if self.num_layers > 2:
        parts.append(self.fastgrnn3.cell.getModelSize())
    parts.append(self.hidden_units_list[2] * self.num_classes)
    return sum(parts)
|
||||
|
||||
def normalize(self, mean, std):
|
||||
self.mean = mean
|
||||
|
|
|
@ -70,7 +70,7 @@ def copySupport(src, dest):
|
|||
return dest
|
||||
|
||||
|
||||
def countnnZ(A, s, bytesPerVar=4):
|
||||
def estimateNNZ(A, s, bytesPerVar=4):
|
||||
'''
|
||||
Returns # of non-zeros and representative size of the tensor
|
||||
Uses dense for s >= 0.5 - 4 byte
|
||||
|
@ -89,6 +89,19 @@ def countnnZ(A, s, bytesPerVar=4):
|
|||
return nnZ, nnZ * bytesPerVar, hasSparse
|
||||
|
||||
|
||||
def countNNZ(A: torch.nn.Parameter, isSparse):
    '''
    Returns # of non-zeros

    A: parameter (tensor) to measure.
    isSparse: if truthy, count the entries of A that are actually
        non-zero; otherwise count every entry, i.e. the product of
        A.shape (1 for a 0-dimensional tensor).
    '''
    if isSparse:
        # Tensor.numpy() only works on CPU tensors, so copy the data to
        # host memory first (a no-op when A already lives on the CPU).
        return np.count_nonzero(A.detach().cpu().numpy())

    # Dense count needs only the shape; the data is never touched, so this
    # path works regardless of the device A is stored on.
    nnzs = 1
    for dim in A.shape:
        nnzs *= int(dim)
    return nnzs
|
||||
|
||||
def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
|
||||
'''
|
||||
Restructures a matrix from [nNodes*nClasses, Proj] to
|
||||
|
|
|
@ -21,8 +21,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Variable, Function
|
||||
import torch.onnx
|
||||
|
||||
from torch.autograd import Variable, Function
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
from training_config import TrainingConfig
|
||||
|
|
Загрузка…
Ссылка в новой задаче