зеркало из https://github.com/microsoft/archai.git
Fixed more flake8 warnings
This commit is contained in:
Родитель
b5d00d3098
Коммит
0c2ebec20b
|
@ -188,7 +188,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
arch = self._create_uniq_arch(cfg)
|
||||
else:
|
||||
arch = None
|
||||
while arch == None:
|
||||
while arch is None:
|
||||
cfg = self._rand_modify_config(
|
||||
self.cfgs_orig[0],
|
||||
len(self.e_range),
|
||||
|
@ -198,7 +198,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
len(self.depth_mult_range),
|
||||
)
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
assert arch != None
|
||||
assert arch is not None
|
||||
|
||||
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
|
||||
|
||||
|
@ -232,7 +232,7 @@ class DiscreteSearchSpaceMobileNetV2(DiscreteSearchSpace):
|
|||
block_cfg = cfg["arch_def"][block][0]
|
||||
|
||||
res = re.search(rf"_{type}(\d)_", block_cfg)
|
||||
if res != None:
|
||||
if res is not None:
|
||||
curr = res.group(1)
|
||||
curr_idx = curr_range.index(int(curr))
|
||||
mod_range = curr_range[max(0, curr_idx - delta) : min(len(curr_range), curr_idx + delta + 1)]
|
||||
|
@ -304,7 +304,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
|||
cfg_1 = json.loads(model_1.metadata["config"])
|
||||
|
||||
arch = None
|
||||
while arch == None:
|
||||
while arch is None:
|
||||
cfg = self._rand_modify_config(
|
||||
cfg_1,
|
||||
len(self.e_range),
|
||||
|
@ -314,7 +314,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
|||
len(self.depth_mult_range),
|
||||
)
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
assert arch != None
|
||||
assert arch is not None
|
||||
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
|
||||
|
||||
return arch
|
||||
|
@ -329,7 +329,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
|||
cfg = copy.deepcopy(cfg_1)
|
||||
|
||||
arch = None
|
||||
while arch == None:
|
||||
while arch is None:
|
||||
for block in range(2, len(cfg["arch_def"])):
|
||||
cfg["arch_def"][block] = random.choice((cfg_1["arch_def"][block], cfg_2["arch_def"][block]))
|
||||
|
||||
|
@ -337,7 +337,7 @@ class ConfigSearchSpaceExt(DiscreteSearchSpaceMobileNetV2, EvolutionarySearchSpa
|
|||
cfg["depth_multiplier"] = random.choice((cfg_1["depth_multiplier"], cfg_2["depth_multiplier"]))
|
||||
|
||||
arch = self._create_uniq_arch(cfg)
|
||||
assert arch != None
|
||||
assert arch is not None
|
||||
logger.info(f"{sys._getframe(0).f_code.co_name} return archid = {arch.archid} with config = {arch.metadata}")
|
||||
|
||||
return arch
|
||||
|
@ -360,12 +360,14 @@ if __name__ == "__main__":
|
|||
|
||||
model.to("cpu").eval()
|
||||
pred = model(torch.randn(1, 3, img_size, img_size))
|
||||
assert pred is not None
|
||||
model_summary = summary(
|
||||
model,
|
||||
input_size=(1, 3, img_size, img_size),
|
||||
col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
|
||||
device="cpu",
|
||||
)
|
||||
print(model_summary)
|
||||
|
||||
return arch
|
||||
|
||||
|
|
|
@ -120,10 +120,10 @@ def evaluate(model, criterion, data_loader, epoch, device, print_freq=100, log_s
|
|||
return float(metric_logger.error.global_avg)
|
||||
|
||||
|
||||
def load_data(traindir, valdir, args):
|
||||
def load_data(traindir, args):
|
||||
# Data loading code
|
||||
print("Loading data")
|
||||
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
|
||||
_, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
|
||||
|
||||
print("Loading training data")
|
||||
st = time.time()
|
||||
|
@ -163,11 +163,8 @@ def train(args, model: nn.Module = None):
|
|||
else:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
train_dir = os.path.join(args.data_path, "train")
|
||||
val_dir = os.path.join(args.data_path, "val")
|
||||
dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, val_dir, args)
|
||||
dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path, args)
|
||||
|
||||
collate_fn = None
|
||||
num_classes = dataset.dataset.num_landmarks
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
|
|
Загрузка…
Ссылка в новой задаче