{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rotated MNIST with a spurious colour feature\n",
    "\n",
    "For each data case in `DATA_CASES` and each of the `N_SUBSETS` index subsets, this notebook:\n",
    "\n",
    "- selects MNIST digits (train and test splits pooled) using precomputed index files,\n",
    "- colours each digit with a label-specific colour with probability `SPUR_PROB` (the spurious feature), otherwise leaves it uncoloured,\n",
    "- rotates every image by each angle in `ROTATIONS`, and\n",
    "- saves augmented and non-augmented image tensors per rotation, plus the labels and the spurious-feature indicator, as `.pt` files in `DATA_DIR`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import argparse\n",
    "import copy\n",
    "import random\n",
    "import json\n",
    "from functools import lru_cache\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import asarray\n",
    "\n",
    "#Rand Number using Numpy\n",
    "from numpy.random import default_rng\n",
    "\n",
    "\n",
    "#SciPy\n",
    "from scipy.stats import bernoulli\n",
    "\n",
    "#Pytorch\n",
    "import torch\n",
    "from torch.autograd import grad\n",
    "from torch import nn, optim\n",
    "from torch.nn import functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.utils import save_image\n",
    "from torch.autograd import Variable\n",
    "import torch.utils.data as data_utils\n",
    "\n",
    "#Pillow\n",
    "from PIL import Image, ImageColor, ImageOps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base random seed; the get_data calls below derive a per-(data_case, subset) seed from it\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Probability that a digit is coloured with its label-specific colour (spurious feature)\n",
    "SPUR_PROB = 0.7\n",
    "# Rotation angles in degrees; one pair of image tensors is saved per angle\n",
    "ROTATIONS = (0, 15, 30, 45, 60, 75, 90)\n",
    "DATA_CASES = ['val', 'test']\n",
    "N_SUBSETS = 10\n",
    "\n",
    "# Directory of precomputed index files, and the MNIST download / output directory\n",
    "INDS_DIR = '../../data/datasets/rot_mnist/rot_mnist_resnet18_indices/'\n",
    "DATA_DIR = '../../data/datasets/rot_mnist_spur/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transforms and index loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Digit tensor -> PIL image resized to 224x224\n",
    "to_pil = transforms.Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    transforms.Resize((224, 224)),\n",
    "])\n",
    "\n",
    "# Random augmentation applied to the rotated PIL image\n",
    "to_augment = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "])\n",
    "\n",
    "# PIL -> float tensor, normalised with a single mean/std that broadcasts over the RGB channels\n",
    "to_tensor = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "])\n",
    "\n",
    "# Label-specific colours used as the spurious feature; list index = digit label\n",
    "color_list = ['red', 'blue', 'green', 'orange', 'yellow', 'brown', 'pink', 'magenta', 'olive', 'cyan']\n",
    "\n",
    "\n",
    "def load_inds(mnist_subset, data_case, inds_dir=INDS_DIR):\n",
    "    \"\"\"Load the precomputed MNIST index array for one subset.\n",
    "\n",
    "    Args:\n",
    "        mnist_subset: subset id; the file read is ``supervised_inds_<mnist_subset>.npy``.\n",
    "        data_case: ``'val'`` reads from the ``val/`` sub-directory of ``inds_dir``;\n",
    "            any other value reads from ``inds_dir`` itself.\n",
    "        inds_dir: directory holding the index files.\n",
    "\n",
    "    Returns:\n",
    "        The array loaded with ``np.load``.\n",
    "    \"\"\"\n",
    "    # NOTE(review): 'test' (and any other non-'val' case) shares the top-level\n",
    "    # index files; confirm this is intended.\n",
    "    case_dir = os.path.join(inds_dir, 'val') if data_case == 'val' else inds_dir\n",
    "    return np.load(os.path.join(case_dir, 'supervised_inds_' + str(mnist_subset) + '.npy'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@lru_cache(maxsize=None)\n",
    "def _load_mnist(data_dir):\n",
    "    \"\"\"Return (images, labels) of the MNIST train and test splits concatenated.\n",
    "\n",
    "    Cached per ``data_dir`` so MNIST is read from disk once instead of on every\n",
    "    ``get_data`` call. The returned tensors are shared between calls, so do not\n",
    "    modify them in place.\n",
    "    \"\"\"\n",
    "    data_obj_train = datasets.MNIST(data_dir, train=True, download=True,\n",
    "                                    transform=transforms.ToTensor())\n",
    "    data_obj_test = datasets.MNIST(data_dir, train=False, download=True,\n",
    "                                   transform=transforms.ToTensor())\n",
    "    mnist_imgs = torch.cat((data_obj_train.data, data_obj_test.data))\n",
    "    mnist_labels = torch.cat((data_obj_train.targets, data_obj_test.targets))\n",
    "    return mnist_imgs, mnist_labels\n",
    "\n",
    "\n",
    "def _colorize(imgs, labels, rand_var):\n",
    "    \"\"\"Resize every digit to 224x224 and colourise it; returns a list of PIL images.\n",
    "\n",
    "    Digit ``idx`` is mapped black -> black, white -> its label colour from\n",
    "    ``color_list`` when ``rand_var[idx]`` is truthy, and black -> black,\n",
    "    white -> white otherwise.\n",
    "    \"\"\"\n",
    "    colored_imgs = []\n",
    "    for idx in range(imgs.shape[0]):\n",
    "        pil_img = to_pil(imgs[idx])\n",
    "        white = color_list[labels[idx].item()] if rand_var[idx] else 'white'\n",
    "        colored_imgs.append(ImageOps.colorize(pil_img, black='black', white=white))\n",
    "    return colored_imgs\n",
    "\n",
    "\n",
    "def get_data(data_dir, data_case, subset, seed=None, spur_prob=SPUR_PROB, rotations=ROTATIONS):\n",
    "    \"\"\"Build and save the rotated, spuriously-coloured MNIST tensors for one subset.\n",
    "\n",
    "    Args:\n",
    "        data_dir: MNIST download root and output directory for the ``.pt`` files.\n",
    "        data_case: data case name (e.g. ``'val'`` / ``'test'``), passed to ``load_inds``\n",
    "            and used in the output file names.\n",
    "        subset: index-subset id.\n",
    "        seed: if given, seeds torch's global RNG (random augmentation) and the\n",
    "            Bernoulli draw so the output is reproducible; ``None`` uses the current\n",
    "            global RNG state.\n",
    "        spur_prob: probability that a digit gets its label-specific colour.\n",
    "        rotations: rotation angles in degrees; one pair of image files per angle.\n",
    "\n",
    "    Side effects:\n",
    "        For every rotation, saves ``Imgs_...`` (augmented) and ``Imgs_org_...``\n",
    "        (not augmented) float tensors of shape (N, 3, 224, 224); then saves\n",
    "        ``Labels_...`` and ``Spur_...`` (1 where the label colour was applied, else 0).\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    mnist_imgs, mnist_labels = _load_mnist(data_dir)\n",
    "\n",
    "    # Select the subset of data corresponding to data_case\n",
    "    sub_inds = load_inds(subset, data_case)\n",
    "    imgs = mnist_imgs[sub_inds]\n",
    "    labels = mnist_labels[sub_inds]\n",
    "    mnist_size = labels.shape[0]\n",
    "\n",
    "    rand_var = bernoulli.rvs(spur_prob, size=mnist_size, random_state=seed)\n",
    "    spur = torch.tensor(rand_var)\n",
    "\n",
    "    # Colouring does not depend on the rotation, so do it once per subset\n",
    "    colored_imgs = _colorize(imgs, labels, rand_var)\n",
    "\n",
    "    suffix = '_case_' + data_case + '_subset_' + str(subset)\n",
    "    for rotation in rotations:\n",
    "        imgs_rot = torch.zeros((mnist_size, 3, 224, 224))\n",
    "        imgs_rot_org = torch.zeros((mnist_size, 3, 224, 224))\n",
    "        for idx, colored_img in enumerate(colored_imgs):\n",
    "            curr_img = transforms.functional.rotate(colored_img, rotation)\n",
    "            imgs_rot[idx] = to_tensor(to_augment(curr_img))  # with augmentation\n",
    "            imgs_rot_org[idx] = to_tensor(curr_img)  # without augmentation\n",
    "\n",
    "        print('Data Case: ', data_case, ' Subset: ', subset, ' Rotation: ', rotation)\n",
    "        print('Image: ', imgs_rot.shape, ' Labels: ', labels.shape, ' Spur: ', spur.shape)\n",
    "        print('Image: ', imgs_rot.dtype, ' Labels: ', labels.dtype, ' Spur: ', spur.dtype)\n",
    "        rot_suffix = suffix + '_rot_' + str(rotation) + '.pt'\n",
    "        torch.save(imgs_rot, os.path.join(data_dir, 'Imgs' + rot_suffix))\n",
    "        torch.save(imgs_rot_org, os.path.join(data_dir, 'Imgs_org' + rot_suffix))\n",
    "\n",
    "    torch.save(labels, os.path.join(data_dir, 'Labels' + suffix + '.pt'))\n",
    "    torch.save(spur, os.path.join(data_dir, 'Spur' + suffix + '.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate all subsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for case_idx, data_case in enumerate(DATA_CASES):\n",
    "    for subset in range(N_SUBSETS):\n",
    "        # Distinct, reproducible seed for every (data_case, subset) pair (while N_SUBSETS <= 100)\n",
    "        get_data(DATA_DIR, data_case, subset, seed=SEED + 100 * case_idx + subset)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main-envs",
   "language": "python",
   "name": "main-envs"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}