428 строки
2.5 MiB
Plaintext
428 строки
2.5 MiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import copy\n",
|
||
|
"import torch\n",
|
||
|
"import torchvision\n",
|
||
|
"import torchvision.transforms as transforms\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"import backwardcompatibilityml.loss as bcloss\n",
|
||
|
"import backwardcompatibilityml.scores as scores\n",
|
||
|
"from backwardcompatibilityml.helpers import training\n",
|
||
|
"from backwardcompatibilityml.widget.compatibility_analysis import CompatibilityAnalysis\n",
|
||
|
"\n",
|
||
|
"# Turn off warnings so that the widget screen\n",
|
||
|
"# real estate does not decrease.\n",
|
||
|
"import warnings\n",
|
||
|
"warnings.filterwarnings(\"ignore\")\n",
|
||
|
"\n",
|
||
|
"%matplotlib inline"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<torch._C.Generator at 0x7f0a4d7979b0>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"n_epochs = 3\n",
|
||
|
"batch_size_train = 64\n",
|
||
|
"batch_size_test = 1000\n",
|
||
|
"learning_rate = 0.01\n",
|
||
|
"momentum = 0.5\n",
|
||
|
"log_interval = 10\n",
|
||
|
"\n",
|
||
|
"random_seed = 1\n",
|
||
|
"torch.backends.cudnn.enabled = False\n",
|
||
|
"torch.manual_seed(random_seed)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"train_loader = list(torch.utils.data.DataLoader(\n",
|
||
|
" torchvision.datasets.MNIST('datasets/', train=True, download=True,\n",
|
||
|
" transform=torchvision.transforms.Compose([\n",
|
||
|
" torchvision.transforms.ToTensor(),\n",
|
||
|
" torchvision.transforms.Normalize(\n",
|
||
|
" (0.1307,), (0.3081,))\n",
|
||
|
" ])),\n",
|
||
|
" batch_size=batch_size_train, shuffle=True))\n",
|
||
|
"\n",
|
||
|
"test_loader = list(torch.utils.data.DataLoader(\n",
|
||
|
" torchvision.datasets.MNIST('datasets/', train=False, download=True,\n",
|
||
|
" transform=torchvision.transforms.Compose([\n",
|
||
|
" torchvision.transforms.ToTensor(),\n",
|
||
|
" torchvision.transforms.Normalize(\n",
|
||
|
" (0.1307,), (0.3081,))\n",
|
||
|
" ])),\n",
|
||
|
" batch_size=batch_size_test, shuffle=True))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"train_loader_a = train_loader[:int(len(train_loader)/2)]\n",
|
||
|
"train_loader_b = train_loader[int(len(train_loader)/2):]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(torch.Size([64, 1, 28, 28]), torch.Size([64]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_loader_a[0][0].size(), train_loader_a[0][1].size()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAebElEQVR4nO3de5CUxbnH8d8jgoAgiGAEwSsxKoiomAoIiopyEBERJCpiMN4vORXkHNQcCtTSiGBiNBE4GhW1MEKiRFG8xitHQ0XUCApRURRvICCIy02kzx8zvL79Zmd2Lj2zs7vfT9VW9bP9Tr+9s1377Nv9Tr/mnBMAAMXaobY7AACoH0goAIAgSCgAgCBIKACAIEgoAIAgSCgAgCDqdUIxs33MzJnZjrVw7mVm1q/c50UYjB0UqiGPnaITipmdYWbzzazKzFamy5eamYXoYKmY2Texr21mtjEWj8izrelmdn3AvpmZ/Y+ZfWxmX5vZg2a2S6j2KwVjpyRj51gzW2hma81stZnNNrM9Q7VfKRg74cdOus2zzOyj9Pv6VzNrk8/ri0ooZjZG0q2SJkvaQ9IPJF0s6ShJTTK8plEx5wzFOddi+5ekjyUNin1vxvbjauO/DEnnSBqp1PvYQVIzSb+vhX6UDGOnZN6R1N8511qpsfOepKm10I+SYeyUhpl1kfS/Sv3t+YGkDZKm5NWIc66gL0mtJFVJGlrDcdOVGtBz08f3k3SQpBckrZX0tqRTYse/IOn8WDxK0rxY7JQaPO+lX3+7JEvXNZJ0s6RVkj6QdFn6+B1r6OMySf3S5b6SPpF0paQvJN2f7EOsH50lXSjpW0lbJH0jaU6szf+S9JakdZJmSmqa43v7F0n/HYt7SdokqXmhv69K+mLslG7sJM6zk6QbJb1T279zxk7ljx1Jv5b0QCzeP91+y1x/P8VcofRUasA+ksOxZ0m6QVJLSfMlzZH0tKTdJf1C0gwz+1Ee5z5Z0pGSukkaLql/+vsXpOsOk9RD0rA82ozbQ1IbSXsr9YvLyDl3h6QZkia51H8Zg2LVwyX9h6R9030dtb0iPSXRO0vTlijvJOmHefwMlYyxo9KNHTPby8zWStqo1B+XSYX9KBWJsaOSjZ0ukv4ZO8dSpRLKAbn+AMUklLaSVjnntm7/hpm9ku7wRjM7OnbsI865/3PObZPUXVILSROdc1ucc89JekzSmXmce6Jzbq1z7mNJz6fblFJv5O+cc8udc2uU+u+sENskTXDObXbObSywDUm6zTn3Wbovc2L9lHOutXNuXobXPSnp/PTiXiul/muRpOZF9KWSMHZqVujYkXPuY5ea8moraZykJUX0o9IwdmpW6NhpodRVTdw6pRJyTopJKKsltY3P9TnneqUH8upE28tj5Q6Slqd/ydt9JCmfhcMvYuUNSr0RUduJdgvxpXNuU4GvjcvUz5rcLelPSl2Gv63U4JVSl8T1AWOnZoWOnUj6D8q9kh6ppfWcUmDs1KzQsfONpOTNP7tIWp/riYtJKK9K2ixpcA7Hxrc0/kxSJzOLn3svSZ+my1Xy/xPfI48+fS6pU6LdQiS3YPb6ZGbJPgXdstk5t805N8E5t49zrqNSSeVTff8e1XWMnczHh7ajUlM89eUuQcZO5uOL9bakQ2Pn20+p6cV3c22g4ITinFsr6VpJU8xsmJm1NLMdzKy7pJ2zvHS+UllzrJk1NrO+kgZJejBd/6ak08ysuZl1lnReHt2aJek/zayjme0q6ao8f6xM/impi5l1N7Omkq5J1K+QtF+gc8nM2pjZ/unbhw+W9FtJ1yX+u6qzGDue0GPnNDP7Ufr9bKfU2HkjfbVS5zF2PEHHjlJrMoPMrI+Z7SzpOkkPO+fKcoUi59wkSVdIGqvUD7dCqdvOrpT0SobXbFHqFzlAqbsipkg6xzm3fZ73FqUWglYodbk+o7p2MrhT0lNK/SJel/Rwfj9R9Zxz7yr15j6r1F0eyTnIuyQdnJ7H/WsubabvO++Tobqtvr875QlJd6cX4eoNxk4k9NjZU6k1uPWSFio1Lz+kkL5XKsZOJOjYcc69rdSdbDMkrVRq7eTSfPq8/bY3AACKUq+3XgEAlA8JBQAQBAkFABAECQUAEAQJBQAQRF6fnjUzbgmrQM65St+ym3FTmVY559rVdieyYexUrGrHDlcoQMNV6BYhQLVjh4QCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIIq8HbAHwde7cOSoff/zxXt3gwYO9eMCAAVHZOf+5UQcccEBUfv/990N2ERXq4IMPjso9e/b06u64446Mr9thB/86YNu2bVH5rLPO8upmzpxZTBfzxhUKACAIEgoAIAgSCgAgiDq9htKhQwcv7tq1qxfvuuuuUfmkk07y6vr16+fF7du3j8oLFizw6h566CEvfuyxx6LyokWL8ugx6pqBAwd68VVXXeXFhxxySFRu2bJl1rbic91JQ4YMicqTJ0/Op4uoUJdccokXH3jggV7cp0+fqBwfR1L2sZIUP/b222/36rZu3erFyb9loXGFAgAIgoQCAAjCkrcvZj3YLPeDA+nSpYsXX3zxxVH57LPP9uqWLVvmxV999VVUfvzxx3M+5+GHH+7F3bp18+L99tsvKj/xxBNe3ejRo6Py8uXLcz5nMZxzVpYTFag2xk1So0aNonLv3r29ul/96lde3KtXr6jcrFkzr86sNG91/Jzz588vyTmqscA516NcJytEJYydbPbZZx8vHjp0aFQeP368V9eiRQsvzmdaKy7bbcNJ69at8+JTTz01Ks+bN6+g86dVO3a4QgEABEFCAQAEQUIBAARREbcNx+epzzvvPK9u3LhxXhyfP5wzZ45Xd84555Sgd1Ljxo29OH4L8uzZs726Nm3aROXjjjuuJP1B/uLjKDm3XYwXX3wxKs+dO9erS66/XHPNNRnbGTZsWFQu4xoKivTUU095cXx9NR9r16714vvvvz/jsfvuu68Xn3zyyRmPbdWqlRfvvPPOBfQud1yhAACCIKEAAIIgoQAAgqiVNZTkZ0t+85vfROUTTzzRq3vmmWe8OD7XvH79+hL07t99++23XlxVVZXx2Pj8ZnzrF8n/XAzKa++9947KmzZt8uq+++47L45vH//www97dcltxeNtxbdPkaRbb701Y3+2bNnixXfddVfGY1G7evT4/uMWEyZM8Or22GOPIOe47LLLvHjWrFkZj40/BkHKvoaSdO2110bl5PpPCFyhAACCIKEAAIKolSmvn//8514cn+aKb60iZX9yWbnstNNOXjxx4sSonNyKY/HixVGZKa7KER9zyR1ZN2/e7MXZdpA+88wzvfiKK66Iyskte5Li01znn3++V7dkyZKsr0X5xKe4JOkf//hHVC50uxRJWr16tRfHPyKR/AhENl9++aUXx7d4ik/tVifbreshcIUCAAiChAIACIKEAgAIomxrKCNHjozKl19+uVcXv0WuEtZMkrc1J59y9sMf/jAqf/31117dBRdcULqOIYjkEzmTa2SdO3eOyg888IBXd8QRR+R8nvjtx5I0aNCgqPzuu+/m3A5K65hjjvHiu+++24vj6yb5rKFMmzbNi59++mkvzmfdJO61117z4kcffTQqJ28/TsrncSWF4AoFABAECQUAEAQJBQAQRNnWUA455JDvT7qjf9rkVhjlkNxaPL4lfXwrGEnq1KlTxnZeffVVL/70008D9A6hxcfc6aef7tWNGTPGiw877LAg51y5cqUXf/TRR0HaRfHij+598MEHvbq2bdvm3E5yW6bf//73UTm+zYkkbdiwIY8eZpbcgr5169Y5v7ZDhw5B+pAJVygAgCBIKACAICriiY3xp5wln3j2wQcfFNxufLff+C7F0r/fuhyfkrv66qu9uuQT0rg1uO6J37KZvE00m+RtolOnTs1Y/5Of/MSr69WrlxdPmTIlKiefTIryGjp0aFTOZ4orKT7FJUlXXnllwW3l6uijj/biESNG5Pza+M7
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 6 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAebElEQVR4nO3de5CUxbnH8d8jgoAgiGAEwSsxKoiomAoIiopyEBERJCpiMN4vORXkHNQcCtTSiGBiNBE4GhW1MEKiRFG8xitHQ0XUCApRURRvICCIy02kzx8zvL79Zmd2Lj2zs7vfT9VW9bP9Tr+9s1377Nv9Tr/mnBMAAMXaobY7AACoH0goAIAgSCgAgCBIKACAIEgoAIAgSCgAgCDqdUIxs33MzJnZjrVw7mVm1q/c50UYjB0UqiGPnaITipmdYWbzzazKzFamy5eamYXoYKmY2Texr21mtjEWj8izrelmdn3AvpmZ/Y+ZfWxmX5vZg2a2S6j2KwVjpyRj51gzW2hma81stZnNNrM9Q7VfKRg74cdOus2zzOyj9Pv6VzNrk8/ri0ooZjZG0q2SJkvaQ9IPJF0s6ShJTTK8plEx5wzFOddi+5ekjyUNin1vxvbjauO/DEnnSBqp1PvYQVIzSb+vhX6UDGOnZN6R1N8511qpsfOepKm10I+SYeyUhpl1kfS/Sv3t+YGkDZKm5NWIc66gL0mtJFVJGlrDcdOVGtBz08f3k3SQpBckrZX0tqRTYse/IOn8WDxK0rxY7JQaPO+lX3+7JEvXNZJ0s6RVkj6QdFn6+B1r6OMySf3S5b6SPpF0paQvJN2f7EOsH50lXSjpW0lbJH0jaU6szf+S9JakdZJmSmqa43v7F0n/HYt7SdokqXmhv69K+mLslG7sJM6zk6QbJb1T279zxk7ljx1Jv5b0QCzeP91+y1x/P8VcofRUasA+ksOxZ0m6QVJLSfMlzZH0tKTdJf1C0gwz+1Ee5z5Z0pGSukkaLql/+vsXpOsOk9RD0rA82ozbQ1IbSXsr9YvLyDl3h6QZkia51H8Zg2LVwyX9h6R9030dtb0iPSXRO0vTlijvJOmHefwMlYyxo9KNHTPby8zWStqo1B+XSYX9KBWJsaOSjZ0ukv4ZO8dSpRLKAbn+AMUklLaSVjnntm7/hpm9ku7wRjM7OnbsI865/3PObZPUXVILSROdc1ucc89JekzSmXmce6Jzbq1z7mNJz6fblFJv5O+cc8udc2uU+u+sENskTXDObXbObSywDUm6zTn3Wbovc2L9lHOutXNuXobXPSnp/PTiXiul/muRpOZF9KWSMHZqVujYkXPuY5ea8moraZykJUX0o9IwdmpW6NhpodRVTdw6pRJyTopJKKsltY3P9TnneqUH8upE28tj5Q6Slqd/ydt9JCmfhcMvYuUNSr0RUduJdgvxpXNuU4GvjcvUz5rcLelPSl2Gv63U4JVSl8T1AWOnZoWOnUj6D8q9kh6ppfWcUmDs1KzQsfONpOTNP7tIWp/riYtJKK9K2ixpcA7Hxrc0/kxSJzOLn3svSZ+my1Xy/xPfI48+fS6pU6LdQiS3YPb6ZGbJPgXdstk5t805N8E5t49zrqNSSeVTff8e1XWMnczHh7ajUlM89eUuQcZO5uOL9bakQ2Pn20+p6cV3c22g4ITinFsr6VpJU8xsmJm1NLMdzKy7pJ2zvHS+UllzrJk1NrO+kgZJejBd/6ak08ysuZl1lnReHt2aJek/zayjme0q6ao8f6xM/impi5l1N7Omkq5J1K+QtF+gc8nM2pjZ/unbhw+W9FtJ1yX+u6qzGDue0GPnNDP7Ufr9bKfU2HkjfbVS5zF2PEHHjlJrMoPMrI+Z7SzpOkkPO+fKcoUi59wkSVdIGqvUD7dCqdvOrpT0SobXbFHqFzlAqbsipkg6xzm3fZ73FqUWglYodbk+o7p2MrhT0lNK/SJel/Rwfj9R9Zxz7yr15j6r1F0eyTnIuyQdnJ7H/WsubabvO++Tobqtvr875QlJd6cX4eoNxk4k9NjZU6k1uPWSFio1Lz+kkL5XKsZOJOjYcc69rdSdbDMkrVRq7eTSfPq8/bY3AACKUq+3XgEAlA8JBQAQBAkFABAECQUAEAQJBQAQRF6fnjUzbgmrQM65St+ym3FTmVY559rVdieyYexUrGrHDlcoQMNV6BYhQLVjh4QCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIgoQCAAiChAIACIKEAgAIIq8HbAHwde7cOSoff/zxXt3gwYO9eMCAAVHZOf+5UQcccEBUfv/990N2ERXq4IMPjso9e/b06u64446Mr9thB/86YNu2bVH5rLPO8upmzpxZTBfzxhUKACAIEgoAIAgSCgAgiDq9htKhQwcv7tq1qxfvuuuuUfmkk07y6vr16+fF7du3j8oLFizw6h566CEvfuyxx6LyokWL8ugx6pqBAwd68VVXXeXFhxxySFRu2bJl1rbic91JQ4YMicqTJ0/Op4uoUJdccokXH3jggV7cp0+fqBwfR1L2sZIUP/b222/36rZu3erFyb9loXGFAgAIgoQCAAjCkrcvZj3YLPeDA+nSpYsXX3zxxVH57LPP9uqWLVvmxV999VVUfvzxx3M+5+GHH+7F3bp18+L99tsvKj/xxBNe3ejRo6Py8uXLcz5nMZxzVpYTFag2xk1So0aNonLv3r29ul/96lde3KtXr6jcrFkzr86sNG91/Jzz588vyTmqscA516NcJytEJYydbPbZZx8vHjp0aFQeP368V9eiRQsvzmdaKy7bbcNJ69at8+JTTz01Ks+bN6+g86dVO3a4QgEABEFCAQAEQUIBAARREbcNx+epzzvvPK9u3LhxXhyfP5wzZ45Xd84555Sgd1Ljxo29OH4L8uzZs726Nm3aROXjjjuuJP1B/uLjKDm3XYwXX3wxKs+dO9erS66/XHPNNRnbGTZsWFQu4xoKivTUU095cXx9NR9r16714vvvvz/jsfvuu68Xn3zyyRmPbdWqlRfvvPPOBfQud1yhAACCIKEAAIIgoQAAgqiVNZTkZ0t+85vfROUTTzzRq3vmmWe8OD7XvH79+hL07t99++23XlxVVZXx2Pj8ZnzrF8n/XAzKa++9947KmzZt8uq+++47L45vH//www97dcltxeNtxbdPkaRbb701Y3+2bNnixXfddVfGY1G7evT4/uMWEyZM8Or22GOPIOe47LLLvHjWrFkZj40/BkHKvoaSdO2110bl5PpPCFyhAACCIKEAAIKolSmvn//8514cn+aKb60iZX9yWbnstNNOXjxx4sSonNyKY/HixVGZKa7KER9zyR1ZN2/e7MXZdpA+88wzvfiKK66Iyskte5Li01znn3++V7dkyZKsr0X5xKe4JOkf//hHVC50uxRJWr16tRfHPyKR/AhENl9++aUXx7d4ik/tVifbreshcIUCAAiChAIACIKEAgAIomxrKCNHjozKl19+uVcXv0WuEtZMkrc1J59y9sMf/jAqf/31117dBRdcULqOIYjkEzmTa2SdO3eOyg888IBXd8QRR+R8nvjtx5I0aNCgqPzuu+/m3A5K65hjjvHiu+++24vj6yb5rKFMmzbNi59++mkvzmfdJO61117z4kcffTQqJ28/TsrncSWF4AoFABAECQUAEAQJBQAQRNnWUA455JDvT7qjf9rkVhjlkNxaPL4lfXwrGEnq1KlTxnZeffVVL/70008D9A6hxcfc6aef7tWNGTPGiw877LAg51y5cqUXf/TRR0HaRfHij+598MEHvbq2bdvm3E5yW6bf//73UTm+zYkkbdiwIY8eZpbcgr5169Y5v7ZDhw5B+pAJVygAgCBIKACAICriiY3xp5wln3j2wQcfFNxufLff+C7F0r/fuhyfkrv66qu9uuQT0rg1uO6J37KZvE00m+RtolOnTs1Y/5Of/MSr69WrlxdPmTIlKiefTIryGjp0aFTOZ4orKT7FJUlXXnllwW3l6uijj/biESNG5Pza+M7
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 6 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure()\n",
|
||
|
"for i in range(6):\n",
|
||
|
" plt.subplot(2,3,i+1)\n",
|
||
|
" plt.tight_layout()\n",
|
||
|
" plt.imshow(train_loader_a[0][0][i][0], cmap='gray', interpolation='none')\n",
|
||
|
" plt.title(\"Ground Truth: {}\".format(train_loader_a[0][1][i]))\n",
|
||
|
" plt.xticks([])\n",
|
||
|
" plt.yticks([])\n",
|
||
|
"fig"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Net(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(Net, self).__init__()\n",
|
||
|
" self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
|
||
|
" self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
|
||
|
" self.conv2_drop = nn.Dropout2d()\n",
|
||
|
" self.fc1 = nn.Linear(320, 50)\n",
|
||
|
" self.fc2 = nn.Linear(50, 10)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
|
||
|
" x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
|
||
|
" x = x.view(-1, 320)\n",
|
||
|
" x = F.relu(self.fc1(x))\n",
|
||
|
" x = F.dropout(x, training=self.training)\n",
|
||
|
" x = self.fc2(x)\n",
|
||
|
" return x, F.softmax(x, dim=1), F.log_softmax(x, dim=1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"h1 = Net().cuda()\n",
|
||
|
"optimizer = optim.SGD(h1.parameters(), lr=learning_rate, momentum=momentum)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"train_counter, test_counter, train_losses, test_losses = training.train(\n",
|
||
|
" n_epochs, h1, optimizer, F.nll_loss, train_loader_a, test_loader,\n",
|
||
|
" batch_size_train, batch_size_test, device=\"cuda\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Text(0, 0.5, 'negative log likelihood loss')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dd5xU5dXHv2cpCyzSi8CiFAFpK2UVwQZG7IbX2AVL1NdYscUWNDEm5k18E2sSsbxKolgSCxqxYYsogoIiVaVLC026dJ73j3Mf5u7szO7d3Zmd3Z3z/Xzmc8vM3Hvu3dnnd895nucccc5hGIZhZC85mTbAMAzDyCwmBIZhGFmOCYFhGEaWY0JgGIaR5ZgQGIZhZDm1M21AWWnRooXr0KFDps0wDMOoVkybNm2tc65loveqnRB06NCBqVOnZtoMwzCMaoWILEn2noWGDMMwshwTAsMwjCzHhMAwDCPLqXZ9BIZh1Cx27drFsmXL2L59e6ZNqRHUq1eP/Px86tSpE/k7JgSGYWSUZcuWsd9++9GhQwdEJNPmVGucc6xbt45ly5bRsWPHyN+z0JBhGBll+/btNG/e3EQgBYgIzZs3L7N3ZUJgGEbGMRFIHeW5l9khBGPHQocOkJPD7gM78+glU9i0KdNGGYZhVA1qvhCMHQuXXw5LloBzvPDd4Vzx1ADuOmdupi0zDKMKsG7dOvr06UOfPn3Yf//9adeu3b7tnTt3lvjdqVOnMnLkyDKdr0OHDqxdu7YiJqecmt9ZPGoU/PADAA74EzcB8Ne3OnHjMsjPz6BthmFknObNmzN9+nQA7rrrLho2bMjPf/7zfe/v3r2b2rUTN5WFhYUUFhZWip3ppOZ7BN99t2/13xzDl/RjFL9lL8I99+j+xYthw4bYV9asgQ8/rFQrDcOoQlx88cVcccUVDBgwgFtuuYXPPvuMgQMH0rdvXwYNGsQ333wDwIcffsipp54KqIhccsklDB48mE6dOvHQQw9FPt/ixYs59thjKSgo4Ec/+hHfBe3WP//5T3r16sUhhxzC0UcfDcDs2bM57LDD6NOnDwUFBcybN6/C11vzPYIDDtCwEOoNtGQ1o7iH7xseyONPXMDChfDOO3D44fDJJyACw4fD++/DihXQqlWG7TeMLOL66yF4OE8ZffrAAw+U/XvLli1j0qRJ1KpVi02bNjFx4kRq167Nu+++yy9+8QteeumlYt/5+uuv+eCDD9i8eTPdunXjyiuvjDSe/9prr+Wiiy7ioosu4sknn2TkyJGMGzeOu+++m7fffpt27dqxIXhaHT16NNdddx3Dhw9n586d7Nmzp+wXF0fN9wjuuQcaNOBruvE6p3EVf6V+gxzu+J886tWDGTPg3HNh8mR46il4/XWYMAH27IFXXsm08YZhZIqzzjqLWrVqAbBx40bOOussevXqxQ033MDs2bMTfueUU04hNzeXFi1a0KpVK1atWhXpXJ9++innn38+ABdccAEff/wxAEcccQQXX3wxjz/++L4Gf+DAgfzud7/jD3/4A0uWLKF+/foVvdQs8AiGDwdgwQ3vcsCaJVyV/y/4/WO0Hf4TFp0LjRpBnTqwbBnceis0aQLdu8Pu3fDii/Czn2XYfsPIIsrz5J4u8vLy9q3feeedDBkyhFdeeYXFixczePDghN/Jzc3dt16rVi12795dIRtGjx7NlClTGD9+PP3792fatGmcf/75DBgwgPHjx3PyySfz6KOPcuyxx1boPDXfIwAYPpxTVj/Foj0H0mrptH3i0KIF1K2r4aC//EX7CRYsgPvvh7POgg8+gCrWuW8YRgbYuHEj7dq1A2DMmDEpP/6gQYN4/vnnARg7dixHHXUUAAsWLGDAgAHcfffdtGzZkqVLl7Jw4UI6derEyJEjGTZsGDNmzKjw+bNDCAJySrjaggL405/gppvghBPgzDM1PPTqq/p+CsJwhmFUU2655RZuv/12+vbtW+GnfICCggLy8/PJz8/nxhtv5OGHH+app56ioKCAp59+mgcffBCAm2++md69e9OrVy8GDRrEIYccwj/+8Q969epFnz59mDVrFhdeeGGF7RHnXIUPUpkUFha6yihM4xwcdBC0awe9e8MTT2ifwcknp/3UhpFVzJ07l+7du2fajBpFonsqItOccwnHumaVR1AWRNQrmDgRHn0U8vLgF7+AvXszbZlhGEZqMSEogWuvhWuu0eFsDz4IX30VCxUZhmHUFEwISiA/Hx5+GHr1gvPOg65d4a67zCswDKNmYUIQkdq14Ze/1HkHl18On36q/QiGYRjVHROCMnDuuXDRRfD00zBoEPzqV5m2yDAMo+KYEJSBWrVgzBhYvVqF4F//yrRFhmEYFceEoBw0bgxHHw2zZsGOHZm2xjCMilCRNNSgiecmTZqU8L0xY8ZwzTXXpNrklFPzU0ykiX79NA3FrFnQv3+mrTEMo7yUloa6ND788EMaNmzIoEGD0mVi2jGPoJz066fLL77IrB2GkXWEKg7SoYNup5hp06ZxzDHH0L9/f0444QRWrlwJwEMPPUSPHj0oKCjg3HPPZfHixYwePZr777+fPn36MHHixEjHv+++++jVqxe9evXigSDB0tatWznllFM45JBD6NWrFy+88AIAt912275zlkWgyoJ5BOWkUycNEZkQGEYl4isOBsWmWLJEt2FfDrGK4pzj2muv5dVXX6Vly5a88MILjBo1iieffJLf//73LFq0iNzcXDZs2ECTJk244ooryuRFTJs2jaeeeoopU6bgnGPAgAEcc8wxLFy4kLZt2zJ+/HhA8xutW7eOV155ha+//hoR2ZeKOtWYR1BORKBvXxMCw6hUQhUH9/HDD7o/RezYsYNZs2YxdOhQ+vTpw29/+1uWLVsGaI6g4cOH88wzzyStWlYaH3/8Maeffjp5eXk0bNiQn/zkJ0ycOJHevXszYcIEbr31ViZOnEjjxo1p3Lgx9erV49JLL+Xll1+mQYMGKbvOMCYEFaBfP51tvGtXpi0xjCwhVHEw0v5y4JyjZ8+eTJ8+nenTpzNz5kzeeecdAMaPH8/VV1/NF198waGHHpqSBHSerl278sUXX9C7d2/uuOMO7r77bmrXrs1nn33GmWeeyeuvv86JJ56YsvOFMSGoAP366aihr7/OtCWGkSUccEDZ9peD3Nxc1qxZw6effgrArl27mD17Nnv37mXp0qUMGTKEP/zhD2zcuJEtW7aw3377sXnz5sjHP+qooxg3bhw//PADW7du5ZVXXuGoo45ixYoVNGjQgBEjRnDzzTfzxRdfsGXLFjZu3MjJJ5/M/fffz1dffZWy6wxjfQQVINxh3Lt3Zm0xjKzgnnuK9hEANGjAvgLkKSAnJ4cXX3yRkSNHsnHjRnbv3s31119P165dGTFiBBs3bsQ5x8iRI2nSpAmnnXYaZ555Jq+++ioPP/zwvloCnjFjxjBu3Lh925MnT+biiy/msMMOA+Cyyy6jb9++vP3229x8883k5ORQp04dHnnkETZv3sywYcPYvn07zjnuu+++lF1nGEtDXQH27NEKZ5ddpknpDMMoO2VOQz12rPYJfPedegL33JOyjuKaQlnTUJtHUAFq1dIO4+efh5494cILoV69TFtlGDWc4cOt4U8x1kdQQe67D9q319rGTZtCt25ax2D79kxbZhiGEQ0Tggpy2GHw+efw/vtw9dXQsSO89JLWOzYMIxrVLURdlSnPvUybEIhIexH5QETmiMhsEbkuwWdERB4SkfkiMkNE+qXLnnQiAkOGwB//qOUsc3NhwoRMW2UY1YN69eqxbt06E4MU4Jxj3bp11CtjjDqdfQS7gZucc1+IyH7ANBGZ4JybE/rMSUCX4DUAeCRYVlvq14cjjzQhMIyo5Ofns2zZMtasWZNpU2oE9erVIz8/v0zfSZsQOOdWAiuD9c0iMhdoB4SFYBjwd6ePApNFpImItAm+W20ZOhRuuw1WroQ2bTJtjWFUberUqUPHjh0zbUZWUyl9BCLSAegLTIl7qx2wNLS9LNgX//3LRWSqiEytDk8NQ4fq8t13M2uHYRhGFNI
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure()\n",
|
||
|
"plt.plot(train_counter, train_losses, color='blue')\n",
|
||
|
"plt.scatter(test_counter, test_losses, color='red')\n",
|
||
|
"plt.legend(['Train Loss', 'Test Loss'], loc='upper right')\n",
|
||
|
"plt.xlabel('number of training examples seen')\n",
|
||
|
"plt.ylabel('negative log likelihood loss')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"with torch.no_grad():\n",
|
||
|
" _, _, output = h1(test_loader[0][0].to(\"cuda\"))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAe7ElEQVR4nO3debCU1ZnH8d/jsCk7aFAUwR0FCSAKBrcQjBsYBEVwGXUsStCMlURFSYESUcao0bECyEjcSgw4RrQI7mgig6AZYIjGBURlX0Q2QTFsZ/7o5vU9r7f7dvc93bcv9/upoup57nn77cO9h/vwnvP2ec05JwAAqmq/6u4AAGDfQEEBAARBQQEABEFBAQAEQUEBAARBQQEABFHjC4qZPWFmd6Xj081sUYHnmWhmo8L2DuWMsYNCMG4yK0lBMbOlZrbdzLaZ2br0D6RR6Pdxzv2Pc+64HPpztZnNTrx2qHNuTOg+VfDeE9Pfh71//mlmW4v9vjUVY8d7b8ZOjhg33nsPMrNFZrbFzL4wsyfNrEkx3quUVyh9nXONJHWV1E3SyOQBZlanhP2pFulB1GjvH0lTJD1b3f0qc4wdMXYKwLhJeVtST+dcU0lHSqoj6a5ivFHJp7ycc6skvSypoySZmTOzG8zsE0mfpL/Wx8wWmtlmM5tjZp32vt7MupjZAjPbambPSGoQazvLzFbG8jZmNs3M1pvZBjMbZ2bHS5oo6dT0/142p4+NLmPT+RAzW2JmG81supm1jrU5MxtqZp+k+zjezCzf74WZNZQ0QNKT+b62NmLsfIexk7vaPm6ccyucc1/GvrRb0tH5fA9zVfKCYmZtJJ0v6f9iX+4nqbukE8ysi6THJF0nqaWk/5I03czqm1k9SS9IekpSC6X+dzYgw/v8i6QZkpZJaifpUElTnXMfSRoqaW76f3rNKnhtL0n/IWmgpEPS55iaOKyPpJMldUofd076tYenf+CH5/DtGCBpvaRZORxb6zF2PIydHDFuJDM7zcy2SNqa7v9/Zjq2SpxzRf8jaamkbZI2K/WNmiBp/3Sbk9QrduzDksYkXr9I0pmSzpC0WpLF2uZIuisdnyVpZTo+Val/cHUq6M/VkmYnvvZE7DyPSro31tZI0k5J7WJ9Pi3W/t+Sbivg+/KGpNGl+BnU1D+MHcYO4ybouDlU0mhJxxbj+17K+cN+zrmZGdpWxOK2kq4ys3+Pfa2epNZKfVNXufR3Jm1ZhnO2kbTMObergL62lrRgb+Kc22ZmG5T6YSxNf3lt7PhvlBoAOUv/b+IsSUMK6F9tw9iJYezkjHGT4JxbZWavKHX107WAfmZVLrcNx39YKyTd7ZxrFvtzgHNuiqQ1kg5NzB1musxbIelwq3jRrbItllcrNcgkRfPVLSWtquwvkocrJb3tnPss4DlrI8YOClEbx81edSQdVYTzlk1BiZskaaiZdbeUhmZ2gZk1ljRX0i5JN5pZXTPrL+mUDOf5m1KD4Z70ORqYWc902zpJh6XnRysyRdI1ZtbZzOpLGivpXefc0kB/R0n6V6UueREOYweF2KfHjZldvnd9xczaSrpbqSnT4MquoDjn5il1KT9O0iZJS5Saf5Rzboek/ul8o6RLJU3LcJ7dkvoqdTfDckkr08dL0puSPpC01sy+rOC1MyWNkvScUgPkKEmDcul/eoFsWyULZKdKOkzc8hkUYweFqAXj5gRJc8zsa6VuIV6kIk2Xmj81CABAYcruCgUAUDNRUAAAQVBQAABBUFAAAEFQUAAAQeT1SXkz45awMuScy3tzwVJi3JStL51zB1V3J7Jh7JStCscOVyhA7ZVpCxGgMhWOHQoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgiLx2GwaA2uyiiy7y8scee8zLmzZtmvG1ixcv9vJp06ZF8SOPPOK1LV26tMAeVi+uUAAAQVBQAABBmHO5P7+Gh92UJx6whQLNd851q+5OZFMOY2f06NFRfPvtt3ttGzdu9PLdu3dnPE+DBg28vHHjxlG8ZMkSr+3YY4/Nt5ulVuHY4QoFABAEBQUAEAQFBQAQBLcNA0DMYYcd5uU33HBDFA8fPtxrmzhxopdv27Yt43kPP/xwLx81alQUX3zxxV7bwQcf7OVr167N0uPywRUKACAICgoAIIh9esqrZcuWUXzSSSd5bSeccIKX33LLLVHcunVrry15a/X27dujuF+/fl7b66+/XlhnUTJ9+vTx8gsvvNDLr7322oyvXbBggZf/5je/ieIZM2Z4ba1atfLyevXqRfGKFSty6yxK7ptvvvHye++9N4rvv//+gs+7fPlyL49/Uv6yyy7z2pjyAgDUahQUAEAQFBQAQBD79BrKIYccEsXnnXee13bjjTdmfF1yzSSZx7dQSG6ngOrRo0ePrO2DBg2K4uQayQEHHODl2bYj6tKli5dPmTIlil966SWv7cc//rGXt2jRIoonTZrktQ0bNizje6K0ktup3HfffUV5n/hOxcnfI8lbjBcuXFiUPoTGFQoAIAgKCgAgCAoKACCIsl9Dad++vZfHP0NwwQUXeG3NmjXz8iZNmkRxu3btwndO0rfffluU86JygwcPjuLJkyd7bfk8lqEq4usvAwYMyPl1559/fjG6gzKSXJsbOXKkl8c/a7J+/Xqvbfr06cXrWBFxhQIACIKCAgAIouynvEaMGOHlV1xxRc6vNfvuQYbJWwGT22QsXbo0it9++22vLX5rqCQ1b9485z6geIo1bbRmzZooTm61kpxmLVT8aX2S/4S+xYsXB3kPlF58mv2RRx7x2i655BIvjz+l8dJLLy1ux0qEKxQAQBAUFABAEBQUAEAQZb+GktyiomnTplGc3K7gzTff9PL4/PfMmTNzfs+OHTt6+f777+/l8bUZlE7yNsxevXpF8X77+f832rNnT87nTc51Z9sGJfkYhNGjR0dxPms6ySf7sW5SM8QfiSFJDzzwgJf37t07ihs2bOi1Jbe+Hz9+fBQnt7avqbhCAQAEQUEBAARBQQEABFH2ayizZ8/OmhdD8hHAybWaUm3rAV/yMyDxR+wm10yy/YyqsnX8/Pnzvfz666+P4s8//zzn8zCGyle3bt28/Ne//nUUJx9JEP/ciSTNmTMnipOfoSvF767qxhUKACAICgoAIIiyn/KqDrfcckvW9g0bNkTxO++8U+zuIG3FihVevn379ihO3lKcdN1110XxU089FbZjqNGST/CM384rSXXr1o3i5EcGklOX8R3PTzzxRK/tmGOO8fLHH388/86WOa5QAABBUFAAAEFQUAAAQbCGknbkkUdG8U9+8hOvLTlPOmHChCjeunVrcTuGSHK9atCgQVGcvJ3zxRdf9PL4Iwl27NhRhN6hpurbt6+X16tXL+OxlW271KFDhyhOrsUkPfbYY1H8pz/9yWubNm2al8+aNSuKV61alfW81YkrFABAEBQUAEAQFBQAQBCWzxYQZrbP7hdx5ZVXRvGTTz6Z9djOnTtH8XvvvVe0PuXKOVfW++nvy+Pmd7/7XRT/4he/yPl1Q4cO9fLkdjAlMt85163yw6pPKcZOjx49vLxOndyXlvN5JHT8Mc+Sv+4XfyxHRdavXx/F7777rtc2ZsyYKJ43b17O/amiCscOVygAgCAoKACAIGrtlFfjxo29PH4Z2b59e68tuUto/EmBu3btKkLv8sOUV/X5+OOPo/joo4/OeuyaNWuiuE2bNkXrUx6Y8qpGLVq0iOLkRxUGDx7s5eedd14UJ29r3rRpUxTfdNNNXltl0/dVwJQXAKB4KCgAgCAoKACAIGrtGkq7du28/NNPP43i5PYKyVsDX3755aL1qxCsoZROcm578uTJUVzZv6UHH3wwiit7REKJsIZSQ8TXW5K3H7/00ktRHH+kgyQdd9xxXr5t27ZQXWINBQBQPBQUAEAQFBQAQBC1dvv6bHPY8+fP9/I33nij2N1BDXHwwQcX/NoFCxYE7Alqk40bN0Zx8jEO8a3uk2t8yc9GLVy4sAi9+w5XKAC
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 6 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAe7ElEQVR4nO3debCU1ZnH8d/jsCk7aFAUwR0FCSAKBrcQjBsYBEVwGXUsStCMlURFSYESUcao0bECyEjcSgw4RrQI7mgig6AZYIjGBURlX0Q2QTFsZ/7o5vU9r7f7dvc93bcv9/upoup57nn77cO9h/vwnvP2ec05JwAAqmq/6u4AAGDfQEEBAARBQQEABEFBAQAEQUEBAARBQQEABFHjC4qZPWFmd6Xj081sUYHnmWhmo8L2DuWMsYNCMG4yK0lBMbOlZrbdzLaZ2br0D6RR6Pdxzv2Pc+64HPpztZnNTrx2qHNuTOg+VfDeE9Pfh71//mlmW4v9vjUVY8d7b8ZOjhg33nsPMrNFZrbFzL4wsyfNrEkx3quUVyh9nXONJHWV1E3SyOQBZlanhP2pFulB1GjvH0lTJD1b3f0qc4wdMXYKwLhJeVtST+dcU0lHSqoj6a5ivFHJp7ycc6skvSypoySZmTOzG8zsE0mfpL/Wx8wWmtlmM5tjZp32vt7MupjZAjPbambPSGoQazvLzFbG8jZmNs3M1pvZBjMbZ2bHS5oo6dT0/142p4+NLmPT+RAzW2JmG81supm1jrU5MxtqZp+k+zjezCzf74WZNZQ0QNKT+b62NmLsfIexk7vaPm6ccyucc1/GvrRb0tH5fA9zVfKCYmZtJJ0v6f9iX+4nqbukE8ysi6THJF0nqaWk/5I03czqm1k9SS9IekpSC6X+dzYgw/v8i6QZkpZJaifpUElTnXMfSRoqaW76f3rNKnhtL0n/IWmgpEPS55iaOKyPpJMldUofd076tYenf+CH5/DtGCBpvaRZORxb6zF2PIydHDFuJDM7zcy2SNqa7v9/Zjq2SpxzRf8jaamkbZI2K/WNmiBp/3Sbk9QrduzDksYkXr9I0pmSzpC0WpLF2uZIuisdnyVpZTo+Val/cHUq6M/VkmYnvvZE7DyPSro31tZI0k5J7WJ9Pi3W/t+Sbivg+/KGpNGl+BnU1D+MHcYO4ybouDlU0mhJxxbj+17K+cN+zrmZGdpWxOK2kq4ys3+Pfa2epNZKfVNXufR3Jm1ZhnO2kbTMObergL62lrRgb+Kc22ZmG5T6YSxNf3lt7PhvlBoAOUv/b+IsSUMK6F9tw9iJYezkjHGT4JxbZWavKHX107WAfmZVLrcNx39YKyTd7ZxrFvtzgHNuiqQ1kg5NzB1musxbIelwq3jRrbItllcrNcgkRfPVLSWtquwvkocrJb3tnPss4DlrI8YOClEbx81edSQdVYTzlk1BiZskaaiZdbeUhmZ2gZk1ljRX0i5JN5pZXTPrL+mUDOf5m1KD4Z70ORqYWc902zpJh6XnRysyRdI1ZtbZzOpLGivpXefc0kB/R0n6V6UueREOYweF2KfHjZldvnd9xczaSrpbqSnT4MquoDjn5il1KT9O0iZJS5Saf5Rzboek/ul8o6RLJU3LcJ7dkvoqdTfDckkr08dL0puSPpC01sy+rOC1MyWNkvScUgPkKEmDcul/eoFsWyULZKdKOkzc8hkUYweFqAXj5gRJc8zsa6VuIV6kIk2Xmj81CABAYcruCgUAUDNRUAAAQVBQAABBUFAAAEFQUAAAQeT1SXkz45awMuScy3tzwVJi3JStL51zB1V3J7Jh7JStCscOVyhA7ZVpCxGgMhWOHQoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgCAoKACAICgoAIAgKCgAgiLx2GwaA2uyiiy7y8scee8zLmzZtmvG1ixcv9vJp06ZF8SOPPOK1LV26tMAeVi+uUAAAQVBQAABBmHO5P7+Gh92UJx6whQLNd851q+5OZFMOY2f06NFRfPvtt3ttGzdu9PLdu3dnPE+DBg28vHHjxlG8ZMkSr+3YY4/Nt5ulVuHY4QoFABAEBQUAEAQFBQAQBLcNA0DMYYcd5uU33HBDFA8fPtxrmzhxopdv27Yt43kPP/xwLx81alQUX3zxxV7bwQcf7OVr167N0uPywRUKACAICgoAIIh9esqrZcuWUXzSSSd5bSeccIKX33LLLVHcunVrry15a/X27dujuF+/fl7b66+/XlhnUTJ9+vTx8gsvvNDLr7322oyvXbBggZf/5je/ieIZM2Z4ba1atfLyevXqRfGKFSty6yxK7ptvvvHye++9N4rvv//+gs+7fPlyL49/Uv6yyy7z2pjyAgDUahQUAEAQFBQAQBD79BrKIYccEsXnnXee13bjjTdmfF1yzSSZx7dQSG6ngOrRo0ePrO2DBg2K4uQayQEHHODl2bYj6tKli5dPmTIlil966SWv7cc//rGXt2jRIoonTZrktQ0bNizje6K0ktup3HfffUV5n/hOxcnfI8lbjBcuXFiUPoTGFQoAIAgKCgAgCAoKACCIsl9Dad++vZfHP0NwwQUXeG3NmjXz8iZNmkRxu3btwndO0rfffluU86JygwcPjuLJkyd7bfk8lqEq4usvAwYMyPl1559/fjG6gzKSXJsbOXKkl8c/a7J+/Xqvbfr06cXrWBFxhQIACIKCAgAIouynvEaMGOHlV1xxRc6vNfvuQYbJWwGT22QsXbo0it9++22vLX5rqCQ1b9485z6geIo1bbRmzZooTm61kpxmLVT8aX2S/4S+xYsXB3kPlF58mv2RRx7x2i655BIvjz+l8dJLLy1ux0qEKxQAQBAUFABAEBQUAEAQZb+GktyiomnTplGc3K7gzTff9PL4/PfMmTNzfs+OHTt6+f777+/l8bUZlE7yNsxevXpF8X77+f832rNnT87nTc51Z9sGJfkYhNGjR0dxPms6ySf7sW5SM8QfiSFJDzzwgJf37t07ihs2bOi1Jbe+Hz9+fBQnt7avqbhCAQAEQUEBAARBQQEABFH2ayizZ8/OmhdD8hHAybWaUm3rAV/yMyDxR+wm10yy/YyqsnX8/Pnzvfz666+P4s8//zzn8zCGyle3bt28/Ne//nUUJx9JEP/ciSTNmTMnipOfoSvF767qxhUKACAICgoAIIiyn/KqDrfcckvW9g0bNkTxO++8U+zuIG3FihVevn379ihO3lKcdN1110XxU089FbZjqNGST/CM384rSXXr1o3i5EcGklOX8R3PTzzxRK/tmGOO8fLHH388/86WOa5QAABBUFAAAEFQUAAAQbCGknbkkUdG8U9+8hOvLTlPOmHChCjeunVrcTuGSHK9atCgQVGcvJ3zxRdf9PL4Iwl27NhRhN6hpurbt6+X16tXL+OxlW271KFDhyhOrsUkPfbYY1H8pz/9yWubNm2al8+aNSuKV61alfW81YkrFABAEBQUAEAQFBQAQBCWzxYQZrbP7hdx5ZVXRvGTTz6Z9djOnTtH8XvvvVe0PuXKOVfW++nvy+Pmd7/7XRT/4he/yPl1Q4cO9fLkdjAlMt85163yw6pPKcZOjx49vLxOndyXlvN5JHT8Mc+Sv+4XfyxHRdavXx/F7777rtc2ZsyYKJ43b17O/amiCscOVygAgCAoKACAIGrtlFfjxo29PH4Z2b59e68tuUto/EmBu3btKkLv8sOUV/X5+OOPo/joo4/OeuyaNWuiuE2bNkXrUx6Y8qpGLVq0iOLkRxUGDx7s5eedd14UJ29r3rRpUxTfdNNNXltl0/dVwJQXAKB4KCgAgCAoKACAIGrtGkq7du28/NNPP43i5PYKyVsDX3755aL1qxCsoZROcm578uTJUVzZv6UHH3wwiit7REKJsIZSQ8TXW5K3H7/00ktRHH+kgyQdd9xxXr5t27ZQXWINBQBQPBQUAEAQFBQAQBC1dvv6bHPY8+fP9/I33nij2N1BDXHwwQcX/NoFCxYE7Alqk40bN0Zx8jEO8a3uk2t8yc9GLVy4sAi9+w5XKAC
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 6 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure()\n",
|
||
|
"for i in range(6):\n",
|
||
|
" plt.subplot(2,3,i+1)\n",
|
||
|
" plt.tight_layout()\n",
|
||
|
" plt.imshow(test_loader[0][0][i][0], cmap='gray', interpolation='none')\n",
|
||
|
" plt.title(\"Prediction: {}\".format(\n",
|
||
|
" output.data.max(1, keepdim=True)[1][i].item()))\n",
|
||
|
" plt.xticks([])\n",
|
||
|
" plt.yticks([])\n",
|
||
|
"fig"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div class=\"container\">\n",
|
||
|
" <style>\n",
|
||
|
" #viz {\n",
|
||
|
" font-family: \"Helvetica Neue\", Helvetica, Arial, sans-serif;\n",
|
||
|
" position: relative;\n",
|
||
|
" width: 960px;\n",
|
||
|
"}\n",
|
||
|
" \n",
|
||
|
".axis text {\n",
|
||
|
" font: 10px sans-serif;\n",
|
||
|
"}\n",
|
||
|
" \n",
|
||
|
".axis path,\n",
|
||
|
".axis line {\n",
|
||
|
" fill: none;\n",
|
||
|
" stroke: #000;\n",
|
||
|
" shape-rendering: crispEdges;\n",
|
||
|
"}\n",
|
||
|
" \n",
|
||
|
".bar {\n",
|
||
|
" fill: rgba(113, 113, 255, 0.8);\n",
|
||
|
"}\n",
|
||
|
" \n",
|
||
|
".bar:hover {\n",
|
||
|
" fill: rgba(226, 75, 158, 0.8);\n",
|
||
|
"}\n",
|
||
|
" \n",
|
||
|
"label {\n",
|
||
|
" position: absolute;\n",
|
||
|
" top: 10px;\n",
|
||
|
" right: 10px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".plot {\n",
|
||
|
" display: inline-block;\n",
|
||
|
" margin-left: 20px;\n",
|
||
|
" margin-right:20px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".table {\n",
|
||
|
" width: 600px;\n",
|
||
|
" text-align: center;\n",
|
||
|
" background-color: lightblue;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".data-selector .control-group {\n",
|
||
|
" display: inline-block;\n",
|
||
|
" margin-right: 50px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".data-selector .control-subgroup {\n",
|
||
|
" display: inline-block;\n",
|
||
|
" margin-right: 20px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".data-selector .control {\n",
|
||
|
" display: inline-block;\n",
|
||
|
" margin-right: 5px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
" </style>\n",
|
||
|
" <div id=\"widget\">\n",
|
||
|
" </div>\n",
|
||
|
" <script>\n",
|
||
|
" var data = null;\n",
|
||
|
" window.API_SERVICE_ENVIRONMENT = {\"environment_type\": \"local\", \"base_url\": \"\", \"port\": 5000};\n",
|
||
|
" window.WIDGET_STATE = {\n",
|
||
|
" data: data,\n",
|
||
|
" sweepStatus: null,\n",
|
||
|
" selectedDataPoint: null,\n",
|
||
|
" training: true,\n",
|
||
|
" testing: true,\n",
|
||
|
" newError: true,\n",
|
||
|
" strictImitation: true,\n",
|
||
|
" error: null,\n",
|
||
|
" loading: false\n",
|
||
|
" };\n",
|
||
|
"\n",
|
||
|
" !function(n){var t={};function e(c){if(t[c])return t[c].exports;var I=t[c]={i:c,l:!1,exports:{}};return n[c].call(I.exports,I,I.exports,e),I.l=!0,I.exports}e.m=n,e.c=t,e.d=function(n,t,c){e.o(n,t)||Object.defineProperty(n,t,{enumerable:!0,get:c})},e.r=function(n){\"undefined\"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:\"Module\"}),Object.defineProperty(n,\"__esModule\",{value:!0})},e.t=function(n,t){if(1&t&&(n=e(n)),8&t)return n;if(4&t&&\"object\"==typeof n&&n&&n.__esModule)return n;var c=Object.create(null);if(e.r(c),Object.defineProperty(c,\"default\",{enumerable:!0,value:n}),2&t&&\"string\"!=typeof n)for(var I in n)e.d(c,I,function(t){return n[t]}.bind(null,I));return c},e.n=function(n){var t=n&&n.__esModule?function(){return n.default}:function(){return n};return e.d(t,\"a\",t),t},e.o=function(n,t){return Object.prototype.hasOwnProperty.call(n,t)},e.p=\"/\",e(e.s=20)}([function(module,exports,__webpack_require__){\"use strict\";eval(\"\\n\\nif (true) {\\n module.exports = __webpack_require__(22);\\n} else {}\\n//# sourceURL=[module]\\n//# sourceMappingURL=data:application/json;charset=utf-8;base64,eyJ2ZXJzaW9uIjozLCJzb3VyY2VzIjpbIndlYnBhY2s6Ly8vLi9ub2RlX21vZHVsZXMvcmVhY3QvaW5kZXguanM/Y2E3OCJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBYTs7QUFFYixJQUFJLElBQXFDO0FBQ3pDLG1CQUFtQixtQkFBTyxDQUFDLEVBQStCO0FBQzFELENBQUMsTUFBTSxFQUVOIiwiZmlsZSI6IjAuanMiLCJzb3VyY2VzQ29udGVudCI6WyIndXNlIHN0cmljdCc7XG5cbmlmIChwcm9jZXNzLmVudi5OT0RFX0VOViA9PT0gJ3Byb2R1Y3Rpb24nKSB7XG4gIG1vZHVsZS5leHBvcnRzID0gcmVxdWlyZSgnLi9janMvcmVhY3QucHJvZHVjdGlvbi5taW4uanMnKTtcbn0gZWxzZSB7XG4gIG1vZHVsZS5leHBvcnRzID0gcmVxdWlyZSgnLi9janMvcmVhY3QuZGV2ZWxvcG1lbnQuanMnKTtcbn1cbiJdLCJzb3VyY2VSb290IjoiIn0=\\n//# sourceURL=webpack-internal:///0\\n\")},function(module,exports,__webpack_require__){\"use strict\";eval(\"\\n\\nvar bind = __webpack_require__(10);\\n\\n/*global toString:true*/\\n\\n// utils is a library of generic helper functions non-specific to axios\\n\\nvar toString = Object.prototype.toString;\\n\\n/**\\n * Determine if a value is an Array\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if value is an Array, otherwise false\\n */\\nfunction isArray(val) {\\n return toString.call(val) === '[object Array]';\\n}\\n\\n/**\\n * Determine if a value is undefined\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if the value is undefined, otherwise false\\n */\\nfunction isUndefined(val) {\\n return typeof val === 'undefined';\\n}\\n\\n/**\\n * Determine if a value is a Buffer\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if value is a Buffer, otherwise false\\n */\\nfunction isBuffer(val) {\\n return val !== null && !isUndefined(val) && val.constructor !== null && !isUndefined(val.constructor)\\n && typeof val.constructor.isBuffer === 'function' && val.constructor.isBuffer(val);\\n}\\n\\n/**\\n * Determine if a value is an ArrayBuffer\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if value is an ArrayBuffer, otherwise false\\n */\\nfunction isArrayBuffer(val) {\\n return toString.call(val) === '[object ArrayBuffer]';\\n}\\n\\n/**\\n * Determine if a value is a FormData\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if value is an FormData, otherwise false\\n */\\nfunction isFormData(val) {\\n return (typeof FormData !== 'undefined') && (val instanceof FormData);\\n}\\n\\n/**\\n * Determine if a value is a view on an ArrayBuffer\\n *\\n * @param {Object} val The value to test\\n * @returns {boolean} True if value is a view on an ArrayBuffer, otherwise false\\n */\\nfunction isArrayBufferView(val) {\\n var result;\\n if ((typeof ArrayBuffer !== 'undefined') && (ArrayBuffer.isView)) {\\n result = ArrayBuffer.isView(val);\\n } else {\\n result = (val) && (val.buffer) && (val.buffer instanceof ArrayBuffer);\\n }\\n return result;\\n}\\n\\n/**\\n * Determine if a value is a String\\n *\\n * @param {Object} val The value to test\\n * @returns {boolea
|
||
|
" </script>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"h2 = copy.deepcopy(h1)\n",
|
||
|
"\n",
|
||
|
"analysis = CompatibilityAnalysis(\"sweeps-mnist\", n_epochs, h1, h2, train_loader_b, test_loader,\n",
|
||
|
" batch_size_train, batch_size_test,\n",
|
||
|
" OptimizerClass=optim.SGD,\n",
|
||
|
" optimizer_kwargs={\"lr\": learning_rate, \"momentum\": momentum},\n",
|
||
|
" NewErrorLossClass=bcloss.BCCrossEntropyLoss,\n",
|
||
|
" StrictImitationLossClass=bcloss.StrictImitationCrossEntropyLoss,\n",
|
||
|
" lambda_c_stepsize=0.25, device=\"cuda\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.6.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|