From a02b62caa01cff215a462fa591c37516b1ec4c10 Mon Sep 17 00:00:00 2001 From: divyat09 Date: Mon, 10 Aug 2020 21:04:35 +0000 Subject: [PATCH] Pacs data loader; ResNet-50 model added --- algorithms/algo.py | 4 ++-- evaluation/base_eval.py | 4 ++-- models/resnet.py | 4 ++++ train.py | 4 ++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/algorithms/algo.py b/algorithms/algo.py index 5f6157c..41d59d8 100644 --- a/algorithms/algo.py +++ b/algorithms/algo.py @@ -47,9 +47,9 @@ class BaseAlgo(): if self.args.model_name == 'alexnet': from models.alexnet import alexnet phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name) - if self.args.model_name == 'resnet18': + if 'resnet' in self.args.model_name: from models.resnet import get_resnet - phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name, + phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name, self.args.img_c, self.args.pre_trained) print('Model Architecture: ', self.args.model_name) diff --git a/evaluation/base_eval.py b/evaluation/base_eval.py index 09742df..a292a03 100644 --- a/evaluation/base_eval.py +++ b/evaluation/base_eval.py @@ -84,9 +84,9 @@ class BaseEval(): if self.args.model_name == 'alexnet': from models.alexnet import alexnet phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name) - if self.args.model_name == 'resnet18': + if 'resnet' in self.args.model_name: from models.resnet import get_resnet - phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name, + phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name, self.args.img_c, self.args.pre_trained) print('Model Architecture: ', self.args.model_name) diff --git a/models/resnet.py b/models/resnet.py index dd4acc3..a6c3214 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -18,6 +18,10 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained): model= torchvision.models.resnet18(pre_trained) n_inputs = model.fc.in_features n_outputs= classes + elif model_name == 'resnet50': + model= torchvision.models.resnet50(pre_trained) + n_inputs = model.fc.in_features + n_outputs= classes if erm_base == 'matchdg_ctr': model.fc = Identity(n_inputs) diff --git a/train.py b/train.py index ec711c6..d9c0979 100644 --- a/train.py +++ b/train.py @@ -29,9 +29,9 @@ parser.add_argument('--method_name', type=str, default='erm_match', help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm') parser.add_argument('--model_name', type=str, default='resnet18', help='Architecture of the model to be trained') -parser.add_argument('--train_domains', type=int, default=["15", "30", "45", "60", "75"], +parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"], help='List of train domains') -parser.add_argument('--test_domains', type=int, default=["0", "90"], +parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"], help='List of test domains') parser.add_argument('--out_classes', type=int, default=10, help='Total number of classes in the dataset')