# Mirror of https://github.com/microsoft/Chestist.git
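"""Training and evaluation utilities for the Chestist chest X-ray classifier.

TrainerTester wires a CNNModel to image data mounted from an Azure ML
datastore, trains it with BCE loss, and reports per-class AUROC on the
test set using ten-crop evaluation.
"""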
import os
import re
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from azureml.core import Workspace, Datastore, Dataset
from azureml.core.authentication import InteractiveLoginAuthentication
# sklearn.metrics.ranking was removed; roc_auc_score lives in sklearn.metrics
from sklearn.metrics import roc_auc_score

from CNNModel import CNNModel
from DatasetGenerator import DatasetGenerator

class TrainerTester:

    @staticmethod
    def trainer(pathDirData, pathFileTrain, pathFileVal, pathFileTest, nnArchitecture,
                nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, transResize,
                transCrop, launchTimestamp, checkpoint):
        """Train the network, checkpointing whenever the validation loss improves."""

        #-------------------- SETTINGS: NETWORK ARCHITECTURE
        if nnArchitecture == 'CNNModel':
            model = CNNModel(nnClassCount, nnIsTrained).cuda()
        else:
            raise ValueError('Unsupported architecture: ' + str(nnArchitecture))
        model = torch.nn.DataParallel(model).cuda()

        #-------------------- SETTINGS: AML WORKSPACE AND DATASTORE
        interactive_auth = InteractiveLoginAuthentication(tenant_id=os.environ['TENANT_ID'])
        ws = Workspace(
            subscription_id=os.environ['SUBSCRIPTION_ID'],
            resource_group=os.environ['RESOURCE_GROUP'],
            workspace_name=os.environ['WORKSPACE_NAME'],
            auth=interactive_auth
        )
        datastore = Datastore.get(ws, datastore_name=os.environ['DATASTORE_NAME'])

        #-------------------- SETTINGS: MOUNTING THE DATASET TO MAKE IT AVAILABLE
        chestist_data = Dataset.get_by_name(ws, os.environ['DATASET_NAME_CSV'])
        mountPoint = chestist_data.mount()
        mountPoint.start()
        mountFolder = mountPoint.mount_point
        files = os.listdir(mountFolder)  # TODO: generalize for the whole dataset
        csvFilePath = mountFolder  # path to the CSV file with the labels

        #-------------------- SETTINGS: DATA TRANSFORMS (IMAGE SETTINGS)
        # ImageNet mean/std are a common default, computed over millions of images;
        # dataset-specific statistics could be substituted.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        transformList = []
        transformList.append(transforms.RandomResizedCrop(transCrop))  # random crop with randomized scale
        transformList.append(transforms.RandomHorizontalFlip())        # random flip for augmentation
        transformList.append(transforms.ToTensor())                    # convert PIL image to tensor
        transformList.append(normalize)
        transformSequence = transforms.Compose(transformList)

        # Lists of image file names for train, validation and test
        listImagesTrain = []
        listImagesVal = []
        listImagesTest = []

        for imagePath in pathFileTrain:
            listImagesTrain.append(os.path.basename(imagePath))

        for imagePath in pathFileVal:
            listImagesVal.append(os.path.basename(imagePath))

        for imagePath in pathFileTest:
            listImagesTest.append(os.path.basename(imagePath))

        labelList = []

        #-------------------- DATASET BUILDERS
        datasetTrain = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTrain,
                                        listImages=listImagesTrain, labelList=labelList,
                                        transform=transformSequence, csvFilePath=csvFilePath)
        datasetVal = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileVal,
                                      listImages=listImagesVal, labelList=labelList,
                                      transform=transformSequence, csvFilePath=csvFilePath)
        dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,
                                     num_workers=24, pin_memory=True)
        dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=trBatchSize, shuffle=False,
                                   num_workers=24, pin_memory=True)

        #-------------------- SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999),
                               eps=1e-08, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5, mode='min')

        #-------------------- SETTINGS: LOSS
        # size_average is deprecated; reduction='mean' is the equivalent setting
        loss = torch.nn.BCELoss(reduction='mean')

        #---- Load checkpoint
        if checkpoint is not None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'], strict=False)
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        #---- TRAIN THE NETWORK
        lossMIN = float('inf')  # best validation loss seen so far

        for epochID in range(trMaxEpoch):

            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime

            TrainerTester.epochTrain(model, dataLoaderTrain, optimizer, scheduler,
                                     trMaxEpoch, nnClassCount, loss)
            lossVal, losstensor = TrainerTester.epochVal(model, dataLoaderVal, optimizer,
                                                         scheduler, trMaxEpoch, nnClassCount, loss)

            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime

            # Step the LR scheduler on the mean validation loss
            scheduler.step(losstensor.item())

            # Save a checkpoint whenever the validation loss improves
            if lossVal < lossMIN:
                lossMIN = lossVal
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(),
                            'best_loss': lossMIN, 'optimizer': optimizer.state_dict()},
                           'm-' + launchTimestamp + '.pth.tar')
                print('Epoch [' + str(epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(lossVal))
            else:
                print('Epoch [' + str(epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(lossVal))

    #--------------------------------------------------------------------------------

    @staticmethod
    def epochTrain(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss):
        """Run one training epoch over dataLoader."""

        model.train()

        for batchID, (input, target) in enumerate(dataLoader):

            target = target.cuda(non_blocking=True)

            # torch.autograd.Variable is a deprecated no-op in modern PyTorch;
            # plain tensors carry gradients, and nn.DataParallel scatters the input batch.
            varOutput = model(input)
            lossvalue = loss(varOutput, target)

            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()

    #--------------------------------------------------------------------------------

    @staticmethod
    def epochVal(model, dataLoader, optimizer, scheduler, epochMax, classCount, loss):
        """Run one validation epoch; returns (mean loss as float, mean loss as tensor)."""

        with torch.no_grad():
            model.eval()

            lossVal = 0
            lossValNorm = 0
            losstensorMean = 0

            for i, (input, target) in enumerate(dataLoader):

                target = target.cuda(non_blocking=True)

                # volatile=True was removed from PyTorch; torch.no_grad() already
                # disables autograd here.
                varOutput = model(input)

                losstensor = loss(varOutput, target)
                losstensorMean += losstensor

                lossVal += losstensor.item()
                lossValNorm += 1

            outLoss = lossVal / lossValNorm
            losstensorMean = losstensorMean / lossValNorm

        return outLoss, losstensorMean

    @staticmethod
    def computeAUROC(dataGT, dataPRED, classCount):
        """Compute the per-class AUROC from ground-truth and prediction tensors."""

        outAUROC = []

        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()

        for i in range(classCount):
            try:
                # roc_auc_score(y_true, y_scores)
                outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
            except ValueError:
                # Raised when a class has only one ground-truth value in the batch; skip it.
                pass

        return outAUROC

    #--------------------------------------------------------------------------------

    @staticmethod
    def tester(pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount,
               nnIsTrained, trBatchSize, transResize, transCrop, launchTimeStamp):
        """Evaluate a saved model on the test set and print per-class AUROC."""

        CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
                       'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
                       'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

        cudnn.benchmark = True

        if nnArchitecture == 'CNNModel':
            model = CNNModel(nnClassCount, nnIsTrained).cuda()
        else:
            raise ValueError('Unsupported architecture: ' + str(nnArchitecture))

        model = torch.nn.DataParallel(model).cuda()

        checkpoint = torch.load(pathModel)
        state_dict = checkpoint['state_dict']
        remove_data_parallel = False  # set to True if the model is not wrapped in nn.DataParallel

        # Older torchvision DenseNet checkpoints name layers like 'norm.1';
        # rewrite such keys to the current 'norm1' convention.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            match = pattern.match(key)
            new_key = match.group(1) + match.group(2) if match else key
            new_key = new_key[7:] if remove_data_parallel else new_key  # drop the 'module.' prefix
            state_dict[new_key] = state_dict[key]
            # Delete the old key only if it was modified.
            if match or remove_data_parallel:
                del state_dict[key]

        model.load_state_dict(state_dict, strict=False)

        #-------------------- SETTINGS: DATA TRANSFORMS, TEN CROPS
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        transformList = []
        transformList.append(transforms.Resize(transResize))
        transformList.append(transforms.TenCrop(transCrop))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
        transformList.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))
        transformSequence = transforms.Compose(transformList)

        #-------------------- SETTINGS: AML WORKSPACE AND DATASTORE
        interactive_auth = InteractiveLoginAuthentication(tenant_id=os.environ['TENANT_ID'])
        ws = Workspace(
            subscription_id=os.environ['SUBSCRIPTION_ID'],
            resource_group=os.environ['RESOURCE_GROUP'],
            workspace_name=os.environ['WORKSPACE_NAME'],
            auth=interactive_auth
        )
        datastore = Datastore.get(ws, datastore_name=os.environ['DATASTORE_NAME'])

        #-------------------- SETTINGS: MOUNTING THE DATASET TO MAKE IT AVAILABLE
        chestist_data = Dataset.get_by_name(ws, os.environ['DATASET_NAME'])
        mountPoint = chestist_data.mount()
        mountPoint.start()
        pathDirData = mountPoint.mount_point

        listImagesTest = []
        for imagePath in pathFileTest:
            listImagesTest.append(os.path.basename(imagePath))

        csvFilePath = ""
        labelList = []

        datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest,
                                       listImages=listImagesTest, labelList=labelList,
                                       transform=transformSequence, csvFilePath=csvFilePath)
        dataLoaderTest = DataLoader(dataset=datasetTest, batch_size=trBatchSize, num_workers=8,
                                    shuffle=False, pin_memory=True)

        with torch.no_grad():

            outGT = torch.FloatTensor().cuda()
            outPRED = torch.FloatTensor().cuda()

            model.eval()

            for i, (input, target) in enumerate(dataLoaderTest):

                target = target.cuda()
                outGT = torch.cat((outGT, target), 0)

                # TenCrop yields (batch, n_crops, c, h, w); fold the crops into the batch dimension
                bs, n_crops, c, h, w = input.size()
                varInput = input.view(-1, c, h, w).cuda()

                out = model(varInput)
                # Average the predictions over the ten crops of each image
                outMean = out.view(bs, n_crops, -1).mean(1)

                outPRED = torch.cat((outPRED, outMean), 0)

        aurocIndividual = TrainerTester.computeAUROC(outGT, outPRED, nnClassCount)
        aurocMean = np.array(aurocIndividual).mean()
        print('AUROC mean ', aurocMean)

        for i in range(len(aurocIndividual)):
            print(CLASS_NAMES[i], ' ', aurocIndividual[i])

        return
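

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The argument values
    # below (paths, epochs, crop sizes, class count) are illustrative assumptions,
    # the pathFile* lists would come from an actual dataset split, and the env vars
    # TENANT_ID, SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE_NAME, DATASTORE_NAME,
    # DATASET_NAME_CSV and DATASET_NAME must be set for the Azure ML calls to succeed.
    launchTimestamp = time.strftime("%d%m%Y-%H%M%S")
    pathFileTrain = []  # hypothetical: list of training image paths
    pathFileVal = []    # hypothetical: list of validation image paths
    pathFileTest = []   # hypothetical: list of test image paths

    TrainerTester.trainer(pathDirData='./database', pathFileTrain=pathFileTrain,
                          pathFileVal=pathFileVal, pathFileTest=pathFileTest,
                          nnArchitecture='CNNModel', nnIsTrained=True, nnClassCount=14,
                          trBatchSize=16, trMaxEpoch=10, transResize=256, transCrop=224,
                          launchTimestamp=launchTimestamp, checkpoint=None)
    TrainerTester.tester(pathDirData='./database', pathFileTest=pathFileTest,
                         pathModel='m-' + launchTimestamp + '.pth.tar',
                         nnArchitecture='CNNModel', nnClassCount=14, nnIsTrained=True,
                         trBatchSize=16, transResize=256, transCrop=224,
                         launchTimeStamp=launchTimestamp)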