This commit is contained in:
Wei-ge Chen 2023-04-19 13:50:28 -07:00
Родитель e5d7271e5f
Коммит e2aee26da7
4 изменённых файлов: 85 добавлений и 112 удалений

Просмотреть файл

@ -7,13 +7,29 @@ import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import presets
class FaceLandmarkDataset(Dataset):
"""Dataset class for Microsoft Face Synthetics dataset.
Args:
directory (str): Path to the directory containing the PNG images and landmarks files.
limit (int, optional): Maximum number of samples to load from the dataset. Defaults to None.
crop_size (int, optional): Size of the square crop to apply to the images. Defaults to 128.
Attributes:
png_files (list): List of paths to the PNG image files in the dataset.
transform (FaceLandmarkTransform): Transform to apply to the samples.
_num_landmarks (int): Number of landmarks in each sample.
Methods:
__len__(): Returns the number of samples in the dataset.
__getitem__(index): Returns the image and landmarks of the sample at the given index.
num_landmarks(): Returns the number of landmarks in each sample.
"""
def __init__(self, directory, limit=None, crop_size = 128):
""" initialize """
pattern = os.path.join(directory, "[0-9][0-9][0-9][0-9][0-9][0-9].png") #don't load *_seg.png files
self.png_files = glob.glob(pattern)
assert len(self.png_files) > 0, f"Can't find any PNG image in folder: {directory}"
@ -27,7 +43,15 @@ class FaceLandmarkDataset(Dataset):
return len(self.png_files)
def __getitem__(self, index):
"""get a sample"""
"""
Returns the image and landmarks of the sample at the given index.
Args:
index (int): Index of the sample to retrieve.
Returns:
tuple: A tuple containing the transformed image and landmarks of the sample.
"""
png_file = self.png_files[index]
image = Image.open(png_file)
label_file = png_file.replace(".png", "_ldmks.txt")
@ -35,18 +59,22 @@ class FaceLandmarkDataset(Dataset):
assert label.size > 0, "Can't find data in landmarks file: f{label_file}"
#label[:, 1] = image.height - label[:, 1] #flip due to the landmarks Y definition
sample = presets.Sample(bgr_img=image, ldmks_2d=label)
sample = presets.Sample(image=image, landmarks=label)
assert sample is not None
sample_transformed = self.transform(sample)
assert sample_transformed is not None
return sample_transformed.bgr_img, sample_transformed.ldmks_2d
return sample_transformed.image, sample_transformed.landmarks
@property
def num_landmarks(self):
""" number of landmarks in each sample"""
"""
Returns the number of landmarks in each sample.
Returns:
int: The number of landmarks in each sample.
"""
if self._num_landmarks is None:
_, label = self.__getitem__(0)
self._num_landmarks = torch.numel(label)
return self._num_landmarks

Просмотреть файл

