This commit is contained in:
Edward Hu 2022-03-19 14:47:15 -04:00
Parent a2fec5fdb3
Commit 7758dae40b
2 changed files: 53 additions and 20 deletions

View file

@@ -28,6 +28,13 @@ class SetBaseShapeCase(unittest.TestCase):
set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
return get_infshapes(target_model)
def get_mlp_infshapes1meta(self):
    """Meta-device counterpart of get_mlp_infshapes1.

    Builds the base and delta MLPs on the 'meta' device (shape-only
    construction), applies their shapes to a concretely-built target
    model, and returns the target's infshapes. Also writes the base
    shapes to ``self.mlp_base_shapes_file`` as a side effect.
    """
    base = _generate_MLP(64, True, True, True, device='meta')
    delta = _generate_MLP(65, True, True, True, device='meta')
    target = _generate_MLP(128, True, True, True)
    set_base_shapes(target, base,
                    delta=delta, savefile=self.mlp_base_shapes_file)
    return get_infshapes(target)
def get_mlp_infshapes2(self):
target_model = _generate_MLP(128, True, True, True)
set_base_shapes(target_model, self.mlp_base_shapes_file)
@@ -41,6 +48,14 @@ class SetBaseShapeCase(unittest.TestCase):
set_base_shapes(target_model, base_infshapes)
return get_infshapes(target_model)
def get_mlp_infshapes3meta(self):
    """Meta-device counterpart of get_mlp_infshapes3.

    Derives base shapes with make_base_shapes() from base/delta models
    built on the 'meta' device, applies them to a concrete target model,
    and returns the target's infshapes.
    """
    base = _generate_MLP(64, True, True, True, device='meta')
    delta = _generate_MLP(65, True, True, True, device='meta')
    shapes = make_base_shapes(base, delta)
    target = _generate_MLP(128, True, True, True)
    set_base_shapes(target, shapes)
    return get_infshapes(target)
def get_mlp_infshapes4(self):
base_model = _generate_MLP(64, True, True, True)
delta_model = _generate_MLP(65, True, True, True)
@@ -48,6 +63,13 @@ class SetBaseShapeCase(unittest.TestCase):
set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
return get_infshapes(target_model)
def get_mlp_infshapes4meta(self):
    """Mixed-device variant of get_mlp_infshapes4.

    The base model is built concretely while the delta and target models
    live on the 'meta' device; shapes are extracted via get_shapes()
    before being handed to set_base_shapes(). Returns the target's
    infshapes.
    """
    cpu_base = _generate_MLP(64, True, True, True)
    meta_delta = _generate_MLP(65, True, True, True, device='meta')
    meta_target = _generate_MLP(128, True, True, True, device='meta')
    set_base_shapes(meta_target,
                    get_shapes(cpu_base), delta=get_shapes(meta_delta))
    return get_infshapes(meta_target)
def get_mlp_infshapes5(self):
delta_model = _generate_MLP(65, True, True, True)
target_model = _generate_MLP(128, True, True, True)
@@ -55,6 +77,13 @@ class SetBaseShapeCase(unittest.TestCase):
set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model))
return get_infshapes(target_model)
def get_mlp_infshapes5meta(self):
    """Meta-device counterpart of get_mlp_infshapes5.

    Loads base shapes from the previously saved file. The delta model's
    shapes are still passed, but are ignored when the base comes from a
    shape file.
    """
    delta = _generate_MLP(65, True, True, True, device='meta')
    target = _generate_MLP(128, True, True, True)
    # `delta` is a no-op here because base shapes come from the file.
    set_base_shapes(target, self.mlp_base_shapes_file,
                    delta=get_shapes(delta))
    return get_infshapes(target)
def get_mlp_infshapes_bad(self):
base_model = _generate_MLP(64, True, True, True)
target_model = _generate_MLP(128, True, True, True)
@@ -62,10 +91,14 @@ class SetBaseShapeCase(unittest.TestCase):
return get_infshapes(target_model)
def test_set_base_shape(self):
    """Every route of constructing base shapes must agree on infshapes.

    Each pair below builds the same target infshapes through two
    different code paths (model objects, shape files, make_base_shapes,
    get_shapes, meta-device models) and must compare equal; the
    deliberately broken route must differ.
    """
    matching_routes = [
        (self.get_mlp_infshapes1, self.get_mlp_infshapes1meta),
        (self.get_mlp_infshapes1, self.get_mlp_infshapes2),
        (self.get_mlp_infshapes3, self.get_mlp_infshapes2),
        (self.get_mlp_infshapes3, self.get_mlp_infshapes4),
        (self.get_mlp_infshapes3, self.get_mlp_infshapes3meta),
        (self.get_mlp_infshapes4, self.get_mlp_infshapes4meta),
        (self.get_mlp_infshapes5, self.get_mlp_infshapes4),
        (self.get_mlp_infshapes5, self.get_mlp_infshapes5meta),
    ]
    for first, second in matching_routes:
        self.assertEqual(first(), second())
    # A mismatched base model must yield different infshapes.
    self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad())

