This commit is contained in:
Edward Hu 2022-05-08 08:34:13 -04:00
Родитель 3e3daabdcb
Коммит e968350db8
1 изменённых файлов: 2 добавлений и 3 удалений

Просмотреть файл

@ -126,9 +126,8 @@ optimizer = MuSGD(model.parameters(), lr=0.1)
```
Note the base and delta models *do not need to be trained* --- we are only extracting parameter shape information from them.
Therefore, optionally, we can avoid instantiating these potentially large models by passing `device='meta'` to their constructor.
However, you need to make sure that the `device` flag is appropriately passed down to the constructor of all submodules.
Of course, it'd be even better if PyTorch can do this automatically for any existing `nn.Module`. If you want to see this happen, please upvote [this PyTorch issue](https://github.com/pytorch/pytorch/issues/74143).
Therefore, optionally, we can avoid instantiating these potentially large models by using the `deferred_init` function in `torchdistx`.
After installing [`torchdistx`](https://github.com/pytorch/torchdistx), use `torchdistx.deferred_init.deferred_init(MyModel, **args)` instead of `MyModel(**args)`. See [this page](https://pytorch.org/torchdistx/latest/deferred_init.html) for more detail.
## How `mup` Works Under the Hood