Adding segmentation models based model sessions

This commit is contained in:
Caleb Robinson 2021-06-08 22:11:46 +00:00
Родитель 4fff2bfd73
Коммит 5f68bc087b
1 изменённых файлов: 3 добавлений и 0 удалений

Просмотреть файл

@ -21,6 +21,7 @@ from web_tool.ModelSessionKerasExample import KerasDenseFineTune
from web_tool.ModelSessionPytorchSolar import SolarFineTuning
from web_tool.ModelSessionPyTorchExample import TorchFineTuning
from web_tool.ModelSessionRandomForest import ModelSessionRandomForest
from web_tool.ModelSessionPyTorchSegmentationModel import TorchSegmentationModel
from web_tool.Utils import setup_logging, serialize, deserialize
from web_tool.Models import load_models
@ -93,6 +94,8 @@ def main():
model = SolarFineTuning(args.gpu_id, **model_configs[args.model_key])
elif model_type == "random_forest":
model = ModelSessionRandomForest(**model_configs[args.model_key])
elif model_type == "pytorch_segmodel":
model = TorchSegmentationModel(args.gpu_id, **model_configs[args.model_key])
else:
raise NotImplementedError("The given model type is not implemented yet.")