зеркало из https://github.com/microsoft/SpareNet.git
fix bugs of cuda
This commit is contained in:
Родитель
2a3336c708
Коммит
fb419dae70
43
test.py
43
test.py
|
@ -2,53 +2,46 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import yaml
|
||||
import torch
|
||||
import open3d
|
||||
import argparse
|
||||
import pprint
|
||||
from utils.misc import set_logger
|
||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||
|
||||
|
||||
def get_args_from_command_line():
|
||||
"""
|
||||
config the parameter
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="The argument parser of R2Net runner")
|
||||
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
|
||||
|
||||
# choose model
|
||||
parser.add_argument("--model", type=str, default="sparenet",
|
||||
help="sparenet, atlasnet, msn, grnet")
|
||||
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
|
||||
|
||||
# choose test mode
|
||||
parser.add_argument("--test_mode", default="default",
|
||||
help="default vis render kitti", type=str)
|
||||
parser.add_argument("--test_mode", default="default", help="default vis render kitti", type=str)
|
||||
|
||||
# choose load model
|
||||
parser.add_argument("--weights", dest="weights",
|
||||
help="Initialize network from the weights file", default=None)
|
||||
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
|
||||
|
||||
# setup gpu
|
||||
parser.add_argument("--gpu", dest="gpu_id",
|
||||
help="GPU device to use", default="0", type=str)
|
||||
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
|
||||
|
||||
# setup workdir
|
||||
parser.add_argument("--workdir", dest="workdir",
|
||||
help="where to save files", default="./output", type=str)
|
||||
parser.add_argument("--workdir", dest="workdir", help="where to save files", default="./output", type=str)
|
||||
|
||||
# choose train mode
|
||||
parser.add_argument("--gan", dest="gan", help="use gan",
|
||||
action="store_true", default=False)
|
||||
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
# update config
|
||||
args = get_args_from_command_line()
|
||||
|
||||
# Set GPU to use
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
||||
|
||||
# update config
|
||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||
|
||||
if args.gan:
|
||||
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
||||
else:
|
||||
|
@ -60,13 +53,11 @@ def main():
|
|||
# Set up folders for logs and checkpoints
|
||||
if not os.path.exists(cfg.DIR.logs):
|
||||
os.makedirs(cfg.DIR.logs)
|
||||
from utils.misc import set_logger
|
||||
|
||||
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
||||
logger.info("save into dir: %s" % cfg.DIR.logs)
|
||||
|
||||
# Set GPU to use
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
|
||||
|
||||
if "weights" not in cfg.CONST or not os.path.exists(cfg.CONST.weights):
|
||||
logger.error("Please specify the file path of checkpoint.")
|
||||
sys.exit(2)
|
||||
|
|
41
train.py
41
train.py
|
@ -2,49 +2,42 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import yaml
|
||||
import torch
|
||||
import open3d
|
||||
import argparse
|
||||
import pprint
|
||||
from utils.misc import set_logger
|
||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||
|
||||
|
||||
def get_args_from_command_line():
|
||||
"""
|
||||
config the parameter
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="The argument parser of R2Net runner")
|
||||
parser = argparse.ArgumentParser(description="The argument parser of R2Net runner")
|
||||
|
||||
# choose model
|
||||
parser.add_argument("--model", type=str, default="sparenet",
|
||||
help="sparenet, atlasnet, msn, grnet")
|
||||
parser.add_argument("--model", type=str, default="sparenet", help="sparenet, atlasnet, msn, grnet")
|
||||
|
||||
# choose train mode
|
||||
parser.add_argument("--gan", dest="gan", help="use gan",
|
||||
action="store_true", default=False)
|
||||
parser.add_argument("--gan", dest="gan", help="use gan", action="store_true", default=False)
|
||||
|
||||
# choose load model
|
||||
parser.add_argument("--weights", dest="weights",
|
||||
help="Initialize network from the weights file", default=None)
|
||||
parser.add_argument("--weights", dest="weights", help="Initialize network from the weights file", default=None)
|
||||
|
||||
# setup gpu
|
||||
parser.add_argument("--gpu", dest="gpu_id",
|
||||
help="GPU device to use", default="0", type=str)
|
||||
parser.add_argument("--gpu", dest="gpu_id", help="GPU device to use", default="0", type=str)
|
||||
|
||||
# setup workdir
|
||||
parser.add_argument("--workdir", dest="workdir",
|
||||
help="where to save files", default=None)
|
||||
parser.add_argument("--workdir", dest="workdir", help="where to save files", default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
# update config
|
||||
args = get_args_from_command_line()
|
||||
|
||||
# Set GPU to use
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
||||
|
||||
# update config
|
||||
from configs.base_config import cfg, cfg_from_file, cfg_update
|
||||
|
||||
if args.gan:
|
||||
cfg_from_file("configs/" + args.model + "_gan.yaml")
|
||||
else:
|
||||
|
@ -54,13 +47,11 @@ def main():
|
|||
# Set up folders for logs and checkpoints
|
||||
if not os.path.exists(cfg.DIR.logs):
|
||||
os.makedirs(cfg.DIR.logs)
|
||||
from utils.misc import set_logger
|
||||
|
||||
logger = set_logger(os.path.join(cfg.DIR.logs, "log.txt"))
|
||||
logger.info("save into dir: %s" % cfg.DIR.logs)
|
||||
|
||||
# Set GPU to use
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.device
|
||||
|
||||
# Start train/inference process
|
||||
if args.gan:
|
||||
runners = __import__("runners." + args.model + "_gan_runner")
|
||||
|
|
|
@ -32,7 +32,8 @@ def gpu_init(cfg):
|
|||
os.makedirs(cfg.DIR.checkpoints)
|
||||
# GPU setup
|
||||
torch.backends.cudnn.benchmark = True
|
||||
return [int(x) for x in cfg.CONST.device.split(",")]
|
||||
gup_ids = [int(x) for x in cfg.CONST.device.split(",")]
|
||||
return list(range(len(gup_ids)))
|
||||
|
||||
|
||||
def writer_init(cfg):
|
||||
|
|
Загрузка…
Ссылка в новой задаче