Pacs data loader; ResNet-50 model added
This commit is contained in:
Родитель
dd5389b016
Коммит
a02b62caa0
|
@ -47,9 +47,9 @@ class BaseAlgo():
|
|||
if self.args.model_name == 'alexnet':
|
||||
from models.alexnet import alexnet
|
||||
phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name)
|
||||
if self.args.model_name == 'resnet18':
|
||||
if 'resnet' in self.args.model_name:
|
||||
from models.resnet import get_resnet
|
||||
phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name,
|
||||
phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name,
|
||||
self.args.img_c, self.args.pre_trained)
|
||||
|
||||
print('Model Architecture: ', self.args.model_name)
|
||||
|
|
|
@ -84,9 +84,9 @@ class BaseEval():
|
|||
if self.args.model_name == 'alexnet':
|
||||
from models.alexnet import alexnet
|
||||
phi= alexnet(self.args.out_classes, self.args.pre_trained, self.args.method_name)
|
||||
if self.args.model_name == 'resnet18':
|
||||
if 'resnet' in self.args.model_name:
|
||||
from models.resnet import get_resnet
|
||||
phi= get_resnet('resnet18', self.args.out_classes, self.args.method_name,
|
||||
phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name,
|
||||
self.args.img_c, self.args.pre_trained)
|
||||
|
||||
print('Model Architecture: ', self.args.model_name)
|
||||
|
|
|
@ -18,6 +18,10 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained):
|
|||
model= torchvision.models.resnet18(pre_trained)
|
||||
n_inputs = model.fc.in_features
|
||||
n_outputs= classes
|
||||
elif model_name == 'resnet50':
|
||||
model= torchvision.models.resnet50(pre_trained)
|
||||
n_inputs = model.fc.in_features
|
||||
n_outputs= classes
|
||||
|
||||
if erm_base == 'matchdg_ctr':
|
||||
model.fc = Identity(n_inputs)
|
||||
|
|
4
train.py
4
train.py
|
@ -29,9 +29,9 @@ parser.add_argument('--method_name', type=str, default='erm_match',
|
|||
help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
|
||||
parser.add_argument('--model_name', type=str, default='resnet18',
|
||||
help='Architecture of the model to be trained')
|
||||
parser.add_argument('--train_domains', type=int, default=["15", "30", "45", "60", "75"],
|
||||
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
|
||||
help='List of train domains')
|
||||
parser.add_argument('--test_domains', type=int, default=["0", "90"],
|
||||
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
|
||||
help='List of test domains')
|
||||
parser.add_argument('--out_classes', type=int, default=10,
|
||||
help='Total number of classes in the dataset')
|
||||
|
|
Загрузка…
Ссылка в новой задаче