This commit is contained in:
Ze Liu 2024-01-31 15:57:42 +08:00 committed by GitHub
Parent 2cb103f2de
Commit 968e6b5e42
No key found matching this signature
GPG key ID: B5690EEEBB952194
6 changed files with 49 additions and 12 deletions

View file

@@ -6,9 +6,13 @@
 # --------------------------------------------------------'
 
 import os
+import torch
 import yaml
 from yacs.config import CfgNode as CN
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 _C = CN()
 
 # Base config files
@@ -334,7 +338,10 @@ def update_config(config, args):
         config.TRAIN.OPTIMIZER.NAME = args.optim
 
     # set local rank for distributed training
-    config.LOCAL_RANK = args.local_rank
+    if PYTORCH_MAJOR_VERSION == 1:
+        config.LOCAL_RANK = args.local_rank
+    else:
+        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])
 
     # output folder
     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
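
For context: on PyTorch >= 2.0 the recommended launcher, `torchrun`, no longer passes a `--local_rank` argument; it exports `LOCAL_RANK` (along with `RANK` and `WORLD_SIZE`) into each worker's environment, which is what the new `else:` branch above reads. A minimal sketch of that convention, assuming a `torchrun --nproc_per_node=<N> main.py ...` launch (illustrative only, not part of this commit):

import os
import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK for every worker process it spawns
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)

# RANK and WORLD_SIZE are also exported by torchrun, so env:// rendezvous works out of the box
dist.init_process_group(backend='nccl', init_method='env://')

With the PyTorch 1.x launcher (`python -m torch.distributed.launch`), the same value arrives as the `--local_rank` command-line argument instead, which is why the training scripts below register that argument only when `PYTORCH_MAJOR_VERSION == 1`.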

View file

@@ -29,6 +29,9 @@ from logger import create_logger
 from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
     reduce_tensor
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
@@ -64,7 +67,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     # for acceleration
     parser.add_argument('--fused_window_process', action='store_true',

View file

@@ -33,6 +33,9 @@ from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_re
 
 assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
@@ -68,7 +71,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args, unparsed = parser.parse_known_args()

View file

@@ -26,7 +26,11 @@ from data import build_loader
 from lr_scheduler import build_scheduler
 from optimizer import build_optimizer
 from logger import create_logger
-from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
+from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \
+    reduce_tensor
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
 
 
 def parse_option():
@@ -57,7 +61,10 @@ def parse_option():
     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args = parser.parse_args()
@@ -67,7 +74,8 @@ def parse_option():
 
 
 def main(config):
-    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, is_pretrain=False)
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
+                                                                                            is_pretrain=False)
 
     logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
     model = build_model(config, is_pretrain=False)
@@ -110,7 +118,7 @@
         logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
 
     if config.MODEL.RESUME:
-        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
         acc1, acc5, loss = validate(config, data_loader_val, model)
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.EVAL_MODE:
@@ -147,7 +155,7 @@
 def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
     model.train()
     optimizer.zero_grad()
     logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')
 
     num_steps = len(data_loader)
@@ -331,4 +339,4 @@ if __name__ == '__main__':
     # print config
     logger.info(config.dump())
 
-    main(config)
+    main(config)

View file

@@ -26,6 +26,9 @@ from optimizer import build_optimizer
 from logger import create_logger
 from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper
 
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
 
 def parse_option():
     parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
@@ -52,7 +55,10 @@ def parse_option():
     parser.add_argument('--tag', help='tag of experiment')
 
     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
 
     args = parser.parse_args()
@@ -225,4 +231,4 @@ if __name__ == '__main__':
     # print config
     logger.info(config.dump())
 
-    main(config)
+    main(config)

View file

@@ -8,7 +8,11 @@
 import os
 import torch
 import torch.distributed as dist
-from torch._six import inf
+
+try:
+    from torch._six import inf
+except:
+    from torch import inf
 
 
 def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
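
The try/except added above is needed because the private `torch._six` module has been removed in recent PyTorch releases (it no longer exists in 2.x), while `inf` is importable directly from `torch` in newer versions. A short sketch of how the shim behaves, with a hypothetical use (illustrative, not taken from this repository):

import math
import torch

try:
    from torch._six import inf   # PyTorch 1.x location (private module)
except ImportError:
    from torch import inf        # available directly in newer PyTorch releases

# Either way, `inf` is an ordinary float infinity, e.g. usable as a norm selector
assert inf == math.inf
norm_type = inf  # hypothetical: ask a gradient-clipping helper for the max-norm

Catching `ImportError` explicitly, as in the sketch, is slightly stricter than the bare `except:` used in the commit, but both resolve `inf` on either major version.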