robustdg/docs/notebooks/Spur_Rotated_MNIST.ipynb

180 строки
6.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import numpy as np\n",
"import argparse\n",
"import copy\n",
"import random\n",
"import json\n",
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"from numpy import asarray\n",
"\n",
"#Rand Number using Numpy\n",
"from numpy.random import default_rng\n",
"\n",
"\n",
"#Sklearn\n",
"from scipy.stats import bernoulli\n",
"\n",
"#Pytorch\n",
"import torch\n",
"from torch.autograd import grad\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torchvision import datasets, transforms\n",
"from torchvision.utils import save_image\n",
"from torch.autograd import Variable\n",
"import torch.utils.data as data_utils\n",
"\n",
"#Pillow\n",
"from PIL import Image, ImageColor, ImageOps "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upsample a raw 28x28 MNIST digit tensor to a 224x224 PIL image.\n",
"to_pil = transforms.Compose([\n",
"    transforms.ToPILImage(),\n",
"    transforms.Resize((224, 224)),\n",
"])\n",
"\n",
"# Random augmentation: crop a random patch covering 70-100% of the area,\n",
"# resize it to 224x224, then flip horizontally at random.\n",
"to_augment = transforms.Compose([\n",
"    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),\n",
"    transforms.RandomHorizontalFlip(),\n",
"])\n",
"\n",
"# PIL image -> float tensor normalised with the standard MNIST mean/std\n",
"# (the single-channel statistics are applied to every RGB channel).\n",
"to_tensor = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),\n",
"])\n",
"\n",
"# Tint colour per digit class: color_list[label] is used for spuriously coloured digits.\n",
"color_list = 'red blue green orange yellow brown pink magenta olive cyan'.split()\n",
"\n",
"def load_inds(mnist_subset, data_case,\n",
"              inds_dir='../../data/datasets/rot_mnist/rot_mnist_resnet18_indices/'):\n",
"    \"\"\"Load the saved MNIST index array for one subset.\n",
"\n",
"    Args:\n",
"        mnist_subset: subset number; the file read is supervised_inds_<mnist_subset>.npy.\n",
"        data_case: 'val' reads from the val/ sub-folder of inds_dir; any other\n",
"            value reads from inds_dir itself.\n",
"        inds_dir: folder holding the index files (defaults to the original\n",
"            hard-coded location, so existing calls are unchanged).\n",
"\n",
"    Returns:\n",
"        The array stored in the .npy file (get_data uses it to index MNIST).\n",
"    \"\"\"\n",
"    folder = os.path.join(inds_dir, 'val') if data_case == 'val' else inds_dir\n",
"    # os.path.join avoids the '//' the old string concatenation produced.\n",
"    return np.load(os.path.join(folder, 'supervised_inds_' + str(mnist_subset) + '.npy'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _colorize_digits(imgs, labels, spur_mask):\n",
"    \"\"\"Upsample every digit to 224x224 RGB and tint it, once per subset.\n",
"\n",
"    Digits whose spur_mask entry is truthy are tinted with their label's\n",
"    colour from color_list (the spurious feature); the rest stay white on\n",
"    black. Colouring does not depend on the rotation angle, so it is done\n",
"    once here rather than once per rotation.\n",
"    \"\"\"\n",
"    colored = []\n",
"    for idx in range(labels.shape[0]):\n",
"        pil_img = to_pil(imgs[idx])\n",
"        white = color_list[labels[idx].item()] if spur_mask[idx] else 'white'\n",
"        colored.append(ImageOps.colorize(pil_img, black='black', white=white))\n",
"    return colored\n",
"\n",
"\n",
"def _out_path(data_dir, prefix, data_case, subset, rotation=None):\n",
"    \"\"\"Return <data_dir>/<prefix>_case_<data_case>_subset_<subset>[_rot_<rotation>].pt.\"\"\"\n",
"    name = f'{prefix}_case_{data_case}_subset_{subset}'\n",
"    if rotation is not None:\n",
"        name += f'_rot_{rotation}'\n",
"    # os.path.join works whether or not data_dir ends with a slash\n",
"    return os.path.join(data_dir, name + '.pt')\n",
"\n",
"\n",
"def get_data(data_dir, data_case, subset, seed=None, spur_prob=0.7,\n",
"             rotations=(0, 15, 30, 45, 60, 75, 90)):\n",
"    \"\"\"Generate and save one subset of the spurious-colour rotated MNIST data.\n",
"\n",
"    Args:\n",
"        data_dir: root passed to torchvision's MNIST (downloaded if missing)\n",
"            and folder where the generated .pt files are written.\n",
"        data_case: split name forwarded to load_inds and used in file names.\n",
"        subset: subset number forwarded to load_inds and used in file names.\n",
"        seed: optional int. When given, it seeds the Bernoulli draw of the\n",
"            spurious-colour flags and torch's global RNG (which drives the\n",
"            random crop/flip in to_augment), making the output reproducible.\n",
"            None keeps the original unseeded behaviour.\n",
"        spur_prob: probability that a digit is tinted with its label colour.\n",
"        rotations: angles (degrees) for which image tensors are generated.\n",
"\n",
"    Side effects:\n",
"        Per rotation, saves Imgs_* (augmented) and Imgs_org_* (unaugmented)\n",
"        float tensors of shape (N, 3, 224, 224); then saves Labels_* and\n",
"        Spur_* (the 0/1 colour flags) for the subset.\n",
"    \"\"\"\n",
"    # Only the raw uint8 .data/.targets are used, so no transform is needed.\n",
"    data_obj_train = datasets.MNIST(data_dir, train=True, download=True)\n",
"    data_obj_test = datasets.MNIST(data_dir, train=False, download=True)\n",
"    mnist_imgs = torch.cat((data_obj_train.data, data_obj_test.data))\n",
"    mnist_labels = torch.cat((data_obj_train.targets, data_obj_test.targets))\n",
"\n",
"    # Select the subset of data corresponding to data_case (train/val/test)\n",
"    sub_inds = load_inds(subset, data_case)\n",
"    imgs = mnist_imgs[sub_inds]\n",
"    labels = mnist_labels[sub_inds]\n",
"    mnist_size = labels.shape[0]\n",
"\n",
"    if seed is not None:\n",
"        torch.manual_seed(seed)\n",
"    # rand_var[idx] == 1 -> digit idx is tinted with its label colour\n",
"    rand_var = bernoulli.rvs(spur_prob, size=mnist_size, random_state=seed)\n",
"    spur = torch.tensor(rand_var)\n",
"\n",
"    colored_imgs = _colorize_digits(imgs, labels, rand_var)\n",
"\n",
"    for rotation in rotations:\n",
"        imgs_rot = torch.zeros((mnist_size, 3, 224, 224))\n",
"        imgs_rot_org = torch.zeros((mnist_size, 3, 224, 224))\n",
"        for idx, colored in enumerate(colored_imgs):\n",
"            curr_img = transforms.functional.rotate(colored, rotation)\n",
"            # Augmented copy (random resized crop + horizontal flip)\n",
"            imgs_rot[idx] = to_tensor(to_augment(curr_img))\n",
"            # Unaugmented copy\n",
"            imgs_rot_org[idx] = to_tensor(curr_img)\n",
"\n",
"        print('Data Case: ', data_case, ' Subset: ', subset, ' Rotation: ', rotation)\n",
"        print('Image: ', imgs_rot.shape, ' Labels: ', labels.shape, ' Spur: ', spur.shape)\n",
"        print('Image: ', imgs_rot.dtype, ' Labels: ', labels.dtype, ' Spur: ', spur.dtype)\n",
"        torch.save(imgs_rot, _out_path(data_dir, 'Imgs', data_case, subset, rotation))\n",
"        torch.save(imgs_rot_org, _out_path(data_dir, 'Imgs_org', data_case, subset, rotation))\n",
"\n",
"    torch.save(labels, _out_path(data_dir, 'Labels', data_case, subset))\n",
"    torch.save(spur, _out_path(data_dir, 'Spur', data_case, subset))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Output folder for the generated tensors (also the MNIST download root used by get_data).\n",
"data_dir = '../../data/datasets/rot_mnist_spur/'\n",
"\n",
"# NOTE(review): only the 'val' and 'test' cases are generated here; confirm\n",
"# whether the 'train' case is produced elsewhere.\n",
"for data_case in ['val', 'test']:\n",
"    for subset in range(10):\n",
"        get_data(data_dir, data_case, subset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "main-envs",
"language": "python",
"name": "main-envs"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}