зеркало из https://github.com/microsoft/mup.git
Update main.py
This commit is contained in:
Родитель
1c7771ab25
Коммит
2448e700e3
|
@ -312,7 +312,6 @@ if __name__ == '__main__':
|
||||||
if args.deferred_init:
|
if args.deferred_init:
|
||||||
from torchdistx.deferred_init import deferred_init
|
from torchdistx.deferred_init import deferred_init
|
||||||
# We don't need to instantiate the base and delta models
|
# We don't need to instantiate the base and delta models
|
||||||
# Note: this only works with torch nightly since unsqueeze isn't supported for fake tensors in stable
|
|
||||||
base_shapes = get_shapes(
|
base_shapes = get_shapes(
|
||||||
deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
|
||||||
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
tied=args.tied, bias=args.bias, encoder_var=args.init_var,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче