This commit is contained in:
Edward Hu 2022-05-30 17:49:00 -04:00
Parent e968350db8
Commit 265f2d9f63
2 changed files with 48 additions and 19 deletions

View file

@@ -95,7 +95,8 @@ if __name__ == '__main__':
                         help='Do coord check with this many steps.')
     parser.add_argument('--coord_check_nseeds', type=int, default=5,
                         help='number of seeds for testing correctness of μ parametrization')
+    parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. Requires torchdistx.')
     args = parser.parse_args()
     torch.manual_seed(args.seed)
@@ -223,17 +224,25 @@ if __name__ == '__main__':
     for width in [64, 128, 256, 512, 1024, 2048, 4096, 8192]:
         # print(f'{nonlin.__name__}_{criterion.__name__}_{str(width)}')
-        mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device)
         if args.save_base_shapes:
             print(f'saving base shapes at {args.save_base_shapes}')
-            base_shapes = get_shapes(mynet)
-            delta_shapes = get_shapes(
-                # just need to change whatever dimension(s) we are scaling
-                MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
-            )
+            if args.deferred_init:
+                from torchdistx.deferred_init import deferred_init
+                base_shapes = get_shapes(deferred_init(MLP, width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult))
+                delta_shapes = get_shapes(
+                    # just need to change whatever dimension(s) we are scaling
+                    deferred_init(MLP, width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
+                )
+            else:
+                base_shapes = get_shapes(MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult))
+                delta_shapes = get_shapes(
+                    # just need to change whatever dimension(s) we are scaling
+                    MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult)
+                )
             make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
             print('done and exit')
             import sys; sys.exit()
+        mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device)
         if args.load_base_shapes:
             print(f'loading base shapes from {args.load_base_shapes}')
             set_base_shapes(mynet, args.load_base_shapes)
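The point of the new branch above: torchdistx's deferred_init records a module's constructor call and places its parameters on a "fake" device, so no real tensor storage is allocated, while metadata such as .shape is still available. Since mup's get_shapes only reads shapes, the base-shape file can be produced without paying for two full model instantiations. A minimal standalone sketch of the pattern, assuming mup and torchdistx are installed; the three-layer MLP and the filename 'mlp128.bsh' are illustrative, not the example script's actual model or output:

import torch.nn as nn
from mup import MuReadout, get_shapes, make_base_shapes
from torchdistx.deferred_init import deferred_init

class MLP(nn.Module):
    def __init__(self, width=128):
        super().__init__()
        self.fc1 = nn.Linear(3072, width, bias=False)
        self.fc2 = nn.Linear(width, width, bias=False)
        self.fc3 = MuReadout(width, 10, bias=False)  # muP-aware output layer

# deferred_init(cls, *args, **kwargs) records the constructor call and puts
# parameters on a "fake" device: they carry shapes but no real storage.
base = deferred_init(MLP, width=128)
delta = deferred_init(MLP, width=129)  # differs only in the dimension being scaled

# get_shapes only inspects .shape, which fake tensors still expose, so the
# base-shape file is written without ever materializing either model.
# 'mlp128.bsh' is an illustrative filename.
make_base_shapes(get_shapes(base), get_shapes(delta), savefile='mlp128.bsh')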

View file

@@ -186,7 +186,8 @@ if __name__ == '__main__':
                         help='Do coord check with this many steps.')
     parser.add_argument('--coord_check_nseeds', type=int, default=3,
                         help='number of seeds for testing correctness of μ parametrization')
+    parser.add_argument('--deferred_init', action='store_true', help='Skip instantiating the base and delta models for mup. Requires torchdistx.')
     args = parser.parse_args()
     print(args)
@@ -306,22 +307,41 @@ if __name__ == '__main__':
         import sys; sys.exit()
-    model = mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
-                                 tied=args.tied, bias=args.bias, encoder_var=args.init_var,
-                                 decoder_var=args.init_var, standparam=args.load_base_shapes=='')
     if args.save_base_shapes:
         print(f'saving base shapes at {args.save_base_shapes}')
-        base_shapes = get_shapes(model)
-        delta_shapes = get_shapes(
-            # just need to change whatever dimension(s) we are scaling
-            mdl.TransformerModel(args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
-                                 nlayers=args.nlayers, dropout=args.dropout,
-                                 tied=args.tied, bias=args.bias, encoder_var=args.init_var,
-                                 decoder_var=args.init_var, standparam=args.load_base_shapes=='')
-        )
+        if args.deferred_init:
+            from torchdistx.deferred_init import deferred_init
+            base_shapes = get_shapes(
+                deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
+                              tied=args.tied, bias=args.bias, encoder_var=args.init_var,
+                              decoder_var=args.init_var, standparam=args.load_base_shapes=='')
+            )
+            delta_shapes = get_shapes(
+                # just need to change whatever dimension(s) we are scaling
+                deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
+                              nlayers=args.nlayers, dropout=args.dropout,
+                              tied=args.tied, bias=args.bias, encoder_var=args.init_var,
+                              decoder_var=args.init_var, standparam=args.load_base_shapes=='')
+            )
+        else:
+            base_shapes = get_shapes(
+                mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
+                                     tied=args.tied, bias=args.bias, encoder_var=args.init_var,
+                                     decoder_var=args.init_var, standparam=args.load_base_shapes=='')
+            )
+            delta_shapes = get_shapes(
+                # just need to change whatever dimension(s) we are scaling
+                mdl.TransformerModel(args, ntokens, ninp=args.d_model*2, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio*2,
+                                     nlayers=args.nlayers, dropout=args.dropout,
+                                     tied=args.tied, bias=args.bias, encoder_var=args.init_var,
+                                     decoder_var=args.init_var, standparam=args.load_base_shapes=='')
+            )
         make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
         print('done and exit')
         import sys; sys.exit()
+    model = mdl.TransformerModel(args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
+                                 tied=args.tied, bias=args.bias, encoder_var=args.init_var,
+                                 decoder_var=args.init_var, standparam=args.load_base_shapes=='')
     if args.load_base_shapes:
         print(f'loading base shapes from {args.load_base_shapes}')
         set_base_shapes(model, args.load_base_shapes)
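On the loading side nothing changes: the training model is instantiated normally, with real tensors, and set_base_shapes rescales its initialization against the saved file, exactly as in the unchanged args.load_base_shapes branch above. A hedged sketch of that consumer side, reusing the hypothetical MLP class and 'mlp128.bsh' filename from the earlier example:

from mup import set_base_shapes, MuAdam

# The training model is built eagerly; its weights must actually exist now.
model = MLP(width=1024)                # train at a larger width than the base
set_base_shapes(model, 'mlp128.bsh')   # rescale init per muP from the saved base shapes
optimizer = MuAdam(model.parameters(), lr=1e-3)  # muP-aware Adam from the mup package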