Further clean up. Add copyright statements.

Wei-ge Chen 2023-04-21 12:24:43 -07:00
Parent 424d37e791
Commit 8352dd39f0
12 changed files with 174 additions and 306 deletions

View file

@@ -1,4 +1,5 @@
-"""Datasets for Microsoft Face Synthetics dataset."""
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import glob
 import os

View file

@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import io
 from typing import Dict, List, Optional, Tuple, Union

View file

@@ -1,18 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+import torch
+import warnings
 from typing import Any, Callable, List, Optional
-import torch
 from torch import nn, Tensor
-#from torchvision.ops import Conv2dNormActivation
 from torchvision.models._utils import _make_divisible
-import warnings
-from typing import Callable, List, Optional
-import torch
-from torch import Tensor
-#from ..utils import _log_api_usage_once

 def _log_api_usage_once(obj: Any) -> None:
     """

View file

@@ -1,65 +0,0 @@
-import torch
-from torchvision.transforms import autoaugment, transforms
-from torchvision.transforms.functional import InterpolationMode
-
-
-class ClassificationPresetTrain:
-    def __init__(
-        self,
-        crop_size,
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
-        hflip_prob=0.5,
-        auto_augment_policy=None,
-        random_erase_prob=0.0,
-    ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
-        if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
-        if auto_augment_policy is not None:
-            if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation))
-            elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
-            else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
-        trans.extend(
-            [
-                transforms.PILToTensor(),
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
-            ]
-        )
-        if random_erase_prob > 0:
-            trans.append(transforms.RandomErasing(p=random_erase_prob))
-
-        self.transforms = transforms.Compose(trans)
-
-    def __call__(self, img):
-        return self.transforms(img)
-
-
-class ClassificationPresetEval:
-    def __init__(
-        self,
-        crop_size,
-        resize_size=256,
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
-    ):
-        self.transforms = transforms.Compose(
-            [
-                transforms.Resize(resize_size, interpolation=interpolation),
-                transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
-            ]
-        )
-
-    def __call__(self, img):
-        return self.transforms(img)
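For reference, the removed presets compose standard torchvision transforms. A minimal usage sketch (hypothetical crop size and file name, assuming a PIL RGB input):

import torch
from PIL import Image

# Training-time pipeline of the removed ClassificationPresetTrain; "ta_wide"
# selects TrivialAugmentWide, and random erasing runs after normalization.
preset = ClassificationPresetTrain(crop_size=128, auto_augment_policy="ta_wide", random_erase_prob=0.25)
img = Image.open("face.png").convert("RGB")  # hypothetical input image
x = preset(img)  # float32 tensor of shape (3, 128, 128)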

View file

@@ -1,62 +0,0 @@
-import math
-
-import torch
-import torch.distributed as dist
-
-
-class RASampler(torch.utils.data.Sampler):
-    """Sampler that restricts data loading to a subset of the dataset for distributed,
-    with repeated augmentation.
-    It ensures that different each augmented version of a sample will be visible to a
-    different process (GPU).
-    Heavily based on 'torch.utils.data.DistributedSampler'.
-    This is borrowed from the DeiT Repo:
-    https://github.com/facebookresearch/deit/blob/main/samplers.py
-    """
-
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available!")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available!")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
-        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
-        self.shuffle = shuffle
-        self.seed = seed
-        self.repetitions = repetitions
-
-    def __iter__(self):
-        if self.shuffle:
-            # Deterministically shuffle based on epoch
-            g = torch.Generator()
-            g.manual_seed(self.seed + self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-
-        # Add extra samples to make it evenly divisible
-        indices = [ele for ele in indices for i in range(self.repetitions)]
-        indices += indices[: (self.total_size - len(indices))]
-        assert len(indices) == self.total_size
-
-        # Subsample
-        indices = indices[self.rank : self.total_size : self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices[: self.num_selected_samples])
-
-    def __len__(self):
-        return self.num_selected_samples
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
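The removed RASampler was designed to drop into a distributed loader. A sketch of the intended wiring (assumes torch.distributed is already initialized and that dataset and num_epochs exist):

import torch

sampler = RASampler(dataset, shuffle=True, repetitions=3)  # 3 augmented copies per sample
loader = torch.utils.data.DataLoader(dataset, batch_size=128, sampler=sampler, num_workers=4)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # deterministic reshuffle per epoch (seed + epoch)
    for images, targets in loader:
        pass  # training step goes here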

View file

@@ -1,5 +1,8 @@
-import sys
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+import sys
+# to be removed before merging
 if ('--debug' in sys.argv):
     import debugpy
     debugpy.listen(5678)
@@ -149,30 +152,5 @@ def _main() -> None:
     search = SearchFaceLandmarkModels()
     search.search()
-
-    ###To be moved to trainder
-    """
-    model = _create_model_from_csv (
-        nas.search_args.nas_finalize_archid,
-        nas.search_args.nas_finalize_models_csv,
-        num_classes=NUM_LANDMARK_CLASSES)
-
-    print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
-    if (not nas.search_args.nas_load_nonqat_weights):
-        if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
-            model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
-
-    if (nas.search_args.nas_use_tvmodel):
-        model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
-
-    # Load pretrained weights after fixing classifier as the weights match the exact network architecture
-    if (nas.search_args.nas_load_nonqat_weights):
-        assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
-        print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
-        model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
-
-    val_error = model_trainer.train(nas.trainer_args, model)
-    print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
-    """

 if __name__ == "__main__":
     _main()

View file

@@ -26,8 +26,8 @@ depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
 # model trainer args
 data_path: /data/public_face_synthetics/dataset_100000
 output_dir: /home/wchen/tmp
-max_num_images: 20000
-#max_num_images: 1000
+#max_num_images: 20000
+max_num_images: 1000
 train_crop_size: 128
 epochs: 30
 batch_size: 128

View file

@@ -1,3 +1,6 @@
+"""Search space for facial landmark detection task"""
+
 import copy
 import json
 import math
@@ -5,38 +7,27 @@ import random
 import re
 import sys
 from hashlib import sha1
-from os import path
-from pathlib import Path
 from typing import List

 import numpy as np
-import pandas as pd
 import torch
-import torch.nn as nn
 from overrides.overrides import overrides

 from archai.common.common import logger
 from archai.discrete_search.api.archai_model import ArchaiModel
-from archai.discrete_search.api.search_space import DiscreteSearchSpace, BayesOptSearchSpace, EvolutionarySearchSpace
+from archai.discrete_search.api.search_space import (
+    BayesOptSearchSpace,
+    DiscreteSearchSpace,
+    EvolutionarySearchSpace,
+)

 from model import CustomMobileNetV2

 def _gen_tv_mobilenet ( arch_def,
                         channel_multiplier=1.0,
                         depth_multiplier=1.0,
                         num_classes=1000):
-    # default mbv2 setting
-    # t - exp factor, c - channels, n - number of block repeats, s - stride
-    # # t, c, n, s
-    # [1, 16, 1, 1],
-    # [6, 24, 2, 2],
-    # [6, 32, 3, 2],
-    # [6, 64, 4, 2],
-    # [6, 96, 3, 1],
-    # [6, 160, 3, 2],
-    # [6, 320, 1, 1],
-    # archid 0a7b6 - {"arch_def": [["ds_r1_k3_s1_c16"], ["ir_r2_k3_s2_e6_c24"], ["ir_r3_k3_s2_e6_c32"], ["ir_r4_k3_s2_e6_c64"], ["ir_r2_k3_s1_e4_c96"], ["ir_r2_k3_s2_e6_c160"], ["ir_r1_k3_s1_e5_c320"]], "channel_multiplier": 1.0, "depth_multiplier": 0.75}
+    """generate mobilenet v2 from torchvision. Adapted from timm source code"""
     ir_setting = []
     for block_def in arch_def:
         parts = block_def[0].split("_")
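Each arch_def entry is a timm-style block string that _gen_tv_mobilenet splits on underscores. A hypothetical standalone decoder (not part of the diff) illustrates the encoding:

def decode_block(block_str: str) -> dict:
    """Illustrative decoder for a timm-style block string such as 'ir_r3_k3_s2_e6_c32'."""
    parts = block_str.split("_")
    cfg = {"type": parts[0]}  # 'ds' = depthwise separable, 'ir' = inverted residual
    for token in parts[1:]:
        key, value = token[0], token[1:]  # r=repeats, k=kernel, s=stride, e=expansion, c=channels
        cfg[key] = int(value)
    return cfg

print(decode_block("ir_r3_k3_s2_e6_c32"))
# -> {'type': 'ir', 'r': 3, 'k': 3, 's': 2, 'e': 6, 'c': 32}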
@@ -83,44 +74,9 @@ def _gen_tv_mobilenet ( arch_def,
     model = CustomMobileNetV2(inverted_residual_setting=ir_setting, dropout=0, num_classes=num_classes)
     return model
-    # return mobilenet_v2(quantize=False,
-    #                     inverted_residual_setting = [
-    #                         [1, 16, 1, 1],
-    #                         [6, 24, 2, 2],
-    #                         [6, 32, 3, 2],
-    #                         [6, 64, 3, 2],
-    #                         [4, 96, 2, 1],
-    #                         [6, 160, 2, 2],
-    #                         [5, 320, 1, 1],
-    #                     ])

-def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
-    csv_path = Path(csv_file)
-    assert csv_path.exists()
-    df0 = pd.read_csv(csv_path)#.query('metric == @metric')
-    row = df0[df0['archid'] == archid]
-    cfg = json.loads(row['config'].to_list()[0])
-    # Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
-    # kwargs.pop('num_classes', None)
-    # wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
-    model = _gen_tv_mobilenet(cfg['arch_def'],
-                              channel_multiplier=cfg['channel_multiplier'],
-                              depth_multiplier=cfg['depth_multiplier'],
-                              num_classes=num_classes)
-    return model

-def _load_pretrain_weight(weight_file: str, model) :
-    print("=> loading pretrained weight '{}'".format(weight_file))
-    assert path.isfile(weight_file)
-    source_state = torch.load(weight_file)
-    state_dict = source_state['state_dict']
-    state_dict.pop('classifier' + '.weight', None)
-    state_dict.pop('classifier' + '.bias', None)
-    model.load_state_dict(state_dict, strict = False)
-    return model

-class ConfigSearchModel(nn.Module):
+class ConfigSearchModel(torch.nn.Module):
+    """This is a wrapper class to allow the model to be used in the search space"""
     def __init__(self, model : ArchaiModel, archid: str, metadata : dict):
         super().__init__()
         self.model = model
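ConfigSearchModel simply carries the archid and config metadata next to the generated torch module so the search loop can round-trip them. An illustrative construction, assuming the wrapped object is the generated module (cfg and archid values are placeholders):

cfg = {'arch_def': [['ds_r1_k3_s1_c16'], ['ir_r2_k3_s2_e6_c24']],
       'channel_multiplier': 1.0, 'depth_multiplier': 1.0}
net = _gen_tv_mobilenet(cfg['arch_def'],
                        channel_multiplier=cfg['channel_multiplier'],
                        depth_multiplier=cfg['depth_multiplier'],
                        num_classes=140)
wrapped = ConfigSearchModel(net, archid='0a7b6ee5', metadata=cfg)  # placeholder archid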
@@ -133,52 +89,72 @@ class ConfigSearchModel(nn.Module):

 class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
     def __init__(self, args, num_classes=140):
         super().__init__()
-        ##mvn2's config
-        self.cfgs_orig = {'arch_def': [['ds_r1_k3_s1_c16'],
-                                       ['ir_r2_k3_s2_e6_c24'],
-                                       ['ir_r3_k3_s2_e6_c32'],
-                                       ['ir_r4_k3_s2_e6_c64'],
-                                       ['ir_r3_k3_s1_e6_c96'],
-                                       ['ir_r3_k3_s2_e6_c160'],
-                                       ['ir_r1_k3_s1_e6_c320']],
-                          'channel_multiplier': 1.00,
-                          'depth_multiplier': 1.00}
-        self.cfgs_orig1 = {'arch_def': [['ds_r1_k3_s1_c16'],
-                                        ['ir_r2_k3_s2_e6_c24'],
-                                        ['ir_r3_k3_s2_e6_c32'],
-                                        ['ir_r4_k3_s2_e6_c64'],
-                                        ['ir_r3_k3_s1_e6_c96'],
-                                        ['ir_r3_k3_s2_e6_c160'],
-                                        ['ir_r1_k3_s1_e6_c320']],
-                           'channel_multiplier': 0.75,
-                           'depth_multiplier': 0.75}
-        self.cfgs_orig2 = {'arch_def': [['ds_r1_k3_s1_c16'],
-                                        ['ir_r2_k3_s2_e6_c24'],
-                                        ['ir_r3_k3_s2_e6_c32'],
-                                        ['ir_r4_k3_s2_e6_c64'],
-                                        ['ir_r3_k3_s1_e6_c96'],
-                                        ['ir_r3_k3_s2_e6_c160'],
-                                        ['ir_r1_k3_s1_e6_c320']],
-                           'channel_multiplier': 0.5,
-                           'depth_multiplier': 0.5}
-        self.cfgs_orig3 = {'arch_def': [['ds_r1_k3_s1_c16'],
-                                        ['ir_r2_k3_s2_e6_c24'],
-                                        ['ir_r3_k3_s2_e6_c32'],
-                                        ['ir_r4_k3_s2_e6_c64'],
-                                        ['ir_r3_k3_s1_e6_c96'],
-                                        ['ir_r3_k3_s2_e6_c160'],
-                                        ['ir_r1_k3_s1_e6_c320']],
-                           'channel_multiplier': 1.25,
-                           'depth_multiplier': 1.25}
-        self.cfgs_orig_el0 = {'arch_def': [['ds_r1_k3_s1_e1_c16'],
-                                           ['ir_r2_k3_s2_e6_c24'],
-                                           ['ir_r2_k5_s2_e6_c40'],
-                                           ['ir_r3_k3_s2_e6_c80'],
-                                           ['ir_r3_k5_s1_e6_c112'],
-                                           ['ir_r4_k5_s2_e6_c192'],
-                                           ['ir_r1_k3_s1_e6_c320']],
-                              'channel_multiplier': 1.00,
-                              'depth_multiplier': 1.00}
+        """ Default mobilenetv2 setting in more readable format
+            t - exp factor, c - channels, n - number of block repeats, s - stride
+             t,  c,  n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            [6, 160, 3, 2],
+            [6, 320, 1, 1]"""
+        #set up a few models configs with variable depth and width
+        self.cfgs_orig = [
+            {
+                'arch_def': [
+                    ['ds_r1_k3_s1_c16'],
+                    ['ir_r2_k3_s2_e6_c24'],
+                    ['ir_r3_k3_s2_e6_c32'],
+                    ['ir_r4_k3_s2_e6_c64'],
+                    ['ir_r3_k3_s1_e6_c96'],
+                    ['ir_r3_k3_s2_e6_c160'],
+                    ['ir_r1_k3_s1_e6_c320']
+                ],
+                'channel_multiplier': 1.00,
+                'depth_multiplier': 1.00
+            },
+            {
+                'arch_def': [
+                    ['ds_r1_k3_s1_c16'],
+                    ['ir_r2_k3_s2_e6_c24'],
+                    ['ir_r3_k3_s2_e6_c32'],
+                    ['ir_r4_k3_s2_e6_c64'],
+                    ['ir_r3_k3_s1_e6_c96'],
+                    ['ir_r3_k3_s2_e6_c160'],
+                    ['ir_r1_k3_s1_e6_c320']
+                ],
+                'channel_multiplier': 0.75,
+                'depth_multiplier': 0.75
+            },
+            {
+                'arch_def': [
+                    ['ds_r1_k3_s1_c16'],
+                    ['ir_r2_k3_s2_e6_c24'],
+                    ['ir_r3_k3_s2_e6_c32'],
+                    ['ir_r4_k3_s2_e6_c64'],
+                    ['ir_r3_k3_s1_e6_c96'],
+                    ['ir_r3_k3_s2_e6_c160'],
+                    ['ir_r1_k3_s1_e6_c320']
+                ],
+                'channel_multiplier': 0.5,
+                'depth_multiplier': 0.5
+            },
+            {
+                'arch_def': [
+                    ['ds_r1_k3_s1_c16'],
+                    ['ir_r2_k3_s2_e6_c24'],
+                    ['ir_r3_k3_s2_e6_c32'],
+                    ['ir_r4_k3_s2_e6_c64'],
+                    ['ir_r3_k3_s1_e6_c96'],
+                    ['ir_r3_k3_s2_e6_c160'],
+                    ['ir_r1_k3_s1_e6_c320']
+                ],
+                'channel_multiplier': 1.25,
+                'depth_multiplier': 1.25
+            }
+        ]
         self.config_all = {}
         self.arch_counter= 0
         self.num_classes = num_classes
@@ -187,31 +163,22 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
         self.k_range = tuple(args.k_range)
         self.channel_mult_range = tuple(args.channel_mult_range)
         self.depth_mult_range = tuple(args.depth_mult_range)

+        #make each run deterministic
         torch.manual_seed(0)
         random.seed(0)
         np.random.seed(0)

     @overrides
     def random_sample(self)->ArchaiModel:
-        ''' Uniform random sample an architecture, always start with the original model'''
-        if (self.arch_counter== 0):
-            cfg = copy.deepcopy(self.cfgs_orig)
-            arch = self._create_uniq_arch(cfg)
-        elif (self.arch_counter== 1):
-            cfg = copy.deepcopy(self.cfgs_orig1)
-            arch = self._create_uniq_arch(cfg)
-        elif (self.arch_counter== 2):
-            cfg = copy.deepcopy(self.cfgs_orig2)
-            arch = self._create_uniq_arch(cfg)
-        elif (self.arch_counter== 3):
-            cfg = copy.deepcopy(self.cfgs_orig3)
+        ''' Uniform random sample an architecture, always start with a few seed models'''
+        if (self.arch_counter >= 0 and self.arch_counter <= 3):
+            cfg = copy.deepcopy(self.cfgs_orig[self.arch_counter])
             arch = self._create_uniq_arch(cfg)
         else:
-            #del cfg [block]
             arch = None
             while (arch == None) :
-                cfg = self._rand_modify_config(self.cfgs_orig, len(self.e_range), len(self.r_range), len(self.k_range),
+                cfg = self._rand_modify_config(self.cfgs_orig[0], len(self.e_range), len(self.r_range), len(self.k_range),
                                                len(self.channel_mult_range), len(self.depth_mult_range))
                 arch = self._create_uniq_arch(cfg)
         assert (arch != None)
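A short illustrative driver (assuming an args object carrying the e/r/k and multiplier ranges) makes the sampling policy concrete:

ss = DiscreteSearchSpaceMobileNetV2(args)
seeds = [ss.random_sample() for _ in range(4)]  # first four calls return cfgs_orig[0..3] in order
mutant = ss.random_sample()  # afterwards: random perturbations of cfgs_orig[0]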
@@ -291,12 +258,11 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
         return cfg

     def _create_uniq_arch(self, cfg):
+        """create a unique arch from the config"""
         cfg_str = json.dumps(cfg)
         archid = sha1(cfg_str.encode('ascii')).hexdigest()[0:8]
         if cfg_str in list(self.config_all.values()):
-            #return None
             print(f"Creating duplicated model: {cfg_str} ")
         else :
             self.config_all[archid] = cfg_str
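The archid is just the first eight hex digits of a SHA-1 over the JSON-serialized config, so identical configs deliberately map to the same id. A self-contained sketch (illustrative config value):

import json
from hashlib import sha1

cfg = {'arch_def': [['ds_r1_k3_s1_c16']], 'channel_multiplier': 1.0, 'depth_multiplier': 1.0}
archid = sha1(json.dumps(cfg).encode('ascii')).hexdigest()[0:8]
print(archid)  # stable 8-character id for this config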
@@ -310,8 +276,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
         return arch

 class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
-    ''' We are subclassing CNNSearchSpace just to save up space'''
+    """Search space for config search"""

     @overrides
     def mutate(self, model_1: ArchaiModel) -> ArchaiModel:
@@ -349,33 +314,27 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
         logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
         return ArchaiModel(arch=arch, archid=arch.archid, metadata={'config' : arch.metadata})

     @overrides
     def encode(self, model: ArchaiModel) -> np.ndarray:
-        #TBD
+        #TBD, this is not needed for this implementation
         assert (False)

 if __name__ == "__main__":
+    """unit test for this module"""
     from torchinfo import summary
-    img_size = 192
+    img_size = 128

     def create_random_model(ss):
         arch = ss.random_sample()
         model = arch.arch
-        #print(type(model))
-        #print(isinstance(model, torch.nn.Module))
-        #make sure it works
         model.to('cpu').eval()
         pred = model(torch.randn(1, 3, img_size, img_size))
         model_summary =summary(model, input_size=(1, 3, img_size, img_size), col_names=['input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'], device='cpu')
-        #print(model_summary)
-        #print(model.summary) #model_desc_builder's
         return arch

-    #conf = common.common_init(config_filepath= '../nas_landmarks_darts.yaml')
     ss = DiscreteSearchSpaceMobileNetV2()
     for i in range(0, 2):
         archai_model = create_random_model(ss)

View file

@@ -15,13 +15,13 @@ import os
 import time
 import warnings

-import presets
+#import presets
 import torch
 import torch.utils.data
 import torchvision
 import transforms
 import utils
-from sampler import RASampler
+#from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
@@ -221,10 +221,10 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
-        if hasattr(args, "ra_sampler") and args.ra_sampler:
-            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
-        else:
-            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        # if hasattr(args, "ra_sampler") and args.ra_sampler:
+        #     train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
+        # else:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
@@ -562,3 +562,56 @@ def get_args_parser(add_help=True):
 if __name__ == "__main__":
     args, _ = get_args_parser().parse_known_args()
     train(args)
+
+###To be moved to trainder
+"""
+model = _create_model_from_csv (
+    nas.search_args.nas_finalize_archid,
+    nas.search_args.nas_finalize_models_csv,
+    num_classes=NUM_LANDMARK_CLASSES)
+
+print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
+if (not nas.search_args.nas_load_nonqat_weights):
+    if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
+        model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
+
+if (nas.search_args.nas_use_tvmodel):
+    model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
+
+# Load pretrained weights after fixing classifier as the weights match the exact network architecture
+if (nas.search_args.nas_load_nonqat_weights):
+    assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
+    print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
+    model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
+
+val_error = model_trainer.train(nas.trainer_args, model)
+print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
+"""
+
+""" TBD: need to move to trainer? Or delete.
+def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
+    csv_path = Path(csv_file)
+    assert csv_path.exists()
+    df0 = pd.read_csv(csv_path)#.query('metric == @metric')
+    row = df0[df0['archid'] == archid]
+    cfg = json.loads(row['config'].to_list()[0])
+    # Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
+    # kwargs.pop('num_classes', None)
+    # wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
+    model = _gen_tv_mobilenet(cfg['arch_def'],
+                              channel_multiplier=cfg['channel_multiplier'],
+                              depth_multiplier=cfg['depth_multiplier'],
+                              num_classes=num_classes)
+    return model
+
+def _load_pretrain_weight(weight_file: str, model) :
+    print("=> loading pretrained weight '{}'".format(weight_file))
+    assert path.isfile(weight_file)
+    source_state = torch.load(weight_file)
+    state_dict = source_state['state_dict']
+    state_dict.pop('classifier' + '.weight', None)
+    state_dict.pop('classifier' + '.bias', None)
+    model.load_state_dict(state_dict, strict = False)
+    return model
+"""

View file

@@ -1,13 +1,16 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import math
 from typing import Tuple

-import torch
-from torch import Tensor
-from torchvision.transforms import functional as F
-from torchvision.transforms import ToTensor, Compose
-import numpy as np
 import cv2
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import functional as F

 class Sample():
     """A sample of an image and its landmarks."""

View file

@@ -1,3 +1,4 @@
+"""This is from torchvision source code"""
 import copy
 import datetime
 import errno

View file

@@ -1,3 +1,5 @@
+#to be removed before merging
+
 """This module contains methods and callbacks for visualizing training or validation data."""
 from heapq import heapify, heappush, heappushpop