From 968e6b5e428186dc99d19166996439c23dccc4d1 Mon Sep 17 00:00:00 2001
From: Ze Liu
Date: Wed, 31 Jan 2024 15:57:42 +0800
Subject: [PATCH] supporting pytorch 2.x (#346)

---
 config.py         |  9 ++++++++-
 main.py           |  8 +++++++-
 main_moe.py       |  8 +++++++-
 main_simmim_ft.py | 20 ++++++++++++++------
 main_simmim_pt.py | 10 ++++++++--
 utils.py          |  6 +++++-
 6 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/config.py b/config.py
index 1671ec3..88acdb6 100644
--- a/config.py
+++ b/config.py
@@ -6,9 +6,13 @@
 # --------------------------------------------------------'
 
 import os
+import torch
 import yaml
 from yacs.config import CfgNode as CN
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 _C = CN()
 
 # Base config files
@@ -334,7 +338,10 @@ def update_config(config, args):
         config.TRAIN.OPTIMIZER.NAME = args.optim
 
     # set local rank for distributed training
-    config.LOCAL_RANK = args.local_rank
+    if PYTORCH_MAJOR_VERSION == 1:
+        config.LOCAL_RANK = args.local_rank
+    else:
+        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])
 
     # output folder
     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
diff --git a/main.py b/main.py
index 84230ea..7fa03ca 100644
--- a/main.py
+++ b/main.py
@@ -29,6 +29,9 @@ from logger import create_logger
 from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
     reduce_tensor
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
@@ -64,7 +67,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     # for acceleration
     parser.add_argument('--fused_window_process', action='store_true',
diff --git a/main_moe.py b/main_moe.py
index acf5d20..cc7e664 100644
--- a/main_moe.py
+++ b/main_moe.py
@@ -33,6 +33,9 @@ from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_re
 
 assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
@@ -68,7 +71,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args, unparsed = parser.parse_known_args()
 
diff --git a/main_simmim_ft.py b/main_simmim_ft.py
index 067dfbb..c757847 100644
--- a/main_simmim_ft.py
+++ b/main_simmim_ft.py
@@ -26,7 +26,11 @@ from data import build_loader
 from lr_scheduler import build_scheduler
 from optimizer import build_optimizer
 from logger import create_logger
-from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
+from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \
+    reduce_tensor
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
 
 
 def parse_option():
@@ -57,7 +61,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args = parser.parse_args()
 
@@ -67,7 +74,8 @@ def parse_option():
 
 
 def main(config):
-    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, is_pretrain=False)
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
+                                                                                            is_pretrain=False)
 
     logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
     model = build_model(config, is_pretrain=False)
@@ -110,7 +118,7 @@ def main(config):
             logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
 
     if config.MODEL.RESUME:
-        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) 
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
         acc1, acc5, loss = validate(config, data_loader_val, model)
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.EVAL_MODE:
@@ -147,7 +155,7 @@ def main(config):
 def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
     model.train()
     optimizer.zero_grad()
-    
+
     logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')
 
     num_steps = len(data_loader)
@@ -331,4 +339,4 @@ if __name__ == '__main__':
 
     # print config
     logger.info(config.dump())
-    main(config)
\ No newline at end of file
+    main(config)
diff --git a/main_simmim_pt.py b/main_simmim_pt.py
index 6591d21..f7ca542 100644
--- a/main_simmim_pt.py
+++ b/main_simmim_pt.py
@@ -26,6 +26,9 @@ from optimizer import build_optimizer
 from logger import create_logger
 from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
@@ -52,7 +55,10 @@ def parse_option():
     parser.add_argument('--tag', help='tag of experiment')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args = parser.parse_args()
 
@@ -225,4 +231,4 @@ if __name__ == '__main__':
 
     # print config
     logger.info(config.dump())
-    main(config)
\ No newline at end of file
+    main(config)
diff --git a/utils.py b/utils.py
index eb607cf..328ad09 100644
--- a/utils.py
+++ b/utils.py
@@ -8,7 +8,11 @@
 import os
 import torch
 import torch.distributed as dist
-from torch._six import inf
+
+try:
+    from torch._six import inf
+except:
+    from torch import inf
 
 
 def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
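
Notes (editor addendum, not part of the patch): the change keeps a single code path by branching on the installed PyTorch major version. Under PyTorch 1.x the scripts are still launched with `python -m torch.distributed.launch --nproc_per_node <N> main.py ...`, which injects `--local_rank` into every worker's argument list; under PyTorch 2.x the expected launcher is `torchrun --nproc_per_node <N> main.py ...`, which exports the `LOCAL_RANK` environment variable instead. The `utils.py` hunk falls back to `from torch import inf` because the `torch._six` shim is no longer available in PyTorch 2.x. Below is a minimal stand-alone sketch of the local-rank resolution the patch implements; `get_local_rank` is an illustrative name only, not a helper in this repository.

import argparse
import os

import torch

# pytorch major version (1.x or 2.x), mirroring the patch
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])


def get_local_rank():
    """Resolve the local rank under either launcher convention (illustrative sketch)."""
    parser = argparse.ArgumentParser('local rank resolution sketch', add_help=False)
    if PYTORCH_MAJOR_VERSION == 1:
        # torch.distributed.launch (pytorch 1.x) passes --local_rank to each worker process
        parser.add_argument("--local_rank", type=int, default=0,
                            help='local rank for DistributedDataParallel')
    args, _ = parser.parse_known_args()
    if PYTORCH_MAJOR_VERSION == 1:
        return args.local_rank
    # torchrun (pytorch >= 2.0) exports LOCAL_RANK for each worker process instead
    return int(os.environ.get('LOCAL_RANK', 0))


if __name__ == '__main__':
    print(f'local rank: {get_local_rank()}')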