зеркало из https://github.com/microsoft/mup.git
add --deferred_init option
This commit is contained in:
Родитель
e968350db8
Коммит
265f2d9f63
|
@@ -95,7 +95,8 @@ if __name__ == '__main__':
|
|||
help='Do coord check with this many steps.')
|
||||
parser.add_argument('--coord_check_nseeds', type=int, default=5,
|
||||
help='number of seeds for testing correctness of μ parametrization')
|
||||
|
||||
parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. Requires torchdistx.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
@@ -223,17 +224,25 @@ if __name__ == '__main__':
|
|||
|
||||
for width in [64, 128, 256, 512, 1024, 2048, 4096, 8192]:
|
||||
# print(f'{nonlin.__name__}_{criterion.__name__}_{str(width)}')
|
||||
mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device)
|
||||
if args.save_base_shapes:
|
||||
print(f'saving base shapes at {args.save_base_shapes}')
|
||||
base_shapes = get_shapes(mynet)
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
|
||||
)
|
||||
if args.deferred_init:
|
||||
from torchdistx.deferred_init import deferred_init
|
||||
base_shapes = get_shapes(deferred_init(MLP, width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult))
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
deferred_init(MLP, width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
|
||||
)
|
||||
else:
|
||||
base_shapes = get_shapes(MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult))
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
|
||||
)
|
||||
make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
|
||||
print('done and exit')
|
||||
import sys; sys.exit()
|
||||
mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device)
|
||||
if args.load_base_shapes:
|
||||
print(f'loading base shapes from {args.load_base_shapes}')
|
||||
set_base_shapes(mynet, args.load_base_shapes)
|
||||
|
|
|
@@ -186,7 +186,8 @@ if __name__ == '__main__':
|
|||
help='Do coord check with this many steps.')
|
||||
parser.add_argument('--coord_check_nseeds', type=int, default=3,
|
||||
help='number of seeds for testing correctness of μ parametrization')
|
||||
|
||||
parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. Requires torchdistx.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
@@ -306,22 +307,41 @@ if __name__ == '__main__':
|
|||
import sys; sys.exit()
|
||||
|
||||
|
||||
model = mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
if args.save_base_shapes:
|
||||
print(f'saving base shapes at {args.save_base_shapes}')
|
||||
base_shapes = get_shapes(model)
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
mdl.TransformerModel(args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
|
||||
nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
)
|
||||
if args.deferred_init:
|
||||
from torchdistx.deferred_init import deferred_init
|
||||
base_shapes = get_shapes(
|
||||
deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
)
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
|
||||
nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
)
|
||||
else:
|
||||
base_shapes = get_shapes(
|
||||
mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
)
|
||||
delta_shapes = get_shapes(
|
||||
# just need to change whatever dimension(s) we are scaling
|
||||
mdl.TransformerModel(args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
|
||||
nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
)
|
||||
make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
|
||||
print('done and exit')
|
||||
import sys; sys.exit()
|
||||
model = mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||
decoder_var=args.init_var, standparam=args.load_base_shapes=='')
|
||||
if args.load_base_shapes:
|
||||
print(f'loading base shapes from {args.load_base_shapes}')
|
||||
set_base_shapes(model, args.load_base_shapes)
|
||||
|
|
Загрузка…
Ссылка в новой задаче