{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import numpy as np\n",
"import argparse\n",
"import copy\n",
"import random\n",
"import json\n",
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Random number generation with NumPy\n",
"from numpy.random import default_rng\n",
"\n",
"# PyTorch\n",
"import torch\n",
"from torch.autograd import grad\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torchvision import datasets, transforms\n",
"from torchvision.utils import save_image\n",
"from torch.autograd import Variable\n",
"import torch.utils.data as data_utils\n",
"\n",
"# Pillow\n",
"from PIL import Image\n",
"\n",
"# TorchXRayVision\n",
"import torchxrayvision as xrv"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Random Seed\n",
"seed= 100\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed_all(seed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"to_pil= xrv.datasets.ToPILImage()\n",
"\n",
"to_resize= transforms.Resize([224, 224])\n",
"\n",
"to_augment= transforms.Compose([\n",
"#     transforms.RandomAffine(45, translate=(0.15, 0.15), scale=(0.85, 1.15)),\n",
"    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),\n",
"    transforms.RandomHorizontalFlip(),\n",
"])\n",
"\n",
"to_tensor= transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    # repeat the single grayscale channel to get a 3-channel image\n",
"    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),\n",
"    # ImageNet mean and std-dev\n",
"    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
"])\n",
"\n",
"root_dir= '/home/t-dimaha/RobustDG/robustdg'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def translate(img, label, domain, idx, test_case):\n",
"    \n",
"    img= to_pil(img)\n",
"    \n",
"    # No translation for images with label 0 in the test domains\n",
"    if label == 0 and test_case == 1:\n",
"        return to_resize(img)\n",
"    \n",
"    # No translation for images with label 1 in the train domains\n",
"    if label == 1 and test_case == 0:\n",
"        return to_resize(img)\n",
"    \n",
"# # No translation for Images with label 1 in case of NIH, Chex\n",
"#     if label ==1:\n",
"#         if domain in ['nih', 'chex'] or not test_case:\n",
"#             return to_resize(img)\n",
"    \n",
"# # Translation for Images with label 0 in case of NIH, Chex\n",
"#     if label ==0:\n",
"#         if domain in ['kaggle'] and test_case:\n",
"#             return to_resize(img)\n",
"    \n",
"    hor_shift= 0\n",
"    \n",
"    # Domain-specific vertical shift (in pixels)\n",
"    if domain == 'nih':\n",
"        ver_shift= 45\n",
"    elif domain == 'chex':\n",
"        ver_shift= 35\n",
"    elif domain == 'kaggle':\n",
"        ver_shift= 15\n",
"\n",
"#     if domain == 'nih':\n",
"#         ver_shift= 45\n",
"#     elif domain == 'chex':\n",
"#         ver_shift= 45\n",
"#     elif domain == 'kaggle':\n",
"#         ver_shift= 45\n",
"\n",
"    # PIL Image.AFFINE maps each output pixel (x, y) to input pixel (a*x + b*y + c, d*x + e*y + f)\n",
"    a = 1\n",
"    b = 0\n",
"    c = hor_shift  # horizontal shift (left/right)\n",
"    d = 0\n",
"    e = 1\n",
"    f = -ver_shift  # vertical shift (up/down); a negative value shifts the content down\n",
"\n",
"    img_new = img.transform(img.size, Image.AFFINE, (a, b, c, d, e, f))\n",
"    size = (img_new.size[0] - hor_shift, img_new.size[1] - ver_shift)\n",
"    \n",
"    # Crop to the desired size: rectangular crop (left, upper, right, lower)\n",
"    if f > 0:\n",
"        # Upward shift in AFFINE leaves a black region at the bottom; cut it out with EXTENT\n",
"        img_new = img_new.transform(size, Image.EXTENT, (0, 0, size[0], size[1]))\n",
"    else:\n",
"        # Downward shift in AFFINE leaves a black region at the top; cut it out with EXTENT\n",
"        img_new = img_new.transform(size, Image.EXTENT, (0, ver_shift, img_new.size[0], img_new.size[1]))\n",
"    \n",
"    # Resize, save and return the image\n",
"    img= to_resize(img)\n",
"    img_new= to_resize(img_new)\n",
"    #save_img(img, img_new, idx)\n",
"    \n",
"    return img_new"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def save_img(img, img_new, idx):\n",
"    # Save the original image\n",
"    plt.imsave('images/figa_' + str(idx) + '.png', img, cmap='Greys_r')\n",
"\n",
"    # Save the translated image\n",
"    plt.imsave('images/figb_' + str(idx) + '.png', img_new, cmap='Greys_r')\n",
"\n",
"    # Reopen fig a and fig b\n",
"    figa=plt.imread('images/figa_' + str(idx) + '.png')\n",
"    figb=plt.imread('images/figb_' + str(idx) + '.png')\n",
"\n",
"    # Stitch the two figures together\n",
"    figc=np.concatenate((figa,figb),axis=1)\n",
"\n",
"    # Save the stitched comparison image and remove the intermediate files\n",
"    plt.imsave('images/figc_' + str(idx) + '.png', figc, cmap='Greys_r')\n",
"    \n",
"    os.remove('images/figa_' + str(idx) + '.png')\n",
"    os.remove('images/figb_' + str(idx) + '.png')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Lung Lesion doesn't exist. Adding nans instead.\n",
"Fracture doesn't exist. Adding nans instead.\n",
"Lung Opacity doesn't exist. Adding nans instead.\n",
"Enlarged Cardiomediastinum doesn't exist. Adding nans instead.\n",
"{'Support Devices', 'Pleural Other'} will be dropped\n",
"Infiltration doesn't exist. Adding nans instead.\n",
"Emphysema doesn't exist. Adding nans instead.\n",
"Fibrosis doesn't exist. Adding nans instead.\n",
"Pleural_Thickening doesn't exist. Adding nans instead.\n",
"Nodule doesn't exist. Adding nans instead.\n",
"Mass doesn't exist. Adding nans instead.\n",
"Hernia doesn't exist. Adding nans instead.\n",
"Atelectasis doesn't exist. Adding nans instead.\n",
"Consolidation doesn't exist. Adding nans instead.\n",
"Infiltration doesn't exist. Adding nans instead.\n",
"Pneumothorax doesn't exist. Adding nans instead.\n",
"Edema doesn't exist. Adding nans instead.\n",
"Emphysema doesn't exist. Adding nans instead.\n",
"Fibrosis doesn't exist. Adding nans instead.\n",
"Effusion doesn't exist. Adding nans instead.\n",
"Pleural_Thickening doesn't exist. Adding nans instead.\n",
"Cardiomegaly doesn't exist. Adding nans instead.\n",
"Nodule doesn't exist. Adding nans instead.\n",
"Mass doesn't exist. Adding nans instead.\n",
"Hernia doesn't exist. Adding nans instead.\n",
"Lung Lesion doesn't exist. Adding nans instead.\n",
"Fracture doesn't exist. Adding nans instead.\n",
"Enlarged Cardiomediastinum doesn't exist. Adding nans instead.\n"
]
}
],
"source": [
"d_nih = xrv.datasets.NIH_Dataset(imgpath=root_dir + \"/data/datasets/NIH/images_224/\",\n",
"                                 views=[\"PA\",\"AP\"], unique_patients=False)\n",
"\n",
"d_chex = xrv.datasets.CheX_Dataset(imgpath=root_dir + \"/data/datasets/CheXpert-v1.0-small\",\n",
"                                   csvpath= root_dir + \"/data/datasets/CheXpert-v1.0-small/train.csv\",\n",
"                                   views=[\"PA\",\"AP\"], unique_patients=False)\n",
"\n",
"d_rsna = xrv.datasets.RSNA_Pneumonia_Dataset(imgpath=root_dir + \"/data/datasets/Kaggle/stage_2_train_images_jpg\",\n",
"                                             views=[\"PA\",\"AP\"],\n",
"                                             unique_patients=False)\n",
"\n",
"xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, d_nih)\n",
"xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, d_chex)\n",
"xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, d_rsna)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# data={}\n",
"    \n",
"# for data_name in ['nih', 'chex', 'kaggle']:\n",
"    \n",
"#     #TestSize: 30 percent of label 1\n",
"#     #TrainSize: Label 1 - TestSize\n",
"#     #ValSize: 25 percent of TrainSize\n",
"    \n",
"#     data[data_name]={}\n",
"#     if data_name == 'nih':\n",
"#         data[data_name]['obj']= d_nih\n",
"#         data[data_name]['test_size']= 430\n",
"#         data[data_name]['train_size']= 800\n",
"#         data[data_name]['val_size']= 200\n",
"#     elif data_name == 'chex':\n",
"#         data[data_name]['obj']= d_chex\n",
"#         data[data_name]['test_size']= 1402\n",
"#         data[data_name]['train_size']= 2618\n",
"#         data[data_name]['val_size']= 654\n",
"#     elif data_name == 'kaggle':\n",
"#         data[data_name]['obj']= d_rsna\n",
"#         data[data_name]['test_size']= 1803\n",
"#         data[data_name]['train_size']= 3368\n",
"#         data[data_name]['val_size']= 841\n",
"    \n",
"#     data[data_name]['size']= len(data[data_name]['obj'])\n",
"\n",
"\n",
"data={}\n",
"    \n",
"for data_name in ['nih', 'chex', 'kaggle']:\n",
"    \n",
"    data[data_name]={}\n",
"    if data_name == 'nih':\n",
"        data[data_name]['obj']= d_nih\n",
"        data[data_name]['test_size']= 400\n",
"        data[data_name]['train_size']= 800\n",
"        data[data_name]['val_size']= 200\n",
"    elif data_name == 'chex':\n",
"        data[data_name]['obj']= d_chex\n",
"        data[data_name]['test_size']= 400\n",
"        data[data_name]['train_size']= 800\n",
"        data[data_name]['val_size']= 200\n",
"    elif data_name == 'kaggle':\n",
"        data[data_name]['obj']= d_rsna\n",
"        data[data_name]['test_size']= 400\n",
"        data[data_name]['train_size']= 800\n",
"        data[data_name]['val_size']= 200\n",
"    \n",
"    data[data_name]['size']= len(data[data_name]['obj'])"
]
},
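{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, not part of the original pipeline: it applies `translate()` to a single sample so the vertical shift can be inspected visually. The sample index and the `test_case=1` choice are just illustrative.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sanity-check sketch: translate one NIH sample and show it next to the untranslated version\n",
"task = xrv.datasets.default_pathologies.index('Pneumonia')\n",
"sample = data['nih']['obj'][0]\n",
"img_plain = to_resize(to_pil(sample['img']))\n",
"img_shift = translate(sample['img'], sample['lab'][task], 'nih', 0, 1)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
"axes[0].imshow(np.array(img_plain), cmap='Greys_r')\n",
"axes[1].imshow(np.array(img_shift), cmap='Greys_r')\n",
"plt.show()"
]
},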
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Translating Images:\n",
"\n",
"Create the directory '/data/datasets/chestxray/' (under `root_dir`) and run the following cells.\n",
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For each domain, sample a label-balanced set of indices per split (train/val/test) and save them\n",
"for data_name in ['nih', 'chex', 'kaggle']:\n",
"\n",
"    rng = default_rng()\n",
"    indices = rng.choice(data[data_name]['size'], size=data[data_name]['size'], replace=False)\n",
"    print(indices.shape)\n",
"    \n",
"    count=0\n",
"    for case in ['train', 'val', 'test']:\n",
"        size= data[data_name][case+'_size']\n",
"        ids=[]\n",
"        \n",
"        count_l0=0\n",
"        count_l1=0\n",
"        count_lim=int(size/2)\n",
"        \n",
"        while count_l0 < count_lim or count_l1 < count_lim:\n",
"\n",
"            index= indices[count].item()\n",
"            task = xrv.datasets.default_pathologies.index('Pneumonia')\n",
"            label= data[data_name]['obj'][index]['lab'][task]\n",
"            count+=1\n",
"            \n",
"            if np.isnan(label):\n",
"                continue\n",
"            else:\n",
"                \n",
"                if label == 0:\n",
"                    if count_l0 < count_lim:\n",
"#                         print('Label 0')\n",
"                        count_l0+= 1\n",
"                    else:\n",
"                        continue\n",
"\n",
"                if label ==1:\n",
"                    if count_l1 < count_lim:\n",
"#                         print('Label 1')\n",
"                        count_l1+= 1\n",
"                    else:\n",
"                        continue\n",
"            \n",
"            ids.append(index)\n",
"        \n",
"        ids= np.array(ids)\n",
"        \n",
"        print(count_l0, count_l1)\n",
"        base_dir= root_dir + '/data/datasets/chestxray/'\n",
"        np.save(base_dir + data_name + '_' + case + '_' + 'indices.npy', ids)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Without Translation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(112120,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(191010,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(26684,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n"
]
}
],
"source": [
"# Build label-balanced image/label tensors per domain and split, without the translation-based feature\n",
"base_dir=root_dir + '/data/datasets/chestxray/'\n",
"    \n",
"for data_name in ['nih', 'chex', 'kaggle']:\n",
"    \n",
"    indices= np.random.randint(0, data[data_name]['size'], data[data_name]['size'])\n",
"    print(indices.shape)\n",
"    \n",
"    count=0\n",
"    for case in ['train', 'val', 'test']:\n",
"        \n",
"        size= data[data_name][case+'_size']\n",
"        imgs=[]\n",
"        labels=[]\n",
"        imgs_org=[]\n",
"        \n",
"        indices= np.load(base_dir + data_name + '_' + case + '_' + 'indices.npy')\n",
"        \n",
"        count_l0=0\n",
"        count_l1=0\n",
"        for idx in range(indices.shape[0]):\n",
"            \n",
"            index= indices[idx].item()\n",
"            task = xrv.datasets.default_pathologies.index('Pneumonia')\n",
"            \n",
"            img= data[data_name]['obj'][index]['img']\n",
"            img_org= data[data_name]['obj'][index]['img']\n",
"            label= data[data_name]['obj'][index]['lab'][task]\n",
"            count+=1\n",
"            \n",
"            if np.isnan(data[data_name]['obj'][index]['lab'][task]):\n",
"                print('Error: Nan in the labels')\n",
"            \n",
"            if label == 0:\n",
"                count_l0+=1\n",
"            \n",
"            if label == 1:\n",
"                count_l1+=1\n",
"            \n",
"            label=torch.tensor(label).long()\n",
"            label= label.view(1)\n",
"            \n",
"            \n",
"            img= to_pil(img)\n",
"            img= to_resize(img)\n",
"            img= to_tensor(to_augment(img))\n",
"            \n",
"            img_org= to_pil(img_org)\n",
"            img_org= to_resize(img_org)\n",
"            img_org= to_tensor(img_org)\n",
"            \n",
"            img= img.view(1, img.shape[0], img.shape[1], img.shape[2])\n",
"            img_org= img_org.view(1, img_org.shape[0], img_org.shape[1], img_org.shape[2])\n",
"            \n",
"#             print('Data: ', data_name, count, img.shape, label)\n",
"            imgs.append(img)\n",
"            imgs_org.append(img_org)\n",
"            labels.append(label)\n",
"            \n",
"            if torch.all(torch.eq(img, img_org)):\n",
"                print('Error:')\n",
"        \n",
"        \n",
"        imgs=torch.cat(imgs)\n",
"        imgs_org=torch.cat(imgs_org)\n",
"        labels=torch.cat(labels)\n",
"        print(imgs.shape, imgs_org.shape, labels.shape, count_l0, count_l1)\n",
"        \n",
"        torch.save(imgs, base_dir + data_name + '_' + case + '_' + 'image.pt')\n",
"        torch.save(imgs_org, base_dir + data_name + '_' + case + '_' + 'image_org.pt')\n",
"        torch.save(labels, base_dir + data_name + '_' + case + '_' + 'label.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## With Translation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(112120,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(191010,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n",
"(26684,)\n",
"torch.Size([800, 3, 224, 224]) torch.Size([800, 3, 224, 224]) torch.Size([800]) 400 400\n",
"torch.Size([200, 3, 224, 224]) torch.Size([200, 3, 224, 224]) torch.Size([200]) 100 100\n",
"torch.Size([400, 3, 224, 224]) torch.Size([400, 3, 224, 224]) torch.Size([400]) 200 200\n"
]
}
],
"source": [
"# Same pipeline, but each image is translated with the train-domain pattern (test_case=0); saved with the '_trans_' prefix\n",
"base_dir=root_dir + '/data/datasets/chestxray/'\n",
"    \n",
"for data_name in ['nih', 'chex', 'kaggle']:\n",
"    \n",
"    indices= np.random.randint(0, data[data_name]['size'], data[data_name]['size'])\n",
"    print(indices.shape)\n",
"    \n",
"    count=0\n",
"    for case in ['train', 'val', 'test']:\n",
"        \n",
"        size= data[data_name][case+'_size']\n",
"        imgs=[]\n",
"        labels=[]\n",
"        imgs_org=[]\n",
"        \n",
"        indices= np.load(base_dir + data_name + '_' + case + '_' + 'indices.npy')\n",
"        \n",
"        count_l0=0\n",
"        count_l1=0\n",
"        for idx in range(indices.shape[0]):\n",
"            \n",
"            index= indices[idx].item()\n",
"            task = xrv.datasets.default_pathologies.index('Pneumonia')\n",
"            \n",
"            img= data[data_name]['obj'][index]['img']\n",
"            img_org= data[data_name]['obj'][index]['img']\n",
"            label= data[data_name]['obj'][index]['lab'][task]\n",
"            count+=1\n",
"            \n",
"            if np.isnan(data[data_name]['obj'][index]['lab'][task]):\n",
"                print('Error: Nan in the labels')\n",
"            \n",
"            if label == 0:\n",
"                count_l0+=1\n",
"            \n",
"            if label == 1:\n",
"                count_l1+=1\n",
"            \n",
"            label=torch.tensor(label).long()\n",
"            label= label.view(1)\n",
"            \n",
"            img= to_tensor( to_augment( translate(img, label, data_name, index, 0) ) )\n",
"            img_org= to_tensor( translate(img_org, label, data_name, index, 0) )\n",
"            \n",
"            \n",
"            img= img.view(1, img.shape[0], img.shape[1], img.shape[2])\n",
"            img_org= img_org.view(1, img_org.shape[0], img_org.shape[1], img_org.shape[2])\n",
"            \n",
"#             print('Data: ', data_name, count, img.shape, label)\n",
"            imgs.append(img)\n",
"            imgs_org.append(img_org)\n",
"            labels.append(label)\n",
"            \n",
"            if torch.all(torch.eq(img, img_org)):\n",
"                print('Error:')\n",
"        \n",
"        \n",
"        imgs=torch.cat(imgs)\n",
"        imgs_org=torch.cat(imgs_org)\n",
"        labels=torch.cat(labels)\n",
"        print(imgs.shape, imgs_org.shape, labels.shape, count_l0, count_l1)\n",
"        \n",
"        torch.save(imgs, base_dir + data_name + '_trans_' + case + '_' + 'image.pt')\n",
"        torch.save(imgs_org, base_dir + data_name + '_trans_' + case + '_' + 'image_org.pt')\n",
"        torch.save(labels, base_dir + data_name + '_trans_' + case + '_' + 'label.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Opposite Translation for Test Domains"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same pipeline, but with the opposite (test-domain) translation pattern (test_case=1); saved with the '_opp_trans_' prefix\n",
"base_dir=root_dir + '/data/datasets/chestxray/'\n",
"    \n",
"for data_name in ['nih', 'chex', 'kaggle']:\n",
"    \n",
"    indices= np.random.randint(0, data[data_name]['size'], data[data_name]['size'])\n",
"    print(indices.shape)\n",
"    \n",
"    count=0\n",
"    for case in ['train', 'val', 'test']:\n",
"        \n",
"        size= data[data_name][case+'_size']\n",
"        imgs=[]\n",
"        labels=[]\n",
"        imgs_org=[]\n",
"        \n",
"        indices= np.load(base_dir + data_name + '_' + case + '_' + 'indices.npy')\n",
"        \n",
"        count_l0=0\n",
"        count_l1=0\n",
"        for idx in range(indices.shape[0]):\n",
"            \n",
"            index= indices[idx].item()\n",
"            task = xrv.datasets.default_pathologies.index('Pneumonia')\n",
"            \n",
"            img= data[data_name]['obj'][index]['img']\n",
"            img_org= data[data_name]['obj'][index]['img']\n",
"            label= data[data_name]['obj'][index]['lab'][task]\n",
"            count+=1\n",
"            \n",
"            if np.isnan(data[data_name]['obj'][index]['lab'][task]):\n",
"                print('Error: Nan in the labels')\n",
"            \n",
"            if label == 0:\n",
"                count_l0+=1\n",
"            \n",
"            if label == 1:\n",
"                count_l1+=1\n",
"            \n",
"            label=torch.tensor(label).long()\n",
"            label= label.view(1)\n",
"            \n",
"            img= to_tensor( to_augment( translate(img, label, data_name, index, 1) ) )\n",
"            img_org= to_tensor( translate(img_org, label, data_name, index, 1) )\n",
"            \n",
"            \n",
"            img= img.view(1, img.shape[0], img.shape[1], img.shape[2])\n",
"            img_org= img_org.view(1, img_org.shape[0], img_org.shape[1], img_org.shape[2])\n",
"            \n",
"#             print('Data: ', data_name, count, img.shape, label)\n",
"            imgs.append(img)\n",
"            imgs_org.append(img_org)\n",
"            labels.append(label)\n",
"            \n",
"            if torch.all(torch.eq(img, img_org)):\n",
"                print('Error:')\n",
"        \n",
"        \n",
"        imgs=torch.cat(imgs)\n",
"        imgs_org=torch.cat(imgs_org)\n",
"        labels=torch.cat(labels)\n",
"        print(imgs.shape, imgs_org.shape, labels.shape, count_l0, count_l1)\n",
"        \n",
"        torch.save(imgs, base_dir + data_name + '_opp_trans_' + case + '_' + 'image.pt')\n",
"        torch.save(imgs_org, base_dir + data_name + '_opp_trans_' + case + '_' + 'image_org.pt')\n",
"        torch.save(labels, base_dir + data_name + '_opp_trans_' + case + '_' + 'label.pt')"
]
},
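{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading the saved tensors (example)\n",
"\n",
"A minimal sketch, not part of the original pipeline: it shows how the `.pt` files written above might be loaded and wrapped in a `TensorDataset`/`DataLoader`. The `data_name` and `case` values are illustrative.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged usage sketch: load one saved split and wrap it in a DataLoader\n",
"# (assumes the '_trans_' files created by the cells above exist under base_dir)\n",
"data_name, case = 'nih', 'train'\n",
"imgs = torch.load(base_dir + data_name + '_trans_' + case + '_' + 'image.pt')\n",
"labels = torch.load(base_dir + data_name + '_trans_' + case + '_' + 'label.pt')\n",
"\n",
"dataset = data_utils.TensorDataset(imgs, labels)\n",
"loader = data_utils.DataLoader(dataset, batch_size=16, shuffle=True)\n",
"print(imgs.shape, labels.shape, len(loader))"
]
},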
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "matchdg-env",
"language": "python",
"name": "matchdg-env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}