fast_retraining/experiments/03_football_GPU.ipynb

1354 строки
38 KiB
Plaintext
Исходник Постоянная ссылка Ответственный История

Этот файл содержит неоднозначные символы Юникода!

Этот файл содержит неоднозначные символы Юникода, которые могут быть перепутаны с другими в текущей локали. Если это намеренно, можете спокойно проигнорировать это предупреждение. Используйте кнопку Экранировать, чтобы подсветить эти символы.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Experiment 3: Football match prediction (GPU version)\n",
"\n",
"In this experiment we are going to use the [Kaggle football dataset](https://www.kaggle.com/hugomathien/soccer). The dataset has information on more than 25,000 matches and more than 10,000 players from 11 European countries, covering the top league of each country during the 2008 to 2016 seasons. It also contains player attributes sourced from EA Sports' FIFA video game series. The problem we address is to predict whether a match ends in a win, a draw or a defeat for the home team. \n",
"\n",
"Part of the code used in this notebook comes from this [Kaggle kernel](https://www.kaggle.com/airback/match-outcome-prediction-in-football).\n",
"\n",
"The details of the machine we used and the versions of the libraries can be found in [experiment 01](01_airline.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"System version: 3.5.2 |Anaconda custom (64-bit)| (default, Jul 2 2016, 17:53:06) \n",
"[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]\n",
"XGBoost version: 0.6\n",
"LightGBM version: 0.2\n",
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"import os,sys\n",
"import pandas as pd\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import itertools\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"import xgboost as xgb\n",
"import lightgbm as lgb\n",
"from libs.loaders import load_football\n",
"from libs.football import get_fifa_data, create_feables\n",
"from libs.timer import Timer\n",
"from libs.conversion import convert_cols_categorical_to_numeric\n",
"from libs.metrics import classification_metrics_multilabel\n",
"import pickle\n",
"import pkg_resources\n",
"import json\n",
"\n",
"\n",
"print(\"System version: {}\".format(sys.version))\n",
"print(\"XGBoost version: {}\".format(pkg_resources.get_distribution('xgboost').version))\n",
"print(\"LightGBM version: {}\".format(pkg_resources.get_distribution('lightgbm').version))\n",
"\n",
"%matplotlib inline\n",
"% load_ext autoreload\n",
"% autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### Data loading and management\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:libs.loaders:MOUNT_POINT not found in environment. Defaulting to /fileshare\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(11, 2)\n",
"(25979, 115)\n",
"(11, 3)\n",
"(299, 5)\n",
"(183978, 42)\n",
"CPU times: user 4 s, sys: 864 ms, total: 4.86 s\n",
"Wall time: 20.2 s\n"
]
}
],
"source": [
"%%time\n",
"# load_football returns five tables; report the shape of each one in the same order.\n",
"countries, matches, leagues, teams, players = load_football()\n",
"for table in (countries, matches, leagues, teams, players):\n",
"    print(table.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>country_id</th>\n",
" <th>name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Belgium Jupiler League</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1729</td>\n",
" <td>1729</td>\n",
" <td>England Premier League</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4769</td>\n",
" <td>4769</td>\n",
" <td>France Ligue 1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>7809</td>\n",
" <td>7809</td>\n",
" <td>Germany 1. Bundesliga</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>10257</td>\n",
" <td>10257</td>\n",
" <td>Italy Serie A</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>13274</td>\n",
" <td>13274</td>\n",
" <td>Netherlands Eredivisie</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>15722</td>\n",
" <td>15722</td>\n",
" <td>Poland Ekstraklasa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>17642</td>\n",
" <td>17642</td>\n",
" <td>Portugal Liga ZON Sagres</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>19694</td>\n",
" <td>19694</td>\n",
" <td>Scotland Premier League</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>21518</td>\n",
" <td>21518</td>\n",
" <td>Spain LIGA BBVA</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>24558</td>\n",
" <td>24558</td>\n",
" <td>Switzerland Super League</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id country_id name\n",
"0 1 1 Belgium Jupiler League\n",
"1 1729 1729 England Premier League\n",
"2 4769 4769 France Ligue 1\n",
"3 7809 7809 Germany 1. Bundesliga\n",
"4 10257 10257 Italy Serie A\n",
"5 13274 13274 Netherlands Eredivisie\n",
"6 15722 15722 Poland Ekstraklasa\n",
"7 17642 17642 Portugal Liga ZON Sagres\n",
"8 19694 19694 Scotland Premier League\n",
"9 21518 21518 Spain LIGA BBVA\n",
"10 24558 24558 Switzerland Super League"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the leagues table (one row per league)\n",
"leagues"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>country_id</th>\n",
" <th>league_id</th>\n",
" <th>season</th>\n",
" <th>stage</th>\n",
" <th>date</th>\n",
" <th>match_api_id</th>\n",
" <th>home_team_api_id</th>\n",
" <th>away_team_api_id</th>\n",
" <th>home_team_goal</th>\n",
" <th>...</th>\n",
" <th>SJA</th>\n",
" <th>VCH</th>\n",
" <th>VCD</th>\n",
" <th>VCA</th>\n",
" <th>GBH</th>\n",
" <th>GBD</th>\n",
" <th>GBA</th>\n",
" <th>BSH</th>\n",
" <th>BSD</th>\n",
" <th>BSA</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2008/2009</td>\n",
" <td>1</td>\n",
" <td>2008-08-17 00:00:00</td>\n",
" <td>492473</td>\n",
" <td>9987</td>\n",
" <td>9993</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>4.00</td>\n",
" <td>1.65</td>\n",
" <td>3.40</td>\n",
" <td>4.50</td>\n",
" <td>1.78</td>\n",
" <td>3.25</td>\n",
" <td>4.00</td>\n",
" <td>1.73</td>\n",
" <td>3.40</td>\n",
" <td>4.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2008/2009</td>\n",
" <td>1</td>\n",
" <td>2008-08-16 00:00:00</td>\n",
" <td>492474</td>\n",
" <td>10000</td>\n",
" <td>9994</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>3.80</td>\n",
" <td>2.00</td>\n",
" <td>3.25</td>\n",
" <td>3.25</td>\n",
" <td>1.85</td>\n",
" <td>3.25</td>\n",
" <td>3.75</td>\n",
" <td>1.91</td>\n",
" <td>3.25</td>\n",
" <td>3.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2008/2009</td>\n",
" <td>1</td>\n",
" <td>2008-08-16 00:00:00</td>\n",
" <td>492475</td>\n",
" <td>9984</td>\n",
" <td>8635</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>2.50</td>\n",
" <td>2.35</td>\n",
" <td>3.25</td>\n",
" <td>2.65</td>\n",
" <td>2.50</td>\n",
" <td>3.20</td>\n",
" <td>2.50</td>\n",
" <td>2.30</td>\n",
" <td>3.20</td>\n",
" <td>2.75</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2008/2009</td>\n",
" <td>1</td>\n",
" <td>2008-08-17 00:00:00</td>\n",
" <td>492476</td>\n",
" <td>9991</td>\n",
" <td>9998</td>\n",
" <td>5</td>\n",
" <td>...</td>\n",
" <td>7.50</td>\n",
" <td>1.45</td>\n",
" <td>3.75</td>\n",
" <td>6.50</td>\n",
" <td>1.50</td>\n",
" <td>3.75</td>\n",
" <td>5.50</td>\n",
" <td>1.44</td>\n",
" <td>3.75</td>\n",
" <td>6.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2008/2009</td>\n",
" <td>1</td>\n",
" <td>2008-08-16 00:00:00</td>\n",
" <td>492477</td>\n",
" <td>7947</td>\n",
" <td>9985</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>4.50</td>\n",
" <td>3.40</td>\n",
" <td>1.65</td>\n",
" <td>4.50</td>\n",
" <td>3.50</td>\n",
" <td>1.65</td>\n",
" <td>4.75</td>\n",
" <td>3.30</td>\n",
" <td>1.67</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 115 columns</p>\n",
"</div>"
],
"text/plain": [
" id country_id league_id season stage date \\\n",
"0 1 1 1 2008/2009 1 2008-08-17 00:00:00 \n",
"1 2 1 1 2008/2009 1 2008-08-16 00:00:00 \n",
"2 3 1 1 2008/2009 1 2008-08-16 00:00:00 \n",
"3 4 1 1 2008/2009 1 2008-08-17 00:00:00 \n",
"4 5 1 1 2008/2009 1 2008-08-16 00:00:00 \n",
"\n",
" match_api_id home_team_api_id away_team_api_id home_team_goal ... \\\n",
"0 492473 9987 9993 1 ... \n",
"1 492474 10000 9994 0 ... \n",
"2 492475 9984 8635 0 ... \n",
"3 492476 9991 9998 5 ... \n",
"4 492477 7947 9985 1 ... \n",
"\n",
" SJA VCH VCD VCA GBH GBD GBA BSH BSD BSA \n",
"0 4.00 1.65 3.40 4.50 1.78 3.25 4.00 1.73 3.40 4.20 \n",
"1 3.80 2.00 3.25 3.25 1.85 3.25 3.75 1.91 3.25 3.60 \n",
"2 2.50 2.35 3.25 2.65 2.50 3.20 2.50 2.30 3.20 2.75 \n",
"3 7.50 1.45 3.75 6.50 1.50 3.75 5.50 1.44 3.75 6.50 \n",
"4 1.73 4.50 3.40 1.65 4.50 3.50 1.65 4.75 3.30 1.67 \n",
"\n",
"[5 rows x 115 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first rows of the raw match table\n",
"matches.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(21374, 115)\n"
]
}
],
"source": [
"# Drop every match missing any key match column or any of the 22 player slots.\n",
"# This also shrinks the data so the expensive feature computation below stays tractable.\n",
"match_meta_cols = ['country_id', 'league_id', 'season', 'stage', 'date', 'match_api_id',\n",
"                   'home_team_api_id', 'away_team_api_id', 'home_team_goal', 'away_team_goal']\n",
"player_cols = ['{}_player_{}'.format(side, slot)\n",
"               for side in ('home', 'away')\n",
"               for slot in range(1, 12)]\n",
"cols = match_meta_cols + player_cols\n",
"match_data = matches.dropna(subset=cols)\n",
"print(match_data.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Now, using the information from the matches and players, we are going to create features based on the FIFA attributes. This computation is expensive, so we cache the result to disk the first time it is computed and reuse the cached file on later runs. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(21374, 23)\n",
"CPU times: user 29min 27s, sys: 1min 6s, total: 30min 33s\n",
"Wall time: 31min 14s\n"
]
}
],
"source": [
"%%time\n",
"# get_fifa_data takes about half an hour, so its output is cached as a pickle in the\n",
"# working directory and reused on later runs.\n",
"# NOTE(review): the cache is not invalidated when match_data or players change -- delete\n",
"# fifa_data.pk by hand in that case. Only load pickle files you produced yourself.\n",
"fifa_data_filename = 'fifa_data.pk'\n",
"if not os.path.isfile(fifa_data_filename):\n",
"    fifa_data = get_fifa_data(match_data, players)\n",
"    fifa_data.to_pickle(fifa_data_filename)\n",
"else:\n",
"    fifa_data = pd.read_pickle(fifa_data_filename)\n",
"print(fifa_data.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Finally, we are going to compute the features and labels. The labels describe the result for the team playing at home: `Win`, `Draw` or `Defeat`. "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating match features...\n",
"Generating match labels...\n",
"Generating bookkeeper data...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/strata/lib/python3.5/site-packages/pandas/core/indexing.py:297: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n",
" self.obj[key] = _infer_fill_value(value)\n",
"/anaconda/envs/strata/lib/python3.5/site-packages/pandas/core/indexing.py:477: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n",
" self.obj[item] = s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(19673, 48)\n",
"CPU times: user 10min 44s, sys: 52.2 s, total: 11min 37s\n",
"Wall time: 11min 53s\n"
]
}
],
"source": [
"%%time\n",
"bk_cols = ['B365', 'BW', 'IW', 'LB', 'PS', 'WH', 'SJ', 'VC', 'GB', 'BS']\n",
"bk_cols_selected = ['B365', 'BW'] \n",
"feables = create_feables(match_data, fifa_data, bk_cols_selected, get_overall = True)\n",
"print(feables.shape)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>match_api_id</th>\n",
" <th>home_team_goals_difference</th>\n",
" <th>away_team_goals_difference</th>\n",
" <th>games_won_home_team</th>\n",
" <th>games_won_away_team</th>\n",
" <th>games_against_won</th>\n",
" <th>games_against_lost</th>\n",
" <th>season</th>\n",
" <th>League_1.0</th>\n",
" <th>League_1729.0</th>\n",
" <th>...</th>\n",
" <th>away_player_9_overall_rating</th>\n",
" <th>away_player_10_overall_rating</th>\n",
" <th>away_player_11_overall_rating</th>\n",
" <th>B365_Win</th>\n",
" <th>B365_Draw</th>\n",
" <th>B365_Defeat</th>\n",
" <th>BW_Win</th>\n",
" <th>BW_Draw</th>\n",
" <th>BW_Defeat</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>493017.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2008.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>70.0</td>\n",
" <td>68.0</td>\n",
" <td>63.0</td>\n",
" <td>0.313804</td>\n",
" <td>0.276886</td>\n",
" <td>0.409310</td>\n",
" <td>0.307825</td>\n",
" <td>0.279410</td>\n",
" <td>0.412765</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>493025.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2008.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>67.0</td>\n",
" <td>73.0</td>\n",
" <td>68.0</td>\n",
" <td>0.327179</td>\n",
" <td>0.286281</td>\n",
" <td>0.386540</td>\n",
" <td>0.290493</td>\n",
" <td>0.300176</td>\n",
" <td>0.409331</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>493027.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2008.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>55.0</td>\n",
" <td>58.0</td>\n",
" <td>64.0</td>\n",
" <td>0.672897</td>\n",
" <td>0.209346</td>\n",
" <td>0.117757</td>\n",
" <td>0.672269</td>\n",
" <td>0.226891</td>\n",
" <td>0.100840</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>493034.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2008.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>74.0</td>\n",
" <td>70.0</td>\n",
" <td>69.0</td>\n",
" <td>0.207407</td>\n",
" <td>0.259259</td>\n",
" <td>0.533333</td>\n",
" <td>0.192717</td>\n",
" <td>0.274476</td>\n",
" <td>0.532807</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>493040.0</td>\n",
" <td>-2.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2008.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>60.0</td>\n",
" <td>63.0</td>\n",
" <td>65.0</td>\n",
" <td>0.535211</td>\n",
" <td>0.267606</td>\n",
" <td>0.197183</td>\n",
" <td>0.565759</td>\n",
" <td>0.254990</td>\n",
" <td>0.179250</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 48 columns</p>\n",
"</div>"
],
"text/plain": [
" match_api_id home_team_goals_difference away_team_goals_difference \\\n",
"0 493017.0 0.0 0.0 \n",
"1 493025.0 0.0 0.0 \n",
"2 493027.0 0.0 0.0 \n",
"3 493034.0 1.0 2.0 \n",
"4 493040.0 -2.0 0.0 \n",
"\n",
" games_won_home_team games_won_away_team games_against_won \\\n",
"0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 \n",
"3 1.0 1.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"\n",
" games_against_lost season League_1.0 League_1729.0 ... \\\n",
"0 0.0 2008.0 1 0 ... \n",
"1 0.0 2008.0 1 0 ... \n",
"2 0.0 2008.0 1 0 ... \n",
"3 0.0 2008.0 1 0 ... \n",
"4 0.0 2008.0 1 0 ... \n",
"\n",
" away_player_9_overall_rating away_player_10_overall_rating \\\n",
"0 70.0 68.0 \n",
"1 67.0 73.0 \n",
"2 55.0 58.0 \n",
"3 74.0 70.0 \n",
"4 60.0 63.0 \n",
"\n",
" away_player_11_overall_rating B365_Win B365_Draw B365_Defeat BW_Win \\\n",
"0 63.0 0.313804 0.276886 0.409310 0.307825 \n",
"1 68.0 0.327179 0.286281 0.386540 0.290493 \n",
"2 64.0 0.672897 0.209346 0.117757 0.672269 \n",
"3 69.0 0.207407 0.259259 0.533333 0.192717 \n",
"4 65.0 0.535211 0.267606 0.197183 0.565759 \n",
"\n",
" BW_Draw BW_Defeat label \n",
"0 0.279410 0.412765 0 \n",
"1 0.300176 0.409331 1 \n",
"2 0.226891 0.100840 0 \n",
"3 0.274476 0.532807 0 \n",
"4 0.254990 0.179250 2 \n",
"\n",
"[5 rows x 48 columns]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert categorical columns to numeric ones (the output shows League_<id> 0/1 columns,\n",
"# so this appears to be a one-hot encoding -- confirm in libs.conversion).\n",
"# NOTE(review): `feables` is rebound here, so re-running this cell converts data that was\n",
"# already converted; re-run the create_feables cell first.\n",
"feables = convert_cols_categorical_to_numeric(feables)\n",
"feables.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's now split features and labels."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(19673, 46)\n",
"(19673,)\n"
]
}
],
"source": [
"# Every column except the match identifier and the target is a feature.\n",
"# Index.difference returns the remaining column names in sorted order.\n",
"non_feature_cols = ['match_api_id', 'label']\n",
"features = feables.loc[:, feables.columns.difference(non_feature_cols)]\n",
"labs = feables['label']\n",
"print(features.shape)\n",
"print(labs.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Once we have the features and labels defined, let's create the train and test sets."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 16 ms, sys: 4 ms, total: 20 ms\n",
"Wall time: 17.8 ms\n"
]
}
],
"source": [
"%%time\n",
"# Hold out 20% of the matches for testing. Stratifying on the label keeps the class\n",
"# proportions (approximately) equal in both splits; the fixed seed makes it repeatable.\n",
"X_train, X_test, y_train, y_test = train_test_split(features, labs,\n",
"                                                    stratify=labs,\n",
"                                                    test_size=0.2,\n",
"                                                    random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Wrap the splits in XGBoost's DMatrix container (train first, then test).\n",
"dtrain, dtest = (xgb.DMatrix(data=split_X, label=split_y)\n",
"                 for split_X, split_y in ((X_train, y_train), (X_test, y_test)))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# LightGBM datasets built from plain numpy arrays for both splits (labels included,\n",
"# consistently passed as arrays). free_raw_data=False keeps the raw data after the\n",
"# Dataset is constructed; the test set is built with reference=lgb_train, as LightGBM\n",
"# expects for validation data.\n",
"# NOTE(review): lgb_test is not passed to lgb.train below; predictions use X_test.values.\n",
"lgb_train = lgb.Dataset(X_train.values, y_train.values, free_raw_data=False)\n",
"lgb_test = lgb.Dataset(X_test.values, y_test.values, reference=lgb_train, free_raw_data=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### XGBoost analysis\n",
"Now that the feature engineering is done, we can train a model with each of the libraries. We will start with XGBoost. \n",
"\n",
"For each model we record the training and test times, as well as some performance metrics. "
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Settings shared by all three experiments.\n",
"results_dict = {}  # experiment name -> timings and metrics\n",
"num_rounds = 300  # boosting rounds for every model\n",
"labels = list(range(3))  # class ids 0, 1, 2 -- presumably Win/Draw/Defeat; confirm in libs.football"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# XGBoost baseline: exact split finding with the 'grow_gpu' updater (GPU experiment).\n",
"params = {'max_depth': 3,\n",
"          'objective': 'multi:softprob',  # one probability per class\n",
"          'num_class': len(labels),\n",
"          'min_child_weight': 5,\n",
"          'learning_rate': 0.1,\n",
"          'colsample_bytree': 0.8,\n",
"          # NOTE(review): scale_pos_weight is aimed at binary problems; likely a no-op here\n",
"          'scale_pos_weight': 2,\n",
"          'gamma': 0.1,\n",
"          # previously misspelled 'reg_lamda', which is not an XGBoost parameter\n",
"          # and was therefore ignored; 1 is XGBoost's default L2 weight\n",
"          'reg_lambda': 1,\n",
"          'subsample': 1,\n",
"          'tree_method': 'exact',\n",
"          'updater': 'grow_gpu'\n",
"          }\n"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Time training and prediction separately; each Timer exposes its elapsed time as\n",
"# .interval (presumably seconds), which is stored in results_dict further down.\n",
"with Timer() as t_train:\n",
" xgb_clf_pipeline = xgb.train(params, dtrain, num_boost_round=num_rounds)\n",
" \n",
"with Timer() as t_test:\n",
" # with the multi:softprob objective this yields one probability per class per row\n",
" y_prob_xgb = xgb_clf_pipeline.predict(dtest)"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def quantitize_multilable_prediction(y_pred):\n",
"    \"\"\"Map each row of class probabilities to the index of its most likely class.\"\"\"\n",
"    class_axis = 1  # rows are samples, columns are classes\n",
"    return np.argmax(y_pred, axis=class_axis)"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Hard class predictions for the exact-method XGBoost model\n",
"y_pred_xgb = quantitize_multilable_prediction(y_prob_xgb)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Accuracy / precision / recall / F1 report for the exact-method XGBoost predictions\n",
"report_xgb = classification_metrics_multilabel(y_test, y_pred_xgb, labels)"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Record timings (Timer.interval) and the metric report for the exact XGBoost model.\n",
"results_dict['xgb'] = dict(train_time=t_train.interval,\n",
"                           test_time=t_test.interval,\n",
"                           performance=report_xgb)"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Drop the reference to the trained booster before the next experiment\n",
"del xgb_clf_pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"\n",
"Now let's try XGBoost with the histogram-based tree method (`tree_method='hist'`)."
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# XGBoost with histogram split finding ('hist') and leaf-wise ('lossguide') growth;\n",
"# tree size is capped by max_leaves=2**3, the same leaf budget as the LightGBM run below.\n",
"# NOTE(review): unlike the exact run, no GPU updater is configured here, so this\n",
"# presumably trains on the CPU -- confirm that is intended for the GPU experiment.\n",
"params = {'max_depth': 0,  # no depth limit; size is bounded by max_leaves instead\n",
"          'max_leaves': 2**3,\n",
"          'objective': 'multi:softprob',\n",
"          'num_class': len(labels),\n",
"          'min_child_weight': 5,\n",
"          'learning_rate': 0.1,\n",
"          'colsample_bytree': 0.80,\n",
"          # NOTE(review): scale_pos_weight is aimed at binary problems; likely a no-op here\n",
"          'scale_pos_weight': 2,\n",
"          'gamma': 0.1,\n",
"          # previously misspelled 'reg_lamda', which is not an XGBoost parameter\n",
"          # and was therefore ignored; 1 is XGBoost's default L2 weight\n",
"          'reg_lambda': 1,\n",
"          'subsample': 1,\n",
"          'tree_method': 'hist',\n",
"          'grow_policy': 'lossguide',\n",
"          }\n"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Same timing protocol as the exact XGBoost run, now with the hist parameters.\n",
"# t_train / t_test are rebound, so the exact run's timings must already be stored.\n",
"with Timer() as t_train:\n",
" xgb_hist_clf_pipeline = xgb.train(params, dtrain, num_boost_round=num_rounds)\n",
" \n",
"with Timer() as t_test:\n",
" y_prob_xgb_hist = xgb_hist_clf_pipeline.predict(dtest)"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Hard class predictions for the hist XGBoost model\n",
"y_pred_xgb_hist = quantitize_multilable_prediction(y_prob_xgb_hist)"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Accuracy / precision / recall / F1 report for the hist XGBoost predictions\n",
"report_xgb_hist = classification_metrics_multilabel(y_test, y_pred_xgb_hist, labels)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Record timings (Timer.interval) and the metric report for the hist XGBoost model.\n",
"results_dict['xgb_hist'] = dict(train_time=t_train.interval,\n",
"                                test_time=t_test.interval,\n",
"                                performance=report_xgb_hist)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Drop the reference to the hist booster before the LightGBM experiment\n",
"del xgb_hist_clf_pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### LightGBM analysis\n",
"\n",
"Now let's compare with LightGBM."
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# LightGBM settings mirroring the XGBoost hist run: 2**3 leaves, same learning rate,\n",
"# column sampling, split-gain threshold, child weight and L2 regularisation.\n",
"# NOTE(review): no GPU device is requested here, and scale_pos_weight is documented\n",
"# for binary objectives, so it is presumably ignored for 'multiclass' -- confirm.\n",
"params = dict(num_leaves=2**3,\n",
"              learning_rate=0.1,\n",
"              colsample_bytree=0.80,\n",
"              scale_pos_weight=2,\n",
"              min_split_gain=0.1,\n",
"              min_child_weight=5,\n",
"              reg_lambda=1,\n",
"              subsample=1,\n",
"              objective='multiclass',\n",
"              num_class=len(labels),\n",
"              task='train')"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Time LightGBM training and prediction; prediction is run on the raw test array.\n",
"# NOTE(review): lgb_test is built earlier but never used as a validation set here.\n",
"with Timer() as t_train:\n",
" lgbm_clf_pipeline = lgb.train(params, lgb_train, num_boost_round=num_rounds)\n",
" \n",
"with Timer() as t_test:\n",
" y_prob_lgbm = lgbm_clf_pipeline.predict(X_test.values)"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Hard class predictions for the LightGBM model\n",
"y_pred_lgbm = quantitize_multilable_prediction(y_prob_lgbm)"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Accuracy / precision / recall / F1 report for the LightGBM predictions\n",
"report_lgbm = classification_metrics_multilabel(y_test, y_pred_lgbm, labels)"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Record timings (Timer.interval) and the metric report for the LightGBM model.\n",
"results_dict['lgbm'] = dict(train_time=t_train.interval,\n",
"                            test_time=t_test.interval,\n",
"                            performance=report_lgbm)"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Drop the reference to the trained LightGBM booster\n",
"del lgbm_clf_pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Finally, the results."
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"lgbm\": {\n",
" \"performance\": {\n",
" \"Accuracy\": 0.5344345616264294,\n",
" \"F1\": 0.4704311590503636,\n",
" \"Precision\": 0.48847893806298454,\n",
" \"Recall\": 0.5344345616264294\n",
" },\n",
" \"test_time\": 0.029374134999670787,\n",
" \"train_time\": 0.976751588001207\n",
" },\n",
" \"xgb\": {\n",
" \"performance\": {\n",
" \"Accuracy\": 0.5359593392630242,\n",
" \"F1\": 0.4704659043141339,\n",
" \"Precision\": 0.4825747269523364,\n",
" \"Recall\": 0.5359593392630242\n",
" },\n",
" \"test_time\": 0.006567717999132583,\n",
" \"train_time\": 7.09927419500309\n",
" },\n",
" \"xgb_hist\": {\n",
" \"performance\": {\n",
" \"Accuracy\": 0.537992376111817,\n",
" \"F1\": 0.4723094570741036,\n",
" \"Precision\": 0.4944404394401915,\n",
" \"Recall\": 0.537992376111817\n",
" },\n",
" \"test_time\": 0.007724854996922659,\n",
" \"train_time\": 4.588017762001982\n",
" }\n",
"}\n"
]
}
],
"source": [
"# Results: pretty-print every experiment's timings and metrics (keys sorted for a stable layout)\n",
"results_json = json.dumps(results_dict, sort_keys=True, indent=4)\n",
"print(results_json)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"As can be seen, for this multiclass problem LightGBM trains much faster than both XGBoost variants, although its prediction time is higher. The performance metrics are quite poor, so we wouldn't recommend betting based on this model :-)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}