@ -3,30 +3,29 @@ from typing import Dict, List, Optional, Tuple, Union
import os
os.environ["OMP_NUM_THREADS"] = "1"
import statistics
from time import perf_counter
import onnxruntime as rt
import torch
from archai.discrete_search.api.archai_model import ArchaiModel
import os
import statistics
from math import sqrt
from time import perf_counter
from typing import Dict, Optional, Tuple
class AvgOnnxLatency:
higher_is_better: bool = False
def __init__(self, input_shape: Union[Tuple, List[Tuple]], num_trials: int = 1, num_input: int = 10,
def __init__(self, input_shape: Union[Tuple, List[Tuple]], num_trials: int = 15, num_input: int = 15,
input_dtype: str = 'torch.FloatTensor', rand_range: Tuple[float, float] = (0.0, 1.0),
export_kwargs: Optional[Dict] = None, inf_session_kwargs: Optional[Dict] = None):
"""Uses the average ONNX Latency (in seconds) of an architecture as an objective function for
minimization.
""" Measure the average ONNX Latency (in millseconds) of a model
Args:
input_shape (Union[Tuple, List[Tuple]]): Model Input shape or list of model input shapes.
num_trials (int, optional): Number of trials. Defaults to 1.
num_trials (int, optional): Number of trials. Defaults to 15.
num_input (int, optional): Number of input per trial. Defaults to 15.
input_dtype (str, optional): Data type of input samples.
rand_range (Tuple[float, float], optional): The min and max range of input samples.
export_kwargs (Optional[Dict], optional): Optional dictionary of key-value args passed to
`torch.onnx.export`. Defaults to None.
inf_session_kwargs (Optional[Dict], optional): Optional dictionary of key-value args
@ -40,17 +39,17 @@ class AvgOnnxLatency:
for input_shape in input_shapes
])
self.input_dtype = input_dtype
self.rand_range = rand_range
self.num_trials = num_trials
self.num_input_per_trial = num_input
self.export_kwargs = export_kwargs or dict()
self.inf_session_kwargs = inf_session_kwargs or dict()
def evaluate(self, model: ArchaiModel) -> float:
model.arch.to('cpu')
"""Evaluate the model and return the average latency (in milliseconds)"""
"""Args: model (ArchaiModel): Model to evaluate"""
"""Returns: float: Average latency (in milliseconds)"""
# Exports model to ONNX
model.arch.to('cpu')
exported_model_buffer = io.BytesIO()
torch.onnx.export(
model.arch, self.sample_input, exported_model_buffer,
@ -58,69 +57,37 @@ class AvgOnnxLatency:
opset_version=11,
**self.export_kwargs
)
print("torch.onnx.export done")
exported_model_buffer.seek(0)
opts = rt.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
# Benchmarks ONNX model
onnx_session = rt.InferenceSession(exported_model_buffer.read(), sess_options=opts, **self.inf_session_kwargs)
sample_input = {f'input_{i}': inp.numpy() for i, inp in enumerate(self.sample_input)}
# inf_times = []
inf_time_avg = self.get_time_elapsed (onnx_session, sample_input, num_input = self.num_input_per_trial, num_measures = self.num_trials)
# for _ in range(self.num_trials):
# with MeasureBlockTime('onnx_inference') as t:
# onnx_session.run(None, input_feed=sample_input)
# inf_times.append(t.elapsed)
# return sum(inf_times) / self.num_trials
num_input_per_trial = self.num_input_per_trial
#inf_time_avg, inf_time_std = self.get_model_latency_1cpu(onnx_session, model.arch, sample_input, cpu = 1, onnx = True,
inf_time_avg, inf_time_std = self.get_time_elapsed (onnx_session, model.arch, sample_input, onnx = True, num_input = num_input_per_trial, num_measures = self.num_trials)
#per trial time is idealy longer so that timing can be more accurate
if (inf_time_avg * num_input_per_trial < 100) :
num_input_per_trial = int(1.5 * 100 / inf_time_avg + 0.5)
#inf_time_avg, inf_time_std = self.get_model_latency_1cpu(onnx_session, model.arch, sample_input, cpu = 1, onnx = True, num_input = num_input_per_trial, num_measures = self.num_trials)
inf_time_avg, inf_time_std = self.get_time_elapsed (onnx_session, model.arch, sample_input, onnx = True, num_input = num_input_per_trial, num_measures = self.num_trials)
if (inf_time_std > 0.1 * inf_time_avg):
ratio = (0.1 * inf_time_avg) / inf_time_std
ratio *= ratio * 1.1
num_trails_scaled = int(self.num_trials * ratio + 0.5)
#inf_time_avg, inf_time_std = self.get_model_latency_1cpu(onnx_session, model.arch, sample_input, cpu = 1, onnx = True, num_input = num_input_per_trial, num_measures = num_trails_scaled)
inf_time_avg, inf_time_std = self.get_time_elapsed (onnx_session, model.arch, sample_input, onnx = True, num_input = num_input_per_trial, num_measures = self.num_trials)
assert (inf_time_std < 0.1 * inf_time_avg, f"inf_time_std = {inf_time_std}, inf_time_avg = {inf_time_avg:}")
return inf_time_avg
def get_time_elapsed (self, onnx_session, model, sample_input, onnx: bool = False, num_input:int = 15, num_measures:int = 15) -> Tuple[float, float] :
#print("get_time_elapsed: entering")
def get_time_elapsed (self, onnx_session, sample_input, num_input:int = 15, num_measures:int = 15) -> float:
"""Measure the average time elapsed (in milliseconds) for a given model and input for anumber of times
Args:
onnx_session (onnxruntime.InferenceSession): ONNX Inference Session
sample_input (Dict[str, np.ndarray]): Sample input to the model
num_input (int, optional): Number of input per trial. Defaults to 15.
num_measures (int, optional): Number of measures. Defaults to 15.
Returns:
float: Average time elapsed (in milliseconds)"""
def meausre_func() :
#print(f"measure_func entered")
"""Measure the time elapsed (in milliseconds) for a given model and input, once
Returns: float: Time elapsed (in milliseconds)"""
t0 = perf_counter()
for _ in range(num_input): #this is to incease the accuracy as 1 run maybe too short to measure
#print(f"measure_func iter: {_}")
if (onnx):
#print(f"measure_func: start onnx_session.run")
onnx_session.run(None, input_feed=sample_input)[0]
#print(f"measure_func: left onnx_session.run")
else:
pred = model.forward(sample_input)
for _ in range(num_input):
onnx_session.run(None, input_feed=sample_input)[0]
t1 = perf_counter()
time_measured = 1e3 * (t1 - t0) / num_input
#print(f"measure_func return: {time_measured}")
return time_measured
time_measured_all = []
#print("get_time_elapsed: starting measure_func")
time_measured_all = [meausre_func() for _ in range (num_measures)]
time_measured_avg = statistics.mean(time_measured_all)
time_measured_std = sqrt(num_input) * statistics.stdev(time_measured_all) #sigma^2(x+y) = sigma^2(x) + sigma^2(y); then there is average
time_measured_std /= sqrt(num_measures) #stdev of the sample mean, not the population
#print(f"get_time_elapsed return: {time_measured_avg}, {time_measured_std}")
return time_measured_avg, time_measured_std #, time_measured_all
return statistics.mean([meausre_func() for _ in range (num_measures)])

