This commit is contained in:
XinYuan-believe 2021-06-18 20:41:33 +08:00
Родитель 2a3336c708
Коммит fb419dae70
3 изменённых файлов: 35 добавлений и 52 удалений

43
test.py
Просмотреть файл

@ -2,53 +2,46 @@
# Licensed under the MIT License.
import os
import cv2
import sys
import yaml
import torch
import open3d
import argparse
import pprint
from utils.misc import set_logger
from configs.base_config import cfg, cfg_from_file, cfg_update
def get_args_from_command_line():
"""
config the parameter
"""
parser = argparse.ArgumentParser(
description="The argument parser of R2Net runner")
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
# choose model
parser.add_argument("--model", type=str, default="sparenet",
help="sparenet, atlasnet, msn, grnet")
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
# choose test mode
parser.add_argument("--test_mode", default="default",
help="default vis render kitti", type=str)
parser.add_argument("--test_mode", default="default", help="default vis render kitti", type=str)
# choose load model
parser.add_argument("--weights", dest="weights",
help="Initialize network from the weights file", default=None)
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
# setup gpu
parser.add_argument("--gpu", dest="gpu_id",
help="GPU device to use", default="0", type=str)
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
# setup workdir
parser.add_argument("--workdir", dest="workdir",
help="where to save files", default="./output", type=str)
parser.add_argument("--workdir", dest="workdir", help="where to save files", default="./output", type=str)
# choose train mode
parser.add_argument("--gan", dest="gan", help="use gan",
action="store_true", default=False)
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
return parser.parse_args()
def main():
# update config
args = get_args_from_command_line()
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# update config
from configs.base_config import cfg, cfg_from_file, cfg_update
if args.gan:
cfg_from_file("configs/" + args.model + "_gan.yaml")
else:
@ -60,13 +53,11 @@ def main():
# Set up folders for logs and checkpoints
if not os.path.exists(cfg.DIR.logs):
os.makedirs(cfg.DIR.logs)
from utils.misc import set_logger
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
logger.info("save into dir: %s" % cfg.DIR.logs)
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights):
logger.error("Please specify the file path of checkpoint.")
sys.exit(2)

Просмотреть файл

@ -2,49 +2,42 @@
# Licensed under the MIT License.
import os
import cv2
import sys
import yaml
import torch
import open3d
import argparse
import pprint
from utils.misc import set_logger
from configs.base_config import cfg, cfg_from_file, cfg_update
def get_args_from_command_line():
"""
config the parameter
"""
parser = argparse.ArgumentParser(
description="The argument parser of R2Net runner")
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
# choose model
parser.add_argument("--model", type=str, default="sparenet",
help="sparenet, atlasnet, msn, grnet")
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
# choose train mode
parser.add_argument("--gan", dest="gan", help="use gan",
action="store_true", default=False)
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
# choose load model
parser.add_argument("--weights", dest="weights",
help="Initialize network from the weights file", default=None)
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
# setup gpu
parser.add_argument("--gpu", dest="gpu_id",
help="GPU device to use", default="0", type=str)
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
# setup workdir
parser.add_argument("--workdir", dest="workdir",
help="where to save files", default=None)
parser.add_argument("--workdir", dest="workdir", help="where to save files", default=None)
return parser.parse_args()
def main():
# update config
args = get_args_from_command_line()
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# update config
from configs.base_config import cfg, cfg_from_file, cfg_update
if args.gan:
cfg_from_file("configs/" + args.model + "_gan.yaml")
else:
@ -54,13 +47,11 @@ def main():
# Set up folders for logs and checkpoints
if not os.path.exists(cfg.DIR.logs):
os.makedirs(cfg.DIR.logs)
from utils.misc import set_logger
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
logger.info("save into dir: %s" % cfg.DIR.logs)
# Set GPU to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
# Start train/inference process
if args.gan:
runners = __import__("runners." + args.model + "_gan_runner")

Просмотреть файл

@ -32,7 +32,8 @@ def gpu_init(cfg):
os.makedirs(cfg.DIR.checkpoints)
# GPU setup
torch.backends.cudnn.benchmark = True
return [int(x) for x in cfg.CONST.device.split(",")]
gup_ids = [int(x) for x in cfg.CONST.device.split(",")]
return list(range(len(gup_ids)))
def writer_init(cfg):