зеркало из https://github.com/microsoft/mup.git
add tests for meta tensors
This commit is contained in:
Родитель
a2fec5fdb3
Коммит
7758dae40b
|
@ -28,6 +28,13 @@ class SetBaseShapeCase(unittest.TestCase):
|
|||
set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes1meta(self):
|
||||
base_model = _generate_MLP(64, True, True, True, device='meta')
|
||||
delta_model = _generate_MLP(65, True, True, True, device='meta')
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes2(self):
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
set_base_shapes(target_model, self.mlp_base_shapes_file)
|
||||
|
@ -41,6 +48,14 @@ class SetBaseShapeCase(unittest.TestCase):
|
|||
set_base_shapes(target_model, base_infshapes)
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes3meta(self):
|
||||
base_model = _generate_MLP(64, True, True, True, device='meta')
|
||||
delta_model = _generate_MLP(65, True, True, True, device='meta')
|
||||
base_infshapes = make_base_shapes(base_model, delta_model)
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
set_base_shapes(target_model, base_infshapes)
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes4(self):
|
||||
base_model = _generate_MLP(64, True, True, True)
|
||||
delta_model = _generate_MLP(65, True, True, True)
|
||||
|
@ -48,6 +63,13 @@ class SetBaseShapeCase(unittest.TestCase):
|
|||
set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes4meta(self):
|
||||
base_model = _generate_MLP(64, True, True, True)
|
||||
delta_model = _generate_MLP(65, True, True, True, device='meta')
|
||||
target_model = _generate_MLP(128, True, True, True, device='meta')
|
||||
set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes5(self):
|
||||
delta_model = _generate_MLP(65, True, True, True)
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
|
@ -55,6 +77,13 @@ class SetBaseShapeCase(unittest.TestCase):
|
|||
set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model))
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes5meta(self):
|
||||
delta_model = _generate_MLP(65, True, True, True, device='meta')
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
# `delta` here doesn't do anything because of base shape file
|
||||
set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model))
|
||||
return get_infshapes(target_model)
|
||||
|
||||
def get_mlp_infshapes_bad(self):
|
||||
base_model = _generate_MLP(64, True, True, True)
|
||||
target_model = _generate_MLP(128, True, True, True)
|
||||
|
@ -62,10 +91,14 @@ class SetBaseShapeCase(unittest.TestCase):
|
|||
return get_infshapes(target_model)
|
||||
|
||||
def test_set_base_shape(self):
|
||||
self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes1meta())
|
||||
self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes2())
|
||||
self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes2())
|
||||
self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes4())
|
||||
self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes3meta())
|
||||
self.assertEqual(self.get_mlp_infshapes4(), self.get_mlp_infshapes4meta())
|
||||
self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes4())
|
||||
self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes5meta())
|
||||
self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad())
|
||||
|
||||
|
||||
|
|
|
@ -35,19 +35,19 @@ init_methods = {
|
|||
k: partial(init_model, sampler=s) for k, s in samplers.items()
|
||||
}
|
||||
|
||||
def _generate_MLP(width, bias=True, mup=True, batchnorm=False):
|
||||
mods = [Linear(3072, width, bias=bias),
|
||||
def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'):
|
||||
mods = [Linear(3072, width, bias=bias, device=device),
|
||||
nn.ReLU(),
|
||||
Linear(width, width, bias=bias),
|
||||
Linear(width, width, bias=bias, device=device),
|
||||
nn.ReLU()
|
||||
]
|
||||
if mup:
|
||||
mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False))
|
||||
mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device))
|
||||
else:
|
||||
mods.append(Linear(width, 10, bias=bias))
|
||||
mods.append(Linear(width, 10, bias=bias, device=device))
|
||||
if batchnorm:
|
||||
mods.insert(1, nn.BatchNorm1d(width))
|
||||
mods.insert(4, nn.BatchNorm1d(width))
|
||||
mods.insert(1, nn.BatchNorm1d(width, device=device))
|
||||
mods.insert(4, nn.BatchNorm1d(width, device=device))
|
||||
model = nn.Sequential(*mods)
|
||||
return model
|
||||
|
||||
|
@ -58,7 +58,7 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
|
|||
return set_base_shapes(model, None)
|
||||
# it's important we make `model` first, because of random seed
|
||||
model = _generate_MLP(width, bias, mup, batchnorm)
|
||||
base_model = _generate_MLP(base_width, bias, mup, batchnorm)
|
||||
base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta')
|
||||
set_base_shapes(model, base_model)
|
||||
init_methods[init](model)
|
||||
if readout_zero_init:
|
||||
|
@ -73,29 +73,29 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
|
|||
return model
|
||||
|
||||
|
||||
def _generate_CNN(width, bias=True, mup=True, batchnorm=False):
|
||||
def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'):
|
||||
mods = [
|
||||
nn.Conv2d(3, width, kernel_size=5, bias=bias),
|
||||
nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(width, 2*width, kernel_size=5, bias=bias),
|
||||
nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Flatten(),
|
||||
nn.Linear(2*width*25, width*16, bias=bias),
|
||||
nn.Linear(2*width*25, width*16, bias=bias, device=device),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(width*16, width*10, bias=bias),
|
||||
nn.Linear(width*16, width*10, bias=bias, device=device),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
if mup:
|
||||
mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False))
|
||||
mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device))
|
||||
else:
|
||||
mods.append(nn.Linear(width*10, 10, bias=bias))
|
||||
mods.append(nn.Linear(width*10, 10, bias=bias, device=device))
|
||||
if batchnorm:
|
||||
mods.insert(1, nn.BatchNorm2d(width))
|
||||
mods.insert(5, nn.BatchNorm2d(2*width))
|
||||
mods.insert(10, nn.BatchNorm1d(16*width))
|
||||
mods.insert(13, nn.BatchNorm1d(10*width))
|
||||
mods.insert(1, nn.BatchNorm2d(width, device=device))
|
||||
mods.insert(5, nn.BatchNorm2d(2*width, device=device))
|
||||
mods.insert(10, nn.BatchNorm1d(16*width, device=device))
|
||||
mods.insert(13, nn.BatchNorm1d(10*width, device=device))
|
||||
return nn.Sequential(*mods)
|
||||
|
||||
def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8):
|
||||
|
@ -105,7 +105,7 @@ def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
|
|||
return set_base_shapes(model, None)
|
||||
# it's important we make `model` first, because of random seed
|
||||
model = _generate_CNN(width, bias, mup, batchnorm)
|
||||
base_model = _generate_CNN(base_width, bias, mup, batchnorm)
|
||||
base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta')
|
||||
set_base_shapes(model, base_model)
|
||||
init_methods[init](model)
|
||||
if readout_zero_init:
|
||||
|
|
Загрузка…
Ссылка в новой задаче