зеркало из https://github.com/microsoft/archai.git
Further clean up. Add copyright statements.
This commit is contained in:
Родитель
424d37e791
Коммит
8352dd39f0
|
@ -1,4 +1,5 @@
|
|||
"""Datasets for Microsoft Face Synthetics dataset."""
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import io
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
|
|
@ -1,18 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
#from torchvision.ops import Conv2dNormActivation
|
||||
from torchvision.models._utils import _make_divisible
|
||||
|
||||
import warnings
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
#from ..utils import _log_api_usage_once
|
||||
def _log_api_usage_once(obj: Any) -> None:
|
||||
|
||||
"""
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
import torch
|
||||
from torchvision.transforms import autoaugment, transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
class ClassificationPresetTrain:
    """Callable training-time preprocessing pipeline for classification images.

    The pipeline is built once in ``__init__`` and applied by calling the
    instance on an image. Steps, in order:

    1. ``RandomResizedCrop(crop_size)``
    2. ``RandomHorizontalFlip(hflip_prob)`` when ``hflip_prob > 0``
    3. an optional auto-augmentation step chosen by ``auto_augment_policy``
    4. ``PILToTensor`` -> ``ConvertImageDtype(torch.float)`` -> ``Normalize``
    5. ``RandomErasing(p=random_erase_prob)`` when ``random_erase_prob > 0``

    Args:
        crop_size: Size passed to ``RandomResizedCrop``.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
        interpolation: Interpolation mode used by the crop and by the
            auto-augmentation transforms.
        hflip_prob: Probability of a horizontal flip; the flip is omitted
            entirely when this is not greater than zero.
        auto_augment_policy: ``None`` to disable, ``"ra"`` for RandAugment,
            ``"ta_wide"`` for TrivialAugmentWide; any other value is passed
            to ``autoaugment.AutoAugmentPolicy`` and used with AutoAugment.
        random_erase_prob: Probability for ``RandomErasing``; the transform
            is omitted when this is not greater than zero.
    """

    def __init__(
        self,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment(interpolation=interpolation))
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
            else:
                # Any other string is interpreted as an AutoAugmentPolicy value.
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
        trans.extend(
            [
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        # Random erasing is appended last, after normalization.
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)

    def __call__(self, img):
        """Apply the composed training transforms to ``img`` and return the result."""
        return self.transforms(img)
|
||||
|
||||
|
||||
class ClassificationPresetEval:
    """Callable evaluation-time preprocessing pipeline for classification images.

    The pipeline is deterministic: ``Resize(resize_size)`` ->
    ``CenterCrop(crop_size)`` -> ``PILToTensor`` ->
    ``ConvertImageDtype(torch.float)`` -> ``Normalize(mean, std)``.

    Args:
        crop_size: Size passed to ``CenterCrop``.
        resize_size: Size passed to ``Resize`` before cropping.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
        interpolation: Interpolation mode used by ``Resize``.
    """

    def __init__(
        self,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
    ):

        self.transforms = transforms.Compose(
            [
                transforms.Resize(resize_size, interpolation=interpolation),
                transforms.CenterCrop(crop_size),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, img):
        """Apply the composed evaluation transforms to ``img`` and return the result."""
        return self.transforms(img)
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for
    distributed training, with repeated augmentation.

    Every dataset index is repeated ``repetitions`` times back to back and the
    resulting list is dealt out to the replicas with a stride of
    ``num_replicas``, so that the repeated copies of a sample are spread over
    different processes (GPUs).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py

    Args:
        dataset: Dataset to sample from; only ``len(dataset)`` is used.
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Rank of the current process. Defaults to ``dist.get_rank()``.
        shuffle: If true, indices are permuted with a generator seeded by
            ``seed + epoch``; otherwise they are kept in order.
        seed: Base seed for the shuffle generator.
        repetitions: Number of times each index is repeated.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but the
            distributed package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None or rank is None:
            # Both defaults come from torch.distributed, so one availability
            # check covers them.
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-replica share of the repeated index list, rounded up.
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of indices actually yielded per replica: the dataset length is
        # first rounded down to a multiple of 256, then split across replicas.
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions

    def __iter__(self):
        """Return an iterator over this replica's indices for the current epoch."""
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat every index `repetitions` times, keeping the copies adjacent.
        indices = [ele for ele in indices for _ in range(self.repetitions)]

        # Add extra samples to make it evenly divisible. The list is tiled
        # before slicing so the padding is long enough even when it exceeds the
        # current list length (tiny dataset, many replicas).
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            tiles = math.ceil(padding_size / len(indices))
            indices += (indices * tiles)[:padding_size]
        assert len(indices) == self.total_size

        # Subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        """Return the number of indices yielded per epoch on this replica."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch number that is mixed into the shuffle seed."""
        self.epoch = epoch
