Pacs data loader; ResNet-50 model added

This commit is contained in:
divyat09 2020-08-10 21:04:35 +00:00
Родитель dd5389b016
Коммит a02b62caa0
4 изменённых файлов: 10 добавлений и 6 удалений

Просмотреть файл

@ -47,9 +47,9 @@ class BaseAlgo():
if self.args.model_name == 'alexnet':
from models.alexnet import alexnet
phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name)
if self.args.model_name == 'resnet18':
if 'resnet' in self.args.model_name:
from models.resnet import get_resnet
phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name,
phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name,
self.args.img_c, self.args.pre_trained)
print('Model Architecture: ', self.args.model_name)

Просмотреть файл

@ -84,9 +84,9 @@ class BaseEval():
if self.args.model_name == 'alexnet':
from models.alexnet import alexnet
phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name)
if self.args.model_name == 'resnet18':
if 'resnet' in self.args.model_name:
from models.resnet import get_resnet
phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name,
phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name,
self.args.img_c, self.args.pre_trained)
print('Model Architecture: ', self.args.model_name)

Просмотреть файл

@ -18,6 +18,10 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained):
model= torchvision.models.resnet18(pre_trained)
n_inputs = model.fc.in_features
n_outputs= classes
elif model_name == 'resnet50':
model= torchvision.models.resnet50(pre_trained)
n_inputs = model.fc.in_features
n_outputs= classes
if erm_base == 'matchdg_ctr':
model.fc = Identity(n_inputs)

Просмотреть файл

@ -29,9 +29,9 @@ parser.add_argument('--method_name', type=str, default='erm_match',
help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18',
help='Architecture of the model to be trained')
parser.add_argument('--train_domains', type=int, default=["15", "30", "45", "60", "75"],
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
help='List of train domains')
parser.add_argument('--test_domains', type=int, default=["0", "90"],
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10,
help='Total number of classes in the dataset')