1
0
Форкнуть 0
leprosy-skin-lesion-ai-anal.../prepare_data.py

69 строки
3.0 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import torchvision.transforms as transforms
import os
import pandas as pd
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import pretrainedmodels as ptm
class LeprosyDataset(Dataset):
    """Lesion-image dataset with per-patient sample weights.

    Items are addressed by image name: the image file is read from
    ``root_dir`` and its label and weights are looked up in ``img_df``.
    """

    def __init__(self, img_names, root_dir, img_df, target_name, transform=None,
                 id_col='label_id',
                 extension=".jpg"):
        """
        Args:
            img_names (list of str): Names of the images in this dataset,
                without file extension. Each name is matched against the
                ``id_col`` column of ``img_df``.
            root_dir (str): Directory containing the image files.
            img_df (DataFrame): Table with a ``patient_id`` column, the
                ``id_col`` column and the ``target_name`` column. The columns
                ``sample_weight`` and ``sample_freq_ratio`` are added to this
                frame IN PLACE, so the caller's DataFrame is modified.
            target_name (str): Column of ``img_df`` holding the label.
            transform (callable, optional): Transform applied to each image.
            id_col (str): Column of ``img_df`` holding the image names.
            extension (str): File extension appended to each image name.

        Raises:
            ValueError: If ``img_df`` has no rows.
        """
        self.img_names = img_names
        self.root_dir = root_dir
        self.img_df = img_df
        self.transform = transform
        self.target_name = target_name
        self.extension = extension
        self.id_col = id_col

        # Labels are encoded as their position in the sorted unique values.
        classes = sorted(self.img_df[self.target_name].unique())
        self.classes = classes
        self.class_to_idx = {label: idx for idx, label in enumerate(classes)}

        patient_ids = self.img_df['patient_id']
        if patient_ids.empty:
            raise ValueError("img_df must contain at least one row")

        # Number of images belonging to each row's patient, computed with a
        # single counting pass instead of one list scan per row.
        lesion_count = patient_ids.map(patient_ids.value_counts()).values
        count_max = int(lesion_count.max())  # max num of images of one patient
        self.count_max = count_max

        # Option to re-weight the loss when some patients have significantly
        # more images than others: rarer patients get a larger weight.
        self.img_df['sample_weight'] = count_max / lesion_count
        self.img_df['sample_freq_ratio'] = 1 / lesion_count

    def __len__(self):
        """Return the number of image names in this dataset."""
        return len(self.img_names)

    def __getitem__(self, i):
        """Load the i-th image together with its label and weights.

        Returns:
            tuple: ``(image, target, weight, freq_ratio)`` where
                image is the (optionally transformed) image,
                target is the index of the row's label in ``self.classes``,
                weight is ``count_max / (num images of the patient)``,
                freq_ratio is ``1 / (num images of the patient)``.

        Raises:
            IndexError: If the image name has no matching row in ``img_df``.
        """
        name = self.img_names[i]
        image = Image.open(os.path.join(self.root_dir, name + self.extension))
        # Image.open is lazy; read the pixel data now so the underlying file
        # handle is released even when no transform forces decoding.
        image.load()

        # Select the matching row(s) once and reuse them for every field.
        record = self.img_df.loc[self.img_df[self.id_col] == name]
        weight = record['sample_weight'].values[0]
        freq_ratio = record['sample_freq_ratio'].values[0]
        target = self.class_to_idx[record[self.target_name].values[0]]

        if self.transform:
            image = self.transform(image)
        return image, target, weight, freq_ratio

    def numpatient(self):
        """Return the number of distinct patients with an image in this dataset."""
        in_dataset = self.img_df[self.id_col].isin(self.img_names)
        return len(set(self.img_df.loc[in_dataset, 'patient_id'].values))