This commit is contained in:
Edward Hu 2022-05-08 08:17:03 -04:00
Родитель c9d67001c4
Коммит 3e3daabdcb
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -76,14 +76,14 @@ class MyModel(nn.Module):
### Instantiate a base model
base_model = MyModel(width=1)
### Optionally, use `device='meta'` to avoid instantiating the parameters
### This requires you to pass the device flag down to all sub-modules
# base_model = MyModel(width=1, device='meta')
### Optionally, use `torchdistx.deferred_init.deferred_init` to avoid instantiating the parameters
### Simply install `torchdistx` and use
# base_model = torchdistx.deferred_init.deferred_init(MyModel, width=1)
### Instantiate a "delta" model that differs from the base model
### in all dimensions ("widths") that one wishes to scale.
### Here it's simple, but e.g., in a Transformer, you may want to scale
### both nhead and dhead, so the delta model should differ in both.
delta_model = MyModel(width=2) # Optionally add the `device='meta'` to avoid instantiating
delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating
### Instantiate the target model (the model you actually want to train).
### This should be the same as the base model except