Add retraining experiment and refactors code slightly

This commit is contained in:
msalvaris 2017-05-17 08:52:25 +00:00
Родитель 2be16c7963
Коммит 85f1cf203a
2 изменённых файлов: 1454 добавлений и 82 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -0,0 +1,578 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hoaphumanoid/anaconda3/envs/strata/lib/python3.6/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
" \"This module will be removed in 0.20.\", DeprecationWarning)\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from sklearn.pipeline import Pipeline, FeatureUnion\n",
"\n",
"from sklearn.model_selection import cross_val_score, cross_val_predict\n",
"from xgboost import XGBClassifier\n",
"import numpy as np\n",
"import itertools\n",
"import seaborn\n",
"from sklearn.metrics import roc_auc_score\n",
"from experiments.libs import loaders\n",
"from sklearn.model_selection import StratifiedKFold\n",
"\n",
"from sklearn.svm import LinearSVC\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.metrics import roc_auc_score\n",
"from xgboost import XGBModel\n",
"\n",
"from lightgbm import LGBMClassifier\n",
"\n",
"import mne\n",
"from scipy.io import loadmat\n",
"\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"filepath = '/fileshare/BCI_Comp_III_Wads_2004/Subject_A_Train.mat'\n",
"srate = 240 # Hz\n",
"# Filtered 0.1 - 60Hz"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_event_frame(flash_series, stimtype_series):\n",
" diff_seq = flash_series.diff()\n",
" # First event will be missing so fill with 1\n",
" return pd.DataFrame({'flash_onset':diff_seq.fillna(1), \n",
" 'stim_type': stimtype_series})"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def epoch_generator(eeg_data, event_df, duration=0.8, srate=240):\n",
" \"\"\" Yields epoch as well as classification label\n",
" \"\"\"\n",
" offset = int(np.round(duration * srate))\n",
" for idx in event_df[event_df['flash_onset']==1].index:\n",
" data = eeg_data[idx:idx+offset, :] - eeg_data[idx, :] # Removing offset\n",
" yield data, event_df['stim_type'].at[idx]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def dataset_epoch_generator(filename, srate=240):\n",
" dataset_dict = loadmat(filename)\n",
" num_runs = dataset_dict['Flashing'].shape[0]\n",
" \n",
" for run in range(num_runs):\n",
" eeg_array = dataset_dict['Signal'][run,:,:]\n",
" flash_series = pd.Series(dataset_dict['Flashing'][run, :])\n",
" stimtype_series = pd.Series(dataset_dict['StimulusType'][run, :])\n",
" event_df = create_event_frame(flash_series, stimtype_series)\n",
" # Transforming from time by channel to channel by time\n",
" eeg_array = mne.filter.filter_data(eeg_array.astype(np.float64).T, srate, None, 18, verbose=False).T\n",
" for epoch in epoch_generator(eeg_array, event_df, duration=0.8, srate=srate):\n",
" yield epoch"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"egen = dataset_epoch_generator(filepath, srate=srate)\n",
"data_list = list(egen)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"15300"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data_list)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X = np.array([d[0][::6].ravel() for d in data_list])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y = np.array([d[1] for d in data_list])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(15300, 2048)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(15300,)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"pipeline_steps = [('scale', StandardScaler())]\n",
"continuous_pipeline = Pipeline(steps=pipeline_steps)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"featurisers = [('continuous', continuous_pipeline)]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xgb_clf_pipeline = Pipeline(steps=[('features', FeatureUnion(featurisers)),\n",
" ('clf', XGBClassifier(max_depth=2, \n",
" learning_rate=0.1, \n",
" scale_pos_weight=2,\n",
" n_estimators=100,\n",
" gamma=0.1,\n",
" subsample=1))]) "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lgbm_clf_pipeline = Pipeline(steps=[('features', FeatureUnion(featurisers)),\n",
" ('clf', LGBMClassifier(max_depth=2, \n",
" learning_rate=0.1, \n",
" scale_pos_weight=2,\n",
" n_estimators=100,\n",
" subsample=1))]) "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ123456789_'\n",
"stim_code_translation_dict = dict((l, (i%6+1, int(np.floor(i/6)+7)) ) for i, l in enumerate(letters))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def event_frame_from(stim_code, target_letter, stim_code_translation_dict):\n",
" diff_seq = stim_code.diff()\n",
" diff_seq = diff_seq.fillna(stim_code)\n",
" \n",
" target_codes = stim_code_translation_dict[target_letter]\n",
" target_index = diff_seq[diff_seq.isin(target_codes)].index\n",
" stimtype_series = pd.Series(0, index=stim_code.index)\n",
" stimtype_series[target_index]=1\n",
" diff_seq = (diff_seq/diff_seq.abs()).fillna(0)\n",
" \n",
" return pd.DataFrame({'flash_onset':diff_seq, \n",
" 'stim_type': stimtype_series})"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"test_filename = '/fileshare/BCI_Comp_III_Wads_2004/Subject_A_Test.mat'\n",
"labels_filename = '/fileshare/BCI_Comp_III_Wads_2004/true_labels_a.txt'"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def read_labels(filename):\n",
" return open(filename).readline().strip()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def test_dataset_epoch_generator(data_filename, labels_filename, stim_code_translation_dict, srate=240):\n",
" dataset_dict = loadmat(data_filename)\n",
" labels = read_labels(labels_filename)\n",
" \n",
" for run, target in enumerate(labels):\n",
" eeg_array = dataset_dict['Signal'][run,:,:]\n",
" stimcode_series = pd.Series(dataset_dict['StimulusCode'][run, :])\n",
" \n",
" event_df = event_frame_from(stimcode_series, target, stim_code_translation_dict)\n",
" \n",
" # Transforming from time by channel to channel by time\n",
" eeg_array = mne.filter.filter_data(eeg_array.astype(np.float64).T, srate, None, 18, verbose=False).T\n",
" for epoch in epoch_generator(eeg_array, event_df, duration=0.8, srate=srate):\n",
" yield epoch"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"data_gen = test_dataset_epoch_generator(test_filename, labels_filename, stim_code_translation_dict)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"data_list = list(data_gen)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_test = np.array([d[0][::6].ravel() for d in data_list])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_test = np.array([d[1] for d in data_list])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X_train = np.concatenate([X, X_test[:9000]])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_train = np.concatenate([y, y_test[:9000]])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3min 43s, sys: 16.5 s, total: 3min 59s\n",
"Wall time: 13.9 s\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(steps=[('features', FeatureUnion(n_jobs=1,\n",
" transformer_list=[('continuous', Pipeline(steps=[('scale', StandardScaler(copy=True, with_mean=True, with_std=True))]))],\n",
" transformer_weights=None)), ('clf', XGBClassifier(base_score=0.5, colsample_bylevel=1, colsample_bytree=1,\n",
" gamma=0...logistic', reg_alpha=0, reg_lambda=1,\n",
" scale_pos_weight=2, seed=0, silent=True, subsample=1))])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"xgb_clf_pipeline.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.67662137777777787"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = xgb_clf_pipeline.predict_proba(X_test[9000:])\n",
"roc_auc_score(y_test[9000:], y_pred[:, 1])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 35.4 s, sys: 1.2 s, total: 36.6 s\n",
"Wall time: 6.91 s\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(steps=[('features', FeatureUnion(n_jobs=1,\n",
" transformer_list=[('continuous', Pipeline(steps=[('scale', StandardScaler(copy=True, with_mean=True, with_std=True))]))],\n",
" transformer_weights=None)), ('clf', LGBMClassifier(boosting_type='gbdt', colsample_bytree=1, drop_rate=0.1,\n",
" is_un... subsample_for_bin=50000, subsample_freq=1, uniform_drop=False,\n",
" xgboost_dart_mode=False))])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"lgbm_clf_pipeline.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.67720866666666668"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = lgbm_clf_pipeline.predict_proba(X_test[9000:])\n",
"roc_auc_score(y_test[9000:], y_pred[:, 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"for X_chunk, y_chunk in zip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}