1044 строки
43 KiB
Plaintext
1044 строки
43 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import random\n",
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"class DataLoader:\n",
"    \"\"\"A data interface for public (tabular) data.\n",
"\n",
"    Holds the raw DataFrame, splits its columns into continuous features,\n",
"    categorical features and the outcome, keeps a one-hot-encoded view and a\n",
"    train/test split, and offers the (de)normalization and encoding helpers\n",
"    whose outputs are handed to DiCE (see get_data_params).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, params):\n",
"        \"\"\"Init method.\n",
"\n",
"        :param params: Dictionary with the following entries.\n",
"            dataframe: Pandas DataFrame. It is stored as-is (not copied) and its\n",
"                categorical columns are converted to 'category' dtype in place.\n",
"            continuous_features: List of names of continuous features. The\n",
"                remaining non-outcome features are categorical features.\n",
"            outcome_name: Outcome feature name.\n",
"            permitted_range (optional): Dictionary with feature names as keys and\n",
"                permitted range as values. Defaults to the range inferred from\n",
"                the data.\n",
"            test_size (optional): Proportion of test set split. Defaults to 0.2.\n",
"            test_split_random_state (optional): Random state for train test\n",
"                split. Defaults to 17.\n",
"        :raises ValueError: if a required entry is missing or has the wrong type.\n",
"        \"\"\"\n",
"        # .get() so that a missing key raises the same ValueError as a wrong type\n",
"        dataframe = params.get('dataframe')\n",
"        if not isinstance(dataframe, pd.DataFrame):\n",
"            raise ValueError(\"should provide a pandas dataframe\")\n",
"        self.data_df = dataframe\n",
"\n",
"        continuous_features = params.get('continuous_features')\n",
"        if not isinstance(continuous_features, list):\n",
"            raise ValueError(\n",
"                \"should provide the name(s) of continuous features in the data\")\n",
"        self.continuous_feature_names = continuous_features\n",
"\n",
"        outcome_name = params.get('outcome_name')\n",
"        if not isinstance(outcome_name, str):\n",
"            raise ValueError(\"should provide the name of outcome feature\")\n",
"        self.outcome_name = outcome_name\n",
"\n",
"        all_columns = self.data_df.columns.tolist()\n",
"        non_categorical = set(self.continuous_feature_names) | {self.outcome_name}\n",
"        self.categorical_feature_names = [\n",
"            name for name in all_columns if name not in non_categorical]\n",
"        self.feature_names = [\n",
"            name for name in all_columns if name != self.outcome_name]\n",
"\n",
"        self.continuous_feature_indexes = [\n",
"            self.data_df.columns.get_loc(name)\n",
"            for name in self.continuous_feature_names if name in self.data_df]\n",
"        self.categorical_feature_indexes = [\n",
"            self.data_df.columns.get_loc(name)\n",
"            for name in self.categorical_feature_names if name in self.data_df]\n",
"\n",
"        self.test_size = params.get('test_size', 0.2)\n",
"        self.test_split_random_state = params.get('test_split_random_state', 17)\n",
"\n",
"        if self.categorical_feature_names:\n",
"            self.data_df[self.categorical_feature_names] = self.data_df[\n",
"                self.categorical_feature_names].astype('category')\n",
"            self.one_hot_encoded_data = self.one_hot_encode_data(self.data_df)\n",
"            self.encoded_feature_names = [\n",
"                name for name in self.one_hot_encoded_data.columns.tolist()\n",
"                if name != self.outcome_name]\n",
"        else:\n",
"            # one-hot-encoded data is the original data when there are no\n",
"            # categorical features.\n",
"            self.one_hot_encoded_data = self.data_df\n",
"            self.encoded_feature_names = self.feature_names\n",
"\n",
"        self.train_df, self.test_df = self.split_data(self.data_df)\n",
"        if 'permitted_range' in params:\n",
"            self.permitted_range = params['permitted_range']\n",
"        else:\n",
"            self.permitted_range = self.get_features_range()\n",
"\n",
"    def get_features_range(self):\n",
"        \"\"\"Returns {continuous feature name: [min, max]} over the full data.\"\"\"\n",
"        return {\n",
"            name: [self.data_df[name].min(), self.data_df[name].max()]\n",
"            for name in self.continuous_feature_names}\n",
"\n",
"    @staticmethod\n",
"    def _fraction_digits(value):\n",
"        \"\"\"Returns the digits after the decimal point of `value` ('' if none).\n",
"\n",
"        Uses the shortest round-trip positional form with trailing zeros\n",
"        trimmed, so 40.0 -> '', 39.5 -> '5' and 1e-05 -> '00001'. Unlike\n",
"        str(value).split('.')[1] this does not raise IndexError for values\n",
"        printed in scientific notation or for nan/inf.\n",
"        \"\"\"\n",
"        text = np.format_float_positional(float(value), trim='-')\n",
"        return text.partition('.')[2]\n",
"\n",
"    def get_data_type(self, col):\n",
"        \"\"\"Infers data type of a feature from the training data.\n",
"\n",
"        Scans the values in order: returns 'int' at the first Python int,\n",
"        'float' at the first value with a non-zero fractional part, and 'int'\n",
"        if neither is found.\n",
"        \"\"\"\n",
"        for instance in col.tolist():\n",
"            if isinstance(instance, int):\n",
"                return 'int'\n",
"            # trailing zeros are trimmed, so any remaining digit is non-zero\n",
"            if self._fraction_digits(instance):\n",
"                return 'float'\n",
"        return 'int'\n",
"\n",
"    def one_hot_encode_data(self, data):\n",
"        \"\"\"One-hot-encodes the categorical features of `data`.\"\"\"\n",
"        return pd.get_dummies(data, drop_first=False, columns=self.categorical_feature_names)\n",
"\n",
"    def normalize_data(self, df):\n",
"        \"\"\"Normalizes continuous features to make them fall in the range [0,1].\n",
"\n",
"        Min/max come from the full data_df (not only the training split). A\n",
"        constant feature (max == min) is not guarded against.\n",
"        \"\"\"\n",
"        result = df.copy()\n",
"        for feature_name in self.continuous_feature_names:\n",
"            max_value = self.data_df[feature_name].max()\n",
"            min_value = self.data_df[feature_name].min()\n",
"            result[feature_name] = (df[feature_name] - min_value) / (max_value - min_value)\n",
"        return result\n",
"\n",
"    def de_normalize_data(self, df):\n",
"        \"\"\"De-normalizes continuous features from [0,1] range to original range.\"\"\"\n",
"        result = df.copy()\n",
"        for feature_name in self.continuous_feature_names:\n",
"            max_value = self.data_df[feature_name].max()\n",
"            min_value = self.data_df[feature_name].min()\n",
"            result[feature_name] = (df[feature_name] * (max_value - min_value)) + min_value\n",
"        return result\n",
"\n",
"    def get_minx_maxx(self, normalized=True):\n",
"        \"\"\"Gets the min/max value of features in normalized or de-normalized form.\n",
"\n",
"        Returns two arrays of shape (1, len(encoded_feature_names)). One-hot\n",
"        columns get the range [0, 1]; each continuous feature gets its\n",
"        permitted_range, scaled by the data's min/max when `normalized`.\n",
"        \"\"\"\n",
"        minx = np.array([[0.0] * len(self.encoded_feature_names)])\n",
"        maxx = np.array([[1.0] * len(self.encoded_feature_names)])\n",
"\n",
"        for feature_name in self.continuous_feature_names:\n",
"            # Look the column up by name: continuous_feature_names need not be\n",
"            # listed in the same order as the encoded columns.\n",
"            idx = self.encoded_feature_names.index(feature_name)\n",
"            low = self.permitted_range[feature_name][0]\n",
"            high = self.permitted_range[feature_name][1]\n",
"            if normalized:\n",
"                max_value = self.data_df[feature_name].max()\n",
"                min_value = self.data_df[feature_name].min()\n",
"                minx[0][idx] = (low - min_value) / (max_value - min_value)\n",
"                maxx[0][idx] = (high - min_value) / (max_value - min_value)\n",
"            else:\n",
"                minx[0][idx] = low\n",
"                maxx[0][idx] = high\n",
"        return minx, maxx\n",
"\n",
"    def split_data(self, data):\n",
"        \"\"\"Splits `data` with the configured test_size and random state.\"\"\"\n",
"        train_df, test_df = train_test_split(\n",
"            data, test_size=self.test_size, random_state=self.test_split_random_state)\n",
"        return train_df, test_df\n",
"\n",
"    def get_mads_from_training_data(self, normalized=False):\n",
"        \"\"\"Computes Median Absolute Deviation of continuous features.\n",
"\n",
"        Both branches use the training split (train_df); unless `normalized`\n",
"        is exactly False, the features are first scaled by normalize_data.\n",
"        \"\"\"\n",
"        if normalized is False:\n",
"            source = self.train_df\n",
"        else:\n",
"            source = self.normalize_data(self.train_df)\n",
"        mads = {}\n",
"        for feature in self.continuous_feature_names:\n",
"            values = source[feature].values\n",
"            mads[feature] = np.median(abs(values - np.median(values)))\n",
"        return mads\n",
"\n",
"    def get_data_params(self):\n",
"        \"\"\"Gets all data related params for DiCE.\"\"\"\n",
"        minx, maxx = self.get_minx_maxx(normalized=True)\n",
"\n",
"        # get the column indexes of categorical features after one-hot-encoding\n",
"        self.encoded_categorical_feature_indexes = self.get_encoded_categorical_feature_indexes()\n",
"\n",
"        return minx, maxx, self.encoded_categorical_feature_indexes\n",
"\n",
"    def get_encoded_categorical_feature_indexes(self):\n",
"        \"\"\"Gets the column indexes of categorical features after one-hot-encoding.\n",
"\n",
"        Returns one list of indexes into encoded_feature_names per categorical\n",
"        feature. Columns are matched on the '<feature>_' prefix that\n",
"        pd.get_dummies produces, so a feature whose name is merely a prefix of\n",
"        another column (e.g. 'age' vs 'agency') does not capture it.\n",
"        \"\"\"\n",
"        cols = []\n",
"        for col_parent in self.categorical_feature_names:\n",
"            prefix = col_parent + '_'\n",
"            cols.append([idx for idx, col in enumerate(self.encoded_feature_names)\n",
"                         if col.startswith(prefix)])\n",
"        return cols\n",
"\n",
"    def get_indexes_of_features_to_vary(self, features_to_vary='all'):\n",
"        \"\"\"Gets indexes from feature names of one-hot-encoded data.\n",
"\n",
"        :param features_to_vary: 'all', a single feature name, or a list of\n",
"            feature names. An encoded column is selected when it equals one of\n",
"            the names or starts with '<name>_' (the one-hot dummies).\n",
"        \"\"\"\n",
"        if features_to_vary == 'all':\n",
"            return list(range(len(self.encoded_feature_names)))\n",
"        if isinstance(features_to_vary, str):\n",
"            # a bare name would otherwise be iterated character by character\n",
"            features_to_vary = [features_to_vary]\n",
"        names = tuple(features_to_vary)\n",
"        prefixes = tuple(name + '_' for name in names)\n",
"        return [colidx for colidx, col in enumerate(self.encoded_feature_names)\n",
"                if col in names or col.startswith(prefixes)]\n",
"\n",
"    def from_dummies(self, data, prefix_sep='_'):\n",
"        \"\"\"Gets the original data from dummy encoded data with k levels.\n",
"\n",
"        For each categorical feature, the columns starting with\n",
"        '<feature><prefix_sep>' are collapsed into one categorical column\n",
"        holding the label of the largest dummy; the dummies are dropped.\n",
"        \"\"\"\n",
"        out = data.copy()\n",
"        for feature in self.categorical_feature_names:\n",
"            prefix = feature + prefix_sep\n",
"            cols = [c for c in data.columns if c.startswith(prefix)]\n",
"            labels = [c[len(prefix):] for c in cols]\n",
"            out[feature] = pd.Categorical(\n",
"                np.array(labels)[np.argmax(data[cols].values, axis=1)])\n",
"            out.drop(cols, axis=1, inplace=True)\n",
"        return out\n",
"\n",
"    def get_decimal_precisions(self):\n",
"        \"\"\"Gets the precision of continuous features in the data.\n",
"\n",
"        Returns a list of len(feature_names); position ix holds the number of\n",
"        decimals of the first non-integral value of the ix-th continuous\n",
"        feature (0 if the feature holds Python ints or only whole numbers).\n",
"        \"\"\"\n",
"        precisions = [0] * len(self.feature_names)\n",
"        for ix, col in enumerate(self.continuous_feature_names):\n",
"            for instance in self.data_df[col].tolist():\n",
"                if isinstance(instance, int):\n",
"                    break\n",
"                fraction = self._fraction_digits(instance)\n",
"                if fraction:\n",
"                    precisions[ix] = len(fraction)\n",
"                    break\n",
"        return precisions\n",
"\n",
"    def get_decoded_data(self, data):\n",
"        \"\"\"Gets the original data from dummy encoded data.\n",
"\n",
"        A numpy array is first wrapped in a DataFrame whose columns are\n",
"        encoded_feature_names.\n",
"        \"\"\"\n",
"        if isinstance(data, np.ndarray):\n",
"            data = pd.DataFrame(data=data, index=list(range(len(data))),\n",
"                                columns=self.encoded_feature_names)\n",
"        return self.from_dummies(data)\n",
"\n",
"    def prepare_df_for_encoding(self):\n",
"        \"\"\"Facilitates prepare_query_instance() function.\n",
"\n",
"        Builds a frame listing every category level of each categorical\n",
"        feature (shorter columns are NaN-padded by the outer concat) plus empty\n",
"        continuous columns, so that one-hot encoding a query appended to it\n",
"        yields every dummy column.\n",
"        \"\"\"\n",
"        frames = [pd.DataFrame({name: self.data_df[name].cat.categories.tolist()})\n",
"                  for name in self.categorical_feature_names]\n",
"        frames += [pd.DataFrame({name: []}) for name in self.continuous_feature_names]\n",
"        if not frames:\n",
"            return pd.DataFrame()\n",
"        return pd.concat(frames, axis=1, sort=False)\n",
"\n",
"    def prepare_query_instance(self, query_instance, encode):\n",
"        \"\"\"Prepares user defined test input for DiCE.\n",
"\n",
"        :param query_instance: a list of feature values (in feature_names\n",
"            order) or a dict mapping feature names to values.\n",
"        :param encode: if False, only continuous features are normalized;\n",
"            otherwise categorical features are one-hot encoded as well.\n",
"        :raises ValueError: if query_instance is neither a list nor a dict.\n",
"        \"\"\"\n",
"        if isinstance(query_instance, list):\n",
"            test = pd.DataFrame.from_dict(\n",
"                {'row1': query_instance}, orient='index', columns=self.feature_names)\n",
"        elif isinstance(query_instance, dict):\n",
"            test = pd.DataFrame({name: [value] for name, value in query_instance.items()},\n",
"                                columns=self.feature_names)\n",
"        else:\n",
"            raise ValueError(\"query_instance should be a list or a dict\")\n",
"\n",
"        test = test.reset_index(drop=True)\n",
"\n",
"        if encode is False:\n",
"            return self.normalize_data(test)\n",
"\n",
"        # Put every category level in front of the query so the encoding emits\n",
"        # all dummy columns, then keep only the query rows. DataFrame.append was\n",
"        # removed in pandas 2.0; pd.concat is the equivalent.\n",
"        temp = pd.concat([self.prepare_df_for_encoding(), test],\n",
"                         ignore_index=True, sort=False)\n",
"        temp = self.normalize_data(self.one_hot_encode_data(temp))\n",
"        return temp.tail(test.shape[0]).reset_index(drop=True)\n",
"\n",
"    def get_dev_data(self, model_interface, desired_class, filter_threshold=0.5):\n",
"        \"\"\"Constructs dev data by extracting part of the test data for which\n",
"        finding counterfactuals makes sense.\n",
"\n",
"        Keeps the de-duplicated test rows whose predicted score is above\n",
"        `filter_threshold` when desired_class == 0, or below it otherwise.\n",
"\n",
"        NOTE(review): this uses the TensorFlow 1.x session API through a global\n",
"        name `tf` that this notebook never imports, so it raises NameError\n",
"        unless a TF1-compatible module is bound to `tf` beforehand.\n",
"\n",
"        :returns: (dev_data, dev_preds): the kept rows decoded back to the\n",
"            original categorical values and de-normalized, and their scores.\n",
"        \"\"\"\n",
"        # create TensorFlow session if one is not already created\n",
"        if tf.get_default_session() is not None:\n",
"            self.data_sess = tf.get_default_session()\n",
"        else:\n",
"            self.data_sess = tf.InteractiveSession()\n",
"\n",
"        # loading trained model\n",
"        model_interface.load_model()\n",
"\n",
"        # minx is only used as the initial value of the input variable below\n",
"        minx, _ = self.get_minx_maxx(normalized=True)\n",
"\n",
"        # continuous features scaled to [0,1], categorical features one-hot encoded\n",
"        data_df_transformed = self.normalize_data(self.one_hot_encoded_data)\n",
"\n",
"        # Same test_size / random_state as the split made in __init__.\n",
"        # NOTE: normalize_data scales with min/max of the full data_df, so the\n",
"        # scaling does see the test rows.\n",
"        _, test = self.split_data(data_df_transformed)\n",
"        test = test.drop_duplicates(\n",
"            subset=self.encoded_feature_names).reset_index(drop=True)\n",
"\n",
"        # predicted scores, one row at a time\n",
"        input_tensor = tf.Variable(minx, dtype=tf.float32)\n",
"        output_tensor = model_interface.get_output(input_tensor)\n",
"        temp_data = test[self.encoded_feature_names].values.astype(np.float32)\n",
"        dev_preds = [self.data_sess.run(output_tensor, feed_dict={input_tensor: np.array([row])})\n",
"                     for row in temp_data]\n",
"        dev_preds = [pred[0][0] for pred in dev_preds]\n",
"\n",
"        # keep the rows predicted on the other side of the threshold\n",
"        if desired_class == 0:\n",
"            idxs = [i for i, pred in enumerate(dev_preds) if pred > filter_threshold]\n",
"        else:\n",
"            idxs = [i for i, pred in enumerate(dev_preds) if pred < filter_threshold]\n",
"        dev_data = test[self.encoded_feature_names].iloc[idxs]\n",
"        dev_preds = [dev_preds[i] for i in idxs]\n",
"\n",
"        # convert from one-hot encoded values to a user interpretable format\n",
"        dev_data = self.from_dummies(dev_data)\n",
"        dev_data = self.de_normalize_data(dev_data)\n",
"        return dev_data, dev_preds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Folder that holds adult.csv; the generated .pt split files are written here too.\n",
"base_dir = '../../data/datasets/adult/'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" age workclass education marital_status occupation race sex \\\n",
|
|
"0 39 Government Bachelors Single White-Collar White Male \n",
|
|
"1 50 Self-Employed Bachelors Married White-Collar White Male \n",
|
|
"2 38 Private HS-grad Divorced Blue-Collar White Male \n",
|
|
"3 53 Private School Married Blue-Collar Other Male \n",
|
|
"4 28 Private Bachelors Married Professional Other Female \n",
|
|
"\n",
|
|
" hours_per_week income \n",
|
|
"0 40 0 \n",
|
|
"1 13 0 \n",
|
|
"2 40 0 \n",
|
|
"3 40 0 \n",
|
|
"4 40 0 \n",
|
|
" age workclass education marital_status occupation race sex \\\n",
|
|
"0 39 Government Bachelors Single White-Collar White Male \n",
|
|
"1 50 Self-Employed Bachelors Married White-Collar White Male \n",
|
|
"2 38 Private HS-grad Divorced Blue-Collar White Male \n",
|
|
"3 53 Private School Married Blue-Collar Other Male \n",
|
|
"4 28 Private Bachelors Married Professional Other Female \n",
|
|
"\n",
|
|
" hours_per_week income \n",
|
|
"0 40 0 \n",
|
|
"1 13 0 \n",
|
|
"2 40 0 \n",
|
|
"3 40 0 \n",
|
|
"4 40 0 \n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"continuous_features = ['age', 'hours_per_week']\n",
"outcome_name = 'income'\n",
"data = 'adult'\n",
"\n",
"# DataLoader converts the categorical columns of the frame it receives in place,\n",
"# so it is given a copy of the raw CSV.\n",
"dataset = pd.read_csv(base_dir + 'adult.csv')\n",
"params = {\n",
"    'dataframe': dataset.copy(),\n",
"    'continuous_features': continuous_features,\n",
"    'outcome_name': outcome_name,\n",
"}\n",
"d = DataLoader(params)\n",
"\n",
"# Shuffle every row once with a fixed seed; later cells slice by position.\n",
"data_df = d.data_df.copy().sample(n=len(d.data_df), random_state=100)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"encoded_data = d.normalize_data(d.one_hot_encode_data(data_df))\n",
"# Move the outcome column to the end so the label is always the last column\n",
"# (later cells read it as dataset[:, -1]). It is selected by name rather than by\n",
"# a hard-coded position, so a different column order cannot misplace it.\n",
"cols = [col for col in encoded_data.columns if col != d.outcome_name] + [d.outcome_name]\n",
"encoded_data = encoded_data[cols]\n",
"columns = encoded_data.columns.tolist()\n",
"dataset = encoded_data.to_numpy()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['age',\n",
|
|
" 'hours_per_week',\n",
|
|
" 'workclass_Government',\n",
|
|
" 'workclass_Other/Unknown',\n",
|
|
" 'workclass_Private',\n",
|
|
" 'workclass_Self-Employed',\n",
|
|
" 'education_Assoc',\n",
|
|
" 'education_Bachelors',\n",
|
|
" 'education_Doctorate',\n",
|
|
" 'education_HS-grad',\n",
|
|
" 'education_Masters',\n",
|
|
" 'education_Prof-school',\n",
|
|
" 'education_School',\n",
|
|
" 'education_Some-college',\n",
|
|
" 'marital_status_Divorced',\n",
|
|
" 'marital_status_Married',\n",
|
|
" 'marital_status_Separated',\n",
|
|
" 'marital_status_Single',\n",
|
|
" 'marital_status_Widowed',\n",
|
|
" 'occupation_Blue-Collar',\n",
|
|
" 'occupation_Other/Unknown',\n",
|
|
" 'occupation_Professional',\n",
|
|
" 'occupation_Sales',\n",
|
|
" 'occupation_Service',\n",
|
|
" 'occupation_White-Collar',\n",
|
|
" 'race_Other',\n",
|
|
" 'race_White',\n",
|
|
" 'sex_Female',\n",
|
|
" 'sex_Male',\n",
|
|
" 'income']"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"columns  # final encoded column order; the outcome column comes last"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Version 1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Version 1: per-sex splits in which income is strongly tied to sex.\n",
"#   Male:   train 1000 high / 64 low income, test 60 high / 940 low\n",
"#   Female: train 64 high / 1000 low income, test 940 high / 60 low\n",
"# Validation rows come from the end of each training slice\n",
"# (60 of the 1000 and 4 of the 64).\n",
"idx = columns.index('sex_Male')\n",
"data_male = dataset[dataset[:, idx] == 1]\n",
"idx = columns.index('sex_Female')\n",
"data_female = dataset[dataset[:, idx] == 1]\n",
"print('Gender Stats: ', data_male.shape, data_female.shape)\n",
"\n",
"# The income label is the last column (0 = low, 1 = high).\n",
"data_male_low_inc, data_male_high_inc = (data_male[data_male[:, -1] == label] for label in (0, 1))\n",
"data_female_low_inc, data_female_high_inc = (data_female[data_female[:, -1] == label] for label in (0, 1))\n",
"\n",
"# NOTE(review): the [-61:-1] / [-941:-1] test slices stop one row before the end,\n",
"# so the last row of each pool is never used (still 60 / 940 rows each).\n",
"\n",
"# Male dataset\n",
"train_high_inc, test_high_inc = data_male_high_inc[:1000], data_male_high_inc[-61:-1]\n",
"train_low_inc, test_low_inc = data_male_low_inc[:64], data_male_low_inc[-941:-1]\n",
"val_high_inc, train_high_inc = train_high_inc[940:], train_high_inc[:940]\n",
"val_low_inc, train_low_inc = train_low_inc[60:], train_low_inc[:60]\n",
"\n",
"train_male = np.vstack((train_high_inc, train_low_inc))\n",
"val_male = np.vstack((val_high_inc, val_low_inc))\n",
"test_male = np.vstack((test_high_inc, test_low_inc))\n",
"\n",
"# Female dataset\n",
"train_high_inc, test_high_inc = data_female_high_inc[:64], data_female_high_inc[-941:-1]\n",
"train_low_inc, test_low_inc = data_female_low_inc[:1000], data_female_low_inc[-61:-1]\n",
"val_high_inc, train_high_inc = train_high_inc[60:], train_high_inc[:60]\n",
"val_low_inc, train_low_inc = train_low_inc[940:], train_low_inc[:940]\n",
"\n",
"train_female = np.vstack((train_high_inc, train_low_inc))\n",
"val_female = np.vstack((val_high_inc, val_low_inc))\n",
"test_female = np.vstack((test_high_inc, test_low_inc))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Shapes of each split, then per-split label counts: (# income == 0, # income == 1).\n",
"print('Male Final: ', train_male.shape, val_male.shape, test_male.shape)\n",
"print('Female Final: ', train_female.shape, val_female.shape, test_female.shape)\n",
"print('Sanity Check Male')\n",
"print('Train', np.sum(train_male[:, -1] == 0), np.sum(train_male[:, -1] == 1))\n",
"print('Val', np.sum(val_male[:, -1] == 0), np.sum(val_male[:, -1] == 1))\n",
"print('Test', np.sum(test_male[:, -1] == 0), np.sum(test_male[:, -1] == 1))\n",
"\n",
"print('Sanity Check Female')\n",
"print('Train', np.sum(train_female[:, -1] == 0), np.sum(train_female[:, -1] == 1))\n",
"print('Val', np.sum(val_female[:, -1] == 0), np.sum(val_female[:, -1] == 1))\n",
"print('Test', np.sum(test_female[:, -1] == 0), np.sum(test_female[:, -1] == 1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Each split is written as two tensors: the features (every column but the last)\n",
"# and the income label (last column), named <sex>_<split>_{data,label}.pt.\n",
"saved_splits = {\n",
"    'male_train': train_male,\n",
"    'male_val': val_male,\n",
"    'male_test': test_male,\n",
"    'female_train': train_female,\n",
"    'female_val': val_female,\n",
"    'female_test': test_female,\n",
"}\n",
"for split_name, split_array in saved_splits.items():\n",
"    torch.save(torch.tensor(split_array[:, :-1]), base_dir + split_name + '_data.pt')\n",
"    torch.save(torch.tensor(split_array[:, -1]), base_dir + split_name + '_label.pt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Version 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Gender Stats: (21790, 30) (10771, 30)\n",
|
|
"Male: (15128, 30) (6662, 30)\n",
|
|
"Female: (9592, 30) (1179, 30)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Partition the encoded rows by sex (one-hot columns), then by the income label\n",
"# in the last column (0 = low, 1 = high).\n",
"idx = columns.index('sex_Male')\n",
"data_male = dataset[dataset[:, idx] == 1]\n",
"idx = columns.index('sex_Female')\n",
"data_female = dataset[dataset[:, idx] == 1]\n",
"print('Gender Stats: ', data_male.shape, data_female.shape)\n",
"\n",
"data_male_low_inc, data_male_high_inc = (data_male[data_male[:, -1] == label] for label in (0, 1))\n",
"data_female_low_inc, data_female_high_inc = (data_female[data_female[:, -1] == label] for label in (0, 1))\n",
"\n",
"print('Male: ', data_male_low_inc.shape, data_male_high_inc.shape)\n",
"print('Female: ', data_female_low_inc.shape, data_female_high_inc.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Domain: 0\n",
|
|
"Train Size for each group: 405 Val Size for each group: 45\n",
|
|
"Data: train Gender: male Income : low 122\n",
|
|
"Data: train Gender: male Income : high 283\n",
|
|
"Data: train Gender: female Income : low 284\n",
|
|
"Data: train Gender: female Income : high 121\n",
|
|
"torch.Size([810, 30]) torch.Size([810])\n",
|
|
"Data: val Gender: male Income : low 14\n",
|
|
"Data: val Gender: male Income : high 31\n",
|
|
"Data: val Gender: female Income : low 32\n",
|
|
"Data: val Gender: female Income : high 13\n",
|
|
"torch.Size([90, 30]) torch.Size([90])\n",
|
|
"\n",
|
|
"\n",
|
|
"Domain: 1\n",
|
|
"Train Size for each group: 405 Val Size for each group: 45\n",
|
|
"Data: train Gender: male Income : low 162\n",
|
|
"Data: train Gender: male Income : high 243\n",
|
|
"Data: train Gender: female Income : low 243\n",
|
|
"Data: train Gender: female Income : high 162\n",
|
|
"torch.Size([810, 30]) torch.Size([810])\n",
|
|
"Data: val Gender: male Income : low 18\n",
|
|
"Data: val Gender: male Income : high 27\n",
|
|
"Data: val Gender: female Income : low 27\n",
|
|
"Data: val Gender: female Income : high 18\n",
|
|
"torch.Size([90, 30]) torch.Size([90])\n",
|
|
"\n",
|
|
"\n",
|
|
"Domain: 2\n",
|
|
"Train Size for each group: 90 Val Size for each group: 10\n",
|
|
"Data: test Gender: male Income : low 81\n",
|
|
"Data: test Gender: male Income : high 9\n",
|
|
"Data: test Gender: female Income : low 9\n",
|
|
"Data: test Gender: female Income : high 81\n",
|
|
"torch.Size([180, 30]) torch.Size([180])\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Version 2: three domains. In domain d each sex contributes the same number of\n",
"# rows; a share domain_spur_prob[d] of the male rows (and 1 - domain_spur_prob[d]\n",
"# of the female rows) are high income, so sex is spuriously tied to income.\n",
"# Domains 0 and 1 produce train/val splits, domain 2 only a test split.\n",
"# NOTE(review): the test split is sized with train_size (group_size minus the val\n",
"# share), so domain 2 keeps 90 rather than 100 rows per sex.\n",
"total_domains = 3\n",
"val_frac = 0.1\n",
"domain_spur_prob = [0.7, 0.6, 0.1]\n",
"\n",
"# Rows are consumed sequentially from each (sex, income) pool, so no row is\n",
"# reused across domains or splits.\n",
"start_idx = {'male_low': 0, 'male_high': 0, 'female_low': 0, 'female_high': 0}\n",
"group_data = {\n",
"    'male_low': data_male_low_inc,\n",
"    'male_high': data_male_high_inc,\n",
"    'female_low': data_female_low_inc,\n",
"    'female_high': data_female_high_inc,\n",
"}\n",
"\n",
"for idx in range(total_domains):\n",
"    group_size = 450 if idx in (0, 1) else 100\n",
"    domain_size = 2 * group_size\n",
"    val_size = int(val_frac * group_size)\n",
"    train_size = group_size - val_size\n",
"    print('Domain: ', idx)\n",
"    print('Train Size for each group: ', train_size, 'Val Size for each group: ', val_size)\n",
"\n",
"    for data_case in ('train', 'val', 'test'):\n",
"        # Domains 0/1 have no test split; domain 2 has only a test split.\n",
"        if (data_case == 'test' and idx in (0, 1)) or (data_case != 'test' and idx == 2):\n",
"            continue\n",
"\n",
"        data_size = val_size if data_case == 'val' else train_size\n",
"        curr_data = {}\n",
"        for gender in ('male', 'female'):\n",
"            prob = domain_spur_prob[idx] if gender == 'male' else 1.0 - domain_spur_prob[idx]\n",
"            for income_case in ('low', 'high'):\n",
"                n_high = int(prob * data_size)\n",
"                inc_size = n_high if income_case == 'high' else data_size - n_high\n",
"                key = gender + '_' + income_case\n",
"                offset = start_idx[key]\n",
"                curr_data[key] = group_data[key][offset:offset + inc_size]\n",
"                start_idx[key] = offset + inc_size\n",
"                print('Data: ', data_case, 'Gender: ', gender, ' Income :', income_case, inc_size)\n",
"\n",
"        # Stack the four groups; the spurious attribute is 1 for male rows, 0 for female rows.\n",
"        save_data = torch.tensor(np.vstack(list(curr_data.values())))\n",
"        spur_corr = torch.tensor(np.hstack([\n",
"            np.zeros(rows.shape[0]) if 'female' in name else np.ones(rows.shape[0])\n",
"            for name, rows in curr_data.items()\n",
"        ]))\n",
"        print(save_data.shape, spur_corr.shape)\n",
"\n",
"        file_prefix = base_dir + 'd' + str(idx + 1) + '_' + data_case\n",
"        torch.save(save_data[:, :-1], file_prefix + '_data.pt')\n",
"        torch.save(save_data[:, -1], file_prefix + '_label.pt')\n",
"        torch.save(spur_corr, file_prefix + '_spur.pt')\n",
"\n",
"    print('\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Domain 1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Domain 1 (train/val only):\n",
"#   Male:   train 1000 high / 64 low income\n",
"#   Female: train 64 high / 1000 low income\n",
"# Validation rows come from the end of each training slice (60 of the 1000 and\n",
"# 4 of the 64).\n",
"# NOTE(review): here the spurious attribute is 0 for male and 1 for female rows,\n",
"# the opposite of the Version 2 domain loop (1 = male); confirm which encoding\n",
"# downstream code expects.\n",
"\n",
"# Male dataset\n",
"train_high_inc, train_low_inc = data_male_high_inc[:1000], data_male_low_inc[:64]\n",
"val_high_inc, train_high_inc = train_high_inc[940:], train_high_inc[:940]\n",
"val_low_inc, train_low_inc = train_low_inc[60:], train_low_inc[:60]\n",
"train_male = np.vstack((train_high_inc, train_low_inc))\n",
"val_male = np.vstack((val_high_inc, val_low_inc))\n",
"\n",
"# Female dataset\n",
"train_high_inc, train_low_inc = data_female_high_inc[:64], data_female_low_inc[:1000]\n",
"val_high_inc, train_high_inc = train_high_inc[60:], train_high_inc[:60]\n",
"val_low_inc, train_low_inc = train_low_inc[940:], train_low_inc[:940]\n",
"train_female = np.vstack((train_high_inc, train_low_inc))\n",
"val_female = np.vstack((val_high_inc, val_low_inc))\n",
"\n",
"train_data_1 = np.vstack((train_male, train_female))\n",
"val_data_1 = np.vstack((val_male, val_female))\n",
"\n",
"# Spurious attribute per row, in the same male-then-female order as above.\n",
"spur_male, spur_female = np.zeros(train_male.shape[0]), np.ones(train_female.shape[0])\n",
"train_spur_1 = np.concatenate((spur_male, spur_female), axis=0)\n",
"spur_male, spur_female = np.zeros(val_male.shape[0]), np.ones(val_female.shape[0])\n",
"val_spur_1 = np.concatenate((spur_male, spur_female), axis=0)\n",
"\n",
"print('Male Final: ', train_male.shape, val_male.shape)\n",
"print('Female Final: ', train_female.shape, val_female.shape)\n",
"print('Spur Feature Final: ', train_spur_1.shape, val_spur_1.shape)\n",
"\n",
"print('Sanity Check Male')\n",
"print('Train', np.sum(train_male[:, -1] == 0), np.sum(train_male[:, -1] == 1))\n",
"print('Val', np.sum(val_male[:, -1] == 0), np.sum(val_male[:, -1] == 1))\n",
"\n",
"print('Sanity Check Female')\n",
"print('Train', np.sum(train_female[:, -1] == 0), np.sum(train_female[:, -1] == 1))\n",
"print('Val', np.sum(val_female[:, -1] == 0), np.sum(val_female[:, -1] == 1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Domain 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#train_data_vae= train_data_vae[ train_data_vae['income']==0 ]\n",
|
|
"#train_data_vae.drop('income', axis=1, inplace=True)\n",
|
|
"\n",
|
|
"idx= columns.index('sex_Male')\n",
|
|
"data_male= dataset[dataset[:, idx]==1]\n",
|
|
"idx= columns.index('sex_Female')\n",
|
|
"data_female= dataset[dataset[:, idx]==1]\n",
|
|
"print('Gender Stats: ', data_male.shape, data_female.shape)\n",
|
|
"\n",
|
|
"# Split each gender into low- and high-income groups using the label in the last column\n",
|
|
"data_male_low_inc= data_male[data_male[:, -1] == 0]\n",
|
|
"data_male_high_inc= data_male[data_male[:, -1] == 1]\n",
|
|
"data_female_low_inc= data_female[data_female[:, -1] == 0]\n",
|
|
"data_female_high_inc= data_female[data_female[:, -1] == 1]\n",
|
|
"\n",
|
|
"# print('Male: ', data_male_low_inc.shape, data_male_high_inc.shape)\n",
|
|
"# print('Female: ', data_female_low_inc.shape, data_female_high_inc.shape)\n",
|
|
"\n",
|
|
"# Male: Train: (1000, 190), Test: (160, 840) \n",
|
|
"# Female: Train: (190, 1000), Test: (840, 160) \n",
|
|
"# Val created from Train by selecting 160 from 1000 and 30 from 190\n",
|
|
"\n",
|
|
"#Male Dataset\n",
|
|
"train_high_inc= data_male_high_inc[:1000]\n",
|
|
"train_low_inc= data_male_low_inc[:190]\n",
|
|
"\n",
|
|
"val_high_inc= train_high_inc[840:]\n",
|
|
"train_high_inc= train_high_inc[:840]\n",
|
|
"\n",
|
|
"val_low_inc= train_low_inc[160:]\n",
|
|
"train_low_inc= train_low_inc[:160]\n",
|
|
"\n",
|
|
"train_male= np.concatenate((train_high_inc, train_low_inc), axis=0)\n",
|
|
"val_male= np.concatenate((val_high_inc, val_low_inc), axis=0)\n",
|
|
"\n",
|
|
"#Female Dataset\n",
|
|
"train_high_inc= data_female_high_inc[:190]\n",
|
|
"train_low_inc= data_female_low_inc[:1000]\n",
|
|
"\n",
|
|
"val_high_inc= train_high_inc[160:]\n",
|
|
"train_high_inc= train_high_inc[:160]\n",
|
|
"\n",
|
|
"val_low_inc= train_low_inc[840:]\n",
|
|
"train_low_inc= train_low_inc[:840]\n",
|
|
"\n",
|
|
"train_female= np.concatenate((train_high_inc, train_low_inc), axis=0)\n",
|
|
"val_female= np.concatenate((val_high_inc, val_low_inc), axis=0)\n",
|
|
"\n",
|
|
"\n",
|
|
"train_data_2= np.concatenate((train_male, train_female), axis=0) \n",
|
|
"val_data_2= np.concatenate((val_male, val_female), axis=0) \n",
|
|
"\n",
|
|
"spur_male= np.zeros(train_male.shape[0])\n",
|
|
"spur_female= np.ones(train_female.shape[0])\n",
|
|
"train_spur_2= np.concatenate((spur_male, spur_female), axis=0)\n",
|
|
"\n",
|
|
"spur_male= np.zeros(val_male.shape[0])\n",
|
|
"spur_female= np.ones(val_female.shape[0])\n",
|
|
"val_spur_2= np.concatenate((spur_male, spur_female), axis=0)\n",
|
|
"\n",
|
|
"print('Male Final: ', train_male.shape, val_male.shape)\n",
|
|
"print('Female Final: ', train_female.shape, val_female.shape)\n",
|
|
"print('Spur Feature Final: ', train_spur_2.shape, val_spur_2.shape)\n",
|
|
"\n",
|
|
"print('Sanity Check Male')\n",
|
|
"print('Train', np.sum(train_male[:, -1]==0), np.sum(train_male[:, -1]==1))\n",
|
|
"print('Val', np.sum(val_male[:, -1]==0), np.sum(val_male[:, -1]==1))\n",
|
|
"\n",
|
|
"print('Sanity Check Female')\n",
|
|
"print('Train', np.sum(train_female[:, -1]==0), np.sum(train_female[:, -1]==1))\n",
|
|
"print('Val', np.sum(val_female[:, -1]==0), np.sum(val_female[:, -1]==1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Domain 3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#train_data_vae= train_data_vae[ train_data_vae['income']==0 ]\n",
|
|
"#train_data_vae.drop('income', axis=1, inplace=True)\n",
|
|
"\n",
|
|
"idx= columns.index('sex_Male')\n",
|
|
"data_male= dataset[dataset[:, idx]==1]\n",
|
|
"idx= columns.index('sex_Female')\n",
|
|
"data_female= dataset[dataset[:, idx]==1]\n",
|
|
"print('Gender Stats: ', data_male.shape, data_female.shape)\n",
|
|
"\n",
|
|
"# Split each gender into low- and high-income groups using the label in the last column\n",
|
|
"data_male_low_inc= data_male[data_male[:, -1] == 0]\n",
|
|
"data_male_high_inc= data_male[data_male[:, -1] == 1]\n",
|
|
"data_female_low_inc= data_female[data_female[:, -1] == 0]\n",
|
|
"data_female_high_inc= data_female[data_female[:, -1] == 1]\n",
|
|
"\n",
|
|
"# print('Male: ', data_male_low_inc.shape, data_male_high_inc.shape)\n",
|
|
"# print('Female: ', data_female_low_inc.shape, data_female_high_inc.shape)\n",
|
|
"\n",
|
|
"# Male: Train: (1000, 64), Test: (60, 940) \n",
|
|
"# Female: Train: (64, 1000), Test: (940, 60) \n",
|
|
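"# No validation split in this domain: test rows come from the end of each group ([-61:-1] gives 60 rows, [-941:-1] gives 940; the final row of each group is skipped)\n",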
|
|
"\n",
|
|
"#Male Dataset\n",
|
|
"test_high_inc= data_male_high_inc[-61:-1]\n",
|
|
"test_low_inc= data_male_low_inc[-941:-1]\n",
|
|
"test_male= np.concatenate((test_high_inc, test_low_inc), axis=0)\n",
|
|
"\n",
|
|
"#Female Dataset\n",
|
|
"test_high_inc= data_female_high_inc[-941:-1]\n",
|
|
"test_low_inc= data_female_low_inc[-61:-1]\n",
|
|
"test_female= np.concatenate((test_high_inc, test_low_inc), axis=0)\n",
|
|
"\n",
|
|
"test_data= np.concatenate((test_male, test_female), axis=0) \n",
|
|
"\n",
|
|
"spur_male= np.zeros(test_male.shape[0])\n",
|
|
"spur_female= np.ones(test_female.shape[0])\n",
|
|
"test_spur= np.concatenate((spur_male, spur_female), axis=0)\n",
|
|
"\n",
|
|
"print('Male Final: ', test_male.shape)\n",
|
|
"print('Female Final: ', test_female.shape)\n",
|
|
"print('Spur Feature Final: ', test_spur.shape)\n",
|
|
"\n",
|
|
"print('Sanity Check Male')\n",
|
|
"print('Test', np.sum(test_male[:, -1]==0), np.sum(test_male[:, -1]==1))\n",
|
|
"\n",
|
|
"print('Sanity Check Female')\n",
|
|
"print('Test', np.sum(test_female[:, -1]==0), np.sum(test_female[:, -1]==1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#Domain 1\n",
|
|
"torch.save(torch.tensor(train_data_1[:, :-1]), base_dir + 'd1_train_data.pt')\n",
|
|
"torch.save(torch.tensor(train_data_1[:, -1]), base_dir + 'd1_train_label.pt')\n",
|
|
"torch.save(torch.tensor(train_spur_1), base_dir + 'd1_train_spur.pt')\n",
|
|
"\n",
|
|
"torch.save(torch.tensor(val_data_1[:, :-1]), base_dir + 'd1_val_data.pt')\n",
|
|
"torch.save(torch.tensor(val_data_1[:, -1]), base_dir + 'd1_val_label.pt')\n",
|
|
"torch.save(torch.tensor(val_spur_1), base_dir + 'd1_val_spur.pt')\n",
|
|
"\n",
|
|
"#Domain 2\n",
|
|
"torch.save(torch.tensor(train_data_2[:, :-1]), base_dir + 'd2_train_data.pt')\n",
|
|
"torch.save(torch.tensor(train_data_2[:, -1]), base_dir + 'd2_train_label.pt')\n",
|
|
"torch.save(torch.tensor(train_spur_2), base_dir + 'd2_train_spur.pt')\n",
|
|
"\n",
|
|
"torch.save(torch.tensor(val_data_2[:, :-1]), base_dir + 'd2_val_data.pt')\n",
|
|
"torch.save(torch.tensor(val_data_2[:, -1]), base_dir + 'd2_val_label.pt')\n",
|
|
"torch.save(torch.tensor(val_spur_2), base_dir + 'd2_val_spur.pt')\n",
|
|
"\n",
|
|
"#Domain 3\n",
|
|
"torch.save(torch.tensor(test_data[:, :-1]), base_dir + 'd3_test_data.pt')\n",
|
|
"torch.save(torch.tensor(test_data[:, -1]), base_dir + 'd3_test_label.pt')\n",
|
|
"torch.save(torch.tensor(test_spur), base_dir + 'd3_test_spur.pt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "envs",
|
|
"language": "python",
|
|
"name": "envs"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|