зеркало из https://github.com/microsoft/SpareNet.git
fix bugs of cuda
This commit is contained in:
Родитель
2a3336c708
Коммит
fb419dae70
43
test.py
43
test.py
|
@ -2,53 +2,46 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
|
||||||
import torch
|
|
||||||
import open3d
|
|
||||||
import argparse
|
import argparse
|
||||||
import pprint
|
|
||||||
from utils.misc import set_logger
|
|
||||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_from_command_line():
|
def get_args_from_command_line():
|
||||||
"""
|
"""
|
||||||
config the parameter
|
config the parameter
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
|
||||||
description="The argument parser of R2Net runner")
|
|
||||||
|
|
||||||
# choose model
|
# choose model
|
||||||
parser.add_argument("--model", type=str, default="sparenet",
|
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
|
||||||
help="sparenet, atlasnet, msn, grnet")
|
|
||||||
|
|
||||||
# choose test mode
|
# choose test mode
|
||||||
parser.add_argument("--test_mode", default="default",
|
parser.add_argument("--test_mode", default="default", help="default vis render kitti", type=str)
|
||||||
help="default vis render kitti", type=str)
|
|
||||||
|
|
||||||
# choose load model
|
# choose load model
|
||||||
parser.add_argument("--weights", dest="weights",
|
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
|
||||||
help="Initialize network from the weights file", default=None)
|
|
||||||
|
|
||||||
# setup gpu
|
# setup gpu
|
||||||
parser.add_argument("--gpu", dest="gpu_id",
|
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
|
||||||
help="GPU device to use", default="0", type=str)
|
|
||||||
|
|
||||||
# setup workdir
|
# setup workdir
|
||||||
parser.add_argument("--workdir", dest="workdir",
|
parser.add_argument("--workdir", dest="workdir", help="where to save files", default="./output", type=str)
|
||||||
help="where to save files", default="./output", type=str)
|
|
||||||
|
|
||||||
# choose train mode
|
# choose train mode
|
||||||
parser.add_argument("--gan", dest="gan", help="use gan",
|
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
|
||||||
action="store_true", default=False)
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# update config
|
|
||||||
args = get_args_from_command_line()
|
args = get_args_from_command_line()
|
||||||
|
|
||||||
|
# Set GPU to use
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
||||||
|
|
||||||
|
# update config
|
||||||
|
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||||
|
|
||||||
if args.gan:
|
if args.gan:
|
||||||
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
||||||
else:
|
else:
|
||||||
|
@ -60,13 +53,11 @@ def main():
|
||||||
# Set up folders for logs and checkpoints
|
# Set up folders for logs and checkpoints
|
||||||
if not os.path.exists(cfg.DIR.logs):
|
if not os.path.exists(cfg.DIR.logs):
|
||||||
os.makedirs(cfg.DIR.logs)
|
os.makedirs(cfg.DIR.logs)
|
||||||
|
from utils.misc import set_logger
|
||||||
|
|
||||||
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
||||||
logger.info("save into dir: %s" % cfg.DIR.logs)
|
logger.info("save into dir: %s" % cfg.DIR.logs)
|
||||||
|
|
||||||
# Set GPU to use
|
|
||||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
|
|
||||||
|
|
||||||
if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights):
|
if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights):
|
||||||
logger.error("Please specify the file path of checkpoint.")
|
logger.error("Please specify the file path of checkpoint.")
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
|
41
train.py
41
train.py
|
@ -2,49 +2,42 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
|
||||||
import sys
|
|
||||||
import yaml
|
|
||||||
import torch
|
|
||||||
import open3d
|
|
||||||
import argparse
|
import argparse
|
||||||
import pprint
|
|
||||||
from utils.misc import set_logger
|
|
||||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_from_command_line():
|
def get_args_from_command_line():
|
||||||
"""
|
"""
|
||||||
config the parameter
|
config the parameter
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
|
||||||
description="The argument parser of R2Net runner")
|
|
||||||
|
|
||||||
# choose model
|
# choose model
|
||||||
parser.add_argument("--model", type=str, default="sparenet",
|
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
|
||||||
help="sparenet, atlasnet, msn, grnet")
|
|
||||||
|
|
||||||
# choose train mode
|
# choose train mode
|
||||||
parser.add_argument("--gan", dest="gan", help="use gan",
|
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
|
||||||
action="store_true", default=False)
|
|
||||||
|
|
||||||
# choose load model
|
# choose load model
|
||||||
parser.add_argument("--weights", dest="weights",
|
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
|
||||||
help="Initialize network from the weights file", default=None)
|
|
||||||
|
|
||||||
# setup gpu
|
# setup gpu
|
||||||
parser.add_argument("--gpu", dest="gpu_id",
|
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
|
||||||
help="GPU device to use", default="0", type=str)
|
|
||||||
|
|
||||||
# setup workdir
|
# setup workdir
|
||||||
parser.add_argument("--workdir", dest="workdir",
|
parser.add_argument("--workdir", dest="workdir", help="where to save files", default=None)
|
||||||
help="where to save files", default=None)
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# update config
|
|
||||||
args = get_args_from_command_line()
|
args = get_args_from_command_line()
|
||||||
|
|
||||||
|
# Set GPU to use
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
||||||
|
|
||||||
|
# update config
|
||||||
|
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||||
|
|
||||||
if args.gan:
|
if args.gan:
|
||||||
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
||||||
else:
|
else:
|
||||||
|
@ -54,13 +47,11 @@ def main():
|
||||||
# Set up folders for logs and checkpoints
|
# Set up folders for logs and checkpoints
|
||||||
if not os.path.exists(cfg.DIR.logs):
|
if not os.path.exists(cfg.DIR.logs):
|
||||||
os.makedirs(cfg.DIR.logs)
|
os.makedirs(cfg.DIR.logs)
|
||||||
|
from utils.misc import set_logger
|
||||||
|
|
||||||
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
||||||
logger.info("save into dir: %s" % cfg.DIR.logs)
|
logger.info("save into dir: %s" % cfg.DIR.logs)
|
||||||
|
|
||||||
# Set GPU to use
|
|
||||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
|
|
||||||
|
|
||||||
# Start train/inference process
|
# Start train/inference process
|
||||||
if args.gan:
|
if args.gan:
|
||||||
runners = __import__("runners." + args.model + "_gan_runner")
|
runners = __import__("runners." + args.model + "_gan_runner")
|
||||||
|
|
|
@ -32,7 +32,8 @@ def gpu_init(cfg):
|
||||||
os.makedirs(cfg.DIR.checkpoints)
|
os.makedirs(cfg.DIR.checkpoints)
|
||||||
# GPU setup
|
# GPU setup
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
return [int(x) for x in cfg.CONST.device.split(",")]
|
gup_ids = [int(x) for x in cfg.CONST.device.split(",")]
|
||||||
|
return list(range(len(gup_ids)))
|
||||||
|
|
||||||
|
|
||||||
def writer_init(cfg):
|
def writer_init(cfg):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче