supporting pytorch 2.x (#346)
Parent: 2cb103f2de
Commit: 968e6b5e42
@@ -6,9 +6,13 @@
 # --------------------------------------------------------'

 import os
+import torch
 import yaml
 from yacs.config import CfgNode as CN

+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 _C = CN()

 # Base config files
@@ -334,7 +338,10 @@ def update_config(config, args):
         config.TRAIN.OPTIMIZER.NAME = args.optim

     # set local rank for distributed training
-    config.LOCAL_RANK = args.local_rank
+    if PYTORCH_MAJOR_VERSION == 1:
+        config.LOCAL_RANK = args.local_rank
+    else:
+        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])

     # output folder
     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
main.py (8 changed lines)
@@ -29,6 +29,9 @@ from logger import create_logger
 from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
     reduce_tensor

+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+

 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
@@ -64,7 +67,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

     # for acceleration
     parser.add_argument('--fused_window_process', action='store_true',
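One consequence of registering --local_rank only when PYTORCH_MAJOR_VERSION == 1 (again a note, not part of the diff): under PyTorch 2.x the parsed args namespace has no local_rank attribute, so downstream code has to read the rank from config.LOCAL_RANK or from the environment. Both launchers export RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so an env:// rendezvous works unchanged; a minimal sketch, assuming config.LOCAL_RANK has already been set as in the config change above:

import torch
import torch.distributed as dist

def init_distributed(config):
    # bind this process to its GPU, then join the default process group;
    # the env:// rendezvous reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
    torch.cuda.set_device(config.LOCAL_RANK)
    dist.init_process_group(backend='nccl', init_method='env://')
    dist.barrier()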
@@ -33,6 +33,9 @@ from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_re

 assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"

+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+

 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
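A side note on the existing version assert (not part of the diff): torch.__version__ >= '1.8.0' is a string comparison, so a release such as '1.10.0' compares as smaller than '1.8.0' even though it satisfies the requirement; the integer major-version check added above avoids that pitfall for the 1.x-vs-2.x split. A more robust full-version check could be sketched with the packaging library (assumed to be available):

import torch
from packaging import version

# '1.10.0' >= '1.8.0' is False as strings but True as parsed versions
assert version.parse(torch.__version__) >= version.parse('1.8.0'), \
    "DDP-based MoE requires Pytorch >= 1.8.0"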
@@ -68,7 +71,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

     args, unparsed = parser.parse_known_args()
@@ -26,7 +26,11 @@ from data import build_loader
 from lr_scheduler import build_scheduler
 from optimizer import build_optimizer
 from logger import create_logger
-from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
+from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \
+    reduce_tensor
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])


 def parse_option():
@@ -57,7 +61,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

     args = parser.parse_args()
@@ -67,7 +74,8 @@ def parse_option():


 def main(config):
-    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, is_pretrain=False)
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
+                                                                                            is_pretrain=False)

     logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
     model = build_model(config, is_pretrain=False)
@@ -110,7 +118,7 @@ def main(config):
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

     if config.MODEL.RESUME:
-        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
         acc1, acc5, loss = validate(config, data_loader_val, model)
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.EVAL_MODE:
@@ -147,7 +155,7 @@ def main(config):
 def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
     model.train()
     optimizer.zero_grad()

-
+    logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')

     num_steps = len(data_loader)
@@ -331,4 +339,4 @@ if __name__ == '__main__':
     # print config
     logger.info(config.dump())

-    main(config)
+    main(config)
@@ -26,6 +26,9 @@ from optimizer import build_optimizer
 from logger import create_logger
 from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper

+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+

 def parse_option():
     parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
@@ -52,7 +55,10 @@ def parse_option():
     parser.add_argument('--tag', help='tag of experiment')

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

     args = parser.parse_args()
@@ -225,4 +231,4 @@ if __name__ == '__main__':
     # print config
     logger.info(config.dump())

-    main(config)
+    main(config)
utils.py (6 changed lines)
@@ -8,7 +8,11 @@
 import os
 import torch
 import torch.distributed as dist
-from torch._six import inf
+
+try:
+    from torch._six import inf
+except:
+    from torch import inf


 def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
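Context for the utils.py change (explanatory note, not part of the diff): the private torch._six module was removed in PyTorch 2.0, and in the releases where the fallback branch runs torch.inf is available and equal to math.inf, so both branches bind the same constant. A constant like this is typically used for the infinity-norm case when computing or logging gradient norms; a small sketch of such usage, not the repo's exact implementation:

import torch
try:
    from torch._six import inf   # PyTorch 1.x
except ImportError:
    from torch import inf        # torch._six no longer exists in PyTorch 2.x

def grad_norm(parameters, norm_type=2.0):
    # sketch of a get_grad_norm-style helper for logging gradient norms
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    if norm_type == inf:
        return max(g.detach().abs().max() for g in grads)
    return torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)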