# Bring Your Own PyTorch Lightning Model
The InnerEye toolbox is capable of training any PyTorch Lightning (PL) model inside of AzureML, making use of all the usual InnerEye toolbox features:
- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities
- Training on a local GPU machine or inside of AzureML without code changes
- Supplying commandline overrides for model configuration elements, to quickly queue many jobs
This can be used by:

- Defining a special container class, that encapsulates the PyTorch Lightning model to train, and the data that should be used for training and testing.
- Adding essential trainer parameters like the number of epochs to that container.
- Invoking the InnerEye runner and providing the name of the container class, like this:

```shell
python InnerEye/ML/runner.py --model=MyContainer
```

To train in AzureML, just add the `--azureml` flag.
There is a fully working example `HelloContainer`, that implements a simple 1-dimensional regression model from data stored in a CSV file. You can run that from the command line with `python InnerEye/ML/runner.py --model=HelloContainer`.
## Setup
In order to use these capabilities, you need to implement a class deriving from `LightningContainer`. This class
encapsulates everything that is needed for training with PyTorch Lightning:

- The `create_model` method needs to return a subclass of `LightningModule`, that has all the usual PyTorch Lightning methods required for training, like the `training_step` and `forward` methods. This object needs to adhere to additional constraints, see below.
- The `get_data_module` method of the container needs to return a `LightningDataModule` that has the data loaders for training and validation data.
- The optional `get_inference_data_module` returns a `LightningDataModule` that is used to read the data for inference (that is, evaluating the trained model). By default, this returns the same data as `get_data_module`, but you can override this for special models like segmentation models that are trained on equal sized image patches, but evaluated on full images of varying size.
Your class needs to be defined in a Python file in the `InnerEye/ML/configs` folder, otherwise it won't be picked up
correctly. If you'd like to have your model defined in a different folder, please specify the Python namespace via
the `--model_configs_namespace` argument. For example, use `--model_configs_namespace=My.Own.configs` if your
model configuration classes reside in folder `My/Own/configs` from the repository root.
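
In that case, assuming a container class `MyContainer` defined under `My/Own/configs` (both names are placeholders), the full runner invocation could look like this:

```shell
python InnerEye/ML/runner.py --model=MyContainer --model_configs_namespace=My.Own.configs
```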
## Cross Validation
If you are doing cross validation, you need to ensure that the `LightningDataModule` returned by your container's
`get_data_module` method:

- Takes into account the number of cross validation splits, and the cross validation split index, when preparing the data.
- Logs val/Loss in its `validation_step` method.

You can find a working example of handling cross validation in the `HelloContainer` class; a rough sketch of a split-aware data module follows below.
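
As a minimal sketch of the first requirement, a data module could carve out the train/validation subsets with scikit-learn's `KFold`. The field names `number_of_cross_validation_splits` and `cross_validation_split_index` follow the InnerEye commandline options; the random placeholder dataset and batch size are purely illustrative:

```python
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, TensorDataset


class MyCrossValDataModule(LightningDataModule):
    def __init__(self,
                 number_of_cross_validation_splits: int = 5,
                 cross_validation_split_index: int = 0) -> None:
        super().__init__()
        self.number_of_cross_validation_splits = number_of_cross_validation_splits
        self.cross_validation_split_index = cross_validation_split_index

    def setup(self, stage=None) -> None:
        # Placeholder dataset: 100 random (x, y) pairs. A real module would read
        # its data from the dataset folder instead.
        full_dataset = TensorDataset(torch.randn(100, 1), torch.randn(100, 1))
        # Deterministically split the indices, then pick out the subsets that
        # belong to the current cross validation split.
        splitter = KFold(n_splits=self.number_of_cross_validation_splits)
        splits = list(splitter.split(np.arange(len(full_dataset))))
        train_indices, val_indices = splits[self.cross_validation_split_index]
        self.train_data = Subset(full_dataset, train_indices)
        self.val_data = Subset(full_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_data, batch_size=5)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_data, batch_size=5)
```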
Example:
```python
from pathlib import Path
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer


class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...

    def training_step(self, *args, **kwargs):
        ...

    def forward(self, *args, **kwargs):
        ...

    def configure_optimizers(self):
        ...

    def test_step(self, *args, **kwargs):
        ...


class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)


class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "folder_name_in_azure_blob_storage"
        self.local_dataset = "/some/local/path"
        self.num_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```
## Where does the data for training come from?
- When training a model on a local box or VM, the data is read from the `local_dataset` folder that you define in the container.
- When training a model in AzureML, the code searches for a folder called `folder_name_in_azure_blob_storage` in Azure blob storage. That folder is then downloaded or mounted. The local download path is then copied over the `local_dataset` field in the container, and hence you can always read data from `self.local_dataset`.
- Alternatively, you can use the `prepare_data` method of a `LightningDataModule` to download data from the web, for example. In this case, you don't need to define any of the `local_dataset` or `azure_dataset_id` fields, as shown in the sketch below.
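
For that third route, a minimal sketch of a `prepare_data` method might look like the following. The URL and file name are placeholders, not a real dataset:

```python
from pathlib import Path
from urllib.request import urlretrieve

from pytorch_lightning import LightningDataModule


class MyDownloadingDataModule(LightningDataModule):
    def prepare_data(self) -> None:
        # prepare_data is called by Lightning before training starts, and is
        # the right place for one-off download work.
        target = Path("dataset.csv")
        if not target.exists():
            urlretrieve("https://example.com/dataset.csv", target)  # placeholder URL
```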
In the above example, training is done for 42 epochs. After the model is trained, it will be evaluated on the test set, via PyTorch Lightning's built-in test functionality. See below for an alternative way of running the evaluation on the test set.
## Data loaders
The example above creates `DataLoader` objects from a dataset. When creating those, you need to specify a batch size
(how many samples from your dataset will go into one minibatch), and a number of worker processes. Note that, by
default, data loading will happen in the main process, meaning that your GPU will sit idle while the CPU reads data
from disk. When specifying a number of workers, the data loader will spawn processes that pre-fetch data from disk and put
it into a queue, ready for the GPU to pick up once it is done processing the current minibatch.
For more details, please see the [documentation for DataLoader](https://pytorch.org/docs/stable/data.html). There is also a [tutorial describing the foundations of datasets and data loaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).
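
As an illustration, the following `DataLoader` uses 4 worker processes to pre-fetch minibatches, and `pin_memory=True` to additionally speed up host-to-GPU copies. Both values are example settings, not recommendations:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset of 100 random (x, y) pairs.
dataset = TensorDataset(torch.randn(100, 1), torch.randn(100, 1))
# 4 worker processes pre-fetch minibatches while the GPU is busy.
loader = DataLoader(dataset, batch_size=5, num_workers=4, pin_memory=True)
```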
## Outputting files during training
The Lightning model returned by `create_model` needs to write its output files to the current working directory.
When running the InnerEye toolbox outside of AzureML, the toolbox will change the current working directory to a
newly created output folder, with a name that contains the timestamp and the model name.
When running the InnerEye toolbox in AzureML, the folder structure will be set up such that all files written
to the current working directory are uploaded to Azure blob storage at the end of the AzureML job, and are also
available via the AzureML UI.
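
As a minimal sketch, a model could append a line to a file in the current working directory at the end of every epoch; the file name is illustrative:

```python
from pathlib import Path

from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    def on_train_epoch_end(self) -> None:
        # Files written relative to the current working directory end up in the
        # run's output folder, and are uploaded to blob storage in AzureML.
        with (Path.cwd() / "epoch_log.txt").open("a") as f:
            f.write(f"Finished epoch {self.current_epoch}\n")
```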
## Trainer arguments
All arguments that control the PyTorch Lightning `Trainer` object are defined in the class `TrainerParams`. A
`LightningContainer` object inherits from this class. The most essential one is the `num_epochs` field, which controls
the `max_epochs` argument of the `Trainer`.
Usage example:
```python
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer


class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.num_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```
For further details on how the `TrainerParams` are used, refer to the `create_lightning_trainer` method in
`InnerEye/ML/model_training.py`.
## Optimizer and LR scheduler arguments
There are two possible ways of choosing the optimizer and LR scheduler:

- The Lightning model returned by `create_model` can define its own `configure_optimizers` method, with the same signature as `LightningModule.configure_optimizers`. This is the typical way of configuring it for Lightning models; a sketch is shown after this list.
- Alternatively, the model can inherit from `LightningModuleWithOptimizer`. This class implements a `configure_optimizers` method that uses settings defined in the `OptimizerParams` class. These settings are all available from the command line, and you can, for example, start a new run with a different learning rate by supplying the additional commandline flag `--l_rate=1e-2`.
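
A sketch of the first route, using an Adam optimizer and a step-wise learning rate schedule (both chosen here only for illustration, not InnerEye defaults):

```python
import torch
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    def configure_optimizers(self):
        # Standard Lightning signature: return the optimizer(s), and optionally
        # the LR scheduler(s).
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]
```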
## Evaluating the trained model
The InnerEye toolbox provides two possible routes of implementing that.

You can either use PyTorch Lightning's built-in capabilities, via the `test_step` method. If the model that is
returned by `create_model` implements the `test_step` method, the InnerEye toolbox will use the `trainer.test` method
(see docs). In this case, the best checkpoint from training will be used. The test data is read via the data loader created
by the `test_dataloader` method of the `LightningDataModule` that is used for training/validation.
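
A minimal sketch of such a `test_step`, assuming batches of (input, target) pairs and a regression loss:

```python
import torch
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)
        # Metrics logged here show up in the results of trainer.test.
        self.log("test_mse", loss)
        return loss
```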
Alternatively, the model can implement the methods defined in `InnerEyeInference`. In this case, the methods will be
called in this order:
```text
model.on_inference_start()
for dataset_split in [Train, Val, Test]:
    model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
    for batch_idx, item in enumerate(dataloader[dataset_split]):
        model_outputs = model.forward(item)
        model.inference_step(item, batch_idx, model_outputs)
    model.on_inference_epoch_end()
model.on_inference_end()
```
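
A skeleton of a model implementing these hooks could look as follows. The import path for `InnerEyeInference` is an assumption; check the InnerEye codebase for the exact module, and the `HelloContainer` example for a complete implementation:

```python
# The import path below is an assumption; adjust it to wherever
# InnerEyeInference lives in your version of the InnerEye codebase.
from InnerEye.ML.lightning_base import InnerEyeInference
from pytorch_lightning import LightningModule


class MyInferenceModel(LightningModule, InnerEyeInference):
    def on_inference_start(self) -> None:
        # Reset aggregated results before any dataset split is processed.
        self.results = []

    def on_inference_epoch_start(self, dataset_split, is_ensemble_model: bool) -> None:
        # Called once per dataset split (Train, Val, Test).
        self.current_split = dataset_split

    def inference_step(self, item, batch_idx, model_outputs) -> None:
        # Store or post-process the model outputs for the current batch.
        self.results.append((self.current_split, batch_idx, model_outputs))

    def on_inference_epoch_end(self) -> None:
        # Per-split aggregation could happen here.
        pass

    def on_inference_end(self) -> None:
        # Final aggregation across all splits could happen here.
        pass
```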
## Overriding properties on the commandline
You can define hyperparameters that affect data and/or model, as in the following code snippet:
```python
import param
from pytorch_lightning import LightningModule
from InnerEye.ML.lightning_container import LightningContainer


class DummyContainerWithParameters(LightningContainer):
    num_layers = param.Integer(default=4)

    def create_model(self) -> LightningModule:
        return MyLightningModel(self.num_layers)

    ...
```
All parameters added in this form will be automatically accessible from the commandline; there is no need to define
a separate argument parser. When starting training, you can add a flag like `--num_layers=7`.
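
For example, to train the container above with 7 layers instead of the default 4:

```shell
python InnerEye/ML/runner.py --model=DummyContainerWithParameters --num_layers=7
```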
## Examples
### Setting only the required fields
```python
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer


class Container1(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "some_folder_in_azure"
        self.num_epochs = 20

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # This should read data from self.local_dataset. Before training, the data folder "some_folder_in_azure"
        # (given by self.azure_dataset_id) will be downloaded or mounted, and its local path set in
        # self.local_dataset
        return MyDataModule(root_path=self.local_dataset)
```
### Adding additional arguments for the PyTorch Lightning trainer
```python
from typing import Any, Dict

from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer


class Container2(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "some_folder_in_azure"
        self.num_epochs = 20

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # This should read data from self.local_dataset. Before training, the data folder "some_folder_in_azure"
        # (given by self.azure_dataset_id) will be downloaded or mounted, and its local path set in
        # self.local_dataset
        return MyDataModule(root_path=self.local_dataset)

    def get_trainer_arguments(self) -> Dict[str, Any]:
        # These arguments will be passed through to the Lightning trainer.
        return {"gradient_clip_val": 1, "limit_train_batches": 10}
```