View file

@@ -35,19 +35,19 @@ init_methods = {
k: partial(init_model, sampler=s) for k, s in samplers.items()
}
# NOTE(review): this span is a rendered diff hunk, not runnable code — the
# removed (old) and added (new) versions of _generate_MLP are interleaved and
# the +/- gutter markers were lost in rendering. The change threads a `device`
# keyword (default 'cpu') through every layer constructor so the MLP can be
# built on e.g. the 'meta' device. "# removed"/"# added" restore the gutter.
def _generate_MLP(width, bias=True, mup=True, batchnorm=False):  # removed
mods = [Linear(3072, width, bias=bias),  # removed
def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'):  # added
mods = [Linear(3072, width, bias=bias, device=device),  # added
nn.ReLU(),
Linear(width, width, bias=bias),  # removed
Linear(width, width, bias=bias, device=device),  # added
nn.ReLU()
]
if mup:
mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False))  # removed
mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device))  # added
else:
mods.append(Linear(width, 10, bias=bias))  # removed
mods.append(Linear(width, 10, bias=bias, device=device))  # added
if batchnorm:
mods.insert(1, nn.BatchNorm1d(width))  # removed
mods.insert(4, nn.BatchNorm1d(width))  # removed
mods.insert(1, nn.BatchNorm1d(width, device=device))  # added
mods.insert(4, nn.BatchNorm1d(width, device=device))  # added
model = nn.Sequential(*mods)
return model
@@ -58,7 +58,7 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return set_base_shapes(model, None)
# it's important we make `model` first, because of random seed
model = _generate_MLP(width, bias, mup, batchnorm)
base_model = _generate_MLP(base_width, bias, mup, batchnorm)
base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta')
set_base_shapes(model, base_model)
init_methods[init](model)
if readout_zero_init:
@@ -73,29 +73,29 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return model
# NOTE(review): rendered diff hunk (old/new lines interleaved, +/- gutter
# lost). Same change as _generate_MLP above: a `device` keyword (default
# 'cpu') is threaded through every layer constructor of the CNN.
# "# removed"/"# added" restore the lost gutter markers.
def _generate_CNN(width, bias=True, mup=True, batchnorm=False):  # removed
def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'):  # added
mods = [
nn.Conv2d(3, width, kernel_size=5, bias=bias),  # removed
nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device),  # added
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(width, 2*width, kernel_size=5, bias=bias),  # removed
nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device),  # added
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(2*width*25, width*16, bias=bias),  # removed
nn.Linear(2*width*25, width*16, bias=bias, device=device),  # added
nn.ReLU(inplace=True),
nn.Linear(width*16, width*10, bias=bias),  # removed
nn.Linear(width*16, width*10, bias=bias, device=device),  # added
nn.ReLU(inplace=True),
]
if mup:
mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False))  # removed
mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device))  # added
else:
mods.append(nn.Linear(width*10, 10, bias=bias))  # removed
mods.append(nn.Linear(width*10, 10, bias=bias, device=device))  # added
if batchnorm:
mods.insert(1, nn.BatchNorm2d(width))  # removed
mods.insert(5, nn.BatchNorm2d(2*width))  # removed
mods.insert(10, nn.BatchNorm1d(16*width))  # removed
mods.insert(13, nn.BatchNorm1d(10*width))  # removed
mods.insert(1, nn.BatchNorm2d(width, device=device))  # added
mods.insert(5, nn.BatchNorm2d(2*width, device=device))  # added
mods.insert(10, nn.BatchNorm1d(16*width, device=device))  # added
mods.insert(13, nn.BatchNorm1d(10*width, device=device))  # added
return nn.Sequential(*mods)
def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8):
@@ -105,7 +105,7 @@ def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return set_base_shapes(model, None)
# it's important we make `model` first, because of random seed
model = _generate_CNN(width, bias, mup, batchnorm)
base_model = _generate_CNN(base_width, bias, mup, batchnorm)
base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta')
set_base_shapes(model, base_model)
init_methods[init](model)
if readout_zero_init: