diff --git a/train.py b/train.py index d89acb8..b84d786 100644 --- a/train.py +++ b/train.py @@ -72,10 +72,9 @@ def train(): # training options args = parse_args() - opt = Option() + opt = Option(model_name=args.model_name) opt.data_path = [args.data_path] opt.val_data_path = [args.val_data_path] - opt.model_name = args.model_name # load training data into queue train_iterator = load_dataset(opt)