archai/tasks/facial_landmark_detection/train.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Adapted from torchvision https://github.com/pytorch/vision/blob/main/references/classification/train.py"""
import copy
import datetime
import os
import time
import warnings
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from torchvision.models.quantization.mobilenetv2 import _replace_relu
from torchinfo import summary
from dataset import FaceLandmarkDataset
from search_space import create_model_from_search_results


def average_error(target, output):
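    """Return the mean Euclidean (L2) distance between predicted and target landmarks.

    Both tensors are expected to have shape (B, K, 2); the norm is taken over the
    last (x, y) dimension and then averaged over the batch and all K landmarks.
    """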
    errors = target - output  # shape (B, K, 2)
    norm = torch.norm(errors, dim=-1)  # shape (B, K)
    return norm.mean()


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
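    """Train the model for one epoch.

    Runs the standard forward/backward pass with optional AMP (via ``scaler``),
    gradient clipping and EMA updates, and logs the loss, mean landmark error,
    learning rate and images-per-second throughput.
    """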
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        target = torch.squeeze(target)

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            output = torch.reshape(output, target.shape)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params if we do gradient clipping
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        error = average_error(target, output)
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["error"].update(error.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, epoch=None, device="cuda", print_freq=100, log_suffix=""):
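    """Evaluate ``model`` on ``data_loader`` and return the mean landmark error.

    ``epoch`` is accepted for symmetry with ``train_one_epoch`` but is not used here.
    """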
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            output = torch.reshape(output, target.shape)
            loss = criterion(output, target)
            error = average_error(target, output)

            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["error"].update(error.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Error {metric_logger.error.global_avg:.4f}")
    return float(metric_logger.error.global_avg)


def load_data(traindir, args):
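    """Load the facial landmark dataset and build train/validation samplers.

    The dataset at ``traindir`` is split 90/10 into training and validation
    subsets; distributed samplers are used when ``args.distributed`` is set.
    Returns ``(train_dataset, test_dataset, train_sampler, test_sampler)``.
    """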
    # Data loading code
    print("Loading data")
    _, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size

    print("Loading training data")
    st = time.time()
    assert val_crop_size == train_crop_size
    dataset = FaceLandmarkDataset(traindir, limit=args.max_num_images, crop_size=train_crop_size)
    print("Took", time.time() - st)

    validation_dataset_size = int(len(dataset) * 0.1)
    dataset_train, dataset_test = torch.utils.data.random_split(
        dataset, [len(dataset) - validation_dataset_size, validation_dataset_size]
    )
    dataset = dataset_train

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def setup_qat(model: nn.Module) -> None:
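    """Prepare ``model`` for quantization-aware training.

    Replaces ReLU6 activations with ReLU, configures the quantization backend
    ('fbgemm' here), fuses supported modules, and inserts fake-quantization
    observers with ``torch.ao.quantization.prepare_qat``.
    """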
    assert model is not None
    _replace_relu(model)
    model.eval()
    # Modify quantization engine as appropriate for the target platform
    model.setup_qconfig('fbgemm')
    model.fuse_model()
    torch.ao.quantization.prepare_qat(model.train(), inplace=True)


def train(args, model: nn.Module = None):
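    """Full training entry point.

    Builds the data loaders, creates the model (from NAS search results or from
    ``torchvision.models``), sets up the optimizer, LR schedule and optional
    AMP/EMA/QAT, then runs the training loop with per-epoch evaluation and
    checkpointing. Returns the final validation error.
    """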
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, args)
    num_classes = dataset.dataset.num_landmarks

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

print("Creating model")
if model is None:
if args.search_result_archid:
model = create_model_from_search_results(
args.search_result_archid,
args.search_result_csv,
num_classes=num_classes,
qat=args.qat,
qat_skip_layers=args.qat_skip_layers)
if (args.qat):
print('Preparing for QAT')
setup_qat(model)
else:
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(summary(model, input_size=(1, 3, 192, 192)))
    criterion = nn.MSELoss()

    if args.norm_weight_decay is None:
        parameters = model.parameters()
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        val_error = evaluate(model, criterion, data_loader_test, epoch, device=device)
        if model_ema:
            val_error = evaluate(model_ema, criterion, data_loader_test, epoch, device=device, log_suffix="EMA")
        if args.output_dir:
            model_to_save = model_without_ddp
            if args.qat:
                # Convert a copy of the QAT model for saving so training can continue on the original
                model_to_save = copy.deepcopy(model_without_ddp)
                model_to_save.eval()
                model_to_save.to(torch.device("cpu"))
                torch.ao.quantization.convert(model_to_save, inplace=True)
            checkpoint = {
                "model": model_to_save.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return val_error


def get_args_parser(add_help=True):
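    """Build the command-line argument parser.

    Mirrors the torchvision classification reference script, with additional
    options for NAS search results (``--search_result_archid``,
    ``--search_result_csv``) and quantization-aware training (``--qat``,
    ``--qat_skip_layers``).
    """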
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Facial Landmark Detection Training", add_help=add_help)

    parser.add_argument("--data_path", "--data-path", default=None, type=str, help="dataset path")
    parser.add_argument(
        "--max_num_images",
        "--max-num-images",
        default=None,
        type=int,
        help="limit to number of images to use in dataset",
    )
    parser.add_argument(
        "--search_result_archid", "--search-result-archid", default=None, type=str, help="nas search arch id to use"
    )
    parser.add_argument(
        "--search_result_csv", "--search-result-csv", default=None, type=str, help="nas search result csv to use"
    )
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b",
"--batch_size",
"--batch-size",
default=32,
type=int,
help="images per gpu, the total batch size is $NGPU x batch_size",
)
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
parser.add_argument(
"--lr_scheduler", "--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)"
)
parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
parser.add_argument(
"--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
)
parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
parser.add_argument(
"--lr_step_size", "--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs"
)
parser.add_argument("--lr_gamma", "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output_dir", "--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
)
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument(
"--model-ema-decay",
type=float,
default=0.99998,
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
parser.add_argument(
"--val-resize-size", default=128, type=int, help="the resize size used for validation (default: 128)"
)
parser.add_argument(
"--val-crop-size", default=128, type=int, help="the central crop size used for validation (default: 128)"
)
parser.add_argument(
"--train_crop_size",
"--train-crop-size",
default=128,
type=int,
help="the random crop size used for training (default: 128)",
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--qat", help="Performs quantization aware training", action="store_true")
parser.add_argument(
"--qat_skip_layers", default=0, type=int, help="Number of layers to be skipped from quantization when performing QAT"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
if __name__ == "__main__":
    args, _ = get_args_parser().parse_known_args()
    train(args)
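

# Example invocation (the dataset path and hyper-parameter values below are
# placeholders, not recommended settings):
#
#   python train.py --data_path /path/to/landmark_dataset \
#       --epochs 30 --batch_size 32 --opt adamw --lr 0.001 \
#       --lr_scheduler cosineannealinglr --output_dir ./checkpoints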