This commit is contained in:
XinYuan-believe 2021-06-18 20:41:33 +08:00
Родитель 2a3336c708
Коммит fb419dae70
3 изменённых файлов: 35 добавлений и 52 удалений

43
test.py
Просмотреть файл

@ -2,53 +2,46 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
import cv2
import sys import sys
import yaml
import torch
import open3d
import argparse import argparse
import pprint
from utils.misc import set_logger
from configs.base_config import cfg, cfg_from_file, cfg_update
def get_args_from_command_line(): def get_args_from_command_line():
""" """
config the parameter config the parameter
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
description="The argument parser of R2Net runner")
# choose model # choose model
parser.add_argument("--model", type=str, default="sparenet", parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
help="sparenet, atlasnet, msn, grnet")
# choose test mode # choose test mode
parser.add_argument("--test_mode", default="default", parser.add_argument("--test_mode", default="default", help="default vis render kitti", type=str)
help="default vis render kitti", type=str)
# choose load model # choose load model
parser.add_argument("--weights", dest="weights", parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
help="Initialize network from the weights file", default=None)
# setup gpu # setup gpu
parser.add_argument("--gpu", dest="gpu_id", parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
help="GPU device to use", default="0", type=str)
# setup workdir # setup workdir
parser.add_argument("--workdir", dest="workdir", parser.add_argument("--workdir", dest="workdir", help="where to save files", default="./output", type=str)
help="where to save files", default="./output", type=str)
# choose train mode # choose train mode
parser.add_argument("--gan", dest="gan", help="use gan", parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
action="store_true", default=False)
return parser.parse_args() return parser.parse_args()
def main(): def main():
# update config
args = get_args_from_command_line() args = get_args_from_command_line()
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# update config
from configs.base_config import cfg, cfg_from_file, cfg_update
if args.gan: if args.gan:
cfg_from_file("configs/" + args.model + "_gan.yaml") cfg_from_file("configs/" + args.model + "_gan.yaml")
else: else:
@ -60,13 +53,11 @@ def main():
# Set up folders for logs and checkpoints # Set up folders for logs and checkpoints
if not os.path.exists(cfg.DIR.logs): if not os.path.exists(cfg.DIR.logs):
os.makedirs(cfg.DIR.logs) os.makedirs(cfg.DIR.logs)
from utils.misc import set_logger
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt")) logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
logger.info("save into dir: %s" % cfg.DIR.logs) logger.info("save into dir: %s" % cfg.DIR.logs)
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights): if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights):
logger.error("Please specify the file path of checkpoint.") logger.error("Please specify the file path of checkpoint.")
sys.exit(2) sys.exit(2)

Просмотреть файл

@ -2,49 +2,42 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
import cv2
import sys
import yaml
import torch
import open3d
import argparse import argparse
import pprint
from utils.misc import set_logger
from configs.base_config import cfg, cfg_from_file, cfg_update
def get_args_from_command_line(): def get_args_from_command_line():
""" """
config the parameter config the parameter
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
description="The argument parser of R2Net runner")
# choose model # choose model
parser.add_argument("--model", type=str, default="sparenet", parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
help="sparenet, atlasnet, msn, grnet")
# choose train mode # choose train mode
parser.add_argument("--gan", dest="gan", help="use gan", parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
action="store_true", default=False)
# choose load model # choose load model
parser.add_argument("--weights", dest="weights", parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
help="Initialize network from the weights file", default=None)
# setup gpu # setup gpu
parser.add_argument("--gpu", dest="gpu_id", parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
help="GPU device to use", default="0", type=str)
# setup workdir # setup workdir
parser.add_argument("--workdir", dest="workdir", parser.add_argument("--workdir", dest="workdir", help="where to save files", default=None)
help="where to save files", default=None)
return parser.parse_args() return parser.parse_args()
def main(): def main():
# update config
args = get_args_from_command_line() args = get_args_from_command_line()
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# update config
from configs.base_config import cfg, cfg_from_file, cfg_update
if args.gan: if args.gan:
cfg_from_file("configs/" + args.model + "_gan.yaml") cfg_from_file("configs/" + args.model + "_gan.yaml")
else: else:
@ -54,13 +47,11 @@ def main():
# Set up folders for logs and checkpoints # Set up folders for logs and checkpoints
if not os.path.exists(cfg.DIR.logs): if not os.path.exists(cfg.DIR.logs):
os.makedirs(cfg.DIR.logs) os.makedirs(cfg.DIR.logs)
from utils.misc import set_logger
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt")) logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
logger.info("save into dir: %s" % cfg.DIR.logs) logger.info("save into dir: %s" % cfg.DIR.logs)
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
# Start train/inference process # Start train/inference process
if args.gan: if args.gan:
runners = __import__("runners." + args.model + "_gan_runner") runners = __import__("runners." + args.model + "_gan_runner")

Просмотреть файл

@ -32,7 +32,8 @@ def gpu_init(cfg):
os.makedirs(cfg.DIR.checkpoints) os.makedirs(cfg.DIR.checkpoints)
# GPU setup # GPU setup
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
return [int(x) for x in cfg.CONST.device.split(",")] gup_ids = [int(x) for x in cfg.CONST.device.split(",")]
return list(range(len(gup_ids)))
def writer_init(cfg): def writer_init(cfg):