Mirror of https://github.com/microsoft/archai.git
Further clean up
This commit is contained in:
Parent
7b334a2a4f
Commit
2673902b15
@ -1,60 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "y_data = []\n",
    "y_name = 'onnx_latency (ms)'\n",
    "x_data = []\n",
    "x_name = 'Full training Validation Accuracy'\n",
    "with open('search_results_with_accuracy.csv', 'r') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        if (len(row[x_name]) != 0):\n",
    "            x_data.append(float(row[x_name]))\n",
    "            y_data.append(float(row[y_name]))\n",
    "\n",
    "# Create plot\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x_data, y_data, 'o', color='blue', markersize=5)\n",
    "\n",
    "# Set axis labels and title\n",
    "ax.set_xlabel(x_name)\n",
    "ax.set_ylabel(y_name)\n",
    "ax.set_title(f'{y_name} vs {x_name}')\n",
    "\n",
    "# Save plot to file\n",
    "plt.savefig(f'{y_name} vs {x_name}.png'.replace(' ', '_'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "archai_face_landmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

@ -1,585 +0,0 @@
|
|||
#%%
|
||||
from ast import walk
|
||||
from typing import Any
|
||||
import torch
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
from face_synthetics_training.training.datasets.synthetics import CogFaceSynthetics
|
||||
from face_synthetics_training.training.landmarks2 import data_module
|
||||
from face_synthetics_training.training.landmarks2.lit_landmarks import LitLandmarksTrainer, unnormalize_coordinates, landmarks_error
|
||||
from face_synthetics_training.training.landmarks2.data_module import SyntheticsDataModule
|
||||
from face_synthetics_training.training.landmarks2.nas.mobilenetv2 import InvertedResidual, MobileNetV2
|
||||
from torchvision.models.quantization.mobilenetv2 import mobilenet_v2
|
||||
import torchvision.models as models
|
||||
import deepspeed
|
||||
from deepspeed.compression.compress import init_compression, redundancy_clean
|
||||
from queue import Queue
|
||||
import torchvision.models.quantization as tvqntmodels
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from face_synthetics_training.training.landmarks2.nas.utils import to_onnx, get_model_flops, profile_onnx, get_model_latency_1cpu, get_time_elapsed
|
||||
from onnxruntime import InferenceSession
|
||||
from time import time
|
||||
import statistics
|
||||
|
||||
# This script handles static quantization of PyTorch MobilenetV2 model.
|
||||
# However, it was found that ONNX export of quantized models is not supported.
|
||||
# So, the reduction in latency achieved by static quantizing PyTorch model was subpar
|
||||
# compared to just converting the base model to ONNX. Still, static quantization of
|
||||
# converted ONNX model using ONNX APIs yielded better latency. So, going with that approach.
|
||||
# Leaving this code intact for any future reference.
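# --- Illustrative sketch (added for reference, not part of the original pipeline) ---
# A minimal example of the approach described above: statically quantize an already-exported
# ONNX model with onnxruntime's quantization API. The file names, the 'input' tensor name and
# the 1x3x192x192 shape are assumptions for illustration; a real run would calibrate on
# dataset samples (see static_quantize_onnx_model below) rather than random tensors.
def sketch_quantize_onnx_static(onnx_in: str = 'model.onnx', onnx_out: str = 'model.quant.onnx'):
    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader, quantize_static

    class RandomCalibrationReader(CalibrationDataReader):
        # Feeds a handful of random images purely to exercise the calibration path.
        def __init__(self, num_samples: int = 10) -> None:
            self._samples = iter(
                {'input': np.random.rand(1, 3, 192, 192).astype(np.float32)}
                for _ in range(num_samples)
            )

        def get_next(self) -> dict:
            return next(self._samples, None)

    quantize_static(onnx_in, onnx_out, calibration_data_reader=RandomCalibrationReader())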
|
||||
|
||||
def save_model(model:torch.nn.Module, model_filepath:Path):
|
||||
torch.save(model.state_dict(), model_filepath)
|
||||
|
||||
def load_model(model, model_filepath, device):
|
||||
model.load_state_dict(torch.load(model_filepath, map_location=device))
|
||||
return model
|
||||
|
||||
# def get_time_elapsed (model, img_size: int, onnx: bool = False) -> float :
|
||||
# #sanity check to make sure we are consistent
|
||||
# num_input = 10
|
||||
# input_img = torch.randn(1, 3, img_size, img_size)
|
||||
|
||||
# if (onnx) :
|
||||
# with tempfile.NamedTemporaryFile() as tmp:
|
||||
# output_path = Path(tmp.name)
|
||||
# to_onnx(model, output_path, img_size=(img_size, img_size))
|
||||
# onnx_session = InferenceSession(str(output_path))
|
||||
|
||||
# def measure_func() :
|
||||
# t0 = time()
|
||||
# for _ in range(num_input):
|
||||
# if (onnx):
|
||||
# input_name = onnx_session.get_inputs()[0].name
|
||||
# onnx_session.run(None, input_feed={input_name: input_img.numpy()})[0]
|
||||
# else:
|
||||
# pred = model.forward(input_img)
|
||||
# time_measured = 1e3 * (time() - t0) / num_input
|
||||
# return time_measured
|
||||
|
||||
# while True:
|
||||
# time_measured_all = [measure_func() for _ in range (10)]
|
||||
# time_measured_avg = statistics.mean(time_measured_all)
|
||||
# time_measured_std = statistics.stdev(time_measured_all)
|
||||
# if (time_measured_std < time_measured_avg * 0.1):
|
||||
# break
|
||||
|
||||
# return time_measured_avg, time_measured_std
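# --- Illustrative sketch (added for reference, not part of the original helper) ---
# The commented-out helper above keeps re-timing until the run-to-run standard deviation is
# below 10% of the mean, so a single noisy measurement is not reported. A cleaned-up version
# of that idea for an arbitrary callable might look like this (max_rounds is my own addition
# to avoid looping forever on a noisy machine):
def sketch_measure_stable_latency(run_once, num_repeats: int = 10, max_rounds: int = 20):
    from time import time
    import statistics

    def one_round() -> float:
        t0 = time()
        run_once()
        return 1e3 * (time() - t0)  # milliseconds

    avg = std = 0.0
    for _ in range(max_rounds):
        samples = [one_round() for _ in range(num_repeats)]
        avg, std = statistics.mean(samples), statistics.stdev(samples)
        if std < 0.1 * avg:
            break
    return avg, std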
|
||||
|
||||
def get_1cpu_latency(model_path:str, onnx=False):
|
||||
# model = torch.load(model_path)
|
||||
# model.to('cpu').eval()
|
||||
with torch.no_grad():
|
||||
print(get_model_latency_1cpu(
|
||||
model_path,
|
||||
img_size=192,
|
||||
cpu=1,
|
||||
onnx=onnx,
|
||||
num_input=128
|
||||
))
|
||||
|
||||
def calculate_latency(model:torch.nn.Module, img_size:int):
|
||||
img_size = [img_size, img_size]
|
||||
|
||||
dummy_inputs = [
|
||||
torch.randn(bsz, 3, *img_size[::-1]) % 255
|
||||
for bsz in [1]
|
||||
for _ in range(30)
|
||||
]
|
||||
|
||||
t0 = time()
|
||||
_ = ([
|
||||
model.forward(dummy_input)
|
||||
for dummy_input in dummy_inputs
|
||||
])
|
||||
|
||||
print(f'latency (ms): {1e3 * (time() - t0) / 30}')
|
||||
|
||||
def basic_to_onnx(model:torch.nn.Module, onnx_path:str):
|
||||
torch.onnx.export(
|
||||
model, (torch.ones(1, 3, 192, 192)),
|
||||
onnx_path,
|
||||
opset_version=17,
|
||||
verbose=True,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}})
|
||||
|
||||
class QuantizableInvertedResidual(InvertedResidual):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.skip_add = torch.nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.identity:
|
||||
# Regular addition is not supported for quantized operands. Hence, this way
|
||||
# ref. https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/mobilenetv2.py
|
||||
return self.skip_add.add(x, self.conv(x))
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
class QuantizedModel(torch.nn.Module):
|
||||
def __init__(self, model_fp32):
|
||||
|
||||
super(QuantizedModel, self).__init__()
|
||||
# QuantStub converts tensors from floating point to quantized.
|
||||
# This will only be used for inputs.
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
# FP32 model
|
||||
self.model_fp32 = model_fp32
|
||||
# DeQuantStub converts tensors from quantized to floating point.
|
||||
# This will only be used for outputs.
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self, x):
|
||||
# manually specify where tensors will be converted from floating
|
||||
# point to quantized in the quantized model
|
||||
x = self.quant(x)
|
||||
x = self.model_fp32(x)
|
||||
# manually specify where tensors will be converted from quantized
|
||||
# to floating point in the quantized model
|
||||
x = self.dequant(x)
|
||||
return x
|
||||
|
||||
def static_quantize_mobilenetv2():
|
||||
untr_model = MobileNetV2(num_classes=960, block=QuantizableInvertedResidual)
|
||||
|
||||
print("base model")
|
||||
print(untr_model)
|
||||
untr_model.to('cpu').eval()
|
||||
|
||||
untr_model_copy = copy.deepcopy(untr_model)
|
||||
quantized_untr_model = QuantizedModel(untr_model_copy)
|
||||
quantized_untr_model.to('cpu').eval()
|
||||
quantized_untr_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
model_fp32_fused = quantized_untr_model
|
||||
|
||||
# #fuse layers
|
||||
# TODO: Requires in-depth walkthrough of sub-modules to fuse layers at each level (see the fusion sketch after this function)
|
||||
# ref. https://leimao.github.io/blog/PyTorch-Static-Quantization/
|
||||
# model_fp32_fused = torch.quantization.fuse_modules(
|
||||
# quantized_untr_model, ['features.0.0', 'features.0.1'] # TODO: check layer names
|
||||
# )
|
||||
|
||||
prepared_model = torch.quantization.prepare(model_fp32_fused)
|
||||
img_size = 192
|
||||
|
||||
prepared_model(torch.randn(1, 3, img_size, img_size))
|
||||
|
||||
quantized_model = torch.quantization.convert(prepared_model)
|
||||
quantized_model.to('cpu').eval()
|
||||
print("quantized model")
|
||||
print(quantized_model)
|
||||
out = quantized_model(torch.randn(1, 3, img_size, img_size))
|
||||
print(out.shape)
|
||||
calculate_latency(quantized_model, 192)
|
||||
print('calculating time taken')
|
||||
avg, std = get_time_elapsed(quantized_model, 192, onnx=False)
|
||||
print(f'avg {avg} std {std}')
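# --- Illustrative sketch (added for reference): what the fusion TODO inside
# static_quantize_mobilenetv2 would involve. torch.quantization.fuse_modules folds adjacent
# Conv + BatchNorm (+ ReLU) groups into single modules before prepare()/convert(); for the real
# MobileNetV2, the layer names would have to be collected by walking every sub-module. The toy
# Sequential below is an assumption used only to show the call shape.
def sketch_fuse_conv_bn_relu():
    toy = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )
    toy.eval()  # eager-mode fusion for post-training quantization expects eval()
    fused = torch.quantization.fuse_modules(toy, [['0', '1', '2']])
    return fused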
|
||||
|
||||
def walkthrough_model(model:torch.nn.Module, module_list, parent = ''):
|
||||
for module_name, module in model.named_children():
|
||||
full_name = f'{parent}/{module_name}'
|
||||
module_list.append((full_name, module))
|
||||
# print(f'module name: {full_name}')
|
||||
# print(f'type {type(module)}')
|
||||
# if('torch.nn.quantized.modules.conv.Conv2d' in str(type(module))):
|
||||
# if (hasattr(module, 'weight')):
|
||||
# breakpoint()
|
||||
# for w in module.weight:
|
||||
# print(f'w {w}')
|
||||
# print(f'weights {module.weight}')
|
||||
walkthrough_model(module, module_list, full_name)
|
||||
|
||||
def copy_parameters(model:torch.nn.Module, quantized_model:torch.nn.Module):
|
||||
modules = []
|
||||
quantized_modules = []
|
||||
walkthrough_model(model, modules)
|
||||
walkthrough_model(quantized_model, quantized_modules)
|
||||
print(f'modules {len(modules)}')
|
||||
print(f'qmodules {len(quantized_modules)}')
|
||||
|
||||
# Skip through initial quant modules to find first Sequential module
|
||||
# to match with non-quantized modules list
|
||||
qm_idx = 0
|
||||
while(True):
|
||||
_, qm = quantized_modules[qm_idx]
|
||||
if (type(qm) == torch.nn.modules.container.Sequential):
|
||||
break
|
||||
else:
|
||||
qm_idx += 1
|
||||
|
||||
print(f'qmidx {qm_idx}')
|
||||
|
||||
m_idx = 0
|
||||
while(m_idx < len(modules)):
|
||||
#working_copy.features._modules['0']._modules['0'].weight.data =
|
||||
# quantized_model.model_fp32.features._modules['0']._modules['0'].weight().dequantize()
|
||||
if (type(modules[m_idx][1]) == torch.nn.modules.conv.Conv2d or
|
||||
type(modules[m_idx][1]) == torch.nn.modules.linear.Linear):
|
||||
# print(f'copying weights from {quantized_modules[qm_idx][0]} to {modules[m_idx][0]}')
|
||||
modules[m_idx][1].weight.data = quantized_modules[qm_idx][1].weight().dequantize()
|
||||
|
||||
if (type(modules[m_idx][1]) == torch.nn.modules.batchnorm.BatchNorm2d):
|
||||
# print(f'copying bias from {quantized_modules[qm_idx][0]} to {modules[m_idx][0]}')
|
||||
modules[m_idx][1].bias.data = quantized_modules[qm_idx][1].bias.dequantize()
|
||||
|
||||
m_idx += 1
|
||||
qm_idx += 1
|
||||
|
||||
def quantize_model(saved_state:Path) -> torch.nn.Module:
|
||||
model = tvqntmodels.mobilenet_v2(pretrained=False, num_classes = 960)
|
||||
model.train()
|
||||
model.qconfig = torch.quantization.get_default_qat_qconfig(backend='qnnpack')
|
||||
# model.fuse_model()
|
||||
torch.quantization.prepare_qat(model, inplace=True)
|
||||
model.load_state_dict(torch.load(saved_state, map_location='cpu'))
|
||||
converted_model = torch.quantization.convert(model, inplace=False)
|
||||
converted_model.to('cpu').eval()
|
||||
return converted_model
|
||||
|
||||
def perform_qat(model:torch.nn.Module, quantized_model_path:str, quantized_onnx_path:str, dummy_training:bool = False) -> torch.nn.Module:
|
||||
if model is None:
|
||||
model = tvqntmodels.mobilenet_v2(pretrained=False, num_classes = 960)
|
||||
model.train()
|
||||
model.qconfig = torch.quantization.get_default_qat_qconfig(backend='qnnpack')
|
||||
# model.fuse_model()
|
||||
torch.quantization.prepare_qat(model, inplace=True)
|
||||
|
||||
#Perform training
|
||||
if dummy_training:
|
||||
img_size = 192
|
||||
model(torch.randn(1, 3, img_size, img_size))
|
||||
else:
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
args, *_ = LitLandmarksTrainer.parse_args()
|
||||
LitLandmarksTrainer.train(args, model=model)
|
||||
|
||||
torch.save(model.state_dict(), quantized_model_path)
|
||||
converted_model = torch.quantization.convert(model, inplace=False)
|
||||
converted_model.to('cpu').eval()
|
||||
# torch.save(m, quantized_model_path, _use_new_zipfile_serialization = True)
|
||||
calculate_latency(converted_model, 192)
|
||||
print(get_time_elapsed(converted_model, 192, onnx=False))
|
||||
|
||||
# print(converted_model(torch.ones(1, 3, 192, 192)))
|
||||
# test_torch_model('', converted_model)
|
||||
# print('testing complete')
|
||||
# print('testing with path')
|
||||
# test_torch_model(quantized_model_path)
|
||||
|
||||
# to_onnx(converted_model, quantized_onnx_path, (192, 192))
|
||||
# basic_to_onnx(converted_model, quantized_onnx_path)
|
||||
return converted_model
|
||||
|
||||
|
||||
#%%
|
||||
def mobilenetv2_qat():
|
||||
# return
|
||||
untr_model = MobileNetV2(num_classes=960, block=QuantizableInvertedResidual)
|
||||
|
||||
print("base model")
|
||||
#print(untr_model)
|
||||
untr_model.to('cpu').eval()
|
||||
calculate_latency(untr_model, 192)
|
||||
torch.save(untr_model, '/home/yrajas/tmp/qat/untr_model.pt')
|
||||
|
||||
untr_model_copy = copy.deepcopy(untr_model)
|
||||
|
||||
print("walkthrough untr_model")
|
||||
# walkthrough_model(untr_model)
|
||||
|
||||
quantized_untr_model = QuantizedModel(untr_model_copy)
|
||||
quantized_untr_model.eval()
|
||||
|
||||
# elif backend == 'qnnpack':
|
||||
# model.qconfig = torch.quantization.QConfig( # type: ignore[assignment]
|
||||
# activation=torch.quantization.default_observer,
|
||||
# weight=torch.quantization.default_weight_observer)
|
||||
quantized_untr_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
model_fp32_fused = quantized_untr_model
|
||||
|
||||
# #fuse layers
|
||||
# TODO: Requires in-depth walkthrough of sub-modules to fuse layers at each level
|
||||
# ref. https://leimao.github.io/blog/PyTorch-Static-Quantization/
|
||||
# model_fp32_fused = torch.quantization.fuse_modules(
|
||||
# quantized_untr_model, ['features.0.0', 'features.0.1'] # TODO: check layer names
|
||||
# )
|
||||
|
||||
model_fp32_fused.train()
|
||||
prepared_model = torch.quantization.prepare_qat(model_fp32_fused)
|
||||
prepared_model.train()
|
||||
img_size = 192
|
||||
prepared_model(torch.randn(1, 3, img_size, img_size))
|
||||
|
||||
#Perform QAT
|
||||
# os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
# args, *_ = LitLandmarksTrainer.parse_args()
|
||||
# LitLandmarksTrainer.train(args, prepared_model)
|
||||
|
||||
print("before conversion")
|
||||
#print(prepared_model)
|
||||
|
||||
quantized_model = torch.quantization.convert(prepared_model)
|
||||
quantized_model.to('cpu').eval()
|
||||
print("quantized model")
|
||||
print("walkthrough quantized_model")
|
||||
# walkthrough_model(quantized_model)
|
||||
|
||||
#working_copy.features._modules['0']._modules['0'].weight.data = quantized_model.model_fp32.features._modules['0']._modules['0'].weight().dequantize()
|
||||
#print(quantized_model)
|
||||
#out = quantized_model(torch.randn(1, 3, img_size, img_size))
|
||||
|
||||
# print(out.shape)
|
||||
# calculate_latency(quantized_model, 192)
|
||||
print('calculating time taken')
|
||||
# copy_parameters(untr_model, quantized_model)
|
||||
torch.save(quantized_model, '/home/yrajas/tmp/qat/quantized_model.pt')
|
||||
# untr_model(torch.randn(1, 3, img_size, img_size))
|
||||
# to_onnx(untr_model, '/home/yrajas/tmp/qatcopyparams/mobilenetv2_qat_copyparams.onnx', (img_size, img_size))
|
||||
# avg, std = get_time_elapsed(quantized_model, 192, onnx=True)
|
||||
# print(f'avg {avg} std {std}')
|
||||
|
||||
#%%
|
||||
def get_test_dataset(identities:list, frames:int = 0):
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
data_module = SyntheticsDataModule(
|
||||
batch_size=5,
|
||||
num_workers=1,
|
||||
validation_proportion=0.1,
|
||||
landmarks_definition='dense_320',
|
||||
roi_landmarks_definition='dense_320',
|
||||
landmarks_weights='dense_320_weights',
|
||||
roi_size=(192, 192),
|
||||
roi_size_multiplier=1.1,
|
||||
use_sigma=True,
|
||||
warp_affine=True,
|
||||
warp_scale=0.05,
|
||||
warp_rotate=10,
|
||||
warp_shift=0.05,
|
||||
warp_jiggle=0.05,
|
||||
warp_squash_chance=0.0,
|
||||
motion_blur_chance=0.05,
|
||||
data_dir='/home/yrajas/data/groundtruth_render_20220419_155805/',
|
||||
frames_per_identity=frames,
|
||||
identities=identities,
|
||||
preload=True,
|
||||
load_depth=False,
|
||||
load_seg=False)
|
||||
data_module.setup()
|
||||
return data_module.full_dataset
|
||||
|
||||
def test_model(onnx_model_path:str):
|
||||
# MR model real data evaluation
|
||||
width = 192
|
||||
height = 192
|
||||
synthetics = get_test_dataset(list(range(0,20000,156)), 3)
|
||||
label_coords_unnormalized = torch.stack([s.ldmks_2d for s in synthetics])
|
||||
input_img = np.array([s.bgr_img.numpy() for s in synthetics])
|
||||
|
||||
onnx_session = InferenceSession(str(onnx_model_path))
|
||||
input_name = onnx_session.get_inputs()[0].name
|
||||
predicted_coords_normalized = onnx_session.run([], {input_name: input_img})[0]
|
||||
|
||||
predicted_coords_normalized = torch.tensor(predicted_coords_normalized)[:,:640] # ignore sigma
|
||||
predicted_coords_normalized = predicted_coords_normalized.reshape(-1, 320, 2) # reshape as co-ordinates
|
||||
predicted_coords_unnormalized = unnormalize_coordinates(predicted_coords_normalized, width, height)
|
||||
error = landmarks_error(predicted_coords_unnormalized, label_coords_unnormalized)
|
||||
|
||||
print(f"Error [Val] {error.mean()}")
|
||||
|
||||
def test_torch_model(model_path:str, model:torch.nn.Module = None):
|
||||
# MR model real data evaluation
|
||||
width = 192
|
||||
height = 192
|
||||
synthetics = get_test_dataset(list(range(0,20000,156)), 3)
|
||||
label_coords_unnormalized = torch.stack([s.ldmks_2d for s in synthetics])
|
||||
input_img = torch.stack([s.bgr_img for s in synthetics])
|
||||
|
||||
if model is None:
|
||||
model = torch.load(model_path)
|
||||
|
||||
model.to('cpu').eval()
|
||||
with torch.no_grad():
|
||||
predicted_coords_normalized = model(input_img)
|
||||
predicted_coords_normalized = torch.tensor(predicted_coords_normalized)[:,:640] # ignore sigma
|
||||
predicted_coords_normalized = predicted_coords_normalized.reshape(-1, 320, 2) # reshape as co-ordinates
|
||||
predicted_coords_unnormalized = unnormalize_coordinates(predicted_coords_normalized, width, height)
|
||||
error = landmarks_error(predicted_coords_unnormalized, label_coords_unnormalized)
|
||||
|
||||
print(f"Error [Val] {error.mean()}")
|
||||
|
||||
def static_quantize_onnx_model(onnx_model_path:str, quantized_onnx_model_path:str):
|
||||
from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader
|
||||
class DummyReader(CalibrationDataReader):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.dataset = get_test_dataset(list(range(int(20000))))
|
||||
self.length = len(self.dataset)
|
||||
self.idx = 0
|
||||
|
||||
def get_next(self) -> dict:
|
||||
if self.idx < self.length:
|
||||
img = np.array([self.dataset[self.idx].bgr_img.numpy()])
|
||||
data = {'input': img}
|
||||
self.idx += 1
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
# class RandomReader(CalibrationDataReader):
|
||||
# def __init__(self) -> None:
|
||||
# super().__init__()
|
||||
# self.idx = 0
|
||||
|
||||
# def get_next(self) -> dict:
|
||||
# img_size = 192
|
||||
# if self.idx < 10:
|
||||
# self.idx += 1
|
||||
# return {'input': torch.randn(1, 3, img_size, img_size).numpy()}
|
||||
# else:
|
||||
# return None
|
||||
quantize_static(onnx_model_path, quantized_onnx_model_path, calibration_data_reader=DummyReader())
|
||||
|
||||
def deepspeedcompress_qat():
|
||||
config_path = '/home/yrajas/vision.hu.face.synthetics.training/face_synthetics_training/training/landmarks2/nas/dscompress.json'
|
||||
model = MobileNetV2(num_classes=960)
|
||||
# model = torch.load('/home/yrajas/tmp/qat/tvbaseline/captured_output_dense_320_mobilenetv2_100_192.pt')
|
||||
|
||||
model_copy = copy.deepcopy(model)
|
||||
model_copy.to('cpu').eval()
|
||||
print("***before***")
|
||||
calculate_latency(model_copy, 192)
|
||||
# for name, module in model.named_modules():
|
||||
# print(name)
|
||||
# print(model)
|
||||
# for p in model.parameters():
|
||||
# print(p.norm())
|
||||
|
||||
model = init_compression(model=model, deepspeed_config=config_path)
|
||||
model, _, _, _ = deepspeed.initialize(model=model, config=config_path)
|
||||
model.train()
|
||||
|
||||
# for _ in range(100):
|
||||
# _ = model(torch.randn(1, 3, 192, 192).cuda())
|
||||
#Perform QAT
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
args, *_ = LitLandmarksTrainer.parse_args()
|
||||
LitLandmarksTrainer.train(args, model)
|
||||
|
||||
#model = redundancy_clean(model=model, deepspeed_config=config_path)
|
||||
|
||||
print("****after**")
|
||||
# for p in model.parameters():
|
||||
# print(p.norm())
|
||||
|
||||
model.module.to('cpu').eval()
|
||||
calculate_latency(model.module, 192)
|
||||
|
||||
# torch.onnx.export(
|
||||
# model, (torch.randn(1, 3, 192, 192)),
|
||||
# '~/tmp/test.onnx',
|
||||
# opset_version=11,
|
||||
# verbose=False,
|
||||
# input_names=['input_0'],
|
||||
# output_names=['output_0'])
|
||||
print("done")
|
||||
|
||||
def train_mobilenetv2():
|
||||
model = MobileNetV2(num_classes=960)
|
||||
model.train()
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
args, *_ = LitLandmarksTrainer.parse_args()
|
||||
LitLandmarksTrainer.train(args, model)
|
||||
|
||||
def train_torchvisionmodel():
|
||||
model = tvqntmodels.mobilenet_v2(pretrained=False, num_classes = 960)
|
||||
torch.save(model, '/home/yrajas/tmp/qat/tvbaseline/before_training.pt')
|
||||
model.train()
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
|
||||
args, *_ = LitLandmarksTrainer.parse_args()
|
||||
LitLandmarksTrainer.train(args, model)
|
||||
torch.save(model, '/home/yrajas/tmp/qat/tvbaseline/after_training.pt')
|
||||
|
||||
def convert_to_onnx(model_path:str, onnx_path:str, img_size:int):
|
||||
model = torch.load(model_path)
|
||||
|
||||
to_onnx(model, onnx_path, (img_size, img_size))
|
||||
|
||||
#%%
|
||||
def main():
|
||||
# calculate_latency(torch.load('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.pt'), 192)
|
||||
# calculate_latency(torch.load('/home/yrajas/tmp/qat_nightly/tv_qat_mrpretrained/outputs/dense_320_mobilenetv2_100_192.pt'), 192)
|
||||
|
||||
# get_1cpu_latency('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.onnx', onnx=True)
|
||||
# get_1cpu_latency('/home/yrajas/tmp/qat_nightly/tv_qat_mrpretrained/outputs/dense_320_mobilenetv2_100_192.onnx', onnx=True)
|
||||
# print(profile_onnx('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.onnx', img_size=(192, 192)))
|
||||
# print(get_time_elapsed(torch.load('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.pt'), img_size=192, onnx=True))
|
||||
|
||||
# test_torch_model('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.pt')
|
||||
# test_torch_model('/home/yrajas/tmp/qat_nightly/tv_qat_mrpretrained/outputs/dense_320_mobilenetv2_100_192.pt')
|
||||
|
||||
test_model('/home/yrajas/tmp/qat_nightly/tv_noqat_baseline/outputs/dense_320_mobilenetv2_100_192.onnx')
|
||||
test_model('/home/yrajas/tmp/qat_nightly/tv_qat_mrpretrained/outputs/dense_320_mobilenetv2_100_192.onnx')
|
||||
|
||||
# qat_model = quantize_model('/home/yrajas/tmp/qat/tv_qat_lr1e-4/tv_qat_lr1e-4.pt')
|
||||
# test_torch_model(model_path='', model = qat_model)
|
||||
|
||||
# torchvision_qat()
|
||||
# train_torchvisionmodel()
|
||||
# model = perform_qat(
|
||||
# torch.load('/home/yrajas/tmp/qat/tvbaseline/captured_output_dense_320_mobilenetv2_100_192.pt'),
|
||||
# '/home/yrajas/tmp/qat/tv_qat_swa/tv_qat_swa.pt',
|
||||
# '/home/yrajas/tmp/qat/test_qat.onnx',
|
||||
# dummy_training = False)
|
||||
# test_torch_model(model_path='', model=model)
|
||||
# print('exporting to onnx')
|
||||
# basic_to_onnx(model, '/home/yrajas/tmp/qat/tvbaseline/captured_output_dense_320_mobilenetv2_100_192_test_qat.onnx')
|
||||
# train_mobilenetv2()
|
||||
#static_quantize_mobilenetv2()
|
||||
|
||||
# test_torch_model('/home/yrajas/tmp/qat/tvqat_mrtorch1_10/dense_320_mobilenetv2_100_192_mrtorch110_qat.pt')
|
||||
# test_torch_model('/home/yrajas/tmp/qat/tvbaseline/captured_output_dense_320_mobilenetv2_100_192.pt')
|
||||
# test_torch_model('/home/yrajas/tmp/qat/tvbaseline/captured_output_dense_320_mobilenetv2_100_192_qat.pt')
|
||||
#mobilenetv2_qat()
|
||||
|
||||
# convert_to_onnx('/home/yrajas/tmp/qat/quantized_model.pt', '/home/yrajas/tmp/qat/quantized_model.onnx', 192)
|
||||
|
||||
# deepspeedcompress_qat()
|
||||
|
||||
# calculate_latency(torch.load('/home/yrajas/tmp/qat/baseline/outputs/dense_320_mobilenetv2_100_192.pt'), 192)
|
||||
|
||||
# test_model('/home/yrajas/tmp/qat/quantized_static_model.onnx')
|
||||
# static_quantize_onnx_model(
|
||||
# onnx_model_path = '/home/yrajas/tmp/qat/quantized_model.onnx',
|
||||
# quantized_onnx_model_path='/home/yrajas/tmp/qat/quantized_static_model.onnx'
|
||||
# )
|
||||
|
||||
# test_model('/home/yrajas/tmp/qat/quantized_static_model.onnx')
|
||||
# test_torch_model('/home/yrajas/tmp/qat/baseline/outputs/dense_320_mobilenetv2_100_192.pt')
|
||||
# static_quantize_onnx_model(
|
||||
# onnx_model_path = '/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192.onnx',
|
||||
# quantized_onnx_model_path='/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192_randcalib_stqntz.onnx'
|
||||
# )
|
||||
|
||||
# test_torch_model('/home/yrajas/tmp/qat/quantized_model.pt')
|
||||
# print(get_model_flops(torch.load('/home/yrajas/tmp/qat/quantized_model.pt'), 192))
|
||||
|
||||
# print(profile_onnx('/home/yrajas/tmp/qat/quantized_static_model.onnx', (192, 192)))
|
||||
|
||||
# test model error for timm and quantized from timm
|
||||
# test_model('/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192.onnx')
|
||||
# test_model('/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192_static_quantized.onnx')
|
||||
|
||||
# test model error for timm and random calibrated quantized from timm
|
||||
# test_model('/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192.onnx')
|
||||
# test_model('/home/yrajas/tmp/mr20k/outputs/dense_320_mobilenetv2_100_192_randcalib_stqntz.onnx')
|
||||
|
||||
# test model error for orig paper based mbv2 and quantized form of it
|
||||
# test_model('/home/yrajas/tmp/mbv2paper/mbnetv2_paper.onnx')
|
||||
# test_model('/home/yrajas/tmp/mbv2paper/mbnetv2_paper_stquantized.onnx')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -30,7 +30,6 @@ depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
data_path: face_synthetics/dataset_100000
output_dir: ./output
max_num_images: 20000
#max_num_images: 1000
train_crop_size: 128
epochs: 30
batch_size: 128

@ -1,282 +0,0 @@
|
|||
import copy
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.ao.quantization
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
import utils
|
||||
from torch import nn
|
||||
from train import train_one_epoch, evaluate, load_data
|
||||
|
||||
|
||||
try:
|
||||
from torchvision import prototype
|
||||
except ImportError:
|
||||
prototype = None
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.prototype and prototype is None:
|
||||
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
|
||||
if not args.prototype and args.weights:
|
||||
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
|
||||
if args.output_dir:
|
||||
utils.mkdir(args.output_dir)
|
||||
|
||||
utils.init_distributed_mode(args)
|
||||
print(args)
|
||||
|
||||
if args.post_training_quantize and args.distributed:
|
||||
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
|
||||
|
||||
# Set backend engine to ensure that quantized model runs on the correct kernels
|
||||
if args.backend not in torch.backends.quantized.supported_engines:
|
||||
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
|
||||
torch.backends.quantized.engine = args.backend
|
||||
|
||||
device = torch.device(args.device)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# Data loading code
|
||||
print("Loading data")
|
||||
train_dir = os.path.join(args.data_path, "train")
|
||||
val_dir = os.path.join(args.data_path, "val")
|
||||
|
||||
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
|
||||
)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
|
||||
)
|
||||
|
||||
print("Creating model", args.model)
|
||||
# when training quantized models, we always start from a pre-trained fp32 reference model
|
||||
if not args.prototype:
|
||||
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
|
||||
else:
|
||||
model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
|
||||
model.to(device)
|
||||
|
||||
if not (args.test_only or args.post_training_quantize):
|
||||
model.fuse_model(is_qat=True)
|
||||
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
|
||||
torch.ao.quantization.prepare_qat(model, inplace=True)
|
||||
|
||||
if args.distributed and args.sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
model_without_ddp = model
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
model_without_ddp = model.module
|
||||
|
||||
if args.resume:
|
||||
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||
model_without_ddp.load_state_dict(checkpoint["model"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
||||
args.start_epoch = checkpoint["epoch"] + 1
|
||||
|
||||
if args.post_training_quantize:
|
||||
# perform calibration on a subset of the training dataset
|
||||
# for that, create a subset of the training dataset
|
||||
ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
|
||||
data_loader_calibration = torch.utils.data.DataLoader(
|
||||
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
|
||||
)
|
||||
model.eval()
|
||||
model.fuse_model(is_qat=False)
|
||||
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
|
||||
torch.ao.quantization.prepare(model, inplace=True)
|
||||
# Calibrate first
|
||||
print("Calibrating")
|
||||
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
|
||||
torch.ao.quantization.convert(model, inplace=True)
|
||||
if args.output_dir:
|
||||
print("Saving quantized model")
|
||||
if utils.is_main_process():
|
||||
torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
|
||||
print("Evaluating post-training quantized model")
|
||||
evaluate(model, criterion, data_loader_test, device=device)
|
||||
return
|
||||
|
||||
if args.test_only:
|
||||
evaluate(model, criterion, data_loader_test, device=device)
|
||||
return
|
||||
|
||||
model.apply(torch.ao.quantization.enable_observer)
|
||||
model.apply(torch.ao.quantization.enable_fake_quant)
|
||||
start_time = time.time()
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
print("Starting training for epoch", epoch)
|
||||
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
|
||||
lr_scheduler.step()
|
||||
with torch.inference_mode():
|
||||
if epoch >= args.num_observer_update_epochs:
|
||||
print("Disabling observer for subseq epochs, epoch = ", epoch)
|
||||
model.apply(torch.ao.quantization.disable_observer)
|
||||
if epoch >= args.num_batch_norm_update_epochs:
|
||||
print("Freezing BN for subseq epochs, epoch = ", epoch)
|
||||
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
|
||||
print("Evaluate QAT model")
|
||||
|
||||
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
|
||||
quantized_eval_model = copy.deepcopy(model_without_ddp)
|
||||
quantized_eval_model.eval()
|
||||
quantized_eval_model.to(torch.device("cpu"))
|
||||
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
|
||||
|
||||
print("Evaluate Quantized model")
|
||||
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
|
||||
|
||||
model.train()
|
||||
|
||||
if args.output_dir:
|
||||
checkpoint = {
|
||||
"model": model_without_ddp.state_dict(),
|
||||
"eval_model": quantized_eval_model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"lr_scheduler": lr_scheduler.state_dict(),
|
||||
"epoch": epoch,
|
||||
"args": args,
|
||||
}
|
||||
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
|
||||
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
|
||||
print("Saving models after epoch ", epoch)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print(f"Training time {total_time_str}")
|
||||
|
||||
|
||||
def get_args_parser(add_help=True):
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
|
||||
|
||||
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
|
||||
parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
|
||||
parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
|
||||
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
|
||||
|
||||
parser.add_argument(
|
||||
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
|
||||
)
|
||||
parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
|
||||
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
|
||||
parser.add_argument(
|
||||
"--num-observer-update-epochs",
|
||||
default=4,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to update observers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-batch-norm-update-epochs",
|
||||
default=3,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to update batch norm stats",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-calibration-batches",
|
||||
default=32,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of batches of training set for \
|
||||
observer calibration ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
|
||||
)
|
||||
parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
|
||||
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=1e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
|
||||
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
|
||||
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
|
||||
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
|
||||
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
|
||||
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
|
||||
parser.add_argument(
|
||||
"--cache-dataset",
|
||||
dest="cache_dataset",
|
||||
help="Cache the datasets for quicker initialization. \
|
||||
It also serializes the transforms",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sync-bn",
|
||||
dest="sync_bn",
|
||||
help="Use sync batch norm",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-only",
|
||||
dest="test_only",
|
||||
help="Only test the model",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--post-training-quantize",
|
||||
dest="post_training_quantize",
|
||||
help="Post training quantize the model",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
# distributed training parameters
|
||||
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
|
||||
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
||||
|
||||
parser.add_argument(
|
||||
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
|
||||
)
|
||||
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
|
||||
|
||||
# Prototype models only
|
||||
parser.add_argument(
|
||||
"--prototype",
|
||||
dest="prototype",
|
||||
help="Use prototype model builders instead those from main area",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args_parser().parse_args()
|
||||
main(args)
|
|
@ -132,172 +132,3 @@ class FaceLandmarkTransform:
|
|||
|
||||
def __call__(self, sample: Sample):
|
||||
return self.transform(sample)
|
||||
|
||||
|
||||
class RandomMixup(torch.nn.Module):
|
||||
"""Randomly apply Mixup to the provided batch and targets.
|
||||
The class implements the data augmentations as described in the paper
|
||||
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
|
||||
|
||||
Args:
|
||||
num_classes (int): number of classes used for one-hot encoding.
|
||||
p (float): probability of the batch being transformed. Default value is 0.5.
|
||||
alpha (float): hyperparameter of the Beta distribution used for mixup.
|
||||
Default value is 1.0.
|
||||
inplace (bool): boolean to make this transform inplace. Default set to False.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
||||
super().__init__()
|
||||
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
|
||||
assert alpha > 0, "Alpha param can't be zero."
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.p = p
|
||||
self.alpha = alpha
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
batch (Tensor): Float tensor of size (B, C, H, W)
|
||||
target (Tensor): Integer tensor of size (B, )
|
||||
|
||||
Returns:
|
||||
Tensor: Randomly transformed batch.
|
||||
"""
|
||||
if batch.ndim != 4:
|
||||
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
||||
if target.ndim != 1:
|
||||
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
||||
if not batch.is_floating_point():
|
||||
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
||||
if target.dtype != torch.int64:
|
||||
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
||||
|
||||
if not self.inplace:
|
||||
batch = batch.clone()
|
||||
target = target.clone()
|
||||
|
||||
if target.ndim == 1:
|
||||
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
||||
|
||||
if torch.rand(1).item() >= self.p:
|
||||
return batch, target
|
||||
|
||||
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
||||
batch_rolled = batch.roll(1, 0)
|
||||
target_rolled = target.roll(1, 0)
|
||||
|
||||
# Implemented as on mixup paper, page 3.
|
||||
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
||||
batch_rolled.mul_(1.0 - lambda_param)
|
||||
batch.mul_(lambda_param).add_(batch_rolled)
|
||||
|
||||
target_rolled.mul_(1.0 - lambda_param)
|
||||
target.mul_(lambda_param).add_(target_rolled)
|
||||
|
||||
return batch, target
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = (
|
||||
f"{self.__class__.__name__}("
|
||||
f"num_classes={self.num_classes}"
|
||||
f", p={self.p}"
|
||||
f", alpha={self.alpha}"
|
||||
f", inplace={self.inplace}"
|
||||
f")"
|
||||
)
|
||||
return s
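# --- Illustrative usage sketch (added for reference, not part of the original file) ---
# Mixup blends each image/label with the batch rolled by one. With p=1.0 the transform always
# fires; the batch size and 10-class setup below are assumptions. The returned targets are soft
# (mixed one-hot) vectors, so the training loss must accept probability targets.
def _sketch_random_mixup_usage():
    mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)
    images = torch.rand(8, 3, 192, 192)            # (B, C, H, W) float batch
    labels = torch.randint(0, 10, (8,))            # (B,) int64 class indices
    mixed_images, soft_targets = mixup(images, labels)
    return mixed_images.shape, soft_targets.shape  # torch.Size([8, 3, 192, 192]), torch.Size([8, 10])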
|
||||
|
||||
|
||||
class RandomCutmix(torch.nn.Module):
|
||||
"""Randomly apply Cutmix to the provided batch and targets.
|
||||
The class implements the data augmentations as described in the paper
|
||||
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
|
||||
<https://arxiv.org/abs/1905.04899>`_.
|
||||
|
||||
Args:
|
||||
num_classes (int): number of classes used for one-hot encoding.
|
||||
p (float): probability of the batch being transformed. Default value is 0.5.
|
||||
alpha (float): hyperparameter of the Beta distribution used for cutmix.
|
||||
Default value is 1.0.
|
||||
inplace (bool): boolean to make this transform inplace. Default set to False.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
||||
super().__init__()
|
||||
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
|
||||
assert alpha > 0, "Alpha param can't be zero."
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.p = p
|
||||
self.alpha = alpha
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
batch (Tensor): Float tensor of size (B, C, H, W)
|
||||
target (Tensor): Integer tensor of size (B, )
|
||||
|
||||
Returns:
|
||||
Tensor: Randomly transformed batch.
|
||||
"""
|
||||
if batch.ndim != 4:
|
||||
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
||||
if target.ndim != 1:
|
||||
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
||||
if not batch.is_floating_point():
|
||||
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
||||
if target.dtype != torch.int64:
|
||||
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
||||
|
||||
if not self.inplace:
|
||||
batch = batch.clone()
|
||||
target = target.clone()
|
||||
|
||||
if target.ndim == 1:
|
||||
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
||||
|
||||
if torch.rand(1).item() >= self.p:
|
||||
return batch, target
|
||||
|
||||
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
||||
batch_rolled = batch.roll(1, 0)
|
||||
target_rolled = target.roll(1, 0)
|
||||
|
||||
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
|
||||
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
||||
W, H = F.get_image_size(batch)
|
||||
|
||||
r_x = torch.randint(W, (1,))
|
||||
r_y = torch.randint(H, (1,))
|
||||
|
||||
r = 0.5 * math.sqrt(1.0 - lambda_param)
|
||||
r_w_half = int(r * W)
|
||||
r_h_half = int(r * H)
|
||||
|
||||
x1 = int(torch.clamp(r_x - r_w_half, min=0))
|
||||
y1 = int(torch.clamp(r_y - r_h_half, min=0))
|
||||
x2 = int(torch.clamp(r_x + r_w_half, max=W))
|
||||
y2 = int(torch.clamp(r_y + r_h_half, max=H))
|
||||
|
||||
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
|
||||
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
|
||||
|
||||
target_rolled.mul_(1.0 - lambda_param)
|
||||
target.mul_(lambda_param).add_(target_rolled)
|
||||
|
||||
return batch, target
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = (
|
||||
f"{self.__class__.__name__}("
|
||||
f"num_classes={self.num_classes}"
|
||||
f", p={self.p}"
|
||||
f", alpha={self.alpha}"
|
||||
f", inplace={self.inplace}"
|
||||
f")"
|
||||
)
|
||||
return s
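# --- Illustrative usage sketch (added for reference, not part of the original file) ---
# CutMix is driven exactly like Mixup above; a common recipe picks one of the two per batch,
# for example inside a DataLoader collate_fn. The sizes and class count are assumptions.
def _sketch_random_cutmix_usage():
    import random
    cutmix = RandomCutmix(num_classes=10, p=1.0, alpha=1.0)
    mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)
    images = torch.rand(8, 3, 192, 192)
    labels = torch.randint(0, 10, (8,))
    transform = random.choice([cutmix, mixup])     # one augmentation per batch
    return transform(images, labels)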
|
||||
|
|
|
@ -1,4 +1,4 @@
"""This is from torchvision source code"""
"""This is from torchvision source code: https://github.com/pytorch/vision/blob/main/references/classification/utils.py"""
import copy
import datetime
import errno

@ -1,139 +0,0 @@
|
|||
#to be removed before merging
|
||||
|
||||
"""This module contains methods and callbacks for visualizing training or validation data."""
|
||||
|
||||
from heapq import heapify, heappush, heappushpop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
#from face_synthetics_training.core.utils import add_text, draw_landmark_connectivity, draw_landmarks, make_image_grid
|
||||
#from face_synthetics_training.training.utils import bgr_img_to_np, to_np
|
||||
|
||||
UINT8_MAX = 255
|
||||
def bgr_img_to_np(bgr_img_pt):
|
||||
"""Converts a PyTorch (C, H, W) BGR float image [0, 1] into a NumPy (H, W, C) UINT8 image [0, 255]."""
|
||||
|
||||
assert isinstance(bgr_img_pt, torch.Tensor)
|
||||
bgr_img_np = np.clip(np.transpose(bgr_img_pt.cpu().detach().numpy(), (1, 2, 0)), 0, 1)
|
||||
return np.ascontiguousarray((bgr_img_np * UINT8_MAX).astype(np.uint8))
|
||||
|
||||
|
||||
def to_np(tensor: torch.Tensor):
|
||||
"""Convenience function for NumPy-ifying a PyTorch tensor."""
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
return tensor.cpu().detach().numpy()
|
||||
|
||||
def draw_landmarks(img, ldmks_2d, thickness=1, thicknesses=None, color=(255, 255, 255),
|
||||
colors=None, ldmks_visibility=None):
|
||||
"""Drawing dots on an image."""
|
||||
# pylint: disable=too-many-arguments
|
||||
assert img.dtype == np.uint8
|
||||
|
||||
img_size = (img.shape[1], img.shape[0])
|
||||
|
||||
for i, ldmk in enumerate(ldmks_2d.astype(int)):
|
||||
if ldmks_visibility is not None and ldmks_visibility[i] == 0:
|
||||
continue
|
||||
if np.all(ldmk > 0) and np.all(ldmk < img_size):
|
||||
l_c = tuple(colors[i] if colors is not None else color)
|
||||
t = thicknesses[i] if thicknesses is not None else thickness
|
||||
cv2.circle(img, tuple(ldmk+1), t, 0, -1, cv2.LINE_AA)
|
||||
cv2.circle(img, tuple(ldmk), t, l_c, -1, cv2.LINE_AA)
|
||||
|
||||
|
||||
def visualize_landmarks(
|
||||
color_image: np.ndarray,
|
||||
label_landmarks: Optional[np.ndarray] = None,
|
||||
predicted_landmarks: Optional[np.ndarray] = None,
|
||||
name: Optional[str] = None,
|
||||
connectivity: Optional[np.ndarray] = None,
|
||||
error=None,
|
||||
include_original_image: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""Creates a visualization of landmarks on a training image.
|
||||
|
||||
Args:
|
||||
color_image (np.ndarray): The color image, e.g. an image of a face.
|
||||
label_landmarks (Optional[np.ndarray], optional): The label or GT landmarks. Defaults to None.
|
||||
predicted_landmarks (Optional[np.ndarray], optional): The landmarks predicted by a network. Defaults to None.
|
||||
name (Optional[str], optional): The name of the image. Defaults to None.
|
||||
connectivity (Optional[np.ndarray], optional): The connectivity between landmark pairs. Defaults to None.
|
||||
error ([type], optional): The average Euclidean landmark error. Defaults to None.
|
||||
include_original_image (bool, optional): If true, also include the original image without annotation.
|
||||
|
||||
Returns:
|
||||
np.ndarray: [description]
|
||||
"""
|
||||
# pylint: disable=too-many-arguments
|
||||
vis_img = color_image.copy()
|
||||
if connectivity is not None:
|
||||
if label_landmarks is not None:
|
||||
draw_landmark_connectivity(vis_img, label_landmarks, connectivity, color=(0, 255, 0))
|
||||
if predicted_landmarks is not None:
|
||||
draw_landmark_connectivity(vis_img, predicted_landmarks, connectivity)
|
||||
else:
|
||||
if label_landmarks is not None:
|
||||
draw_landmarks(vis_img, label_landmarks, color=(0, 255, 0))
|
||||
if predicted_landmarks is not None:
|
||||
draw_landmarks(vis_img, predicted_landmarks, color=(0, 165, 255))
|
||||
|
||||
return np.vstack([color_image, vis_img]) if include_original_image else vis_img
|
||||
|
||||
def unnormalize_coordinates(coords: np.array, img_size: Tuple[int, int]):
|
||||
"""Unnormalize coordinates from [-1, 1] to pixel units."""
|
||||
img_size = np.divide(img_size, 2)
|
||||
coords_pixels = np.add(img_size, np.multiply(coords, img_size))
|
||||
return coords_pixels
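def _sketch_unnormalize_example():
    # Worked example (added for reference, assuming a 192x192 image): normalized (-1, -1) maps
    # to pixel (0, 0), (0, 0) maps to the image centre (96, 96), and (1, 1) maps to (192, 192).
    coords = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
    return unnormalize_coordinates(coords, (192, 192))  # [[0, 0], [96, 96], [192, 192]]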
|
||||
|
||||
|
||||
def visualize_batch_data(
|
||||
img_file_prefix: str,
|
||||
epoch: int, #epoch number
|
||||
outputs,
|
||||
batch: Any, #image, label tuple
|
||||
batch_idx: int = 0
|
||||
):
|
||||
"""At the end of each training batch, dump an image visualizing label and predicted landmarks."""
|
||||
# We are overriding, and do not use all arguments, so:
|
||||
# pylint: disable=too-many-arguments,signature-differs,unused-argument
|
||||
|
||||
if batch_idx == 0: # Visualize first batch only
|
||||
vis_imgs = []
|
||||
|
||||
batch_size = batch[0].shape[0]
|
||||
num_images = 1
|
||||
for img_idx in range(num_images):
|
||||
color_image = bgr_img_to_np(batch[0][img_idx]).copy()
|
||||
label_coordinates = to_np(batch[1][img_idx])
|
||||
predicted_coords = to_np(outputs[img_idx])
|
||||
|
||||
img_size = color_image.shape[0:2]
|
||||
label_coordinates_unnormalized = unnormalize_coordinates(label_coordinates, img_size)
|
||||
predicted_coords_unnormalized = unnormalize_coordinates(predicted_coords, img_size)
|
||||
|
||||
vis_img = visualize_landmarks(
|
||||
color_image=color_image,
|
||||
label_landmarks=label_coordinates_unnormalized,
|
||||
predicted_landmarks=predicted_coords_unnormalized)
|
||||
vis_imgs.append(vis_img)
|
||||
|
||||
cv2.imwrite(f"{img_file_prefix}_{epoch:04d}.jpg", vis_imgs[0])
|
||||
#batch_visualization = make_image_grid(vis_imgs, min(num_images, 8))
|
||||
#cv2.imwrite(str(self.log_dir / f"vis_img_train_{epoch:04d}.jpg"), batch_visualization)
|
||||
"""
|
||||
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
epoch = trainer.current_epoch
|
||||
vis_imgs = []
|
||||
|
||||
# Sort all samples by score, and visualize them
|
||||
for _, error, _, sample in sorted(self.samples):
|
||||
vis_imgs.append(visualize_landmarks(**sample, connectivity=self.connectivity, error=error))
|
||||
|
||||
batch_visualization = make_image_grid(vis_imgs, min(len(self.samples), 8))
|
||||
cv2.imwrite(str(self.log_dir / f"{self.name}_{epoch:04d}.jpg"), batch_visualization)
|
||||
"""
|
||||
|