Continued cleanup: simplify arg names

Wei-ge Chen 2023-04-18 17:58:45 -07:00
Parent f07b3787c0
Commit 92cfab6f23
3 changed files with 70 additions and 85 deletions

View file

@@ -13,7 +13,6 @@ from datetime import datetime
from pathlib import Path
import pandas as pd
import torch
import yaml
from overrides.overrides import overrides
@@ -22,24 +21,10 @@ from archai.discrete_search.algos.evolution_pareto import EvolutionParetoSearch
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.api.search_objectives import SearchObjectives
from search_space import (ConfigSearchSpaceExt, _create_model_from_csv, _load_pretrain_weight)
from latency import AvgOnnxLatency
import train as model_trainer
# Hardcoded; should be passed in as a parameter
NUM_LANDMARK_CLASSES = 140
def convert_args_dict_to_list(d):
if d is None:
return []
new_list = []
for key, val in d.items():
new_list.append(f"--{key}")
new_list.append(f"{val}")
return new_list
from dataset import FaceLandmarkDataset
from latency import AvgOnnxLatency
from search_space import ConfigSearchSpaceExt
class AccuracyEvaluator(ModelEvaluator):
def __init__(self, lit_args) -> None:
@@ -59,38 +44,38 @@ class OnnxLatencyEvaluator(ModelEvaluator):
def __init__(self, args) -> None:
self.args = args
self.latency_evaluator = AvgOnnxLatency(input_shape=(1, 3, 128, 128), num_trials=self.args.nas_num_latency_measurements, num_input=self.args.nas_num_input_per_latency_measurement)
self.latency_evaluator = AvgOnnxLatency(input_shape=(1, 3, 128, 128), num_trials=self.args.num_latency_measurements, num_input=self.args.num_input_per_latency_measurement)
@overrides
def evaluate(self, model, dataset_provider, budget = None) -> float:
return self.latency_evaluator.evaluate(model)
class NASLandmarks():
class SearchFaceLandmarkModels():
def __init__(self) -> None:
super().__init__()
config_parser = ArgumentParser(conflict_handler="resolve", description='NAS on Face Tracking.')
config_parser.add_argument("--nas_config", required=True, type=Path, help='YAML config file specifying default arguments')
config_parser = ArgumentParser(conflict_handler="resolve", description='NAS for Facial Landmark Detection.')
config_parser.add_argument("--config", required=True, type=Path, help='YAML config file specifying default arguments')
parser = ArgumentParser(conflict_handler="resolve", description='NAS on Face Tracking.')
parser.add_argument("--nas_output_dir", required=True, type=Path)
parser.add_argument("--nas_search_backbone", type=str, help='backbone to used use for config search')
parser.add_argument('--nas_num_jobs_per_gpu', required=False, type=int, default=1)
parser = ArgumentParser(conflict_handler="resolve", description='NAS for Facial Landmark Detection.')
parser.add_argument("--output_dir", required=True, type=Path)
parser.add_argument('--num_jobs_per_gpu', required=False, type=int, default=1)
finalize_group = parser.add_argument_group('Finalize parameters')
finalize_group.add_argument('--nas_finalize_archid', required=False, type=str, help='archid for the model to be finalized')
finalize_group.add_argument("--nas_finalize_models_csv", required='--nas_finalize_archid' in sys.argv, type=str, help='csv file output from the search stage')
finalize_group.add_argument("--nas_finalize_pretrained_weight_file", required=False, type=str, help='weight file from pretraining')
def convert_args_dict_to_list(d):
if d is None:
return []
qat_group = parser.add_argument_group('QAT parameters')
qat_group.add_argument("--nas_use_tvmodel", action='store_true', help='Use Torchvision model')
qat_group.add_argument("--nas_qat", action='store_true', help='Use model ready for quantization aware training')
qat_group.add_argument("--nas_load_nonqat_weights", action='store_true', help='Use weights from previous training without QAT')
new_list = []
for key, val in d.items():
new_list.append(f"--{key}")
new_list.append(f"{val}")
return new_list
def _parse_args_from_config():
args_config, remaining = config_parser.parse_known_args()
if args_config.nas_config:
with open(args_config.nas_config, 'r') as f:
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
# The usual defaults are overridden if a config file is specified.
parser.set_defaults(**cfg)
@@ -98,45 +83,47 @@ class NASLandmarks():
# The main arg parser parses the rest of the known command line args.
args, remaining_args = parser.parse_known_args(remaining)
# Args in the config file will be returned as a list of strings to
# be further used by LitLandmarksTrainer
# be further used by the trainer
remaining_args = remaining_args + (convert_args_dict_to_list(cfg) if cfg else [])
return args, remaining_args
self.nas_args, remaining_args = _parse_args_from_config()
self.lit_args, _ = model_trainer.get_args_parser().parse_known_args(remaining_args)
self.search_args, remaining_args = _parse_args_from_config()
self.trainer_args, _ = model_trainer.get_args_parser().parse_known_args(remaining_args)
def search(self):
ss = ConfigSearchSpaceExt (self.nas_args, num_classes = NUM_LANDMARK_CLASSES)
dataset = FaceLandmarkDataset (self.trainer_args.data_path)
ss = ConfigSearchSpaceExt (self.search_args, num_classes = dataset.num_landmarks)
search_objectives = SearchObjectives()
search_objectives.add_objective(
'Partial training Validation Accuracy',
AccuracyEvaluator(self.lit_args),
AccuracyEvaluator(self.trainer_args),
higher_is_better=False,
compute_intensive=True)
search_objectives.add_objective(
"onnx_latency (ms)",
OnnxLatencyEvaluator(self.nas_args),
OnnxLatencyEvaluator(self.search_args),
higher_is_better=False,
compute_intensive=False)
algo = EvolutionParetoSearch(
search_space=ss,
search_objectives=search_objectives,
output_dir=self.nas_args.nas_output_dir,
num_iters=self.nas_args.nas_num_iters,
init_num_models=self.nas_args.nas_init_num_models,
num_random_mix=self.nas_args.nas_num_random_mix,
max_unseen_population=self.nas_args.nas_max_unseen_population,
mutations_per_parent=self.nas_args.nas_mutations_per_parent,
num_crossovers=self.nas_args.nas_num_crossovers,
seed=self.nas_args.seed,
output_dir=self.search_args.output_dir,
num_iters=self.search_args.num_iters,
init_num_models=self.search_args.init_num_models,
num_random_mix=self.search_args.num_random_mix,
max_unseen_population=self.search_args.max_unseen_population,
mutations_per_parent=self.search_args.mutations_per_parent,
num_crossovers=self.search_args.num_crossovers,
seed=self.search_args.seed,
save_pareto_model_weights = False)
search_results = algo.search()
results_df = search_results.get_search_state_df()
ids = results_df.archid.values.tolist()
if (len(set(ids)) < len(ids)):
print("Duplicated models detected in NAS results. This is not supposed to happen.")
@@ -152,40 +139,40 @@ class NASLandmarks():
output_csv_name = '-'.join(['search',
datetime.now().strftime("%Y%m%d-%H%M%S"),
'output.csv'])
output_csv_path = os.path.join(self.nas_args.nas_output_dir, output_csv_name)
output_csv_path = os.path.join(self.search_args.output_dir, output_csv_name)
config_df.to_csv(output_csv_path)
return
def _main() -> None:
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" # To overcome 'imgcodecs: OpenEXR codec is disabled' error
nas = NASLandmarks()
if (None == nas.nas_args.nas_finalize_archid):
#args = Namespace(**vars(nas.nas_args), **vars(nas.lit_args))
nas.search()
else:
search = SearchFaceLandmarkModels()
search.search()
### To be moved to the trainer
"""
model = _create_model_from_csv (
nas.nas_args.nas_finalize_archid,
nas.nas_args.nas_finalize_models_csv,
nas.search_args.nas_finalize_archid,
nas.search_args.nas_finalize_models_csv,
num_classes=NUM_LANDMARK_CLASSES)
print(f'Loading weights from {str(nas.nas_args.nas_finalize_pretrained_weight_file)}')
if (not nas.nas_args.nas_load_nonqat_weights):
if (nas.nas_args.nas_finalize_pretrained_weight_file is not None) :
model = _load_pretrain_weight(nas.nas_args.nas_finalize_pretrained_weight_file, model)
print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
if (not nas.search_args.nas_load_nonqat_weights):
if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
if (nas.nas_args.nas_use_tvmodel):
if (nas.search_args.nas_use_tvmodel):
model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
# Load pretrained weights after fixing classifier as the weights match the exact network architecture
if (nas.nas_args.nas_load_nonqat_weights):
assert os.path.exists(nas.nas_args.nas_finalize_pretrained_weight_file)
print(f'Loading weights from previous non-QAT training {nas.nas_args.nas_finalize_pretrained_weight_file}')
model.load_state_dict(torch.load(nas.nas_args.nas_finalize_pretrained_weight_file))
if (nas.search_args.nas_load_nonqat_weights):
assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
val_error = model_trainer.train(nas.lit_args, model)
print(f"Final validation error for model {nas.nas_args.nas_finalize_archid}: {val_error}")
val_error = model_trainer.train(nas.trainer_args, model)
print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
"""
if __name__ == "__main__":
_main()
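Aside: the file above leans on a two-stage argparse pattern that is easy to miss in the flattened diff. A small parser consumes only --config, the YAML contents are installed as defaults on the main parser via set_defaults (so explicit command-line flags still win), and the raw config dict is re-serialized into CLI-style tokens for the trainer's own parser. A minimal self-contained sketch of the pattern, with illustrative argument names:

from argparse import ArgumentParser
from pathlib import Path
import yaml

def convert_args_dict_to_list(d):
    # Flatten {'num_iters': 6} into ['--num_iters', '6'].
    if d is None:
        return []
    tokens = []
    for key, val in d.items():
        tokens.extend([f"--{key}", f"{val}"])
    return tokens

# Stage 1: a tiny parser that only knows about --config.
config_parser = ArgumentParser(add_help=False)
config_parser.add_argument("--config", type=Path)

# Stage 2: the main parser; YAML values override its defaults.
parser = ArgumentParser()
parser.add_argument("--output_dir", type=Path)
parser.add_argument("--num_iters", type=int, default=10)

args_config, remaining = config_parser.parse_known_args()
cfg = None
if args_config.config:
    with open(args_config.config) as f:
        cfg = yaml.safe_load(f)
    parser.set_defaults(**cfg)

# Flags passed explicitly on the command line still take precedence.
args, remaining_args = parser.parse_known_args(remaining)

# Anything the main parser did not claim, plus the config contents,
# is forwarded to the downstream trainer parser as synthetic tokens.
trainer_tokens = remaining_args + convert_args_dict_to_list(cfg)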

View file

@@ -2,15 +2,15 @@
# Config file for NAS search - Debug run
#
# NAS args
nas_num_iters: 6
nas_init_num_models: 4 #16 #32
nas_num_random_mix: 32
nas_num_crossovers: 4
nas_mutations_per_parent: 4
nas_max_unseen_population: 4 #16 #32
nas_num_latency_measurements: 15
nas_num_input_per_latency_measurement: 15
# Search args
num_iters: 6
init_num_models: 4 #16 #32
num_random_mix: 32
num_crossovers: 4
mutations_per_parent: 4
max_unseen_population: 4 #16 #32
num_latency_measurements: 15
num_input_per_latency_measurement: 15
seed: 0
# Search space args
@@ -22,7 +22,6 @@ k_range:
- 7
channel_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
num_neighbors: 100
# model trainer args
data_path: /data/public_face_synthetics/dataset_100000
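One payoff of dropping the nas_ prefix here: the search keys now match EvolutionParetoSearch's keyword arguments one-to-one, so the call site in the search script no longer needs to translate names. A sketch of what the rename enables (the file name below is hypothetical):

import yaml

with open("search_debug_config.yaml") as f:  # hypothetical file name
    cfg = yaml.safe_load(f)

# These keys now line up with EvolutionParetoSearch's kwargs.
search_keys = ("num_iters", "init_num_models", "num_random_mix",
               "num_crossovers", "mutations_per_parent",
               "max_unseen_population", "seed")
search_kwargs = {k: cfg[k] for k in search_keys}
# algo = EvolutionParetoSearch(search_space=ss, search_objectives=so,
#                              output_dir=out_dir, **search_kwargs)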

View file

@@ -130,7 +130,7 @@ class ConfigSearchModel(nn.Module):
def forward(self, x):
return self.model.forward(x)
class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):
class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
def __init__(self, args, num_classes=140):
super().__init__()
## MobileNetV2's config
@@ -187,7 +187,6 @@ class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):
self.k_range = tuple(args.k_range)
self.channel_mult_range = tuple(args.channel_mult_range)
self.depth_mult_range = tuple(args.depth_mult_range)
self.NUM_NEIGHBORS = args.num_neighbors
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
@@ -310,7 +309,7 @@ class DiscreteSearchSpaceMNv2Config(DiscreteSearchSpace):
return arch
class ConfigSearchSpaceExt(DiscreteSearchSpaceMNv2Config, EvolutionarySearchSpace, BayesOptSearchSpace):
class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
'''We subclass DiscreteSearchSpaceMobileNetV2 just to save space.'''
@overrides
@@ -377,7 +376,7 @@ if __name__ == "__main__":
return arch
#conf = common.common_init(config_filepath= '../nas_landmarks_darts.yaml')
ss = DiscreteSearchSpaceMNv2Config()
ss = DiscreteSearchSpaceMobileNetV2()
for i in range(0, 2):
archai_model = create_random_model(ss)
print(archai_model.metadata['config'])
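A closing note on the @overrides decorator used throughout these files (from the overrides package): it checks at class-definition time that the decorated method really overrides something in a base class, which makes renames like DiscreteSearchSpaceMNv2Config -> DiscreteSearchSpaceMobileNetV2 safer to carry out. A minimal sketch:

from overrides import overrides

class BaseEvaluator:
    def evaluate(self, model, budget=None) -> float:
        raise NotImplementedError

class LatencyEvaluator(BaseEvaluator):
    @overrides
    def evaluate(self, model, budget=None) -> float:
        # A real evaluator would measure the model here.
        return 0.0

# Misspelling the method name (e.g. 'evaluatee') raises an error as
# soon as the class body executes, instead of failing later at search
# time.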