Просмотреть файл

@ -1,22 +0,0 @@
def landmarks_loss(
predicted_coords: torch.Tensor,
label_coords: torch.Tensor) -> torch.Tensor:
"""Calculate a scalar loss value for a batch of landmark predictions, given the GT landmark labels.
In the below:
* B: Batch size
* K: Number of landmarks, (aka. keypoints)
Args:
predicted_coords (torch.Tensor): A batch of predicted 2D landmark coordinates (B, K, 2).
label_coords (torch.Tensor): A batch of true (GT) 2D landmark coordinates (B, K, 2).
Returns:
A scalar loss value, averaged over every keypoint in the batch (torch.Tensor)
"""
assert predicted_coords.shape == label_coords.shape
loss_fn = torch.nn.MSELoss(reduction="none")
loss_per_landmark = loss_fn(label_coords, predicted_coords)
return loss_per_landmark.mean()

Просмотреть файл

@ -69,10 +69,10 @@ class ClassificationPresetEval:
class Sample():
# pylint: disable=too-many-arguments
def __init__(self, bgr_img=None, ldmks_2d=None):
def __init__(self, image=None, landmarks=None):
self.bgr_img = np.array(bgr_img)
self.ldmks_2d = ldmks_2d
self.image = np.array(image)
self.landmarks = landmarks
self.warp_region = None
def get_bounds(points):
@ -219,11 +219,11 @@ class GetWarpRegion():
def __call__(self, sample: Sample):
assert sample.ldmks_2d is not None
assert sample.landmarks is not None
ldmks_2d = sample.ldmks_2d
ldmks_2d = sample.landmarks
if self.landmarks_definition:
ldmks_2d = self.landmarks_definition.apply(sample.ldmks_2d)
ldmks_2d = self.landmarks_definition.apply(sample.landmarks)
sample.warp_region = WarpRegion(*get_square_bounds(ldmks_2d), self.roi_size)
sample.warp_region.scale(self.scale)
@ -239,19 +239,19 @@ class ExtractWarpRegion():
def __call__(self, sample : tuple):
assert sample.bgr_img is not None
assert sample.image is not None
assert sample.warp_region is not None
warp_region = sample.warp_region
if self.keep_unwarped:
# Useful for visualizations and debugging
sample.bgr_img_unwarped = np.copy(sample.bgr_img)
sample.image_unwarped = np.copy(sample.image)
sample.bgr_img = warp_region.extract_from_image(sample.bgr_img, **self.kwargs_bgr)
sample.image = warp_region.extract_from_image(sample.image, **self.kwargs_bgr)
if sample.ldmks_2d is not None:
sample.ldmks_2d = warp_region.transform_points(sample.ldmks_2d)
if sample.landmarks is not None:
sample.landmarks = warp_region.transform_points(sample.landmarks)
return sample
@ -265,9 +265,9 @@ class NormalizeCoordinates():
"""Normalize coordinates from pixel units to [-1, 1]."""
def __call__(self, sample: Sample):
assert (sample.ldmks_2d is not None)
width, height = sample.bgr_img.shape[-2::]
sample.ldmks_2d = normalize_coordinates(sample.ldmks_2d, width, height)
assert (sample.landmarks is not None)
width, height = sample.image.shape[-2::]
sample.landmarks = normalize_coordinates(sample.landmarks, width, height)
return sample
@ -276,12 +276,12 @@ class SampleToTensor():
""" Turns a NumPy data in a Sample into PyTorch data """
def __call__(self, sample: Sample):
sample.bgr_img = torch.from_numpy(np.transpose(sample.bgr_img, (2, 0, 1)))
sample.bgr_img = sample.bgr_img / 255.0
sample.bgr_img = sample.bgr_img.float()
sample.image = torch.from_numpy(np.transpose(sample.image, (2, 0, 1)))
sample.image = sample.image / 255.0
sample.image = sample.image.float()
if sample.ldmks_2d is not None:
sample.ldmks_2d = torch.from_numpy(sample.ldmks_2d).float()
if sample.landmarks is not None:
sample.landmarks = torch.from_numpy(sample.landmarks).float()
return sample