This commit is contained in:
Wei-ge Chen 2023-04-28 11:22:15 -07:00
Родитель b5d00d3098
Коммит 0c2ebec20b
2 изменённых файлов: 12 добавлений и 13 удалений

Просмотреть файл

@ -188,7 +188,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
arch = self._create_uniq_arch(cfg)
else:
arch = None
while arch == None:
while arch is None:
cfg = self._rand_modify_config(
self.cfgs_orig[0],
len(self.e_range),
@ -198,7 +198,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
len(self.depth_mult_range),
)
arch = self._create_uniq_arch(cfg)
assert arch != None
assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
@ -232,7 +232,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
block_cfg = cfg["arch_def"][block][0]
res = re.search(rf"_{type}(\d)_", block_cfg)
if res != None:
if res is not None:
curr = res.group(1)
curr_idx = curr_range.index(int(curr))
mod_range = curr_range[max(0, curr_idx - delta) : min(len(curr_range), curr_idx + delta + 1)]
@ -304,7 +304,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg_1 = json.loads(model_1.metadata["config"])
arch = None
while arch == None:
while arch is None:
cfg = self._rand_modify_config(
cfg_1,
len(self.e_range),
@ -314,7 +314,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
len(self.depth_mult_range),
)
arch = self._create_uniq_arch(cfg)
assert arch != None
assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
return arch
@ -329,7 +329,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg = copy.deepcopy(cfg_1)
arch = None
while arch == None:
while arch is None:
for block in range(2, len(cfg["arch_def"])):
cfg["arch_def"][block] = random.choice((cfg_1["arch_def"][block], cfg_2["arch_def"][block]))
@ -337,7 +337,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
cfg["depth_multiplier"] = random.choice((cfg_1["depth_multiplier"], cfg_2["depth_multiplier"]))
arch = self._create_uniq_arch(cfg)
assert arch != None
assert arch is not None
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
return arch
@ -360,12 +360,14 @@ if __name__ == "__main__":
model.to("cpu").eval()
pred = model(torch.randn(1, 3, img_size, img_size))
assert pred is not None
model_summary = summary(
model,
input_size=(1, 3, img_size, img_size),
col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
device="cpu",
)
print(model_summary)
return arch

Просмотреть файл

@ -120,10 +120,10 @@ def evaluate(model, criterion, data_loader, epoch, device, print_freq=100, log_s
return float(metric_logger.error.global_avg)
def load_data(traindir, valdir, args):
def load_data(traindir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
_, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
print("Loading training data")
st = time.time()
@ -163,11 +163,8 @@ def train(args, model: nn.Module = None):
else:
torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, "train")
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, val_dir, args)
dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, args)
collate_fn = None
num_classes = dataset.dataset.num_landmarks
data_loader = torch.utils.data.DataLoader(
dataset,