Frozenmad 2019-07-17 11:03:47 +08:00
Parent babe10102a
Commit b2ca990bfe
3 changed files: 499 additions and 496 deletions

File diff suppressed because one or more lines are too long


@@ -1,262 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
"\n",
"*Licensed under the MIT License.*\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Understand your NLP models\n",
"\n",
"This is a torturial on how to utilize Interpreter class to explain certain hidden layers in your NLP models. We provide the explanation by measuring the information of input words ${\\bf x}_1,...,{\\bf x}_n$ that is encoded in hidden state ${\\bf s} = \\Phi({\\bf x})$. The method is from paper [*Towards a Deep and Unified Understanding of Deep Neural Models in NLP*](https://www.microsoft.com/en-us/research/publication/towards-a-deep-and-unified-understanding-of-deep-neural-models-in-nlp/) that is accepted by **ICML 2019**. In this torturial, we provide a simple example for you to start quickly."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import sys\n",
"\n",
"sys.path.append(\"../../\")\n",
"\n",
"from utils_nlp.interpreter.Interpreter import calculate_regularization, Interpreter"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0 Prepare necessary components\n",
"Suppose the $\\Phi$ we need to explain is a simple linear function:\n",
"$$\\Phi(x)=10 \\times x[0] + 20 \\times x[1] + 5 \\times x[2] - 20 \\times x[3] - 10 \\times x[4]$$\n",
"From the definition of $\\Phi$ we can know that, the weights of the 2nd and the 4th elements in input $x$ are the biggest (in abs form), which means that they contributes the most to the results. Therefore, a reasonable explanation should show a similar pattern."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cpu\" if not torch.cuda.is_available() else \"cuda\")\n",
"\n",
"# Suppose our input is x, and the sentence is simply \"1 2 3 4 5\"\n",
"x = torch.randn(5, 256) / 100\n",
"x = x.to(device)\n",
"words = [\"1\", \"2\", \"3\", \"4\", \"5\"]\n",
"\n",
"# Suppose our hidden state s = Phi(x), where\n",
"# Phi = 10 * word[0] + 20 * word[1] + 5 * word[2] - 20 * word[3] - 10 * word[4]\n",
"def Phi(x):\n",
" W = torch.tensor([10.0, 20.0, 5.0, -20.0, -10.0]).to(device)\n",
" return W @ x\n",
"\n",
"\n",
"# Suppose this is our dataset used for training our models\n",
"dataset = [torch.randn(5, 256) / 100 for _ in range(100)]"
]
},
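{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch that relies only on the `Phi` defined above, not on the Interpreter), we can recover the weights of $\\Phi$ with autograd: the gradient of $\\Phi$ with respect to each word row should be $10, 20, 5, -20, -10$, so the 2nd and 4th words should matter most to $\\bf s$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: the gradient of Phi w.r.t. each word row recovers its weight,\n",
"# so words \"2\" and \"4\" (weights 20 and -20) should contribute the most to s.\n",
"x_check = x.clone().requires_grad_(True)\n",
"Phi(x_check).sum().backward()\n",
"print(x_check.grad.mean(dim=1))  # expected: [10., 20., 5., -20., -10.]"
]
},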
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1 Using Mutual Information to understand $\\Phi$\n",
"\n",
"We briefly introduce our algorithms here. You can also refer to our paper [here](https://www.microsoft.com/en-us/research/publication/towards-a-deep-and-unified-understanding-of-deep-neural-models-in-nlp/) for more details on algorithm.\n",
"\n",
"### 1.1 Multi-level Quantification\n",
"\n",
"Suppose the input random variable is $\\bf X$ and the hidden random variable ${\\bf S} = \\Phi({\\bf X})$. We can provide a global/corpus-level explanation by evaluating the mutual information of $\\bf X$ and $\\bf S$:\n",
"\n",
"$$MI({\\bf X};{\\bf S})=H({\\bf S}) - H({\\bf H}|{\\bf S})$$\n",
"\n",
"Where $MI(\\cdot;\\cdot)$ is the mututal information. $H(\\cdot)$ stands for entropy. Because $H({\\bf S})$ is a constant only related to input dataset $\\bf S$, the only thing we need to consider is $H({\\bf H}|{\\bf S})$. This conditional entropy can be seen as the global/corpus-level information loss when r.v. $\\bf X$ is processed by $\\Phi$. By definition:\n",
"\n",
"$$H({\\bf X}|{\\bf S}) = \\int_{{\\bf s}\\in {\\bf S}}p({\\bf S})H({\\bf X}|{\\bf s})d{\\bf s}$$\n",
"\n",
"Then, we can decompose the corpus-level information loss to sentence-level:\n",
"\n",
"$$H({\\bf X}|{\\bf s}) = \\int_{{\\bf x'}\\in {\\bf X}}p({\\bf x}'|{\\bf s})H({\\bf x}'|{\\bf s})d{\\bf x}'$$\n",
"\n",
"If we make a assumption that the inputs of $\\Phi$ are independent, we can further decompose the sentence-level information loss to word level:\n",
"\n",
"$$H({\\bf X}|{\\bf s}) = \\sum_i H({\\bf X}_i|{\\bf s})$$\n",
"$$H({\\bf X}_i|{\\bf s}) = \\int_{{\\bf x'}_i\\in {\\bf X}_i}p({\\bf x}_i'|{\\bf s})H({\\bf x}_i'|{\\bf s})d{\\bf x}_i'$$\n",
"\n",
"Note that $H({\\bf X}_i|{\\bf s})$ stands for the information loss when word ${\\bf x}_i$ reaches hidden state $s$. Therefore, we can use this value as our explanation. Higher value stands for the information of corrsponding word is largely lost, which means that this word is less important to $\\bf s$, and vice versa.\n",
"\n",
"### 1.2 Perturbation-based Approximation\n",
"\n",
"In order to calculate $H({\\bf X}_i|{\\bf s})$, we propose a perturbation-besed method. Let $\\tilde{\\bf x}_{i}={\\bf x}_{i} +{\\boldsymbol \\epsilon}_{i}$ denote an input with a certain noise $\\boldsymbol{\\epsilon}_{i}$. We assume that the noise term is a random variable that follows a Gaussian distribution, ${\\boldsymbol{\\epsilon}_{i}}\\in \\mathbb{R}^{K}$ and ${\\boldsymbol \\epsilon}_i\\sim{\\mathcal N}({\\bf0},{\\boldsymbol\\Sigma}_{i}=\\sigma_{i}^2{\\bf I})$. \n",
"In order to approximate $H({\\bf X}_i|{\\bf s})$, we first learn an optimal distribution of ${\\boldsymbol{\\epsilon}} = [{\\boldsymbol{\\epsilon}}_1^T, {\\boldsymbol \\epsilon}_2^T, ..., {\\boldsymbol \\epsilon}_n^T]^T$ with respect to the hidden state \n",
"${\\bf s}$ with the following loss:\n",
"\n",
"$$L({\\boldsymbol \\sigma})=\\mathbb{E}_{{\\boldsymbol \\epsilon}}\\Vert\\Phi(\\tilde{\\bf x})-{\\bf s}\\Vert^2-\\lambda\\sum_{i=1}^n H(\\tilde{\\bf X}_{i}|{\\bf s})|_{{\\boldsymbol\\epsilon}_{i}\\sim{\\mathcal N}({\\bf 0},\\sigma_{i}^2{\\bf I})}$$\n",
"\n",
"where $\\lambda>0$ is a hyper-parameter, ${\\boldsymbol \\sigma}=[\\sigma_1,...,\\sigma_n]$, and $\\tilde{\\bf x} = {\\bf x} + \\boldsymbol{\\epsilon}$. The first term on the left corresponds to the maximum likelihood estimation (MLE) of the distribution of $\\tilde{\\bf x}_{i}$ that maximizes $\\sum_{i}\\sum_{\\tilde{\\bf x}_{i}}\\log p(\\tilde{\\bf x}_{i}|{\\bf s})$, if we consider $\\sum_{i}\\log p(\\tilde{\\bf x}_{i}|{\\bf s})\\propto -\\Vert\\Phi(\\tilde{\\bf x})-{\\bf s}\\Vert^2$. In other words, the first term learns a distribution that generates all potential inputs corresponding to the hidden state ${\\bf s}$. The second term on the right encourages a high conditional entropy $H(\\tilde{\\bf X}_{i}|{\\bf s})$, which corresponds to the maximum entropy principle. In other words, the noise $\\boldsymbol \\epsilon$ needs to enumerate all perturbation directions to reach the representation limit of ${\\bf s}$. By minimizing the loss above, we can get the optimal ${\\sigma}_i$, then we can get the $H(\\tilde{\\bf X}_i|{\\bf s})$:\n",
"\n",
"$$H(\\tilde{\\bf X}_{i}|{\\bf s})=\\frac{K}{2}\\log(2\\pi e)+K\\log\\sigma_{i}$$\n",
"\n",
"Then, we can use $H(\\tilde{\\bf X}_i|{\\bf s})$ to approximate $H({\\bf X}_i|{\\bf s})$. Again, you can refer to our paper [here](https://www.microsoft.com/en-us/research/publication/towards-a-deep-and-unified-understanding-of-deep-neural-models-in-nlp/) for more details on algorithm."
]
},
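{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the loss $L({\\boldsymbol \\sigma})$ above concrete, the cell below sketches a bare-bones version of it. This is only an illustration under simplified assumptions (a handful of Monte-Carlo noise samples, a freely chosen $\\lambda$, and no regularization/scale handling); it is **not** the library's implementation. The actual optimization is done by `Interpreter.optimize` later in this tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of L(sigma): MSE term (MLE part) minus lambda * entropy term.\n",
"# The entropy H(~X_i|s) depends on sigma_i only through K * log(sigma_i).\n",
"def sketch_loss(x, Phi, sigma, lam=1.0, n_samples=5):\n",
"    s = Phi(x)\n",
"    mse = 0.0\n",
"    for _ in range(n_samples):\n",
"        eps = torch.randn_like(x) * sigma.unsqueeze(1)  # eps_i ~ N(0, sigma_i^2 I)\n",
"        mse = mse + ((Phi(x + eps) - s) ** 2).mean()\n",
"    return mse / n_samples - lam * torch.log(sigma).sum()\n",
"\n",
"\n",
"# Toy training loop: optimize log(sigma) so that sigma stays positive\n",
"log_sigma = torch.zeros(len(words), device=device, requires_grad=True)\n",
"optimizer = torch.optim.Adam([log_sigma], lr=0.05)\n",
"for _ in range(1000):\n",
"    optimizer.zero_grad()\n",
"    sketch_loss(x, Phi, log_sigma.exp()).backward()\n",
"    optimizer.step()\n",
"print(log_sigma.exp().detach().cpu().numpy())  # smaller sigma = more important word"
]
},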
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 Create an Interpreter instance\n",
"\n",
"In the following, we'll show you how to calculate the $\\sigma_i$ using functions this library. To explain a certain $\\bf x$ and certain $\\Phi$, we need to create an Interpreter instance, and pass your $\\bf x$, $\\Phi$ and regularization term (which is the standard variance of the hidden state r.v. $\\bf S$) to it. We also provide a simple function to calculate the regularization term that is needed in this method."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Interpreter()"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# calculate the regularization term.\n",
"regularization = calculate_regularization(dataset, Phi, device=device)\n",
"\n",
"# create the interpreter instance\n",
"# we recommend you to set hyper-parameter *scale* to 10 * Std[word_embedding_weight], 10 * 0.1 in this example\n",
"interpreter = Interpreter(\n",
" x=x, Phi=Phi, regularization=regularization, scale=10 * 0.1, words=words\n",
")\n",
"interpreter.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3 Train the Interpreter\n",
"\n",
"Then, we need to train our interpreter (by minimizing the loss [here](#1.2-Perturbation-based-Approximation)) to let it find the information loss in each input word ${\\bf x}_i$ when they reach hidden state $\\bf s$. You can control the iteration and learning rate when training."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [00:08<00:00, 601.76it/s]\n"
]
}
],
"source": [
"# Train the interpreter by optimizing the loss\n",
"interpreter.optimize(iteration=5000, lr=0.5, show_progress=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4 Show and visualize results\n",
"\n",
"After training, we can show the sigma (directly speaking, it is the range that every word can change without changing $\\bf s$ too much) we have got. Sigma somewhat stands for the information loss of word ${\\bf x}_i$ when it reaches $\\bf s$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.00315241, 0.00157062, 0.00633752, 0.00157459, 0.00313454],\n",
" dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the sigma we get\n",
"interpreter.get_sigma()"
]
},
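{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the formula from section [1.2](#1.2-Perturbation-based-Approximation), we can also turn these sigmas into the word-level conditional entropies $H(\\tilde{\\bf X}_i|{\\bf s})$. This is just an optional illustration of the formula, not part of the `Interpreter` API; here $K=256$, the embedding dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# H(~X_i|s) = K/2 * log(2*pi*e) + K * log(sigma_i); lower entropy means\n",
"# less information of word x_i is lost, i.e. the word is more important to s.\n",
"sigma = interpreter.get_sigma()\n",
"K = x.shape[1]\n",
"entropy = K / 2 * np.log(2 * np.pi * np.e) + K * np.log(sigma)\n",
"for word, h in zip(words, entropy):\n",
"    print(f\"word {word}: H = {h:.1f}\")"
]
},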
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAACvCAYAAACowErMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAC7pJREFUeJzt3V2MHXUZx/Hfb/tCSWkt2moKJSKJIQFiaLM2mhpikGgLDXjJhXqDIUYuSrggEhIiemli8EpDikYCCsqLF2gQjBBBgbp948VSA1hjA8mWIC+VKG+PF2fO7pyzc/aF8zLPnP1+kkPnzPznP888e3Z/e2ZOiyNCAABkM1F3AQAAVCGgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUlq5lMGr122INZvOGFYty8LJE2/UXUKjbT3343WX0HgHj07XXULjrd20vu4SGu2/J17Ru2+97oXGLSmg1mw6Q5Pfv/3DVwU9sfd3dZfQaH/+w566S2i80y/5Ud0lNN7kVbvqLqHRpm76xqLGcYkPAJASAQUASImAAgCkREABAFIioAAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUFgwo21fbnrI99c6b/x5FTQAALBxQEXFrRExGxOTq9aePoiYAALjEBwDIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlR8TiB9snJP1zeOX0baOkV+suouHoYX/oX//oYf+y9/CTEbFpoUFLCqjsbE9FxGTddTQZPewP/esfPezfuPSQS3wAgJQIKABASuMWULfWXcAYoIf9oX/9o4f9G4sejtU9KADA+Bi3d1AAgDExFgFl+6e2p20/W3ctTWT7LNuP2D5i+znbe+quqWlsr7G9z/bhooc3111TE9leYfug7QfqrqWJbB+z/YztQ7an6q6nX2Nxic/2RZJOSro9Ii6ou56msb1Z0uaIOGB7naT9kr4aEX+rubTGsG1JayPipO1Vkh6XtCcinqy5tEaxfZ2kSUnrI2J33fU0je1jkiYjIvPfgVq0sXgHFRF/kvRa3XU0VUS8EhEHiuW3JB2RdGa9VTVLtJwsnq4qHs3/7W+EbG+RdJmkvXXXghzGIqAwOLbPlrRV0lP1VtI8xeWpQ5KmJT0cEfRwaW6RdL2kD+oupMFC0kO299u+uu5i+kVAYYbt0yTdK+naiHiz7nqaJiLej4gLJW2RtN02l5sXyfZuSdMRsb/uWhpuR0Rsk7RL0jXF7Y/GIqAgSSrum9wr6c6IuK/ueposIl6X9KiknTWX0iQ7JF1e3EO5S9LFtu+ot6TmiYiXiz+nJd0vaXu9FfWHgEL7Bv9tko5ExA/rrqeJbG+yvaFYPlXSJZKer7eq5oiIGyJiS0ScLelKSX+MiK/VXFaj2F5bfMhJttdK+rKkRn+yeSwCyvYvJT0h6Vzbx21fVXdNDbND0tfV+q31UPG4tO6iGmazpEdsPy3pr2rdg+Kj0hilT0h63PZhSfsk/TYiHqy5pr6MxcfMAQDjZyzeQQEAxg8BBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACAlAgoAkBIBBQBIiYACAKREQAEAUiKgAAApEVAAgJQIKABASgQUACCllUsZvOKjn4547+2KLZZctUfVelcuzjtH1dg5i10791o/s8qLm3Zg+7UWeu1XMbRj7dL289xtlYVJnqdP7rFPe0OvL211D9w9bM6Tqv161W7FfIeoKqurxrn79x7ba96YM6D32Ip1XkQNVds6jlc9R9U8Vf2ZXRU91s+dYCnn2Lm9ulEL71c9sufrc+7IRc7fY828Txeqvjxq0QV3f8cvYu4l1jLfz8eeZVWP7T3D3C9ye82B/Yd+HxE75z24lhhQ8e7bOuWz35Imijde7VeILblYN+HS+onZ5YmKdZ4orS/NUd6nvG5m/UTnsReaq3zsYp3b32yWXGyf3d2l6bvGtpcn5o6d6DV25vDd8849xsQSxlbN213D7Paqc+icd3b/hfebe47uqr2rBnvh/Yof2K0vWbGs0nLpy9tre3uO9qUBOzrmn1lWlHpWfdzy9pllVRyr67hVNbS/MSe6lmfm6rHdpTmq5q9a7ti/VPdM/9R5nlXHmq3dpfWzP2bba21porSsYluv5YmquXotq+t7Yp55y8uz/ZipvOscOuPCmv3Z0vnf8nnNrpv5sevZZ+VAqTzunGO0t/Y6RumLVbHOPeqZv8ZyLRXzlmrv2Kdi3p7nU/Ua6TqeJJ26csNGLQKX+AAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFIioAAAKRFQAICUHBGLH2w/KGlR/6veEdko6dW6ixgz9HSw6Ofg0dPBG3VPX42InQsNWlJAZWN7KiIm665jnNDTwaKfg0dPBy9rT7nEBwBIiYACAKTU9IC6te4CxhA9HSz6OXj0dPBS9rTR96AAAOOr6e+gAABjKk1A2d5p+6jtF2x/p2L7KbbvLrY/Zfvs0rYbivVHbX+lWHeW7UdsH7H9nO09ozubHIbQ0zW299k+XPT05tGdTQ6D7mlp2wrbB20/MPyzyGMY/bR9zPYztg/ZnhrNmeQxpJ5usH2P7eeLn6mfH8nJRETtD0krJL0o6RxJqyUdlnRe15hvS/pJsXylpLuL5fOK8adI+lQxzwpJmyVtK8ask/T37jnH+TGknlrSacWYVZKekvS5us+1yT0t7XedpF9IeqDu82x6PyUdk7Sx7vMbs57+XNI3i+XVkjaM4nyyvIPaLumFiHgpIt6RdJekK7rGXKFWkyTpHklfsu1i/V0R8b+I+IekFyRtj4hXIuKAJEXEW5KOSDpzBOeSxTB6GhFxshi/qngsp5uYA++pJNneIukySXtHcA6ZDKWfy9zAe2p7vaSLJN0mSRHxTkS8PoJzSRNQZ0r6V+n5cc0Nk5kxEfGepDckfWwx+xZvYbeq9Rv/cjGUnhaXog5Jmpb0cETQ0x5jlvA6vUXS9ZI+GHzJqQ2rnyHpI
dv7bV89hLozG0ZPz5F0QtLPisvQe22vHU75nbIElCvWdf9m3mvMvPvaPk3SvZKujYg3P3SFzTOUnkbE+xFxoaQtav12dUFfVTbLwHtqe7ek6YjY329xDTSs7/sdEbFN0i5J19i+6MOX2DjD6OlKSdsk/Tgitkr6j6Q597aGIUtAHZd0Vun5Fkkv9xpje6Wkj0h6bb59ba9SK5zujIj7hlJ5XkPpaVvxFv9RSQv+e1pjZBg93SHpctvH1Locc7HtO4ZRfEJDeY1GRPvPaUn3a3ld+htGT49LOl66WnKPWoE1fHXf1Ctuuq2U9JJaN+baN/bO7xpzjTpv7P2qWD5fnTf2XtLsDf3bJd1S9/mNUU83qbg5KulUSY9J2l33uTa5p137flHL60MSw3iNrpW0rhizVtJfJO2s+1yb3NNi22OSzi2WvyvpByM5n7obWmrapWp90u5FSTcW674n6fJieY2kX6t1426fpHNK+95Y7HdU0q5i3RfUenv6tKRDxePSus+z4T39jKSDRU+flXRT3efY9J52zb2sAmoY/VTrfsnh4vFce87l9BjGa1TShZKmiu/930g6fRTnwr8kAQBIKcs9KAAAOhBQAICUCCgAQEoEFAAgJQIKAJASAQUASImAAgCkREABAFL6P7iFr3AGyDfZAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualize the information loss of our sigma\n",
"interpreter.visualize()"
]
},
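{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check (not required by the tutorial), we can also rank the words by their sigma values directly: a smaller sigma means less information is lost, i.e. the word is more important to $\\bf s$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Words sorted from most to least important (smallest sigma first)\n",
"order = np.argsort(interpreter.get_sigma())\n",
"print([words[i] for i in order])  # expected to start with \"2\" and \"4\""
]
},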
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the second and forth words are important to ${\\bf s} = \\Phi({\\bf x})$, which is reasonable because the weights of them are larger."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long