|
|
@ -1,5 +1,8 @@
|
|||
import sys
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sys
|
||||
# to be removed before merging
|
||||
if ('--debug' in sys.argv):
|
||||
import debugpy
|
||||
debugpy.listen(5678)
|
||||
|
@ -149,30 +152,5 @@ def _main() -> None:
|
|||
search = SearchFaceLandmarkModels()
|
||||
search.search()
|
||||
|
||||
###To be moved to trainer
|
||||
"""
|
||||
model = _create_model_from_csv (
|
||||
nas.search_args.nas_finalize_archid,
|
||||
nas.search_args.nas_finalize_models_csv,
|
||||
num_classes=NUM_LANDMARK_CLASSES)
|
||||
|
||||
print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
|
||||
if (not nas.search_args.nas_load_nonqat_weights):
|
||||
if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
|
||||
model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
|
||||
|
||||
if (nas.search_args.nas_use_tvmodel):
|
||||
model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
|
||||
|
||||
# Load pretrained weights after fixing classifier as the weights match the exact network architecture
|
||||
if (nas.search_args.nas_load_nonqat_weights):
|
||||
assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
|
||||
print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
|
||||
model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
|
||||
|
||||
val_error = model_trainer.train(nas.trainer_args, model)
|
||||
print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
|
@ -26,8 +26,8 @@ depth_mult_range: [0.25, 0.5, 0.75, 1.0, 1.25]
|
|||
# model trainer args
|
||||
data_path: /data/public_face_synthetics/dataset_100000
|
||||
output_dir: /home/wchen/tmp
|
||||
max_num_images: 20000
|
||||
#max_num_images: 1000
|
||||
#max_num_images: 20000
|
||||
max_num_images: 1000
|
||||
train_crop_size: 128
|
||||
epochs: 30
|
||||
batch_size: 128
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
"""Search space for facial landmark detection task"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
|
@ -5,38 +7,27 @@ import random
|
|||
import re
|
||||
import sys
|
||||
from hashlib import sha1
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from overrides.overrides import overrides
|
||||
|
||||
from archai.common.common import logger
|
||||
from archai.discrete_search.api.archai_model import ArchaiModel
|
||||
from archai.discrete_search.api.search_space import DiscreteSearchSpace, BayesOptSearchSpace, EvolutionarySearchSpace
|
||||
from archai.discrete_search.api.search_space import (
|
||||
BayesOptSearchSpace,
|
||||
DiscreteSearchSpace,
|
||||
EvolutionarySearchSpace,
|
||||
)
|
||||
|
||||
from model import CustomMobileNetV2
|
||||
|
||||
|
||||
def _gen_tv_mobilenet ( arch_def,
|
||||
channel_multiplier=1.0,
|
||||
depth_multiplier=1.0,
|
||||
num_classes=1000):
|
||||
# default mbv2 setting
|
||||
# t - exp factor, c - channels, n - number of block repeats, s - stride
|
||||
# # t, c, n, s
|
||||
# [1, 16, 1, 1],
|
||||
# [6, 24, 2, 2],
|
||||
# [6, 32, 3, 2],
|
||||
# [6, 64, 4, 2],
|
||||
# [6, 96, 3, 1],
|
||||
# [6, 160, 3, 2],
|
||||
# [6, 320, 1, 1],
|
||||
# archid 0a7b6 - {"arch_def": [["ds_r1_k3_s1_c16"], ["ir_r2_k3_s2_e6_c24"], ["ir_r3_k3_s2_e6_c32"], ["ir_r4_k3_s2_e6_c64"], ["ir_r2_k3_s1_e4_c96"], ["ir_r2_k3_s2_e6_c160"], ["ir_r1_k3_s1_e5_c320"]], "channel_multiplier": 1.0, "depth_multiplier": 0.75}
|
||||
"""generate mobilenet v2 from torchvision. Adapted from timm source code"""
|
||||
ir_setting = []
|
||||
for block_def in arch_def:
|
||||
parts = block_def[0].split("_")
|
||||
|
@ -83,44 +74,9 @@ def _gen_tv_mobilenet ( arch_def,
|
|||
model = CustomMobileNetV2(inverted_residual_setting=ir_setting, dropout=0, num_classes=num_classes)
|
||||
|
||||
return model
|
||||
# return mobilenet_v2(quantize=False,
|
||||
# inverted_residual_setting = [
|
||||
# [1, 16, 1, 1],
|
||||
# [6, 24, 2, 2],
|
||||
# [6, 32, 3, 2],
|
||||
# [6, 64, 3, 2],
|
||||
# [4, 96, 2, 1],
|
||||
# [6, 160, 2, 2],
|
||||
# [5, 320, 1, 1],
|
||||
# ])
|
||||
|
||||
def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
|
||||
csv_path = Path(csv_file)
|
||||
assert csv_path.exists()
|
||||
df0 = pd.read_csv(csv_path)#.query('metric == @metric')
|
||||
row = df0[df0['archid'] == archid]
|
||||
cfg = json.loads(row['config'].to_list()[0])
|
||||
|
||||
# Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
|
||||
# kwargs.pop('num_classes', None)
|
||||
# wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
|
||||
model = _gen_tv_mobilenet(cfg['arch_def'],
|
||||
channel_multiplier=cfg['channel_multiplier'],
|
||||
depth_multiplier=cfg['depth_multiplier'],
|
||||
num_classes=num_classes)
|
||||
return model
|
||||
|
||||
def _load_pretrain_weight(weight_file: str, model) :
|
||||
print("=> loading pretrained weight '{}'".format(weight_file))
|
||||
assert path.isfile(weight_file)
|
||||
source_state = torch.load(weight_file)
|
||||
state_dict = source_state['state_dict']
|
||||
state_dict.pop('classifier' + '.weight', None)
|
||||
state_dict.pop('classifier' + '.bias', None)
|
||||
model.load_state_dict(state_dict, strict = False)
|
||||
return model
|
||||
|
||||
class ConfigSearchModel(nn.Module):
|
||||
class ConfigSearchModel(torch.nn.Module):
|
||||
"""This is a wrapper class to allow the model to be used in the search space"""
|
||||
def __init__(self, model : ArchaiModel, archid: str, metadata : dict):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
@ -133,52 +89,72 @@ class ConfigSearchModel(nn.Module):
|
|||
class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
||||
def __init__(self, args, num_classes=140):
|
||||
super().__init__()
|
||||
##mvn2's config
|
||||
self.cfgs_orig = {'arch_def': [['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']],
|
||||
'channel_multiplier': 1.00,
|
||||
'depth_multiplier': 1.00}
|
||||
self.cfgs_orig1 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']],
|
||||
'channel_multiplier': 0.75,
|
||||
'depth_multiplier': 0.75}
|
||||
self.cfgs_orig2 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']],
|
||||
'channel_multiplier': 0.5,
|
||||
'depth_multiplier': 0.5}
|
||||
self.cfgs_orig3 = {'arch_def': [['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']],
|
||||
'channel_multiplier': 1.25,
|
||||
'depth_multiplier': 1.25}
|
||||
self.cfgs_orig_el0 = {'arch_def': [['ds_r1_k3_s1_e1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r2_k5_s2_e6_c40'],
|
||||
['ir_r3_k3_s2_e6_c80'],
|
||||
['ir_r3_k5_s1_e6_c112'],
|
||||
['ir_r4_k5_s2_e6_c192'],
|
||||
['ir_r1_k3_s1_e6_c320']],
|
||||
'channel_multiplier': 1.00,
|
||||
'depth_multiplier': 1.00}
|
||||
""" Default mobilenetv2 setting in more readable format
|
||||
t - exp factor, c - channels, n - number of block repeats, s - stride
|
||||
t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1]"""
|
||||
#set up a few models configs with variable depth and width
|
||||
self.cfgs_orig = [
|
||||
{
|
||||
'arch_def': [
|
||||
['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']
|
||||
],
|
||||
'channel_multiplier': 1.00,
|
||||
'depth_multiplier': 1.00
|
||||
},
|
||||
{
|
||||
'arch_def': [
|
||||
['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']
|
||||
],
|
||||
'channel_multiplier': 0.75,
|
||||
'depth_multiplier': 0.75
|
||||
},
|
||||
{
|
||||
'arch_def': [
|
||||
['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']
|
||||
],
|
||||
'channel_multiplier': 0.5,
|
||||
'depth_multiplier': 0.5
|
||||
},
|
||||
{
|
||||
'arch_def': [
|
||||
['ds_r1_k3_s1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r3_k3_s2_e6_c32'],
|
||||
['ir_r4_k3_s2_e6_c64'],
|
||||
['ir_r3_k3_s1_e6_c96'],
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
['ir_r1_k3_s1_e6_c320']
|
||||
],
|
||||
'channel_multiplier': 1.25,
|
||||
'depth_multiplier': 1.25
|
||||
}
|
||||
]
|
||||
|
||||
self.config_all = {}
|
||||
self.arch_counter= 0
|
||||
self.num_classes = num_classes
|
||||
|
@ -187,31 +163,22 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
self.k_range = tuple(args.k_range)
|
||||
self.channel_mult_range = tuple(args.channel_mult_range)
|
||||
self.depth_mult_range = tuple(args.depth_mult_range)
|
||||
#make each run deterministic
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
@overrides
|
||||
def random_sample(self)->ArchaiModel:
|
||||
''' Uniform random sample an architecture, always start with the original model'''
|
||||
''' Uniform random sample an architecture, always start with a few seed models'''
|
||||
|
||||
if (self.arch_counter== 0):
|
||||
cfg = copy.deepcopy(self.cfgs_orig)
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
elif (self.arch_counter== 1):
|
||||
cfg = copy.deepcopy(self.cfgs_orig1)
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
elif (self.arch_counter== 2):
|
||||
cfg = copy.deepcopy(self.cfgs_orig2)
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
elif (self.arch_counter== 3):
|
||||
cfg = copy.deepcopy(self.cfgs_orig3)
|
||||
if (self.arch_counter >= 0 and self.arch_counter <= 3):
|
||||
cfg = copy.deepcopy(self.cfgs_orig[self.arch_counter])
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
else:
|
||||
#del cfg [block]
|
||||
arch = None
|
||||
while (arch == None) :
|
||||
cfg = self._rand_modify_config(self.cfgs_orig, len(self.e_range), len(self.r_range), len(self.k_range),
|
||||
cfg = self._rand_modify_config(self.cfgs_orig[0], len(self.e_range), len(self.r_range), len(self.k_range),
|
||||
len(self.channel_mult_range), len(self.depth_mult_range))
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
assert (arch != None)
|
||||
|
@ -291,12 +258,11 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
return cfg
|
||||
|
||||
def _create_uniq_arch(self, cfg):
|
||||
|
||||
"""create a unique arch from the config"""
|
||||
cfg_str = json.dumps(cfg)
|
||||
archid = sha1(cfg_str.encode('ascii')).hexdigest()[0:8]
|
||||
|
||||
if cfg_str in list(self.config_all.values()):
|
||||
#return None
|
||||
print(f"Creating duplicated model: {cfg_str} ")
|
||||
else :
|
||||
self.config_all[archid] = cfg_str
|
||||
|
@ -310,8 +276,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
return arch
|
||||
|
||||
class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpace, BayesOptSearchSpace):
|
||||
''' We are subclassing CNNSearchSpace just to save up space'''
|
||||
|
||||
"""Search space for config search"""
|
||||
@overrides
|
||||
def mutate(self, model_1: ArchaiModel) -> ArchaiModel:
|
||||
|
||||
|
@ -352,30 +317,24 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
|||
|
||||
@overrides
|
||||
def encode(self, model: ArchaiModel) -> np.ndarray:
|
||||
#TBD
|
||||
#TBD, this is not needed for this implementation
|
||||
assert (False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""unit test for this module"""
|
||||
from torchinfo import summary
|
||||
img_size = 192
|
||||
img_size = 128
|
||||
def create_random_model(ss):
|
||||
|
||||
arch = ss.random_sample()
|
||||
model = arch.arch
|
||||
#print(type(model))
|
||||
#print(isinstance(model, torch.nn.Module))
|
||||
|
||||
#make sure it works
|
||||
model.to('cpu').eval()
|
||||
pred = model(torch.randn(1, 3, img_size, img_size))
|
||||
|
||||
model_summary =summary(model, input_size=(1, 3, img_size, img_size), col_names=['input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'], device='cpu')
|
||||
#print(model_summary)
|
||||
|
||||
#print(model.summary) #model_desc_builder's
|
||||
return arch
|
||||
|
||||
#conf = common.common_init(config_filepath= '../nas_landmarks_darts.yaml')
|
||||
ss = DiscreteSearchSpaceMobileNetV2()
|
||||
for i in range(0, 2):
|
||||
archai_model = create_random_model(ss)
|
||||
|
|
|
@ -15,13 +15,13 @@ import os
|
|||
import time
|
||||
import warnings
|
||||
|
||||
import presets
|
||||
#import presets
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
import transforms
|
||||
import utils
|
||||
from sampler import RASampler
|
||||
#from sampler import RASampler
|
||||
from torch import nn
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
@ -221,10 +221,10 @@ def load_data(traindir, valdir, args):
|
|||
|
||||
print("Creating data loaders")
|
||||
if args.distributed:
|
||||
if hasattr(args, "ra_sampler") and args.ra_sampler:
|
||||
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
|
||||
else:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
# if hasattr(args, "ra_sampler") and args.ra_sampler:
|
||||
# train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
|
||||
# else:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
|
||||
else:
|
||||
train_sampler = torch.utils.data.RandomSampler(dataset)
|
||||
|
@ -562,3 +562,56 @@ def get_args_parser(add_help=True):
|
|||
if __name__ == "__main__":
|
||||
args, _ = get_args_parser().parse_known_args()
|
||||
train(args)
|
||||
|
||||
###To be moved to trainer
|
||||
"""
|
||||
model = _create_model_from_csv (
|
||||
nas.search_args.nas_finalize_archid,
|
||||
nas.search_args.nas_finalize_models_csv,
|
||||
num_classes=NUM_LANDMARK_CLASSES)
|
||||
|
||||
print(f'Loading weights from {str(nas.search_args.nas_finalize_pretrained_weight_file)}')
|
||||
if (not nas.search_args.nas_load_nonqat_weights):
|
||||
if (nas.search_args.nas_finalize_pretrained_weight_file is not None) :
|
||||
model = _load_pretrain_weight(nas.search_args.nas_finalize_pretrained_weight_file, model)
|
||||
|
||||
if (nas.search_args.nas_use_tvmodel):
|
||||
model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2), torch.nn.Linear(model.last_channel, NUM_LANDMARK_CLASSES))
|
||||
|
||||
# Load pretrained weights after fixing classifier as the weights match the exact network architecture
|
||||
if (nas.search_args.nas_load_nonqat_weights):
|
||||
assert os.path.exists(nas.search_args.nas_finalize_pretrained_weight_file)
|
||||
print(f'Loading weights from previous non-QAT training {nas.search_args.nas_finalize_pretrained_weight_file}')
|
||||
model.load_state_dict(torch.load(nas.search_args.nas_finalize_pretrained_weight_file))
|
||||
|
||||
val_error = model_trainer.train(nas.trainer_args, model)
|
||||
print(f"Final validation error for model {nas.search_args.nas_finalize_archid}: {val_error}")
|
||||
"""
|
||||
|
||||
""" TBD: need to move to trainer? Or delete.
|
||||
def _create_model_from_csv(archid, csv_file : str, num_classes, use_tvmodel:bool=False, qat:bool=False) :
|
||||
csv_path = Path(csv_file)
|
||||
assert csv_path.exists()
|
||||
df0 = pd.read_csv(csv_path)#.query('metric == @metric')
|
||||
row = df0[df0['archid'] == archid]
|
||||
cfg = json.loads(row['config'].to_list()[0])
|
||||
|
||||
# Ignore number of classes for now. The classifier layer will be rebuilt after loading pretrained weights
|
||||
# kwargs.pop('num_classes', None)
|
||||
# wchen: This doesn't seem to work. _load_pretrain_weight already pops the state_dict so it should be safe to fix the num_classes for now.
|
||||
model = _gen_tv_mobilenet(cfg['arch_def'],
|
||||
channel_multiplier=cfg['channel_multiplier'],
|
||||
depth_multiplier=cfg['depth_multiplier'],
|
||||
num_classes=num_classes)
|
||||
return model
|
||||
|
||||
def _load_pretrain_weight(weight_file: str, model) :
|
||||
print("=> loading pretrained weight '{}'".format(weight_file))
|
||||
assert path.isfile(weight_file)
|
||||
source_state = torch.load(weight_file)
|
||||
state_dict = source_state['state_dict']
|
||||
state_dict.pop('classifier' + '.weight', None)
|
||||
state_dict.pop('classifier' + '.bias', None)
|
||||
model.load_state_dict(state_dict, strict = False)
|
||||
return model
|
||||
"""
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from torchvision.transforms import ToTensor, Compose
|
||||
import numpy as np
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch import Tensor
|
||||
from torchvision.transforms import Compose, ToTensor
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
class Sample():
|
||||
"""A sample of an image and its landmarks."""
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
"""This is from torchvision source code"""
|
||||
import copy
|
||||
import datetime
|
||||
import errno
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
#to be removed before merging
|
||||
|
||||
"""This module contains methods and callbacks for visualizing training or validation data."""
|
||||
|
||||
from heapq import heapify, heappush, heappushpop
|
||||
|
|
Загрузка…
Ссылка в новой задаче