180 строки
6.1 KiB
Plaintext
180 строки
6.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import sys\n",
|
|
"import numpy as np\n",
|
|
"import argparse\n",
|
|
"import copy\n",
|
|
"import random\n",
|
|
"import json\n",
|
|
"from glob import glob\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from numpy import asarray\n",
|
|
"\n",
|
|
"#Random number generation using Numpy\n",
|
|
"from numpy.random import default_rng\n",
|
|
"\n",
|
|
"\n",
|
|
"#Scipy\n",
|
|
"from scipy.stats import bernoulli\n",
|
|
"\n",
|
|
"#Pytorch\n",
|
|
"import torch\n",
|
|
"from torch.autograd import grad\n",
|
|
"from torch import nn, optim\n",
|
|
"from torch.nn import functional as F\n",
|
|
"from torchvision import datasets, transforms\n",
|
|
"from torchvision.utils import save_image\n",
|
|
"from torch.autograd import Variable\n",
|
|
"import torch.utils.data as data_utils\n",
|
|
"\n",
|
|
"#Pillow\n",
|
|
"from PIL import Image, ImageColor, ImageOps "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"to_pil= transforms.Compose([\n",
|
|
" transforms.ToPILImage(),\n",
|
|
" transforms.Resize((224, 224))\n",
|
|
" ])\n",
|
|
"\n",
|
|
"to_augment= transforms.Compose([\n",
|
|
" transforms.RandomResizedCrop(224, scale=(0.7,1.0)),\n",
|
|
" transforms.RandomHorizontalFlip()\n",
|
|
" ])\n",
|
|
"\n",
|
|
"to_tensor= transforms.Compose([\n",
|
|
" transforms.ToTensor(),\n",
|
|
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
|
" ])\n",
|
|
" \n",
|
|
"color_list=['red', 'blue', 'green', 'orange', 'yellow', 'brown', 'pink', 'magenta', 'olive', 'cyan']\n",
|
|
"\n",
|
|
"def load_inds(mnist_subset, data_case):\n",
|
|
" data_dir= '../../data/datasets/rot_mnist/rot_mnist_resnet18_indices/'\n",
|
|
" if data_case != 'val':\n",
|
|
" return np.load(data_dir + '/supervised_inds_' + str(mnist_subset) + '.npy')\n",
|
|
" else:\n",
|
|
" return np.load(data_dir + '/val' + '/supervised_inds_' + str(mnist_subset) + '.npy')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_data(data_dir, data_case, subset ):\n",
|
|
"\n",
|
|
" data_obj_train= datasets.MNIST(data_dir,\n",
|
|
" train=True,\n",
|
|
" download=True,\n",
|
|
" transform=transforms.ToTensor()\n",
|
|
" )\n",
|
|
"\n",
|
|
" data_obj_test= datasets.MNIST(data_dir,\n",
|
|
" train=False,\n",
|
|
" download=True,\n",
|
|
" transform=transforms.ToTensor()\n",
|
|
" )\n",
|
|
" mnist_imgs= torch.cat((data_obj_train.data, data_obj_test.data))\n",
|
|
" mnist_labels= torch.cat((data_obj_train.targets, data_obj_test.targets))\n",
|
|
"\n",
|
|
"\n",
|
|
" # Select the subset of data corresponding to data_case (train/val/test) using the saved indices for this subset\n",
|
|
" sub_inds= load_inds(subset, data_case) \n",
|
|
" imgs = mnist_imgs[sub_inds]\n",
|
|
" labels = mnist_labels[sub_inds]\n",
|
|
" mnist_size= labels.shape[0]\n",
|
|
" \n",
|
|
" rand_var= bernoulli.rvs(0.7, size=mnist_size)\n",
|
|
" spur= torch.tensor(rand_var)\n",
|
|
" \n",
|
|
" for rotation in [0, 15, 30, 45, 60 ,75, 90]:\n",
|
|
" \n",
|
|
" imgs_rot= torch.zeros((mnist_size, 3, 224, 224))\n",
|
|
" imgs_rot_org= torch.zeros((mnist_size, 3, 224, 224))\n",
|
|
" for idx in range(mnist_size):\n",
|
|
" curr_img= imgs[idx]\n",
|
|
" curr_img= to_pil(curr_img)\n",
|
|
"\n",
|
|
" #Color as an additional feature\n",
|
|
" if rand_var[idx]:\n",
|
|
" curr_img = ImageOps.colorize(curr_img, black =\"black\", white =color_list[labels[idx].item()]) \n",
|
|
" else:\n",
|
|
" curr_img = ImageOps.colorize(curr_img, black =\"black\", white =\"white\") \n",
|
|
"\n",
|
|
" #Rotation\n",
|
|
" curr_img= transforms.functional.rotate(curr_img, rotation)\n",
|
|
"\n",
|
|
" #Augmentation\n",
|
|
" imgs_rot[idx]= to_tensor( to_augment(curr_img) )\n",
|
|
"\n",
|
|
" #No Augmentation\n",
|
|
" imgs_rot_org[idx]= to_tensor(curr_img)\n",
|
|
"\n",
|
|
" print('Data Case: ', data_case, ' Subset: ', subset, ' Rotation: ', rotation ) \n",
|
|
" print('Image: ', imgs_rot.shape, ' Labels: ', labels.shape, ' Spur: ', spur.shape)\n",
|
|
" print('Image: ', imgs_rot.dtype, ' Labels: ', labels.dtype, ' Spur: ', spur.dtype) \n",
|
|
" torch.save(imgs_rot, data_dir+ 'Imgs' + '_case_' + data_case + '_subset_' + str(subset) + '_rot_' + str(rotation)+ '.pt' )\n",
|
|
" torch.save(imgs_rot_org, data_dir+ 'Imgs_org' + '_case_' + data_case + '_subset_' + str(subset) + '_rot_' + str(rotation)+ '.pt' )\n",
|
|
" \n",
|
|
" torch.save(labels, data_dir+ 'Labels' + '_case_' + data_case + '_subset_' + str(subset) + '.pt' )\n",
|
|
" torch.save(spur, data_dir+ 'Spur' + '_case_' + data_case + '_subset_' + str(subset) + '.pt' )\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"data_dir= '../../data/datasets/rot_mnist_spur/'\n",
|
|
"\n",
|
|
"for data_case in ['val', 'test']:\n",
|
|
" for subset in range(10):\n",
|
|
" get_data(data_dir, data_case, subset)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "main-envs",
|
|
"language": "python",
|
|
"name": "main-envs"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|