In [None]:
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
from glob import glob
import matplotlib.pyplot as plt
from numpy import asarray

#Rand Number using Numpy
from numpy.random import default_rng


#SciPy
from scipy.stats import bernoulli

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#Pillow
from PIL import Image, ImageColor, ImageOps 

In [None]:
# Tensor -> PIL image, resized to 224x224.
to_pil = transforms.Compose(
    [transforms.ToPILImage(), transforms.Resize((224, 224))]
)

# Random augmentation applied to PIL images: a random crop covering 70-100% of
# the image area resized back to 224, followed by a random horizontal flip.
to_augment = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
)

# PIL image -> normalised tensor.
# NOTE(review): the single-element mean/std look like the usual grayscale MNIST
# statistics; get_data feeds 3-channel images through this pipeline, so the
# same values are presumably applied to every channel -- confirm this is intended.
to_tensor = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Colour names indexed by digit label (position i is the colour for digit i).
color_list = [
    'red',
    'blue',
    'green',
    'orange',
    'yellow',
    'brown',
    'pink',
    'magenta',
    'olive',
    'cyan',
]

def load_inds(mnist_subset, data_case):
    """Load the saved index array for one MNIST subset.

    Args:
        mnist_subset: Subset identifier; its ``str()`` form is embedded in the
            index file name.
        data_case: Split name. ``'val'`` reads from the ``val`` sub-directory;
            any other value reads from the top-level index directory.

    Returns:
        Whatever ``np.load`` returns for the resolved ``.npy`` file.
    """
    index_root = '../../data/datasets/rot_mnist/rot_mnist_resnet18_indices/'
    file_name = '/supervised_inds_' + str(mnist_subset) + '.npy'

    # Only the validation split lives in its own sub-directory.
    if data_case == 'val':
        index_root = index_root + '/val'

    return np.load(index_root + file_name)

In [None]:
def get_data(data_dir, data_case, subset, seed=None):
    """Generate and save rotated MNIST tensors with a spurious colour feature.

    The MNIST train and test splits are loaded (downloading into ``data_dir``
    if needed), concatenated, and reduced to the images selected by
    ``load_inds(subset, data_case)``. Each kept image gets one Bernoulli(0.7)
    draw: on 1 the digit is colourised with the ``color_list`` entry indexed
    by its label, on 0 it stays white on black. For every rotation angle in
    0, 15, ..., 90 two tensors of shape ``(N, 3, 224, 224)`` are saved under
    ``data_dir``: ``Imgs_*`` (with ``to_augment`` applied) and ``Imgs_org_*``
    (without). The labels and the Bernoulli draws are saved once per
    ``(data_case, subset)`` as ``Labels_*`` and ``Spur_*``.

    Args:
        data_dir: Directory used both as the MNIST download root and as the
            output directory. A trailing path separator is optional.
        data_case: Split name; forwarded to ``load_inds`` and embedded in the
            output file names.
        subset: Subset identifier; forwarded to ``load_inds`` and embedded in
            the output file names.
        seed: Optional integer. When given, it seeds the Bernoulli draw and
            the global ``random`` / ``torch`` generators so the generated
            files can be reproduced. ``None`` (default) leaves all generators
            untouched. Passing the same seed for different subsets of equal
            size yields the same Bernoulli pattern, so vary it per call if
            independent draws are wanted.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    if seed is not None:
        # The random transforms in to_augment draw from torch's global RNG in
        # recent torchvision releases; older releases used Python's `random`
        # module, so both are seeded here.
        random.seed(seed)
        torch.manual_seed(seed)

    # Load both official splits and pool them; the index files address the
    # concatenation (train first, then test).
    splits = [
        datasets.MNIST(
            data_dir,
            train=is_train,
            download=True,
            transform=transforms.ToTensor(),
        )
        for is_train in (True, False)
    ]
    mnist_imgs = torch.cat([split.data for split in splits])
    mnist_labels = torch.cat([split.targets for split in splits])

    # Select the subset of images belonging to this data_case / subset.
    sub_inds = load_inds(subset, data_case)
    imgs = mnist_imgs[sub_inds]
    labels = mnist_labels[sub_inds]
    mnist_size = labels.shape[0]

    # One Bernoulli(0.7) draw per image decides whether it gets the
    # label-dependent colour; the same draw is reused for every rotation.
    rand_var = bernoulli.rvs(0.7, size=mnist_size, random_state=seed)
    spur = torch.tensor(rand_var)

    name_suffix = '_case_' + data_case + '_subset_' + str(subset)

    for rotation in [0, 15, 30, 45, 60, 75, 90]:
        imgs_rot = torch.zeros((mnist_size, 3, 224, 224))
        imgs_rot_org = torch.zeros((mnist_size, 3, 224, 224))

        for idx in range(mnist_size):
            curr_img = to_pil(imgs[idx])

            # Colour as an additional (spurious) feature: class colour when
            # the draw is 1, plain white otherwise.
            if rand_var[idx]:
                digit_color = color_list[labels[idx].item()]
            else:
                digit_color = "white"
            curr_img = ImageOps.colorize(curr_img, black="black", white=digit_color)

            # Rotation
            curr_img = transforms.functional.rotate(curr_img, rotation)

            # Augmented and non-augmented versions of the same image.
            imgs_rot[idx] = to_tensor(to_augment(curr_img))
            imgs_rot_org[idx] = to_tensor(curr_img)

        print('Data Case: ', data_case, ' Subset: ', subset, ' Rotation: ', rotation )
        print('Image: ', imgs_rot.shape, ' Labels: ', labels.shape, ' Spur: ', spur.shape)
        print('Image: ', imgs_rot.dtype, ' Labels: ', labels.dtype, ' Spur: ', spur.dtype)

        # os.path.join keeps the file inside data_dir whether or not the
        # caller supplied a trailing separator.
        rot_suffix = name_suffix + '_rot_' + str(rotation) + '.pt'
        torch.save(imgs_rot, os.path.join(data_dir, 'Imgs' + rot_suffix))
        torch.save(imgs_rot_org, os.path.join(data_dir, 'Imgs_org' + rot_suffix))

    # Labels and spurious-feature flags do not depend on the rotation.
    torch.save(labels, os.path.join(data_dir, 'Labels' + name_suffix + '.pt'))
    torch.save(spur, os.path.join(data_dir, 'Spur' + name_suffix + '.pt'))


In [None]:
# Output directory (also used as the MNIST download root by get_data).
data_dir = '../../data/datasets/rot_mnist_spur/'

# Generate every (split, subset) combination: both splits, subsets 0-9.
jobs = [(case, idx) for case in ['val', 'test'] for idx in range(10)]
for data_case, subset in jobs:
    get_data(data_dir, data_case, subset)