зеркало из https://github.com/microsoft/archai.git
Further clean up. Add copyright statements.
This commit is contained in:
Родитель
424d37e791
Коммит
8352dd39f0
|
@ -1,4 +1,5 @@
|
||||||
"""Datasets for Microsoft Face Synthetics dataset."""
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,13 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
|
|
||||||
#from torchvision.ops import Conv2dNormActivation
|
|
||||||
from torchvision.models._utils import _make_divisible
|
from torchvision.models._utils import _make_divisible
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Callable, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
#from ..utils import _log_api_usage_once
|
|
||||||
def _log_api_usage_once(obj: Any) -> None:
|
def _log_api_usage_once(obj: Any) -> None:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
import torch
|
|
||||||
from torchvision.transforms import autoaugment, transforms
|
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
|
||||||
|
|
||||||
class ClassificationPresetTrain:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
crop_size,
|
|
||||||
mean=(0.485, 0.456, 0.406),
|
|
||||||
std=(0.229, 0.224, 0.225),
|
|
||||||
interpolation=InterpolationMode.BILINEAR,
|
|
||||||
hflip_prob=0.5,
|
|
||||||
auto_augment_policy=None,
|
|
||||||
random_erase_prob=0.0,
|
|
||||||
):
|
|
||||||
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
|
||||||
if hflip_prob > 0:
|
|
||||||
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
|
|
||||||
if auto_augment_policy is not None:
|
|
||||||
if auto_augment_policy == "ra":
|
|
||||||
trans.append(autoaugment.RandAugment(interpolation=interpolation))
|
|
||||||
elif auto_augment_policy == "ta_wide":
|
|
||||||
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
|
|
||||||
else:
|
|
||||||
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
|
|
||||||
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
|
|
||||||
trans.extend(
|
|
||||||
[
|
|
||||||
transforms.PILToTensor(),
|
|
||||||
transforms.ConvertImageDtype(torch.float),
|
|
||||||
transforms.Normalize(mean=mean, std=std),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if random_erase_prob > 0:
|
|
||||||
trans.append(transforms.RandomErasing(p=random_erase_prob))
|
|
||||||
|
|
||||||
self.transforms = transforms.Compose(trans)
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
return self.transforms(img)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPresetEval:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
crop_size,
|
|
||||||
resize_size=256,
|
|
||||||
mean=(0.485, 0.456, 0.406),
|
|
||||||
std=(0.229, 0.224, 0.225),
|
|
||||||
interpolation=InterpolationMode.BILINEAR,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.transforms = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.Resize(resize_size, interpolation=interpolation),
|
|
||||||
transforms.CenterCrop(crop_size),
|
|
||||||
transforms.PILToTensor(),
|
|
||||||
transforms.ConvertImageDtype(torch.float),
|
|
||||||
transforms.Normalize(mean=mean, std=std),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
return self.transforms(img)
|
|
||||||
|
|
|
@ -1,62 +0,0 @@
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
class RASampler(torch.utils.data.Sampler):
|
|
||||||
"""Sampler that restricts data loading to a subset of the dataset for distributed,
|
|
||||||
with repeated augmentation.
|
|
||||||
It ensures that different each augmented version of a sample will be visible to a
|
|
||||||
different process (GPU).
|
|
||||||
Heavily based on 'torch.utils.data.DistributedSampler'.
|
|
||||||
|
|
||||||
This is borrowed from the DeiT Repo:
|
|
||||||
https://github.com/facebookresearch/deit/blob/main/samplers.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
|
|
||||||
if num_replicas is None:
|
|
||||||
if not dist.is_available():
|
|
||||||
raise RuntimeError("Requires distributed package to be available!")
|
|
||||||
num_replicas = dist.get_world_size()
|
|
||||||
if rank is None:
|
|
||||||
if not dist.is_available():
|
|
||||||
raise RuntimeError("Requires distributed package to be available!")
|
|
||||||
rank = dist.get_rank()
|
|
||||||
self.dataset = dataset
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.rank = rank
|
|
||||||
self.epoch = 0
|
|
||||||
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
|
||||||
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
|
||||||
self.shuffle = shuffle
|
|
||||||
self.seed = seed
|
|
||||||
self.repetitions = repetitions
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
if self.shuffle:
|
|
||||||
# Deterministically shuffle based on epoch
|
|
||||||
g = torch.Generator()
|
|
||||||
g.manual_seed(self.seed + self.epoch)
|
|
||||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
||||||
else:
|
|
||||||
indices = list(range(len(self.dataset)))
|
|
||||||
|
|
||||||
# Add extra samples to make it evenly divisible
|
|
||||||
indices = [ele for ele in indices for i in range(self.repetitions)]
|
|
||||||
indices += indices[: (self.total_size - len(indices))]
|
|
||||||
assert len(indices) == self.total_size
|
|
||||||
|
|
||||||
# Subsample
|
|
||||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
||||||
assert len(indices) == self.num_samples
|
|
||||||
|
|
||||||
return iter(indices[: self.num_selected_samples])
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_selected_samples
|
|
||||||
|
|
||||||
def set_epoch(self, epoch):
|
|
||||||
self.epoch = epoch
|
|
|
@ -1,5 +1,8 @@
|
||||||
import sys
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
# to be removed before merging
|
||||||
if ('--debug' in sys.argv):
|
if ('--debug' in sys.argv):
|
||||||
import debugpy
|
import debugpy
|
||||||
debugpy.listen(5678)
|
debugpy.listen(5678)
|
||||||
|
@ -149,30 +152,5 @@ def _main() -> None:
|
||||||
search = SearchFaceLandmarkModels()
|
search = SearchFaceLandmarkModels()
|
||||||
search.search()
|
search.search()
|
||||||
|
|
||||||
###To be moved to trainder
|
|
||||||
"""
|
|
||||||
model = _create_model_from_csv (
|
|
||||||
nas.search_args.nas_finalize_archid,
|
|
||||||
nas.search_args.nas_finalize_models_csv,
|
|
||||||
num_classes=NUM_LANDMARK_CLASSES)
|
|
||||||
|
|
||||||
print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
|
|
||||||
if (not nas.search_args.nas_load_nonqat_weights):
|
|
||||||
if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
|
|
||||||
model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
|
|
||||||
|
|
||||||
if (nas.search_args.nas_use_tvmodel):
|
|
||||||
model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
|
|
||||||
|
|
||||||
# Load pretrained weights after fixing classifier as the weights match the exact network architecture
|
|
||||||
if (nas.search_args.nas_load_nonqat_weights):
|
|
||||||
assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
|
|
||||||
print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
|
|
||||||
model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
|
|
||||||
|
|
||||||
val_error = model_trainer.train(nas.trainer_args, model)
|
|
||||||
print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
|
|
||||||
"""
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_main()
|
_main()
|
|
@ -26,8 +26,8 @@ depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
|
||||||
# model trainer args
|
# model trainer args
|
||||||
data_path: /data/public_face_synthetics/dataset_100000
|
data_path: /data/public_face_synthetics/dataset_100000
|
||||||
output_dir: /home/wchen/tmp
|
output_dir: /home/wchen/tmp
|
||||||
max_num_images: 20000
|
#max_num_images: 20000
|
||||||
#max_num_images: 1000
|
max_num_images: 1000
|
||||||
train_crop_size: 128
|
train_crop_size: 128
|
||||||
epochs: 30
|
epochs: 30
|
||||||
batch_size: 128
|
batch_size: 128
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Search space for facial landmark detection task"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
@ -5,38 +7,27 @@ import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from os import path
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from overrides.overrides import overrides
|
from overrides.overrides import overrides
|
||||||
|
|
||||||
from archai.common.common import logger
|
from archai.common.common import logger
|
||||||
from archai.discrete_search.api.archai_model import ArchaiModel
|
from archai.discrete_search.api.archai_model import ArchaiModel
|
||||||
from archai.discrete_search.api.search_space import DiscreteSearchSpace, BayesOptSearchSpace, EvolutionarySearchSpace
|
from archai.discrete_search.api.search_space import (
|
||||||
|
BayesOptSearchSpace,
|
||||||
|
DiscreteSearchSpace,
|
||||||
|
EvolutionarySearchSpace,
|
||||||
|
)
|
||||||
|
|
||||||
from model import CustomMobileNetV2
|
from model import CustomMobileNetV2
|
||||||
|
|
||||||
|
|
||||||
def _gen_tv_mobilenet ( arch_def,
|
def _gen_tv_mobilenet ( arch_def,
|
||||||
channel_multiplier=1.0,
|
channel_multiplier=1.0,
|
||||||
depth_multiplier=1.0,
|
depth_multiplier=1.0,
|
||||||
num_classes=1000):
|
num_classes=1000):
|
||||||
# default mbv2 setting
|
"""generate mobilenet v2 from torchvision. Adapted from timm source code"""
|
||||||
# t - exp factor, c - channels, n - number of block repeats, s - stride
|
|
||||||
# # t, c, n, s
|
|
||||||
# [1, 16, 1, 1],
|
|
||||||
# [6, 24, 2, 2],
|
|
||||||
# [6, 32, 3, 2],
|
|
||||||
# [6, 64, 4, 2],
|
|
||||||
# [6, 96, 3, 1],
|
|
||||||
# [6, 160, 3, 2],
|
|
||||||
# [6, 320, 1, 1],
|
|
||||||
# archid 0a7b6 - {"arch_def": [["ds_r1_k3_s1_c16"], ["ir_r2_k3_s2_e6_c24"], ["ir_r3_k3_s2_e6_c32"], ["ir_r4_k3_s2_e6_c64"], ["ir_r2_k3_s1_e4_c96"], ["ir_r2_k3_s2_e6_c160"], ["ir_r1_k3_s1_e5_c320"]], "channel_multiplier": 1.0, "depth_multiplier": 0.75}
|
|
||||||
ir_setting = []
|
ir_setting = []
|
||||||
for block_def in arch_def:
|
for block_def in arch_def:
|
||||||
parts = block_def[0].split("_")
|
parts = block_def[0].split("_")
|
||||||
|
@ -83,44 +74,9 @@ def _gen_tv_mobilenet ( arch_def,
|
||||||
model = CustomMobileNetV2(inverted_residual_setting=ir_setting, dropout=0, num_classes=num_classes)
|
model = CustomMobileNetV2(inverted_residual_setting=ir_setting, dropout=0, num_classes=num_classes)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
# return mobilenet_v2(quantize=False,
|
|
||||||
# inverted_residual_setting = [
|
|
||||||
# [1, 16, 1, 1],
|
|
||||||
# [6, 24, 2, 2],
|
|
||||||
# [6, 32, 3, 2],
|
|
||||||
# [6, 64, 3, 2],
|
|
||||||
# [4, 96, 2, 1],
|
|
||||||
# [6, 160, 2, 2],
|
|
||||||
# [5, 320, 1, 1],
|
|
||||||
# ])
|
|
||||||
|
|
||||||
def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
|
class ConfigSearchModel(torch.nn.Module):
|
||||||
csv_path = Path(csv_file)
|
"""This is a wrapper class to allow the model to be used in the search space"""
|
||||||
assert csv_path.exists()
|
|
||||||
df0 = pd.read_csv(csv_path)#.query('metric == @metric')
|
|
||||||
row = df0[df0['archid'] == archid]
|
|
||||||
cfg = json.loads(row['config'].to_list()[0])
|
|
||||||
|
|
||||||
# Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
|
|
||||||
# kwargs.pop('num_classes', None)
|
|
||||||
# wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
|
|
||||||
model = _gen_tv_mobilenet(cfg['arch_def'],
|
|
||||||
channel_multiplier=cfg['channel_multiplier'],
|
|
||||||
depth_multiplier=cfg['depth_multiplier'],
|
|
||||||
num_classes=num_classes)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _load_pretrain_weight(weight_file: str, model) :
|
|
||||||
print("=> loading pretrained weight '{}'".format(weight_file))
|
|
||||||
assert path.isfile(weight_file)
|
|
||||||
source_state = torch.load(weight_file)
|
|
||||||
state_dict = source_state['state_dict']
|
|
||||||
state_dict.pop('classifier' + '.weight', None)
|
|
||||||
state_dict.pop('classifier' + '.bias', None)
|
|
||||||
model.load_state_dict(state_dict, strict = False)
|
|
||||||
return model
|
|
||||||
|
|
||||||
class ConfigSearchModel(nn.Module):
|
|
||||||
def __init__(self, model : ArchaiModel, archid: str, metadata : dict):
|
def __init__(self, model : ArchaiModel, archid: str, metadata : dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -133,52 +89,72 @@ class ConfigSearchModel(nn.Module):
|
||||||
class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
||||||
def __init__(self, args, num_classes=140):
|
def __init__(self, args, num_classes=140):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
##mvn2's config
|
""" Default mobilenetv2 setting in more readable format
|
||||||
self.cfgs_orig = {'arch_def': [['ds_r1_k3_s1_c16'],
|
t - exp factor, c - channels, n - number of block repeats, s - stride
|
||||||
|
t, c, n, s
|
||||||
|
[1, 16, 1, 1],
|
||||||
|
[6, 24, 2, 2],
|
||||||
|
[6, 32, 3, 2],
|
||||||
|
[6, 64, 4, 2],
|
||||||
|
[6, 96, 3, 1],
|
||||||
|
[6, 160, 3, 2],
|
||||||
|
[6, 320, 1, 1]"""
|
||||||
|
#set up a few models configs with variable depth and width
|
||||||
|
self.cfgs_orig = [
|
||||||
|
{
|
||||||
|
'arch_def': [
|
||||||
|
['ds_r1_k3_s1_c16'],
|
||||||
['ir_r2_k3_s2_e6_c24'],
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
['ir_r3_k3_s2_e6_c32'],
|
['ir_r3_k3_s2_e6_c32'],
|
||||||
['ir_r4_k3_s2_e6_c64'],
|
['ir_r4_k3_s2_e6_c64'],
|
||||||
['ir_r3_k3_s1_e6_c96'],
|
['ir_r3_k3_s1_e6_c96'],
|
||||||
['ir_r3_k3_s2_e6_c160'],
|
['ir_r3_k3_s2_e6_c160'],
|
||||||
['ir_r1_k3_s1_e6_c320']],
|
['ir_r1_k3_s1_e6_c320']
|
||||||
|
],
|
||||||
'channel_multiplier': 1.00,
|
'channel_multiplier': 1.00,
|
||||||
'depth_multiplier': 1.00}
|
'depth_multiplier': 1.00
|
||||||
self.cfgs_orig1 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
},
|
||||||
|
{
|
||||||
|
'arch_def': [
|
||||||
|
['ds_r1_k3_s1_c16'],
|
||||||
['ir_r2_k3_s2_e6_c24'],
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
['ir_r3_k3_s2_e6_c32'],
|
['ir_r3_k3_s2_e6_c32'],
|
||||||
['ir_r4_k3_s2_e6_c64'],
|
['ir_r4_k3_s2_e6_c64'],
|
||||||
['ir_r3_k3_s1_e6_c96'],
|
['ir_r3_k3_s1_e6_c96'],
|
||||||
['ir_r3_k3_s2_e6_c160'],
|
['ir_r3_k3_s2_e6_c160'],
|
||||||
['ir_r1_k3_s1_e6_c320']],
|
['ir_r1_k3_s1_e6_c320']
|
||||||
|
],
|
||||||
'channel_multiplier': 0.75,
|
'channel_multiplier': 0.75,
|
||||||
'depth_multiplier': 0.75}
|
'depth_multiplier': 0.75
|
||||||
self.cfgs_orig2 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
},
|
||||||
|
{
|
||||||
|
'arch_def': [
|
||||||
|
['ds_r1_k3_s1_c16'],
|
||||||
['ir_r2_k3_s2_e6_c24'],
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
['ir_r3_k3_s2_e6_c32'],
|
['ir_r3_k3_s2_e6_c32'],
|
||||||
['ir_r4_k3_s2_e6_c64'],
|
['ir_r4_k3_s2_e6_c64'],
|
||||||
['ir_r3_k3_s1_e6_c96'],
|
['ir_r3_k3_s1_e6_c96'],
|
||||||
['ir_r3_k3_s2_e6_c160'],
|
['ir_r3_k3_s2_e6_c160'],
|
||||||
['ir_r1_k3_s1_e6_c320']],
|
['ir_r1_k3_s1_e6_c320']
|
||||||
|
],
|
||||||
'channel_multiplier': 0.5,
|
'channel_multiplier': 0.5,
|
||||||
'depth_multiplier': 0.5}
|
'depth_multiplier': 0.5
|
||||||
self.cfgs_orig3 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
},
|
||||||
|
{
|
||||||
|
'arch_def': [
|
||||||
|
['ds_r1_k3_s1_c16'],
|
||||||
['ir_r2_k3_s2_e6_c24'],
|
['ir_r2_k3_s2_e6_c24'],
|
||||||
['ir_r3_k3_s2_e6_c32'],
|
['ir_r3_k3_s2_e6_c32'],
|
||||||
['ir_r4_k3_s2_e6_c64'],
|
['ir_r4_k3_s2_e6_c64'],
|
||||||
['ir_r3_k3_s1_e6_c96'],
|
['ir_r3_k3_s1_e6_c96'],
|
||||||
['ir_r3_k3_s2_e6_c160'],
|
['ir_r3_k3_s2_e6_c160'],
|
||||||
['ir_r1_k3_s1_e6_c320']],
|
['ir_r1_k3_s1_e6_c320']
|
||||||
|
],
|
||||||
'channel_multiplier': 1.25,
|
'channel_multiplier': 1.25,
|
||||||
'depth_multiplier': 1.25}
|
'depth_multiplier': 1.25
|
||||||
self.cfgs_orig_el0 = {'arch_def': [['ds_r1_k3_s1_e1_c16'],
|
}
|
||||||
['ir_r2_k3_s2_e6_c24'],
|
]
|
||||||
['ir_r2_k5_s2_e6_c40'],
|
|
||||||
['ir_r3_k3_s2_e6_c80'],
|
|
||||||
['ir_r3_k5_s1_e6_c112'],
|
|
||||||
['ir_r4_k5_s2_e6_c192'],
|
|
||||||
['ir_r1_k3_s1_e6_c320']],
|
|
||||||
'channel_multiplier': 1.00,
|
|
||||||
'depth_multiplier': 1.00}
|
|
||||||
self.config_all = {}
|
self.config_all = {}
|
||||||
self.arch_counter= 0
|
self.arch_counter= 0
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
@ -187,31 +163,22 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
||||||
self.k_range = tuple(args.k_range)
|
self.k_range = tuple(args.k_range)
|
||||||
self.channel_mult_range = tuple(args.channel_mult_range)
|
self.channel_mult_range = tuple(args.channel_mult_range)
|
||||||
self.depth_mult_range = tuple(args.depth_mult_range)
|
self.depth_mult_range = tuple(args.depth_mult_range)
|
||||||
|
#make each run deterministic
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def random_sample(self)->ArchaiModel:
|
def random_sample(self)->ArchaiModel:
|
||||||
''' Uniform random sample an architecture, always start with the original model'''
|
''' Uniform random sample an architecture, always start with a few seed models'''
|
||||||
|
|
||||||
if (self.arch_counter== 0):
|
if (self.arch_counter >= 0 and self.arch_counter <= 3):
|
||||||
cfg = copy.deepcopy(self.cfgs_orig)
|
cfg = copy.deepcopy(self.cfgs_orig[self.arch_counter])
|
||||||
arch = self._create_uniq_arch(cfg)
|
|
||||||
elif (self.arch_counter== 1):
|
|
||||||
cfg = copy.deepcopy(self.cfgs_orig1)
|
|
||||||
arch = self._create_uniq_arch(cfg)
|
|
||||||
elif (self.arch_counter== 2):
|
|
||||||
cfg = copy.deepcopy(self.cfgs_orig2)
|
|
||||||
arch = self._create_uniq_arch(cfg)
|
|
||||||
elif (self.arch_counter== 3):
|
|
||||||
cfg = copy.deepcopy(self.cfgs_orig3)
|
|
||||||
arch = self._create_uniq_arch(cfg)
|
arch = self._create_uniq_arch(cfg)
|
||||||
else:
|
else:
|
||||||
#del cfg [block]
|
|
||||||
arch = None
|
arch = None
|
||||||
while (arch == None) :
|
while (arch == None) :
|
||||||
cfg = self._rand_modify_config(self.cfgs_orig, len(self.e_range), len(self.r_range), len(self.k_range),
|
cfg = self._rand_modify_config(self.cfgs_orig[0], len(self.e_range), len(self.r_range), len(self.k_range),
|
||||||
len(self.channel_mult_range), len(self.depth_mult_range))
|
len(self.channel_mult_range), len(self.depth_mult_range))
|
||||||
arch = self._create_uniq_arch(cfg)
|
arch = self._create_uniq_arch(cfg)
|
||||||
assert (arch != None)
|
assert (arch != None)
|
||||||
|
@ -291,12 +258,11 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _create_uniq_arch(self, cfg):
|
def _create_uniq_arch(self, cfg):
|
||||||
|
"""create a unique arch from the config"""
|
||||||
cfg_str = json.dumps(cfg)
|
cfg_str = json.dumps(cfg)
|
||||||
archid = sha1(cfg_str.encode('ascii')).hexdigest()[0:8]
|
archid = sha1(cfg_str.encode('ascii')).hexdigest()[0:8]
|
||||||
|
|
||||||
if cfg_str in list(self.config_all.values()):
|
if cfg_str in list(self.config_all.values()):
|
||||||
#return None
|
|
||||||
print(f"Creating duplicated model: {cfg_str} ")
|
print(f"Creating duplicated model: {cfg_str} ")
|
||||||
else :
|
else :
|
||||||
self.config_all[archid] = cfg_str
|
self.config_all[archid] = cfg_str
|
||||||
|
@ -310,8 +276,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
||||||
return arch
|
return arch
|
||||||
|
|
||||||
class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
|
class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
|
||||||
''' We are subclassing CNNSearchSpace just to save up space'''
|
"""Search space for config search"""
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def mutate(self, model_1: ArchaiModel) -> ArchaiModel:
|
def mutate(self, model_1: ArchaiModel) -> ArchaiModel:
|
||||||
|
|
||||||
|
@ -352,30 +317,24 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def encode(self, model: ArchaiModel) -> np.ndarray:
|
def encode(self, model: ArchaiModel) -> np.ndarray:
|
||||||
#TBD
|
#TBD, this is not needed for this implementation
|
||||||
assert (False)
|
assert (False)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
"""unit test for this module"""
|
||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
img_size = 192
|
img_size = 128
|
||||||
def create_random_model(ss):
|
def create_random_model(ss):
|
||||||
|
|
||||||
arch = ss.random_sample()
|
arch = ss.random_sample()
|
||||||
model = arch.arch
|
model = arch.arch
|
||||||
#print(type(model))
|
|
||||||
#print(isinstance(model, torch.nn.Module))
|
|
||||||
|
|
||||||
#make sure it works
|
|
||||||
model.to('cpu').eval()
|
model.to('cpu').eval()
|
||||||
pred = model(torch.randn(1, 3, img_size, img_size))
|
pred = model(torch.randn(1, 3, img_size, img_size))
|
||||||
|
|
||||||
model_summary =summary(model, input_size=(1, 3, img_size, img_size), col_names=['input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'], device='cpu')
|
model_summary =summary(model, input_size=(1, 3, img_size, img_size), col_names=['input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'], device='cpu')
|
||||||
#print(model_summary)
|
|
||||||
|
|
||||||
#print(model.summary) #model_desc_builder's
|
|
||||||
return arch
|
return arch
|
||||||
|
|
||||||
#conf = common.common_init(config_filepath= '../nas_landmarks_darts.yaml')
|
|
||||||
ss = DiscreteSearchSpaceMobileNetV2()
|
ss = DiscreteSearchSpaceMobileNetV2()
|
||||||
for i in range(0, 2):
|
for i in range(0, 2):
|
||||||
archai_model = create_random_model(ss)
|
archai_model = create_random_model(ss)
|
||||||
|
|
|
@ -15,13 +15,13 @@ import os
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import presets
|
#import presets
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torchvision
|
import torchvision
|
||||||
import transforms
|
import transforms
|
||||||
import utils
|
import utils
|
||||||
from sampler import RASampler
|
#from sampler import RASampler
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataloader import default_collate
|
from torch.utils.data.dataloader import default_collate
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
@ -221,9 +221,9 @@ def load_data(traindir, valdir, args):
|
||||||
|
|
||||||
print("Creating data loaders")
|
print("Creating data loaders")
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
if hasattr(args, "ra_sampler") and args.ra_sampler:
|
# if hasattr(args, "ra_sampler") and args.ra_sampler:
|
||||||
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
|
# train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
|
||||||
else:
|
# else:
|
||||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||||
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
|
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
|
||||||
else:
|
else:
|
||||||
|
@ -562,3 +562,56 @@ def get_args_parser(add_help=True):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args, _ = get_args_parser().parse_known_args()
|
args, _ = get_args_parser().parse_known_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|
||||||
|
###To be moved to trainder
|
||||||
|
"""
|
||||||
|
model = _create_model_from_csv (
|
||||||
|
nas.search_args.nas_finalize_archid,
|
||||||
|
nas.search_args.nas_finalize_models_csv,
|
||||||
|
num_classes=NUM_LANDMARK_CLASSES)
|
||||||
|
|
||||||
|
print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
|
||||||
|
if (not nas.search_args.nas_load_nonqat_weights):
|
||||||
|
if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
|
||||||
|
model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
|
||||||
|
|
||||||
|
if (nas.search_args.nas_use_tvmodel):
|
||||||
|
model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
|
||||||
|
|
||||||
|
# Load pretrained weights after fixing classifier as the weights match the exact network architecture
|
||||||
|
if (nas.search_args.nas_load_nonqat_weights):
|
||||||
|
assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
|
||||||
|
print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
|
||||||
|
model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
|
||||||
|
|
||||||
|
val_error = model_trainer.train(nas.trainer_args, model)
|
||||||
|
print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
""" TBD: need to move to trainer? Or delete.
|
||||||
|
def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
|
||||||
|
csv_path = Path(csv_file)
|
||||||
|
assert csv_path.exists()
|
||||||
|
df0 = pd.read_csv(csv_path)#.query('metric == @metric')
|
||||||
|
row = df0[df0['archid'] == archid]
|
||||||
|
cfg = json.loads(row['config'].to_list()[0])
|
||||||
|
|
||||||
|
# Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
|
||||||
|
# kwargs.pop('num_classes', None)
|
||||||
|
# wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
|
||||||
|
model = _gen_tv_mobilenet(cfg['arch_def'],
|
||||||
|
channel_multiplier=cfg['channel_multiplier'],
|
||||||
|
depth_multiplier=cfg['depth_multiplier'],
|
||||||
|
num_classes=num_classes)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _load_pretrain_weight(weight_file: str, model) :
|
||||||
|
print("=> loading pretrained weight '{}'".format(weight_file))
|
||||||
|
assert path.isfile(weight_file)
|
||||||
|
source_state = torch.load(weight_file)
|
||||||
|
state_dict = source_state['state_dict']
|
||||||
|
state_dict.pop('classifier' + '.weight', None)
|
||||||
|
state_dict.pop('classifier' + '.bias', None)
|
||||||
|
model.load_state_dict(state_dict, strict = False)
|
||||||
|
return model
|
||||||
|
"""
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
from torchvision.transforms import functional as F
|
|
||||||
|
|
||||||
from torchvision.transforms import ToTensor, Compose
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torchvision.transforms import Compose, ToTensor
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
class Sample():
|
class Sample():
|
||||||
"""A sample of an image and its landmarks."""
|
"""A sample of an image and its landmarks."""
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
"""This is from torchvision source code"""
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import errno
|
import errno
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
#to be removed before merging
|
||||||
|
|
||||||
"""This module contains methods and callbacks for visualizing training or validation data."""
|
"""This module contains methods and callbacks for visualizing training or validation data."""
|
||||||
|
|
||||||
from heapq import heapify, heappush, heappushpop
|
from heapq import heapify, heappush, heappushpop
|
||||||
|
|
Загрузка…
Ссылка в новой задаче