Mirror of https://github.com/microsoft/archai.git

Continued clean up: simplify arg names

Parent: f07b3787c0
Commit: 92cfab6f23
@@ -13,7 +13,6 @@ from datetime import datetime
 from pathlib import Path

 import pandas as pd
 import torch
 import yaml
 from overrides.overrides import overrides
-
@@ -22,24 +21,10 @@ from archai.discrete_search.algos.evolution_pareto import EvolutionParetoSearch
 from archai.discrete_search.api.model_evaluator import ModelEvaluator
 from archai.discrete_search.api.search_objectives import SearchObjectives

-from search_space import (ConfigSearchSpaceExt, _create_model_from_csv, _load_pretrain_weight)
-from latency import AvgOnnxLatency
-import train as model_trainer
-
-#hardcoded need to get as a parameter
-NUM_LANDMARK_CLASSES = 140
-
-def convert_args_dict_to_list(d):
-    if d is None:
-        return []
-
-    new_list = []
-    for key, val in d.items():
-        new_list.append(f"--{key}")
-        new_list.append(f"{val}")
-
-    return new_list
-
+from dataset import FaceLandmarkDataset
+from latency import AvgOnnxLatency
+from search_space import ConfigSearchSpaceExt
+
 class AccuracyEvaluator(ModelEvaluator):

     def __init__(self, lit_args) -> None:
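For reference, the convert_args_dict_to_list helper that this commit relocates from module scope turns a YAML-derived dict into argv-style tokens that argparse can re-parse. A minimal, self-contained sketch of its behavior (the config values here are made up):

    def convert_args_dict_to_list(d):
        # None means no config file was given, so no extra tokens.
        if d is None:
            return []
        new_list = []
        for key, val in d.items():
            new_list.append(f"--{key}")  # each key becomes a flag name
            new_list.append(f"{val}")    # each value is stringified
        return new_list

    print(convert_args_dict_to_list({"num_iters": 6, "seed": 0}))
    # -> ['--num_iters', '6', '--seed', '0']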
@@ -59,38 +44,38 @@ class OnnxLatencyEvaluator(ModelEvaluator):

     def __init__(self, args) -> None:
         self.args = args
-        self.latency_evaluator = AvgOnnxLatency(input_shape=(1, 3, 128, 128), num_trials=self.args.nas_num_latency_measurements, num_input=self.args.nas_num_input_per_latency_measurement)
+        self.latency_evaluator = AvgOnnxLatency(input_shape=(1, 3, 128, 128), num_trials=self.args.num_latency_measurements, num_input=self.args.num_input_per_latency_measurement)

     @overrides
     def evaluate(self, model, dataset_provider, budget = None) -> float:
         return self.latency_evaluator.evaluate(model)

-class NASLandmarks():
+class SearchFaceLandmarkModels():
     def __init__(self) -> None:
         super().__init__()

-        config_parser = ArgumentParser(conflict_handler="resolve", description='NAS on Face Tracking.')
-        config_parser.add_argument("--nas_config", required=True, type=Path, help='YAML config file specifying default arguments')
+        config_parser = ArgumentParser(conflict_handler="resolve", description='NAS for Facial Landmark Detection.')
+        config_parser.add_argument("--config", required=True, type=Path, help='YAML config file specifying default arguments')

-        parser = ArgumentParser(conflict_handler="resolve", description='NAS on Face Tracking.')
-        parser.add_argument("--nas_output_dir", required=True, type=Path)
-        parser.add_argument("--nas_search_backbone", type=str, help='backbone to use for config search')
-        parser.add_argument('--nas_num_jobs_per_gpu', required=False, type=int, default=1)
+        parser = ArgumentParser(conflict_handler="resolve", description='NAS for Facial Landmark Detection.')
+        parser.add_argument("--output_dir", required=True, type=Path)
+        parser.add_argument('--num_jobs_per_gpu', required=False, type=int, default=1)

-        finalize_group = parser.add_argument_group('Finalize parameters')
-        finalize_group.add_argument('--nas_finalize_archid', required=False, type=str, help='archid for the model to be finalized')
-        finalize_group.add_argument("--nas_finalize_models_csv", required='--nas_finalize_archid' in sys.argv, type=str, help='csv file output from the search stage')
-        finalize_group.add_argument("--nas_finalize_pretrained_weight_file", required=False, type=str, help='weight file from pretraining')
-
-        qat_group = parser.add_argument_group('QAT parameters')
-        qat_group.add_argument("--nas_use_tvmodel", action='store_true', help='Use Torchvision model')
-        qat_group.add_argument("--nas_qat", action='store_true', help='Use model ready for quantization aware training')
-        qat_group.add_argument("--nas_load_nonqat_weights", action='store_true', help='Use weights from previous training without QAT')
+        def convert_args_dict_to_list(d):
+            if d is None:
+                return []
+
+            new_list = []
+            for key, val in d.items():
+                new_list.append(f"--{key}")
+                new_list.append(f"{val}")
+
+            return new_list

         def _parse_args_from_config():
             args_config, remaining = config_parser.parse_known_args()
-            if args_config.nas_config:
-                with open(args_config.nas_config, 'r') as f:
+            if args_config.config:
+                with open(args_config.config, 'r') as f:
                     cfg = yaml.safe_load(f)
                     # The usual defaults are overridden if a config file is specified.
                     parser.set_defaults(**cfg)
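The net effect of the renames above is a shorter CLI surface: the nas_ prefix disappears from every flag. As an illustration only (the script name search.py is a stand-in, not taken from the commit), an invocation changes like this:

    # Before this commit (hypothetical script name and values):
    #   python search.py --nas_config search.yaml --nas_output_dir ./out --nas_num_jobs_per_gpu 2
    # After this commit:
    #   python search.py --config search.yaml --output_dir ./out --num_jobs_per_gpu 2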
@@ -98,45 +83,47 @@ class NASLandmarks():
             # The main arg parser parses the rest of the known command line args.
             args, remaining_args = parser.parse_known_args(remaining)
             # Args in the config file will be returned as a list of strings to
-            # be further used by LitLandmarksTrainer
+            # be further used by the trainer
             remaining_args = remaining_args + convert_args_dict_to_list(cfg) if cfg else None

             return args, remaining_args

-        self.nas_args, remaining_args = _parse_args_from_config()
-        self.lit_args, _ = model_trainer.get_args_parser().parse_known_args(remaining_args)
+        self.search_args, remaining_args = _parse_args_from_config()
+        self.trainer_args, _ = model_trainer.get_args_parser().parse_known_args(remaining_args)

     def search(self):

-        ss = ConfigSearchSpaceExt (self.nas_args, num_classes = NUM_LANDMARK_CLASSES)
+        dataset = FaceLandmarkDataset (self.trainer_args.data_path)
+        ss = ConfigSearchSpaceExt (self.search_args, num_classes = dataset.num_landmarks)

         search_objectives = SearchObjectives()
         search_objectives.add_objective(
             'Partial training Validation Accuracy',
-            AccuracyEvaluator(self.lit_args),
+            AccuracyEvaluator(self.trainer_args),
             higher_is_better=False,
             compute_intensive=True)
         search_objectives.add_objective(
             "onnx_latency (ms)",
-            OnnxLatencyEvaluator(self.nas_args),
+            OnnxLatencyEvaluator(self.search_args),
             higher_is_better=False,
             compute_intensive=False)

         algo = EvolutionParetoSearch(
             search_space=ss,
             search_objectives=search_objectives,
-            output_dir=self.nas_args.nas_output_dir,
-            num_iters=self.nas_args.nas_num_iters,
-            init_num_models=self.nas_args.nas_init_num_models,
-            num_random_mix=self.nas_args.nas_num_random_mix,
-            max_unseen_population=self.nas_args.nas_max_unseen_population,
-            mutations_per_parent=self.nas_args.nas_mutations_per_parent,
-            num_crossovers=self.nas_args.nas_num_crossovers,
-            seed=self.nas_args.seed,
+            output_dir=self.search_args.output_dir,
+            num_iters=self.search_args.num_iters,
+            init_num_models=self.search_args.init_num_models,
+            num_random_mix=self.search_args.num_random_mix,
+            max_unseen_population=self.search_args.max_unseen_population,
+            mutations_per_parent=self.search_args.mutations_per_parent,
+            num_crossovers=self.search_args.num_crossovers,
+            seed=self.search_args.seed,
             save_pareto_model_weights = False)

         search_results = algo.search()
-        results_df = search_results.get_search_state_df()
+
+        results_df = search_results.get_search_state_df()
         ids = results_df.archid.values.tolist()
         if (len(set(ids)) < len(ids)):
             print("Duplicated models detected in nas results. This is not supposed to happen.")
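A note on the duplicate check above: a set collapses repeated archids, so duplicates exist exactly when the set is strictly smaller than the list. A quick illustration with made-up ids:

    ids = ["arch_a", "arch_b", "arch_a"]
    print(len(set(ids)), len(ids))   # 2 3
    print(len(set(ids)) < len(ids))  # True -> duplicates present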
@@ -152,40 +139,40 @@ class NASLandmarks():
         output_csv_name = '-'.join(['search',
                                     datetime.now().strftime("%Y%m%d-%H%M%S"),
                                     'output.csv'])
-        output_csv_path = os.path.join(self.nas_args.nas_output_dir, output_csv_name)
+        output_csv_path = os.path.join(self.search_args.output_dir, output_csv_name)
         config_df.to_csv(output_csv_path)
         return


 def _main() -> None:
     os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error

-    nas = NASLandmarks()
-    if (None == nas.nas_args.nas_finalize_archid):
-        #args = Namespace(**vars(nas.nas_args), **vars(nas.lit_args))
-        nas.search()
-    else:
+    search = SearchFaceLandmarkModels()
+    search.search()

     ### To be moved to trainer
     """
     model = _create_model_from_csv (
-        nas.nas_args.nas_finalize_archid,
-        nas.nas_args.nas_finalize_models_csv,
+        nas.search_args.nas_finalize_archid,
+        nas.search_args.nas_finalize_models_csv,
         num_classes=NUM_LANDMARK_CLASSES)

-    print(f'Loading weights from {str(nas.nas_args.nas_finalize_pretrained_weight_file)}')
-    if (not nas.nas_args.nas_load_nonqat_weights):
-        if (nas.nas_args.nas_finalize_pretrained_weight_file is not None) :
-            model = _load_pretrain_weight(nas.nas_args.nas_finalize_pretrained_weight_file, model)
+    print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
+    if (not nas.search_args.nas_load_nonqat_weights):
+        if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
+            model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)

-    if (nas.nas_args.nas_use_tvmodel):
+    if (nas.search_args.nas_use_tvmodel):
         model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))

     # Load pretrained weights after fixing classifier as the weights match the exact network architecture
-    if (nas.nas_args.nas_load_nonqat_weights):
-        assert os.path.exists(nas.nas_args.nas_finalize_pretrained_weight_file)
-        print(f'Loading weights from previous non-QAT training {nas.nas_args.nas_finalize_pretrained_weight_file}')
-        model.load_state_dict(torch.load(nas.nas_args.nas_finalize_pretrained_weight_file))
+    if (nas.search_args.nas_load_nonqat_weights):
+        assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
+        print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
+        model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))

-    val_error = model_trainer.train(nas.lit_args, model)
-    print(f"Final validation error for model {nas.nas_args.nas_finalize_archid}: {val_error}")
+    val_error = model_trainer.train(nas.trainer_args, model)
+    print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
     """

 if __name__ == "__main__":
     _main()
@@ -2,15 +2,15 @@
 # Config file for NAS search - Debug run
 #

-# NAS args
-nas_num_iters: 6
-nas_init_num_models: 4 #16 #32
-nas_num_random_mix: 32
-nas_num_crossovers: 4
-nas_mutations_per_parent: 4
-nas_max_unseen_population: 4 #16 #32
-nas_num_latency_measurements: 15
-nas_num_input_per_latency_measurement: 15
+# search args
+num_iters: 6
+init_num_models: 4 #16 #32
+num_random_mix: 32
+num_crossovers: 4
+mutations_per_parent: 4
+max_unseen_population: 4 #16 #32
+num_latency_measurements: 15
+num_input_per_latency_measurement: 15
 seed: 0

 # Search space args
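These keys feed parser.set_defaults(**cfg) in the search script above, so each YAML entry overrides the default of the argparse option with the same (now prefix-free) name. A minimal sketch of that round trip, using a trimmed two-key config as a stand-in for the file:

    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)

    cfg = yaml.safe_load("num_iters: 6\nseed: 0\n")  # stands in for the config file
    parser.set_defaults(**cfg)
    print(parser.parse_args([]))  # Namespace(num_iters=6, seed=0)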
@@ -22,7 +22,6 @@ k_range:
 - 7
 channel_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
 depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
-num_neighbors: 100

 # model trainer args
 data_path: /data/public_face_synthetics/dataset_100000
@@ -130,7 +130,7 @@ class ConfigSearchModel(nn.Module):
     def forward(self, x):
         return self.model.forward(x)

-class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):
+class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
     def __init__(self, args, num_classes=140):
         super().__init__()
         ##mvn2's config
@@ -187,7 +187,6 @@ class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):
         self.k_range = tuple(args.k_range)
         self.channel_mult_range = tuple(args.channel_mult_range)
         self.depth_mult_range = tuple(args.depth_mult_range)
-        self.NUM_NEIGHBORS = args.num_neighbors
         torch.manual_seed(0)
         random.seed(0)
         np.random.seed(0)
@@ -310,7 +309,7 @@ class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):

         return arch

-class ConfigSearchSpaceExt(DiscreteSearchSpaceMNv2Config, EvolutionarySearchSpace, BayesOptSearchSpace):
+class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
     ''' We are subclassing CNNSearchSpace just to save up space'''

     @overrides
@@ -377,7 +376,7 @@ if __name__ == "__main__":
         return arch

     #conf = common.common_init(config_filepath= '../nas_landmarks_darts.yaml')
-    ss = DiscreteSearchSpaceMNv2Config()
+    ss = DiscreteSearchSpaceMobileNetV2()
     for i in range(0, 2):
         archai_model = create_random_model(ss)
         print(archai_model.metadata['config'])
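As a closing usage sketch (the args namespace below is hypothetical; the real one comes from the argparse/YAML plumbing above), the renamed search space would be constructed roughly like this:

    from types import SimpleNamespace

    # Made-up ranges mirroring the YAML config above; the real __init__ may read more fields.
    args = SimpleNamespace(
        k_range=[3, 5, 7],
        channel_mult_range=[0.25, 0.5, 0.75, 1.0, 1.25],
        depth_mult_range=[0.25, 0.5, 0.75, 1.0, 1.25])

    ss = DiscreteSearchSpaceMobileNetV2(args, num_classes=140)
    arch = ss.random_sample()  # assuming the DiscreteSearchSpace API's random_sample()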