Added random seed option to wikigold util function.

Hong Lu 2019-06-07 17:32:49 -04:00
Parent 049ddf6442
Commit 26fcc3cbe4
2 changed files with 90 additions and 85 deletions
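
In brief, the notebook's configuration constants are renamed to upper case, the seed is consolidated into a single RANDOM_SEED = 100, and get_train_test_data gains an optional random_seed argument so the train/test shuffle is seeded inside the helper rather than in the notebook. A usage sketch of the updated call, with values taken from the notebook diff below (the import path is an assumption, not confirmed by the diff):

    from utils_nlp.dataset.wikigold import download, read_data, get_train_test_data  # assumed module path

    DATA_DIR = "./data"
    DATA_FILE = "./data/wikigold.conll.txt"
    RANDOM_SEED = 100

    download(DATA_DIR)
    wikigold_text = read_data(DATA_FILE)

    # New optional argument: seeds the shuffle inside the helper, making the
    # 50/50 train/test split reproducible across runs.
    train_text, train_labels, test_text, test_labels = get_train_test_data(
        wikigold_text, test_percentage=0.5, random_seed=RANDOM_SEED
    )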

View file

@@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {
"scrolled": true
},
@@ -59,34 +59,33 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# path configurations\n",
"data_dir = \"./data\"\n",
"data_file = \"./data/wikigold.conll.txt\"\n",
"cache_dir=\".\"\n",
"DATA_DIR = \"./data\"\n",
"DATA_FILE = \"./data/wikigold.conll.txt\"\n",
"CACHE_DIR=\".\"\n",
"\n",
"# set random seeds\n",
"random_seed = 42\n",
"random.seed(random_seed)\n",
"torch.manual_seed(random_seed)\n",
"RANDOM_SEED = 100\n",
"torch.manual_seed(RANDOM_SEED)\n",
"\n",
"# model configurations\n",
"language = Language.ENGLISHCASED\n",
"do_lower_case = False\n",
"max_seq_length = 200\n",
"LANGUAGE = Language.ENGLISHCASED\n",
"DO_LOWER_CASE = False\n",
"MAX_SEQ_LENGTH = 200\n",
"\n",
"# training configurations\n",
"device=\"gpu\"\n",
"batch_size = 16\n",
"num_train_epochs = 5\n",
"DEVICE=\"gpu\"\n",
"BATCH_SIZE = 16\n",
"NUM_TRAIN_EPOCHS = 5\n",
"\n",
"# optimizer configuration\n",
"learning_rate = 3e-5"
"LEARNING_RATE = 3e-5"
]
},
{
@@ -107,12 +106,12 @@
"\n",
"The helper function `get_unique_labels` returns the unique entity labels in the dataset. There are 5 unique labels in the original dataset: 'O' (non-entity), 'I-LOC' (location), 'I-MISC' (miscellaneous), 'I-PER' (person), and 'I-ORG' (organization). An 'X' label is added for the trailing word pieces generated by BERT, because BERT uses WordPiece tokenizer. \n",
"\n",
"The maximum number of words in a sentence is 144, so we set `max_seq_length` to 200 above, because the number of tokens will grow after WordPiece tokenization."
"The maximum number of words in a sentence is 144, so we set MAX_SEQ_LENGTH to 200 above, because the number of tokens will grow after WordPiece tokenization."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {
"scrolled": true
},
@@ -122,24 +121,26 @@
"output_type": "stream",
"text": [
"Maximum sequence length in training data is: 144\n",
"Maximum sequence length in testing data is: 89\n",
"Maximum sequence length in testing data is: 81\n",
"\n",
"Unique entity labels: \n",
"['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG', 'X']\n",
"\n",
"Sample sentence: \n",
"-DOCSTART-\n",
"Two , Samsung based , electronic cash registers were reconstructed in order to expand their functions and adapt them for networking .\n",
"\n",
"Sample sentence labels: \n",
"['O']\n",
"['O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
"\n"
]
}
],
"source": [
"download(data_dir)\n",
"wikigold_text = read_data(data_file)\n",
"train_text, train_labels, test_text, test_labels = get_train_test_data(wikigold_text, test_percentage=0.5)\n",
"download(DATA_DIR)\n",
"wikigold_text = read_data(DATA_FILE)\n",
"train_text, train_labels, test_text, test_labels = get_train_test_data(wikigold_text, \n",
" test_percentage=0.5, \n",
" random_seed=RANDOM_SEED)\n",
"label_list = get_unique_labels()\n",
"print('\\nUnique entity labels: \\n{}\\n'.format(label_list))\n",
"print('Sample sentence: \\n{}\\n'.format(train_text[0]))\n",
@@ -166,7 +167,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {
"scrolled": true
},
@@ -184,15 +185,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"tokenizer = Tokenizer(language=language, \n",
" to_lower=do_lower_case, \n",
" cache_dir=cache_dir)"
"tokenizer = Tokenizer(language=LANGUAGE, \n",
" to_lower=DO_LOWER_CASE, \n",
" cache_dir=CACHE_DIR)"
]
},
{
@@ -205,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {
"scrolled": false
},
@@ -214,13 +215,13 @@
"train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \\\n",
" tokenizer.preprocess_ner_tokens(text=train_text,\n",
" label_map=label_map,\n",
" max_len=max_seq_length,\n",
" max_len=MAX_SEQ_LENGTH,\n",
" labels=train_labels,\n",
" trailing_piece_tag=\"X\")\n",
"test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \\\n",
" tokenizer.preprocess_ner_tokens(text=test_text,\n",
" label_map=label_map,\n",
" max_len=max_seq_length,\n",
" max_len=MAX_SEQ_LENGTH,\n",
" labels=test_labels,\n",
" trailing_piece_tag=\"X\")"
]
@@ -238,7 +239,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"metadata": {
"scrolled": true
},
@@ -248,16 +249,16 @@
"output_type": "stream",
"text": [
"Sample token ids:\n",
"[118, 141, 9244, 9272, 12426, 1942, 118, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"[1960, 117, 20799, 1359, 117, 4828, 5948, 21187, 1127, 15755, 1107, 1546, 1106, 7380, 1147, 4226, 1105, 16677, 1172, 1111, 16074, 119, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"\n",
"Sample attention mask:\n",
"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"\n",
"Sample trailing token mask:\n",
"[True, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
"[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
"\n",
"Sample label ids:\n",
"[0, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"[0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"\n"
]
}
@@ -287,15 +288,15 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"token_classifier = BERTTokenClassifier(language=language,\n",
"token_classifier = BERTTokenClassifier(language=LANGUAGE,\n",
" num_labels=len(label_list),\n",
" cache_dir=cache_dir)"
" cache_dir=CACHE_DIR)"
]
},
{
@@ -307,7 +308,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"metadata": {
"scrolled": false
},
@@ -333,10 +334,10 @@
"output_type": "stream",
"text": [
"\n",
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 41%|████▏ | 24/58 [00:49<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:02<00:12, 1.30s/it]\u001b[A\n",
"Epoch: 20%|██ | 1/5 [01:14<04:59, 74.89s/it]9s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:49<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 81%|████████ | 47/58 [01:01<00:14, 1.31s/it]\u001b[A\n",
"Epoch: 20%|██ | 1/5 [01:15<05:01, 75.44s/it]0s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -344,7 +345,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss: 0.4810219132180872\n"
"Train loss: 0.4387937460480065\n"
]
},
{
@@ -352,11 +353,11 @@
"output_type": "stream",
"text": [
"\n",
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 41%|████▏ | 24/58 [00:45<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:02<00:13, 1.30s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:15<00:13, 1.30s/it]\u001b[A\n",
"Epoch: 40%|████ | 2/5 [02:30<03:44, 74.96s/it]0s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:44<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.31s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:14<00:15, 1.31s/it]\u001b[A\n",
"Epoch: 40%|████ | 2/5 [02:31<03:46, 75.53s/it]1s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -364,7 +365,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss: 0.1079183994182225\n"
"Train loss: 0.11529478796854101\n"
]
},
{
@@ -373,9 +374,9 @@
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:49<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:48<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 60%|██████ | 3/5 [03:46<02:30, 75.28s/it]1s/it]\u001b[A\n",
"Epoch: 60%|██████ | 3/5 [03:47<02:31, 75.73s/it]1s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -383,7 +384,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss: 0.04469434472186298\n"
"Train loss: 0.048355102719706965\n"
]
},
{
@@ -392,10 +393,10 @@
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:43<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:42<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:13<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 80%|████████ | 4/5 [05:02<01:15, 75.57s/it]1s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:12<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 80%|████████ | 4/5 [05:03<01:15, 75.92s/it]2s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -403,7 +404,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss: 0.023166427043555624\n"
"Train loss: 0.030773249510996813\n"
]
},
{
@@ -412,16 +413,17 @@
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:47<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:46<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 100%|██████████| 5/5 [06:18<00:00, 75.79s/it]2s/it]\u001b[A"
"Iteration: 79%|███████▉ | 46/58 [01:16<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 100%|██████████| 5/5 [06:20<00:00, 76.06s/it]2s/it]\u001b[A"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss: 0.014027320994626218\n"
"Train loss: 0.018075179055750627\n"
]
},
{
@@ -436,9 +438,9 @@
"token_classifier.fit(token_ids=train_token_ids, \n",
" input_mask=train_input_mask, \n",
" labels=train_label_ids,\n",
" num_epochs=num_train_epochs, \n",
" batch_size=batch_size, \n",
" learning_rate=learning_rate)"
" num_epochs=NUM_TRAIN_EPOCHS, \n",
" batch_size=BATCH_SIZE, \n",
" learning_rate=LEARNING_RATE)"
]
},
{
@@ -450,7 +452,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"metadata": {
"scrolled": true
},
@@ -474,14 +476,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 100%|██████████| 58/58 [00:25<00:00, 2.30it/s]"
"Iteration: 100%|██████████| 58/58 [00:25<00:00, 2.29it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation loss: 0.13430038492741256\n"
"Evaluation loss: 0.11602881579691994\n"
]
},
{
@@ -496,7 +498,7 @@
"pred_label_ids = token_classifier.predict(token_ids=test_token_ids, \n",
" input_mask=test_input_mask, \n",
" labels=test_label_ids, \n",
" batch_size=batch_size)"
" batch_size=BATCH_SIZE)"
]
},
{
@@ -509,7 +511,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {
"scrolled": true
},
@@ -520,14 +522,14 @@
"text": [
" precision recall f1-score support\n",
"\n",
" LOC 0.91 0.84 0.88 558\n",
" X 0.97 0.98 0.98 2060\n",
" PER 0.90 0.95 0.93 569\n",
" ORG 0.70 0.82 0.76 553\n",
" MISC 0.74 0.71 0.72 399\n",
" MISC 0.73 0.71 0.72 396\n",
" ORG 0.72 0.79 0.76 538\n",
" X 0.96 0.97 0.97 1983\n",
" PER 0.93 0.93 0.93 550\n",
" LOC 0.84 0.89 0.86 543\n",
"\n",
"micro avg 0.89 0.91 0.90 4139\n",
"macro avg 0.90 0.91 0.90 4139\n",
"micro avg 0.88 0.90 0.89 4010\n",
"macro avg 0.88 0.90 0.89 4010\n",
"\n"
]
}
@@ -551,7 +553,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {
"scrolled": true
},
@@ -562,13 +564,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" LOC 0.90 0.85 0.88 507\n",
" PER 0.89 0.94 0.92 460\n",
" ORG 0.65 0.81 0.72 440\n",
" MISC 0.67 0.68 0.68 328\n",
" ORG 0.69 0.79 0.74 442\n",
" PER 0.92 0.92 0.92 455\n",
" MISC 0.71 0.68 0.69 344\n",
" LOC 0.83 0.87 0.85 503\n",
"\n",
"micro avg 0.77 0.83 0.80 1735\n",
"macro avg 0.79 0.83 0.81 1735\n",
"micro avg 0.77 0.83 0.79 1744\n",
"macro avg 0.79 0.83 0.81 1744\n",
"\n"
]
}
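
The markdown cell edited earlier in this notebook (the -107,12 hunk) explains the extra 'X' label and why MAX_SEQ_LENGTH is 200 even though the longest training sentence has 144 words: WordPiece can split a word into several pieces, so tokenized sentences grow, and only the first piece of each word keeps its entity label while trailing pieces are tagged 'X'. A toy sketch of that labeling rule, using a made-up splitter rather than BERT's tokenizer:

    # Toy illustration of the trailing-piece rule: the first word piece keeps the
    # word's label, every trailing piece gets the extra "X" tag.
    def label_word_pieces(words, labels, wordpiece_fn):
        piece_tokens, piece_labels = [], []
        for word, label in zip(words, labels):
            pieces = wordpiece_fn(word)  # e.g. "Samsung" -> ["Sam", "##sung"]
            piece_tokens.extend(pieces)
            piece_labels.extend([label] + ["X"] * (len(pieces) - 1))
        return piece_tokens, piece_labels

    # Hypothetical splitter standing in for WordPiece.
    toy_split = lambda w: [w[:3], "##" + w[3:]] if len(w) > 6 else [w]

    tokens, labels = label_word_pieces(
        ["Two", ",", "Samsung", "registers"], ["O", "O", "I-ORG", "O"], toy_split
    )
    print(tokens)  # ['Two', ',', 'Sam', '##sung', 'reg', '##isters']
    print(labels)  # ['O', 'O', 'I-ORG', 'X', 'O', 'X']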

View file

@@ -32,7 +32,7 @@ def read_data(data_file):
return text
def get_train_test_data(text, test_percentage=0.5):
def get_train_test_data(text, test_percentage=0.5, random_seed=None):
"""
Get the training and testing data based on test_percentage.
@@ -41,6 +41,7 @@ def get_train_test_data(text, test_percentage=0.5):
test_percentage (float, optional): Percentage of data to use for
testing. Since this is a small dataset, the default testing
percentage is set to 0.5
random_seed (float, optional): Random seed used to shuffle the data.
Returns:
tuple: A tuple containing four lists:
@@ -51,11 +52,13 @@ def get_train_test_data(text, test_percentage=0.5):
test_labels_list: List of lists. Each sublist contains the
entity labels of the word in the testing sentence.
"""
# Input data are separated by empty lines
text_split = text.split("\n\n")
# Remove empty line at EOF
text_split = text_split[:-1]
if random_seed:
random.seed(random_seed)
random.shuffle(text_split)
sentence_count = len(text_split)
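
For reference, the seeding logic added above can be exercised on its own. The sketch below re-implements just the shuffle-then-split step from the hunk above; the real helper goes on to separate words from their entity labels, so this is an illustration rather than the repo's function:

    import random

    def split_sentences(text, test_percentage=0.5, random_seed=None):
        # Standalone sketch of the shuffle-then-split step added above.
        sentences = text.split("\n\n")[:-1]  # sentences separated by blank lines; drop the empty chunk at EOF
        if random_seed:  # mirror the helper: only reseed when a seed is given
            random.seed(random_seed)
        random.shuffle(sentences)
        cut = int(len(sentences) * (1 - test_percentage))
        return sentences[:cut], sentences[cut:]

    # Same seed gives the same split, which is what the notebook now relies on via RANDOM_SEED.
    sample = "\n\n".join("sentence {}".format(i) for i in range(10)) + "\n\n"
    assert split_sentences(sample, random_seed=100) == split_sentences(sample, random_seed=100)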