This commit is contained in:
Wei-ge Chen 2023-04-28 11:22:15 -07:00
Родитель b5d00d3098
Коммит 0c2ebec20b
2 изменённых файлов: 12 добавлений и 13 удалений

Просмотреть файл

@ -188,7 +188,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
arch = self._create_uniq_arch(cfg) arch = self._create_uniq_arch(cfg)
else: else:
arch = None arch = None
while arch == None: while arch is None:
cfg = self._rand_modify_config( cfg = self._rand_modify_config(
self.cfgs_orig[0], self.cfgs_orig[0],
len(self.e_range), len(self.e_range),
@ -198,7 +198,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
len(self.depth_mult_range), len(self.depth_mult_range),
) )
arch = self._create_uniq_arch(cfg) arch = self._create_uniq_arch(cfg)
assert arch != None assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}") logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
@ -232,7 +232,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
block_cfg = cfg["arch_def"][block][0] block_cfg = cfg["arch_def"][block][0]
res = re.search(rf"_{type}(\d)_", block_cfg) res = re.search(rf"_{type}(\d)_", block_cfg)
if res != None: if res is not None:
curr = res.group(1) curr = res.group(1)
curr_idx = curr_range.index(int(curr)) curr_idx = curr_range.index(int(curr))
mod_range = curr_range[max(0, curr_idx - delta) : min(len(curr_range), curr_idx + delta + 1)] mod_range = curr_range[max(0, curr_idx - delta) : min(len(curr_range), curr_idx + delta + 1)]
@ -304,7 +304,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg_1 = json.loads(model_1.metadata["config"]) cfg_1 = json.loads(model_1.metadata["config"])
arch = None arch = None
while arch == None: while arch is None:
cfg = self._rand_modify_config( cfg = self._rand_modify_config(
cfg_1, cfg_1,
len(self.e_range), len(self.e_range),
@ -314,7 +314,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
len(self.depth_mult_range), len(self.depth_mult_range),
) )
arch = self._create_uniq_arch(cfg) arch = self._create_uniq_arch(cfg)
assert arch != None assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}") logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
return arch return arch
@ -329,7 +329,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg = copy.deepcopy(cfg_1) cfg = copy.deepcopy(cfg_1)
arch = None arch = None
while arch == None: while arch is None:
for block in range(2, len(cfg["arch_def"])): for block in range(2, len(cfg["arch_def"])):
cfg["arch_def"][block] = random.choice((cfg_1["arch_def"][block], cfg_2["arch_def"][block])) cfg["arch_def"][block] = random.choice((cfg_1["arch_def"][block], cfg_2["arch_def"][block]))
@ -337,7 +337,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg["depth_multiplier"] = random.choice((cfg_1["depth_multiplier"], cfg_2["depth_multiplier"])) cfg["depth_multiplier"] = random.choice((cfg_1["depth_multiplier"], cfg_2["depth_multiplier"]))
arch = self._create_uniq_arch(cfg) arch = self._create_uniq_arch(cfg)
assert arch != None assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}") logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
return arch return arch
@ -360,12 +360,14 @@ if __name__ == "__main__":
model.to("cpu").eval() model.to("cpu").eval()
pred = model(torch.randn(1, 3, img_size, img_size)) pred = model(torch.randn(1, 3, img_size, img_size))
assert pred is not None
model_summary = summary( model_summary = summary(
model, model,
input_size=(1, 3, img_size, img_size), input_size=(1, 3, img_size, img_size),
col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
device="cpu", device="cpu",
) )
print(model_summary)
return arch return arch

Просмотреть файл

@ -120,10 +120,10 @@ def evaluate(model, criterion, data_loader, epoch, device, print_freq=100, log_s
return float(metric_logger.error.global_avg) return float(metric_logger.error.global_avg)
def load_data(traindir, valdir, args): def load_data(traindir, args):
# Data loading code # Data loading code
print("Loading data") print("Loading data")
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size _, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
print("Loading training data") print("Loading training data")
st = time.time() st = time.time()
@ -163,11 +163,8 @@ def train(args, model: nn.Module = None):
else: else:
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, "train") dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, args)
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, val_dir, args)
collate_fn = None
num_classes = dataset.dataset.num_landmarks num_classes = dataset.dataset.num_landmarks
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, dataset,