From 3da805dfb1f5cbbbd390f65addf8c4c0b47d8267 Mon Sep 17 00:00:00 2001 From: divyat09 Date: Thu, 15 Jul 2021 07:19:58 +0000 Subject: [PATCH] azure amlt scripts --- .gitignore | 4 +- azure_scripts/setup_data_mnist.yaml | 20 +- chestxray_download.txt | 16 + docs/notebooks/HParam_Plots.ipynb | 1470 +++++++++++++++++++++++++++ 4 files changed, 1498 insertions(+), 12 deletions(-) create mode 100644 chestxray_download.txt create mode 100644 docs/notebooks/HParam_Plots.ipynb diff --git a/.gitignore b/.gitignore index 6b22f6f..58b5cd5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,14 +9,14 @@ matchdg-env/ results/ #extra-files -*.sh +#*.sh #philly_tools .ptignore .amltignore pt/ amlt/ -*.yaml +#*.yaml .ptconfig .amltconfig diff --git a/azure_scripts/setup_data_mnist.yaml b/azure_scripts/setup_data_mnist.yaml index 58a8cee..cfca4e7 100644 --- a/azure_scripts/setup_data_mnist.yaml +++ b/azure_scripts/setup_data_mnist.yaml @@ -32,15 +32,15 @@ jobs: command: - echo "--debug" && python data/data_gen.py rot_mnist resnet18 -- name: fashion_mnist_resnet18 - # one gpu - sku: G1 - command: - - echo "--debug" && python data/data_gen.py fashion_mnist resnet18 +#- name: fashion_mnist_resnet18 +# # one gpu +# sku: G1 +# command: +# - echo "--debug" && python data/data_gen.py fashion_mnist resnet18 -- name: rot_mnist_lenet - # one gpu - sku: G1 - command: - - echo "--debug" && python data/data_gen.py rot_mnist lenet +#- name: rot_mnist_lenet +# # one gpu +# sku: G1 +# command: +# - echo "--debug" && python data/data_gen.py rot_mnist lenet diff --git a/chestxray_download.txt b/chestxray_download.txt new file mode 100644 index 0000000..f7def5b --- /dev/null +++ b/chestxray_download.txt @@ -0,0 +1,16 @@ +NIH Dataset: + +curl -o nih.zip "https://storage.googleapis.com/kaggle-data-sets/5839%2F18613%2Fbundle%2Farchive.zip?GoogleAccessId=gcp-kaggle-com@kaggle-161607.iam.gserviceaccount.com&Expires=1600359450&Signature=MPds%2FPBnAPNGFXy1cnmRVhHaHsTRggstPA44ZCE0onI35vc4UMwdPSyQS%2Fypf5B%2FhmOsf6%2B6oxy0%2BKL8HCBh8BtFrwMyfY7dVczTmPkBEGPALf7roGbuWFB6oUVrXAVHFpJwCKEwMCSrxkpFIccLxXII%2B84aG4xrqwzu1LQq%2BRyE3W7Rg22ib1tiyX%2FsZjGk8%2BmHqlA7gg2Y9pr4s7xZgTpnpUv0NPiVjLcsWHgWznx2fuWZm8Ox%2Faj6CzZa6dbpYg%2FNWIHpCJ%2BzfPRCZQuGVaoSfKjoPZK9ei3W1FrZ2MDHBzPREQh1OCngT3v2%2Fn%2BXFj5tQ5b4OfvL1YB%2FMou5Cg%3D%3D" + +CheXpert: + +curl -o chexpert.zip http://download.cs.stanford.edu/deep/CheXpert-v1.0-small.zip + +OpenI: + +curl -o openi.tgz https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz + +Kaggle: + +curl -o kaggle.zip "https://storage.googleapis.com/kaggle-competitions-data/kaggle-v2/10338/862042/bundle/archive.zip?GoogleAccessId=web-data@kaggle161607.iam.gserviceaccount.com&Expires=1600347314&Signature=hilOASDiejHlo7KgvJR%2FqzaPg3eKcnBKauYVS%2FM6CIoVUl6mjgDdiDFwXJYOmeuK%2F1WfLO32JEjsc8XB6h7SQWhsMJ6Xs%2F1P7oMKNURjcYkZ2OQYXSV5gFDWVqZ%2Bna4t4B2y%2Bz6Gp9GpGt5HEjc4leOGlMizwLQEhQmlZWSpBqFzgTjLF9eVbNc2ekln5SCsLFWLz0YGFeAgkulq5qgh2Rfu%2BD5QafmPgTc3iMMJf%2BQcVJ0dgqHjcROmANWTnvdWcMjweZMBwXOgYHOomCHHRAgXnWvaXC5AxZsKXmmsbWe%2BsuCDJ4bIwAzm%2BC27XJwnIaeaOudn6BL%2FuLtf1lvv7A%3D%3D&response-content-disposition=attachment%3B+filename%3Drsna-pneumonia-detection-challenge.zip" + diff --git a/docs/notebooks/HParam_Plots.ipynb b/docs/notebooks/HParam_Plots.ipynb new file mode 100644 index 0000000..cd65385 --- /dev/null +++ b/docs/notebooks/HParam_Plots.ipynb @@ -0,0 +1,1470 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir= '/home/t-dimaha/RobustDG/robustdg/amlt/'\n", + "list_dir= ['mdg_htune_photo', 'mdg_htune_art_painting', 'mdg_htune_cartoon', 'mdg_htune_sketch']\n", + "list_model= ['alexnet', 'resnet18', 'resnet50']\n", + "\n", + "for model_name in list_model:\n", + " for dir_name in list_dir:\n", + " \n", + " count=0\n", + " res={}\n", + " curr_dir = base_dir + dir_name\n", + " for case in ['source', 'target']:\n", + " res[case]={}\n", + " for metric in ['acc', 'std']:\n", + " res[case][metric]= 0.0\n", + "\n", + " for subdir, dirs, files in os.walk(curr_dir):\n", + " \n", + " if model_name not in subdir:\n", + " continue\n", + " \n", + " if 'weigh_0.0005' not in subdir:\n", + " continue\n", + " \n", + "# if 'lr_0.01' not in subdir:\n", + "# continue\n", + " \n", + "# # To only obtain ERM results\n", + "# if '_penal_0_' in subdir:\n", + "# continue\n", + " \n", + " for file in files:\n", + " if 'stdout' in file :\n", + " f=open(os.path.join(subdir, file), 'r')\n", + " data= f.readlines()\n", + "\n", + " for idx in range(-40, -30):\n", + " if 'Source' in data[idx] and 'Final' in data[idx]:\n", + " case = 'source'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " elif 'Target' in data[idx] and 'Final' in data[idx] :\n", + " case = 'target'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " \n", + " \n", + " print(model_name, dir_name)\n", + " print(count/2)\n", + " print(res)\n", + " print('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir= '/home/t-dimaha/RobustDG/robustdg/amlt/'\n", + "list_dir= ['rand_htune_photo', 'rand_htune_art_painting', 'rand_htune_cartoon', 'rand_htune_sketch', 'mdghyb_htune_sketch']\n", + "list_model= ['resnet18', 'resnet50']\n", + "\n", + "for model_name in list_model:\n", + " for dir_name in list_dir:\n", + " \n", + " count=0\n", + " res={}\n", + " curr_dir = base_dir + dir_name\n", + " for case in ['source', 'target']:\n", + " res[case]={}\n", + " for metric in ['acc', 'std']:\n", + " res[case][metric]= 0.0\n", + "\n", + " for subdir, dirs, files in os.walk(curr_dir):\n", + " \n", + " if model_name not in subdir:\n", + " continue\n", + " \n", + "# if 'weigh' not in subdir:\n", + "# continue\n", + " \n", + "# if 'lr_0.01' not in subdir:\n", + "# continue\n", + " \n", + "# # To only obtain ERM results\n", + "# if '_penal_0_' in subdir:\n", + "# continue\n", + " \n", + " for file in files:\n", + " if 'stdout' in file :\n", + " f=open(os.path.join(subdir, file), 'r')\n", + " data= f.readlines()\n", + " \n", + " for idx in range(-40, -30):\n", + " if 'Source' in data[idx] and 'Final' in data[idx]:\n", + " case = 'source'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " elif 'Target' in data[idx] and 'Final' in data[idx]:\n", + " case = 'target'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " \n", + " print(model_name, dir_name)\n", + " print(count/2)\n", + " print(res)\n", + " print('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir= '/home/t-dimaha/RobustDG/robustdg/amlt/'\n", + "list_dir= ['erm_htune_photo', 'erm_htune_art_painting', 'erm_htune_cartoon', 'erm_htune_sketch']\n", + "list_model= ['alexnet', 'resnet18']\n", + "\n", + "for model_name in list_model:\n", + " for dir_name in list_dir:\n", + " \n", + " count=0\n", + " res={}\n", + " curr_dir = base_dir + dir_name\n", + " for case in ['source', 'target']:\n", + " res[case]={}\n", + " for metric in ['acc', 'std']:\n", + " res[case][metric]= 0.0\n", + "\n", + " for subdir, dirs, files in os.walk(curr_dir):\n", + " \n", + " if model_name not in subdir:\n", + " continue\n", + " \n", + "# if 'weigh' not in subdir:\n", + "# continue\n", + " \n", + "# if 'lr_0.01' not in subdir:\n", + "# continue\n", + " \n", + "# # To only obtain ERM results\n", + "# if '_penal_0_' in subdir:\n", + "# continue\n", + " \n", + " for file in files:\n", + " if 'stdout' in file :\n", + " f=open(os.path.join(subdir, file), 'r')\n", + " data= f.readlines()\n", + " \n", + " for idx in range(-40, -30):\n", + " if 'Source' in data[idx] and 'Final' in data[idx]:\n", + " case = 'source'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " elif 'Target' in data[idx] and 'Final' in data[idx]:\n", + " case = 'target'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " \n", + " print(model_name, dir_name)\n", + " print(count/2)\n", + " print(res)\n", + " print('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir= '/home/t-dimaha/RobustDG/robustdg/amlt/'\n", + "list_dir= ['hyb_htune_photo', 'hyb_htune_art_painting', 'hyb_htune_cartoon', 'hyb_htune_sketch']\n", + "list_model= ['alexnet']\n", + "\n", + "for model_name in list_model:\n", + " for dir_name in list_dir:\n", + " \n", + " count=0\n", + " res={}\n", + " curr_dir = base_dir + dir_name\n", + " for case in ['source', 'target']:\n", + " res[case]={}\n", + " for metric in ['acc', 'std']:\n", + " res[case][metric]= 0.0\n", + "\n", + " for subdir, dirs, files in os.walk(curr_dir):\n", + " \n", + " if model_name not in subdir:\n", + " continue\n", + " \n", + "# if 'weigh' not in subdir:\n", + "# continue\n", + " \n", + "# if 'lr_0.01' not in subdir:\n", + "# continue\n", + " \n", + "# # To only obtain ERM results\n", + "# if '_penal_0_' in subdir:\n", + "# continue\n", + " \n", + " for file in files:\n", + " if 'stdout' in file :\n", + " f=open(os.path.join(subdir, file), 'r')\n", + " data= f.readlines()\n", + " \n", + " for idx in range(-40, -30):\n", + " if 'Source' in data[idx] and 'Final' in data[idx]:\n", + " case = 'source'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " elif 'Target' in data[idx] and 'Final' in data[idx]:\n", + " case = 'target'\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file) \n", + " \n", + " \n", + " print(model_name, dir_name)\n", + " print(count/2)\n", + " print(res)\n", + " print('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir= '/home/t-dimaha/RobustDG/robustdg/pt/'\n", + "list_dir= ['irm_mnist/', 'irm_fmnist/']\n", + "\n", + "for dir_name in list_dir:\n", + "\n", + " count=0\n", + " res={}\n", + " curr_dir = base_dir + dir_name\n", + " for case in ['source', 'target']:\n", + " res[case]={}\n", + " for metric in ['acc', 'std', 'name']:\n", + " res[case][metric]= 0.0\n", + "\n", + " for subdir, dirs, files in os.walk(curr_dir):\n", + "\n", + " for file in files:\n", + " if 'stdout' in file :\n", + " f=open(os.path.join(subdir, file), 'r')\n", + " data= f.readlines()\n", + "\n", + " for case in ['source', 'target']:\n", + " if case == 'source':\n", + " idx=-5\n", + " else:\n", + " idx=-4\n", + "\n", + " if 'Final' in data[idx]:\n", + " mean= float(data[idx].split(')')[-1][1:].split(' ')[0])\n", + " std= float(data[idx].split(')')[-1][1:].split(' ')[1].rstrip('\\n'))\n", + " count+=1\n", + "\n", + " if mean > res[case]['acc'] or ( mean == res[case]['acc'] and std < res[case]['std'] ): \n", + " res[case]['acc']= mean\n", + " res[case]['std']= std\n", + " res[case]['name']= os.path.join(subdir, file)\n", + " print( dir_name)\n", + " print(count/2)\n", + " print(res)\n", + " print('\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# IRM Loss Plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10,4))\n", + "fig.suptitle('Example of a Single Legend Shared Across Multiple Subplots')\n", + "\n", + "# The data\n", + "x = [1, 2, 3]\n", + "y1 = [1, 2, 3]\n", + "y2 = [3, 1, 3]\n", + "y3 = [1, 3, 1]\n", + "y4 = [2, 2, 3]\n", + "\n", + "# Labels to use in the legend for each line\n", + "line_labels = [\"Line A\", \"Line B\", \"Line C\", \"Line D\"]\n", + "\n", + "# Create the sub-plots, assigning a different color for each line.\n", + "# Also store the line objects created\n", + "l1 = ax1.plot(x, y1, color=\"red\")[0]\n", + "ax1.set_title('Blah')\n", + "ax1.set_xlabel('No')\n", + "l2 = ax2.plot(x, y2, color=\"green\")[0]\n", + "l3 = ax3.plot(x, y3, color=\"blue\")[0]\n", + "l4 = ax3.plot(x, y4, color=\"orange\")[0] # A second line in the third subplot\n", + "\n", + "# Create the legend\n", + "fig.legend([l1, l2, l3, l4], # The line objects\n", + " labels=line_labels, # The labels for each line\n", + " loc=\"center right\", # Position of legend\n", + " borderaxespad=0.1, # Small spacing around legend box\n", + " title=\"Legend Title\" # Title for the legend\n", + " )\n", + "\n", + "# Adjust the scaling factor to fit your legend text completely outside the plot\n", + "# (smaller value results in more space being made for the legend)\n", + "plt.subplots_adjust(right=0.85)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13,5))\n", + "fontsize=22\n", + "ax1.tick_params(labelsize=20)\n", + "ax2.tick_params(labelsize=20)\n", + "\n", + "base_dir='results/rot_mnist/plots/'\n", + "legend_arr_mnist=[]\n", + "for case in ['ERM', 'IRM']: \n", + " if case == 'IRM': \n", + " file= base_dir + 'irm_1_5.txt'\n", + " elif case == 'ERM':\n", + " file= base_dir + 'irm_0_-1.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) / 625 ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " ax1.plot(range(1, 1+len(y)), y, label=case)\n", + "# print(y)\n", + "ax1.set_xlabel('Training Epochs', fontsize=fontsize)\n", + "ax1.set_ylabel('IRM Loss Penalty', fontsize=fontsize)\n", + "ax1.set_title('Rot MNIST', fontsize=fontsize)\n", + "ax1.axvline(x=10, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "ax1.axvline(x=11, ls='--', color='red', label='IRM: Perfect Train Acc')\n", + "\n", + "\n", + "legend_arr_fmnist=[]\n", + "base_dir='results/fashion_mnist/plots/'\n", + "for case in ['ERM', 'IRM']: \n", + " if case == 'IRM': \n", + " file= base_dir + 'irm_0.05_-1.txt'\n", + " elif case == 'ERM':\n", + " file= base_dir + 'irm_0_-1.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) / (5*625) ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " ax2.plot(range(1, 1+len(y)), y, label=case)\n", + " \n", + "ax2.set_xlabel('Training Epochs', fontsize=fontsize)\n", + "# ax2.set_ylabel('IRM Loss Penalty', fontsize=fontsize)\n", + "ax2.set_title('Fashion MNIST', fontsize=fontsize)\n", + "ax2.axvline(x=19, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "ax2.axvline(x=21, ls='--', color='red', label='IRM: Perfect Train Acc')\n", + "\n", + "lines, labels = fig.axes[-1].get_legend_handles_labels() \n", + "lgd= fig.legend(lines, labels, loc=\"lower center\", bbox_to_anchor=(0.5, -0.3), fontsize=fontsize, ncol=2)\n", + "plt.savefig('results/irm_plot.jpg', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for case in [ 'ERM', 'IRM']:\n", + " if case == 'irm': \n", + " file= '/home/t-dimaha/RobustDG/robustdg/pt/irm_fmnist/search_irm_fmnist_penal_0.05_thres_-1/stdout.txt'\n", + " if case == 'irm: p1': \n", + " file= '/home/t-dimaha/RobustDG/robustdg/pt/irm_fmnist/search_irm_fmnist_penal_1_thres_-1/stdout.txt'\n", + " if case == 'irm: p5': \n", + " file= '/home/t-dimaha/RobustDG/robustdg/pt/irm_fmnist/search_irm_fmnist_penal_5_thres_-1/stdout.txt'\n", + " elif case == 'erm':\n", + " file= '/home/t-dimaha/RobustDG/robustdg/pt/irm_fmnist/search_irm_fmnist_penal_0_thres_-1/stdout.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " plt.plot(range(1, 1+len(y)), y, label=case)\n", + " print(y)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('IRM Penalty')\n", + "plt.title('IRM Penalty during training process')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MatchDG Plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13,5))\n", + "fontsize=22\n", + "ax1.tick_params(labelsize=20)\n", + "ax2.tick_params(labelsize=20)\n", + "\n", + "base_dir='results/rot_mnist/plots/'\n", + "legend_arr_mnist=[]\n", + "for case in ['ERM', 'MatchDG']: \n", + " if case == 'MatchDG': \n", + " file= base_dir + 'matchdg.txt'\n", + " elif case == 'ERM':\n", + " file= base_dir + 'erm.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) / 625 ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " ax1.plot(range(1, 1+len(y)), y, label=case)\n", + "# print(y)\n", + "ax1.set_xlabel('Training Epochs', fontsize=fontsize)\n", + "ax1.set_ylabel('MatchDG Loss Penalty', fontsize=fontsize)\n", + "ax1.set_title('Rot MNIST', fontsize=fontsize)\n", + "ax1.axvline(x=10, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "ax1.axvline(x=8, ls='--', color='red', label='MatchDG: Perfect Train Acc')\n", + "\n", + "\n", + "base_dir='results/fashion_mnist/plots/'\n", + "legend_arr_mnist=[]\n", + "for case in ['ERM', 'MatchDG']: \n", + " if case == 'MatchDG': \n", + " file= base_dir + 'matchdg.txt'\n", + " elif case == 'ERM':\n", + " file= base_dir + 'erm.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) / (5*625) ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " ax2.plot(range(1, 1+len(y)), y, label=case)\n", + " \n", + "ax2.set_xlabel('Training Epochs', fontsize=fontsize)\n", + "#ax2.set_ylabel('MatchDG Loss Penalty', fontsize=fontsize)\n", + "ax2.set_title('Fashion MNIST', fontsize=fontsize)\n", + "ax2.axvline(x=19, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "ax2.axvline(x=15, ls='--', color='red', label='MatchDG: Perfect Train Acc')\n", + "\n", + "lines, labels = fig.axes[-1].get_legend_handles_labels() \n", + "lgd= fig.legend(lines, labels, loc=\"lower center\", bbox_to_anchor=(0.5, -0.3), fontsize=fontsize, ncol=2)\n", + "plt.savefig('results/matchdg_plot.jpg', bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir='results/rot_mnist/plots/'\n", + "for case in ['ERM', 'MatchDG']: \n", + " if case == 'MatchDG': \n", + " file= base_dir + 'matchdg.txt'\n", + " elif case == 'ERM':\n", + " file= base_dir + 'erm.txt'\n", + " \n", + " count=0\n", + " y=[]\n", + " f=open(file ,'r')\n", + " data= f.readlines()\n", + " for line in data:\n", + " if 'Train Loss Basic' in line:\n", + " y.append( float( line.split(':')[-1].split(' ')[-1].rstrip('\\n') ) ) \n", + " count+=1\n", + " if count == 25:\n", + " break\n", + " \n", + " plt.plot(range(1, 1+len(y)), y, label=case)\n", + "# print(y)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Match Function Penalty')\n", + "plt.axvline(x=10, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "plt.axvline(x=8, ls='--', color='red', label='MatchDG: Perfect Train Acc')\n", + "# plt.axvline(x=19, ls='--', color='purple', label='ERM: Perfect Train Acc')\n", + "# plt.axvline(x=15, ls='--', color='red', label='MatchDG: Perfect Train Acc')\n", + "plt.title('Rot MNIST: Match Function Penalty during training process')\n", + "plt.legend()\n", + "plt.savefig(base_dir + 'matchdg_plot.jpg')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Eval Acc MNIST (Table 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_path= \"results/fashion_mnist/erm_match/logit_match/train_['15', '30', '45', '60', '75']\"\n", + "for idx in range(3):\n", + " \n", + " file= base_path + \"/Val_Acc_0.1_1.0_1.0_5_0_\" + str(idx) +\"_l2_resnet18.npy\"\n", + " arr= np.load(file) \n", + " print(arr.shape)\n", + " index= np.argmax(arr)\n", + " \n", + " file= base_path + \"/Test_Acc_0.1_1.0_1.0_5_0_\" + str(idx) +\"_l2_resnet18.npy\"\n", + " arr= np.load(file) \n", + " print(arr.shape)\n", + " acc[idx]= arr[index]\n", + " \n", + "print('Mean: ', np.mean(acc), ' Std :', np.std(acc))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arr=[0, 1]\n", + "np.pad(arr, (10, 10),'edge')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Misc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "loss = torch.nn.CrossEntropyLoss(reduction='none')\n", + "input_ = torch.randn(3, 5, requires_grad=True)\n", + "target = torch.empty(3, dtype=torch.long).random_(5)\n", + "output = loss(input_, target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " type(np.random.randint(1, 10, 2)[0].item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x={'a':2, 'b': 3, 'd':4}\n", + "max(x.values())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Slab Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Common imports\n", + "import os\n", + "import sys\n", + "import numpy as np\n", + "import argparse\n", + "import copy\n", + "import random\n", + "import json\n", + "import pickle\n", + "\n", + "#Pytorch\n", + "import torch\n", + "from torch.autograd import grad\n", + "from torch import nn, optim\n", + "from torch.nn import functional as F\n", + "from torchvision import datasets, transforms\n", + "from torchvision.utils import save_image\n", + "from torch.autograd import Variable\n", + "import torch.utils.data as data_utils\n", + "\n", + "#Sklearn\n", + "from sklearn.manifold import TSNE\n", + "\n", + "#robustdg\n", + "from utils.helper import *\n", + "from utils.match_function import *\n", + "\n", + "#slab\n", + "from utils.slab_data import *\n", + "import utils.scripts.utils as slab_utils\n", + "import utils.scripts.lms_utils as slab_lms_utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_1, temp1, _, _= get_data(1000, 0.0, 0.1, 7, 'train', 0, 0)\n", + "data_2, temp1, _, _= get_data(1000, 0.1, 0.1, 7, 'train', 0, 0)\n", + "data_3, temp1, _, _= get_data(1000, 1.0, 0.1, 7, 'test', 0, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_1, ax_2, ax_3) = plt.subplots(1,3,figsize=(17, 5))\n", + "\n", + "W = data_1['W']\n", + "X, Y = data_1['X'], data_1['Y']\n", + "X = X.numpy().dot(W.T)\n", + "ax_1.scatter(X[:,0], X[:,1], c=Y, cmap='coolwarm', s=4, alpha=0.8) \n", + "ax_1.set_xlabel('Linear Feature (x1)', fontsize=20)\n", + "ax_1.set_ylabel('Slab Feature (x2)', fontsize=20)\n", + "ax_1.set_title('Source Domain 1', fontsize=20)\n", + "\n", + "W = data_2['W']\n", + "X, Y = data_2['X'], data_2['Y']\n", + "X = X.numpy().dot(W.T)\n", + "ax_2.scatter(X[:,0], X[:,1], c=Y, cmap='coolwarm', s=4, alpha=0.8) \n", + "ax_2.set_xlabel('Linear Feature (x1)', fontsize=20)\n", + "ax_2.set_ylabel('Slab Feature (x2)', fontsize=20)\n", + "ax_2.set_title('Source Domain 2', fontsize=20)\n", + "\n", + "W = data_3['W']\n", + "X, Y = data_3['X'], data_3['Y']\n", + "X = X.numpy().dot(W.T)\n", + "ax_3.scatter(X[:,0], X[:,1], c=Y, cmap='coolwarm', s=4, alpha=0.8) \n", + "ax_3.set_xlabel('Linear Feature (x1)', fontsize=20)\n", + "ax_3.set_ylabel('Slab Feature (x2)', fontsize=20)\n", + "ax_3.set_title('Target Domain', fontsize=20)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('results/slab_dataset.jpg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a= np.ones(2,3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "torch.randint(0, 1, (5, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "l=[[0, 1], [2, 3], [9, 8], [0, 2]]\n", + "random.shuffle(l)\n", + "print(l)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "x=np.array([1, 2, 3])\n", + "y=x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x[0]= 10\n", + "print(x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "import copy\n", + "a= torch.tensor([1, 2])\n", + "b= copy.deepcopy(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a[0]=3\n", + "print(a, b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a=[1, 2, 3, 5, 6]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.choice(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a=torch.tensor([7, 9, 3, 6, 8, 4, 9])\n", + "b, c= torch.sort(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(b)\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.argmin(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a[c[:3]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(a ==9).nonzero()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a=torch.rand(2,3,4)\n", + "b= a.flatten(start_dim=0, end_dim=1)\n", + "print(a.shape, b.shape)\n", + "\n", + "c= torch.stack(torch.split(b, a.shape[1]))\n", + "\n", + "print(a==c)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "b" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "n= np.array(range(1, 100))\n", + "c1= 181 \n", + "c2= -180\n", + "c3= -40\n", + "c4= -4\n", + "x= c1*(n**np.log2(5)) + c2*(n**2)+ c3*(np.log2(n))*(n**2) + c4*((np.log2(n))**2)*(n**2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "181*np.log2(5) - 448" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "5*109+64*9" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n= np.array(range(1, 1000))\n", + "y= 181*np.log2(5)*(n**(np.log2(5)-1)) -400*n - 88*(n)*(np.log2(n)) - 8*n*(np.log2(n)**2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "181*np.log2(5)*(2**(np.log2(5)-1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.log(n)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n= np.array(range(1, 100))\n", + "z= 10*n**(np.log2(5)) - (n**2)*(np.log2(n))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "n=[1/100*np.array(range(1, 10))]\n", + "print(np.cos(n))\n", + "print(np.sin(n))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n= np.array(range(1, 10))\n", + "a=[]\n", + "for item in n:\n", + " a.append( np.math.factorial(item) )\n", + "a=np.array(a)\n", + "\n", + "a= np.log2(a)\n", + "b= 0.5*x*np.log2(n)\n", + "print(a >= b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(a, b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(a)\n", + "print(b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n= np.array(range(1, 100))\n", + "a= 2**(np.log2(n)**2)\n", + "b= n**2\n", + "print(a >= b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(n, a, '--')\n", + "plt.plot(n, b, '.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n=10\n", + "x=np.array(range(1, n+1))\n", + "f=[]\n", + "for item in x:\n", + " f.append(item*(n+1-item))\n", + "f=np.array(f)\n", + "print(f >= n)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "k=222\n", + "n=5023" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s=0\n", + "for i in range(k, n+1):\n", + " s+= i\n", + "b= 0.5*(n**2 + n - k*(k-1))\n", + "print(s, b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f1='data/datasets/mnist/rot_mnist_lenet/train/seed_0_domain_15_data.pt'\n", + "f2='data/datasets/mnist_2/rot_mnist_lenet/train/seed_0_domain_15_data.pt'\n", + "\n", + "a= torch.load(f1)\n", + "b= torch.load(f2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.sum(torch.eq(a, b))/(1000*32*32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn, optim\n", + " \n", + "from opacus.dp_model_inspector import DPModelInspector\n", + "from opacus.utils import module_modification" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model= torchvision.models.resnet18(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Parameter containing:\n", + "tensor([[-0.0185, -0.0705, -0.0518, ..., -0.0390, 0.1735, -0.0410],\n", + " [-0.0818, -0.0944, 0.0174, ..., 0.2028, -0.0248, 0.0372],\n", + " [-0.0332, -0.0566, -0.0242, ..., -0.0344, -0.0227, 0.0197],\n", + " ...,\n", + " [-0.0103, 0.0033, -0.0359, ..., -0.0279, -0.0115, 0.0128],\n", + " [-0.0359, -0.0353, -0.0296, ..., -0.0330, -0.0110, -0.0513],\n", + " [ 0.0021, -0.0248, -0.0829, ..., 0.0417, -0.0500, 0.0663]],\n", + " requires_grad=True)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fc.weight" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "class GroupNorm(torch.nn.GroupNorm):\n", + " def __init__(self, num_channels, num_groups=32, **kwargs):\n", + " super().__init__(num_groups, num_channels, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "model= torchvision.models.resnet18(0, norm_layer=GroupNorm)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "opt= optim.SGD([ {'params': filter(lambda p: p.requires_grad, model.parameters()) }, \n", + " ], lr= 0.001, weight_decay= 0.0005, momentum= 0.9, nesterov=True )" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Linear(in_features=512, out_features=1000, bias=True)" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fc" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "a=torch.rand(3)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "b=a\n", + "c=a.clone().detach()" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0.7193, 0.8302, 0.5062])\n", + "tensor([0.7193, 0.8302, 0.5062])\n", + "tensor([0.7193, 0.8302, 0.5062])\n" + ] + } + ], + "source": [ + "print(a)\n", + "print(b)\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "b[2]=-10\n", + "b[1]=-10" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 0.7193, -10.0000, -10.0000])\n", + "tensor([ 0.7193, -10.0000, -10.0000])\n", + "tensor([0.7193, 0.8302, 0.5062])\n" + ] + } + ], + "source": [ + "print(a)\n", + "print(b)\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(2)" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(b == -10).nonzero()[1,0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "matchdg-env", + "language": "python", + "name": "matchdg-env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}