diff --git a/examples/Transformer/main.py b/examples/Transformer/main.py index 9af4383..3c6dc80 100644 --- a/examples/Transformer/main.py +++ b/examples/Transformer/main.py @@ -312,7 +312,6 @@ if __name__ == '__main__': if args.deferred_init: from torchdistx.deferred_init import deferred_init # We don't need to instantiate the base and delta models - # Note: this only works with torch nightly since unsqueeze isn't supported for fake tensors in stable base_shapes = get_shapes( deferred_init(mdl.TransformerModel, args, ntokens, ninp=args.d_model, nhead=args.nhead, nhid=args.d_model*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout, tied=args.tied, bias=args.bias, encoder_var=args.init_var,