{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting started with RobustDG: Generalization and Privacy Attacks on the Rotated MNIST Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Domain Generalization (DG) is the task of learning a predictive model that generalizes to data distributions not seen during training. Intuitively, a model trained by simply pooling the data from several domains may overfit to the domains observed during training. Many DG methods have been proposed to improve how well models generalize to such out-of-distribution data.\n",
"\n",
"Here we present a simple application of the RobustDG library: we build a model on a modified MNIST dataset and then evaluate its out-of-distribution accuracy and its robustness to privacy attacks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset: Rotated MNIST"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rotated MNIST consists of several data domains, each corresponding to a specific rotation angle, which makes it an easy way to generate out-of-distribution (OOD) samples. For example, during training the model sees digits rotated by angles between 15 and 75 degrees, while at test time it has to classify digits rotated by 90 degrees. The different rotations (domains) therefore create a shift between the training and test distributions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"images/Merged.png\">"
]
},
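{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code cell below is a minimal NumPy sketch of how rotation domains of this kind can be built. It is not RobustDG's data pipeline (the library prepares its data with `data/data_gen.py`); random 28x28 arrays stand in for MNIST digits, and `rotate_image` and `make_domains` are hypothetical helpers that use nearest-neighbour rotation to avoid extra dependencies. Every domain is built from the same base digits, so each rotated image has a counterpart in every other domain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rotate_image(img, angle_deg):\n",
"    # Rotate a square grayscale image about its center (nearest neighbour, zero fill).\n",
"    # Inverse mapping: for every output pixel, look up the source pixel it comes from.\n",
"    h, w = img.shape\n",
"    theta = np.deg2rad(angle_deg)\n",
"    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0\n",
"    ys, xs = np.mgrid[0:h, 0:w]\n",
"    y0, x0 = ys - cy, xs - cx\n",
"    src_x = np.rint(np.cos(theta) * x0 + np.sin(theta) * y0 + cx).astype(int)\n",
"    src_y = np.rint(-np.sin(theta) * x0 + np.cos(theta) * y0 + cy).astype(int)\n",
"    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)\n",
"    out = np.zeros_like(img)\n",
"    out[valid] = img[src_y[valid], src_x[valid]]\n",
"    return out\n",
"\n",
"\n",
"def make_domains(images, labels, angles):\n",
"    # One domain per angle: the same base digits (and labels), all rotated by that angle.\n",
"    return {a: (np.stack([rotate_image(x, a) for x in images]), labels) for a in angles}\n",
"\n",
"\n",
"# Hypothetical stand-in for MNIST: 8 random 28x28 'digits' with random labels\n",
"rng = np.random.default_rng(0)\n",
"images = rng.random((8, 28, 28)).astype(np.float32)\n",
"labels = rng.integers(0, 10, size=8)\n",
"\n",
"train_domains = make_domains(images, labels, [15, 30, 45, 60, 75])\n",
"test_domains = make_domains(images, labels, [90])\n",
"print({angle: x.shape for angle, (x, y) in train_domains.items()})"
]
},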
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training ML models that can generalize to new domains\n",
"\n",
"Below we provide the commands to train each method. This notebook evaluates models that have already been trained, so if you have not trained them yet, run these commands first from the root of the repository (the directory that the evaluation cells below switch to with `cd ../..`)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare Data for Rot MNIST & Fashion MNIST\n",
"\n",
"From the `data` directory, run the following command:\n",
"\n",
"<code> python data_gen.py resnet18 </code>"
]
},
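{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Baseline: Empirical risk minimization\n",
"We first train a baseline with empirical risk minimization (ERM), which simply pools the data from all training domains and fits a single model. Setting `--penalty_ws 0.0` gives the matching penalty zero weight, so the `erm_match` method reduces to plain ERM."
]
},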
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<code> python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 25 </code>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MatchDG: Domain generalization via causal matching\n",
"\n",
"The MatchDG model regularizes the ERM training objective by matching data samples across domains that were generated from the same base object. More details are in the [arXiv paper](https://arxiv.org/abs/2006.07500).\n",
"\n",
"MatchDG operates in two phases: in the first phase it learns a matching function, and in the second phase it learns a classifier regularized using the matching function learnt in the first phase.\n",
"\n",
"Train the MatchDG model on Rotated MNIST by running the commands for the two phases below, in order."
]
},
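{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the regularized objective described above, the code cell below computes a pooled cross-entropy (ERM) loss over two domains plus a weighted penalty on the squared L2 distance between the representations of matched pairs. It is a simplified NumPy sketch, not RobustDG's implementation: it assumes the matched pairs are already known and aligned row by row, uses a hypothetical linear classifier `W` on top of given representations, and treats `penalty_weight` as the analogue of `--penalty_ws`. With `penalty_weight=0.0` the objective reduces to plain pooled ERM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def cross_entropy(logits, labels):\n",
"    # Mean softmax cross-entropy for integer class labels (numerically stable).\n",
"    shifted = logits - logits.max(axis=1, keepdims=True)\n",
"    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
"    return -log_probs[np.arange(len(labels)), labels].mean()\n",
"\n",
"\n",
"def matched_erm_loss(feat_a, feat_b, W, labels, penalty_weight):\n",
"    # Row i of feat_a (domain A) and row i of feat_b (domain B) are assumed to be a\n",
"    # matched pair: the same base object seen in two domains, sharing labels[i].\n",
"    erm = 0.5 * (cross_entropy(feat_a @ W, labels) + cross_entropy(feat_b @ W, labels))\n",
"    # Penalty: mean squared L2 distance between the representations of matched pairs.\n",
"    penalty = np.mean(np.sum((feat_a - feat_b) ** 2, axis=1))\n",
"    return erm + penalty_weight * penalty, erm, penalty\n",
"\n",
"\n",
"# Hypothetical representations: 4 matched pairs, 3-D features, 2 classes\n",
"rng = np.random.default_rng(0)\n",
"feat_a = rng.normal(size=(4, 3))\n",
"feat_b = feat_a + 0.5 * rng.normal(size=(4, 3))\n",
"W = rng.normal(size=(3, 2))\n",
"labels = np.array([0, 1, 1, 0])\n",
"\n",
"for weight in [0.0, 0.1]:\n",
"    total, erm, penalty = matched_erm_loss(feat_a, feat_b, W, labels, weight)\n",
"    print(f'penalty_weight={weight}: total={total:.3f}, erm={erm:.3f}, penalty={penalty:.3f}')"
]
},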
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Phase 1: Learning Match Function\n",
"\n",
"<code> python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 256 --pos_metric cos\n",
"</code>"
]
},
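{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once Phase 1 has learnt a representation, matches between domains can be read off from it. The code cell below sketches that idea under stated assumptions: it takes `--pos_metric cos` to mean cosine similarity, uses random vectors as stand-ins for the learnt embeddings, and, for each sample in one domain, picks the most similar sample of the same class in another domain. `infer_matches` is a hypothetical helper, and RobustDG's own match-inference procedure may differ in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def cosine_similarity_matrix(a, b):\n",
"    # Pairwise cosine similarity between the rows of a (n, d) and the rows of b (m, d).\n",
"    a = a / np.linalg.norm(a, axis=1, keepdims=True)\n",
"    b = b / np.linalg.norm(b, axis=1, keepdims=True)\n",
"    return a @ b.T\n",
"\n",
"\n",
"def infer_matches(emb_a, labels_a, emb_b, labels_b):\n",
"    # For each row of emb_a, return the index of the most cosine-similar row of emb_b\n",
"    # that has the same label, or -1 if domain B has no sample of that class.\n",
"    sim = cosine_similarity_matrix(emb_a, emb_b)\n",
"    same_class = labels_a[:, None] == labels_b[None, :]\n",
"    sim = np.where(same_class, sim, -np.inf)\n",
"    matches = sim.argmax(axis=1)\n",
"    matches[~same_class.any(axis=1)] = -1\n",
"    return matches\n",
"\n",
"\n",
"# Hypothetical embeddings: domain B holds the same 6 objects as domain A, shuffled\n",
"# and slightly perturbed, as if seen under a different rotation.\n",
"rng = np.random.default_rng(0)\n",
"emb_a = rng.normal(size=(6, 8))\n",
"labels_a = np.array([0, 0, 1, 1, 2, 2])\n",
"perm = rng.permutation(6)\n",
"emb_b = emb_a[perm] + 1e-3 * rng.normal(size=(6, 8))\n",
"labels_b = labels_a[perm]\n",
"\n",
"matches = infer_matches(emb_a, labels_a, emb_b, labels_b)\n",
"print(matches)"
]
},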
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Phase 2: Learning a Classifier Regularized by the Match Function\n",
"\n",
"<code> python train.py --dataset rot_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 </code>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the trained model\n",
"After training, we can evaluate the models on several test metrics, such as accuracy on the unseen domain and metrics for the match function.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Out-of-distribution accuracy\n",
"\n",
"We evaluate both ERM and MatchDG on OOD accuracy, i.e., classification accuracy on the unseen test domain."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### ERM OOD accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%bash\n",
"cd ../..\n",
"python test.py --test_metric acc --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### MatchDG OOD accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%bash\n",
"cd ../..\n",
"python test.py --test_metric acc --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results indicate that MatchDG (96.1 percent) outperforms ERM (93.9 percent) on OOD accuracy by about 2.2 percentage points."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### T-SNE Plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to OOD accuracy, we provide metrics to evaluate the representations learnt by the methods above: T-SNE plots and match function metrics (see the `match_eval.py` module in the `evaluations` directory).\n",
"\n",
"Here, we evaluate the representations learnt with contrastive learning (Phase 1) using T-SNE plots."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"cd ../..\n",
"python test.py --test_metric t_sne --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --pos_metric cos"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The command above saves the T-SNE embeddings as a JSON file. Next, we generate the T-SNE plot from the saved file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# T-SNE embeddings saved by the test.py command above\n",
"save_path = \"../../results/rot_mnist/matchdg_ctr/logit_match/train_['15', '30', '45', '60', '75']/Model_0.01_5_1_0_resnet18_label.json\"\n",
"with open(save_path) as f:\n",
"    data = json.load(f)  # one entry per class label, each a list of 2-D points\n",
"\n",
"# Scatter the 2-D embeddings with one color per class label\n",
"for label, points in data.items():\n",
"    arr = np.array(points)\n",
"    plt.plot(arr[:, 0], arr[:, 1], '.', label=label)\n",
"plt.title('TSNE plot of representations: Legend denotes class labels')\n",
"plt.legend()\n",
"plt.savefig('images/t_sne.png', dpi=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src='images/t_sne.png'>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Robustness to membership inference privacy attack"
]
},
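{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also test the models against membership inference attacks (MIA), in which an attacker tries to determine whether a given sample was part of a model's training data. MIA exploits the generalization gap of ML models: a model that overfits behaves differently on its training samples than on unseen samples, and therefore leaks information about the dataset it was trained on."
]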
},
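{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code cell below illustrates why a generalization gap leaks membership, using the simplest confidence-threshold attack: predict 'member' whenever the model's confidence in the true label exceeds a threshold. This is a hypothetical sketch, not the attack that RobustDG runs with `--test_metric mia` (configured by `--mia_logit` and `--mia_sample_size`); the confidence values are synthetic draws from Beta distributions chosen only for illustration, and the threshold is tuned on the same samples it is scored on, which makes the reported accuracy optimistic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def attack_accuracy(member_conf, nonmember_conf, threshold):\n",
"    # Accuracy of the rule 'member if confidence > threshold' over both groups.\n",
"    correct = np.sum(member_conf > threshold) + np.sum(nonmember_conf <= threshold)\n",
"    return correct / (len(member_conf) + len(nonmember_conf))\n",
"\n",
"\n",
"def best_threshold_attack(member_conf, nonmember_conf):\n",
"    # Try every observed confidence as the threshold and keep the most accurate one.\n",
"    candidates = np.concatenate([member_conf, nonmember_conf])\n",
"    accs = [attack_accuracy(member_conf, nonmember_conf, t) for t in candidates]\n",
"    best = int(np.argmax(accs))\n",
"    return candidates[best], accs[best]\n",
"\n",
"\n",
"# Synthetic confidences on the true label for 1000 train (member) and 1000 test\n",
"# (non-member) samples. The 'overfit' model is much more confident on its training\n",
"# data; the 'generalizing' model behaves the same on both groups.\n",
"rng = np.random.default_rng(0)\n",
"overfit = (rng.beta(9, 1, size=1000), rng.beta(4, 2, size=1000))\n",
"generalizing = (rng.beta(5, 2, size=1000), rng.beta(5, 2, size=1000))\n",
"\n",
"for name, (members, nonmembers) in [('overfit', overfit), ('generalizing', generalizing)]:\n",
"    threshold, acc = best_threshold_attack(members, nonmembers)\n",
"    print(f'{name}: best threshold {threshold:.3f}, attack accuracy {acc:.3f}')"
]
},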
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### ERM MIA accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"cd ../..\n",
"python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### MatchDG MIA accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"cd ../..\n",
"python test.py --test_metric mia --mia_logit 1 --mia_sample_size 2000 --batch_size 64 --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We evaluate ERM and MatchDG on MIA accuracy and observe a pattern similar to that of OOD accuracy: the attack reaches 57.2 percent accuracy against ERM, which drops to 53.1 percent against MatchDG. Lower is better here, since 50 percent corresponds to random guessing when train and test samples are equally represented. Because MatchDG generalizes better, it is harder for the attack to determine whether a sample comes from the training or the test distribution."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}