Added random seed option to wikigold util function.
This commit is contained in:
Родитель
049ddf6442
Коммит
26fcc3cbe4
|
@ -27,7 +27,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -59,34 +59,33 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# path configurations\n",
|
||||
"data_dir = \"./data\"\n",
|
||||
"data_file = \"./data/wikigold.conll.txt\"\n",
|
||||
"cache_dir=\".\"\n",
|
||||
"DATA_DIR = \"./data\"\n",
|
||||
"DATA_FILE = \"./data/wikigold.conll.txt\"\n",
|
||||
"CACHE_DIR=\".\"\n",
|
||||
"\n",
|
||||
"# set random seeds\n",
|
||||
"random_seed = 42\n",
|
||||
"random.seed(random_seed)\n",
|
||||
"torch.manual_seed(random_seed)\n",
|
||||
"RANDOM_SEED = 100\n",
|
||||
"torch.manual_seed(RANDOM_SEED)\n",
|
||||
"\n",
|
||||
"# model configurations\n",
|
||||
"language = Language.ENGLISHCASED\n",
|
||||
"do_lower_case = False\n",
|
||||
"max_seq_length = 200\n",
|
||||
"LANGUAGE = Language.ENGLISHCASED\n",
|
||||
"DO_LOWER_CASE = False\n",
|
||||
"MAX_SEQ_LENGTH = 200\n",
|
||||
"\n",
|
||||
"# training configurations\n",
|
||||
"device=\"gpu\"\n",
|
||||
"batch_size = 16\n",
|
||||
"num_train_epochs = 5\n",
|
||||
"DEVICE=\"gpu\"\n",
|
||||
"BATCH_SIZE = 16\n",
|
||||
"NUM_TRAIN_EPOCHS = 5\n",
|
||||
"\n",
|
||||
"# optimizer configuration\n",
|
||||
"learning_rate = 3e-5"
|
||||
"LEARNING_RATE = 3e-5"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -107,12 +106,12 @@
|
|||
"\n",
|
||||
"The helper function `get_unique_labels` returns the unique entity labels in the dataset. There are 5 unique labels in the original dataset: 'O' (non-entity), 'I-LOC' (location), 'I-MISC' (miscellaneous), 'I-PER' (person), and 'I-ORG' (organization). An 'X' label is added for the trailing word pieces generated by BERT, because BERT uses WordPiece tokenizer. \n",
|
||||
"\n",
|
||||
"The maximum number of words in a sentence is 144, so we set `max_seq_length` to 200 above, because the number of tokens will grow after WordPiece tokenization."
|
||||
"The maximum number of words in a sentence is 144, so we set `MAX_SEQ_LENGTH` to 200 above, because the number of tokens will grow after WordPiece tokenization."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -122,24 +121,26 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Maximum sequence length in training data is: 144\n",
|
||||
"Maximum sequence length in testing data is: 89\n",
|
||||
"Maximum sequence length in testing data is: 81\n",
|
||||
"\n",
|
||||
"Unique entity labels: \n",
|
||||
"['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG', 'X']\n",
|
||||
"\n",
|
||||
"Sample sentence: \n",
|
||||
"-DOCSTART-\n",
|
||||
"Two , Samsung based , electronic cash registers were reconstructed in order to expand their functions and adapt them for networking .\n",
|
||||
"\n",
|
||||
"Sample sentence labels: \n",
|
||||
"['O']\n",
|
||||
"['O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"download(data_dir)\n",
|
||||
"wikigold_text = read_data(data_file)\n",
|
||||
"train_text, train_labels, test_text, test_labels = get_train_test_data(wikigold_text, test_percentage=0.5)\n",
|
||||
"download(DATA_DIR)\n",
|
||||
"wikigold_text = read_data(DATA_FILE)\n",
|
||||
"train_text, train_labels, test_text, test_labels = get_train_test_data(wikigold_text, \n",
|
||||
" test_percentage=0.5, \n",
|
||||
" random_seed=RANDOM_SEED)\n",
|
||||
"label_list = get_unique_labels()\n",
|
||||
"print('\\nUnique entity labels: \\n{}\\n'.format(label_list))\n",
|
||||
"print('Sample sentence: \\n{}\\n'.format(train_text[0]))\n",
|
||||
|
@ -166,7 +167,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -184,15 +185,15 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(language=language, \n",
|
||||
" to_lower=do_lower_case, \n",
|
||||
" cache_dir=cache_dir)"
|
||||
"tokenizer = Tokenizer(language=LANGUAGE, \n",
|
||||
" to_lower=DO_LOWER_CASE, \n",
|
||||
" cache_dir=CACHE_DIR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -205,7 +206,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
|
@ -214,13 +215,13 @@
|
|||
"train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \\\n",
|
||||
" tokenizer.preprocess_ner_tokens(text=train_text,\n",
|
||||
" label_map=label_map,\n",
|
||||
" max_len=max_seq_length,\n",
|
||||
" max_len=MAX_SEQ_LENGTH,\n",
|
||||
" labels=train_labels,\n",
|
||||
" trailing_piece_tag=\"X\")\n",
|
||||
"test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \\\n",
|
||||
" tokenizer.preprocess_ner_tokens(text=test_text,\n",
|
||||
" label_map=label_map,\n",
|
||||
" max_len=max_seq_length,\n",
|
||||
" max_len=MAX_SEQ_LENGTH,\n",
|
||||
" labels=test_labels,\n",
|
||||
" trailing_piece_tag=\"X\")"
|
||||
]
|
||||
|
@ -238,7 +239,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -248,16 +249,16 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample token ids:\n",
|
||||
"[118, 141, 9244, 9272, 12426, 1942, 118, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
|
||||
"[1960, 117, 20799, 1359, 117, 4828, 5948, 21187, 1127, 15755, 1107, 1546, 1106, 7380, 1147, 4226, 1105, 16677, 1172, 1111, 16074, 119, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
|
||||
"\n",
|
||||
"Sample attention mask:\n",
|
||||
"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
|
||||
"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
|
||||
"\n",
|
||||
"Sample trailing token mask:\n",
|
||||
"[True, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
|
||||
"[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
|
||||
"\n",
|
||||
"Sample label ids:\n",
|
||||
"[0, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
|
||||
"[0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
@ -287,15 +288,15 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"token_classifier = BERTTokenClassifier(language=language,\n",
|
||||
"token_classifier = BERTTokenClassifier(language=LANGUAGE,\n",
|
||||
" num_labels=len(label_list),\n",
|
||||
" cache_dir=cache_dir)"
|
||||
" cache_dir=CACHE_DIR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -307,7 +308,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
|
@ -333,10 +334,10 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
|
||||
"Iteration: 41%|████▏ | 24/58 [00:49<00:44, 1.30s/it]\u001b[A\n",
|
||||
"Iteration: 83%|████████▎ | 48/58 [01:02<00:12, 1.30s/it]\u001b[A\n",
|
||||
"Epoch: 20%|██ | 1/5 [01:14<04:59, 74.89s/it]9s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:30<00:45, 1.31s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:49<00:45, 1.31s/it]\u001b[A\n",
|
||||
"Iteration: 81%|████████ | 47/58 [01:01<00:14, 1.31s/it]\u001b[A\n",
|
||||
"Epoch: 20%|██ | 1/5 [01:15<05:01, 75.44s/it]0s/it]\u001b[A\n",
|
||||
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
|
||||
]
|
||||
},
|
||||
|
@ -344,7 +345,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loss: 0.4810219132180872\n"
|
||||
"Train loss: 0.4387937460480065\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -352,11 +353,11 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
|
||||
"Iteration: 41%|████▏ | 24/58 [00:45<00:44, 1.30s/it]\u001b[A\n",
|
||||
"Iteration: 83%|████████▎ | 48/58 [01:02<00:13, 1.30s/it]\u001b[A\n",
|
||||
"Iteration: 83%|████████▎ | 48/58 [01:15<00:13, 1.30s/it]\u001b[A\n",
|
||||
"Epoch: 40%|████ | 2/5 [02:30<03:44, 74.96s/it]0s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:30<00:45, 1.31s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:44<00:45, 1.31s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.31s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:14<00:15, 1.31s/it]\u001b[A\n",
|
||||
"Epoch: 40%|████ | 2/5 [02:31<03:46, 75.53s/it]1s/it]\u001b[A\n",
|
||||
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
|
||||
]
|
||||
},
|
||||
|
@ -364,7 +365,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loss: 0.1079183994182225\n"
|
||||
"Train loss: 0.11529478796854101\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -373,9 +374,9 @@
|
|||
"text": [
|
||||
"\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:49<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:48<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Epoch: 60%|██████ | 3/5 [03:46<02:30, 75.28s/it]1s/it]\u001b[A\n",
|
||||
"Epoch: 60%|██████ | 3/5 [03:47<02:31, 75.73s/it]1s/it]\u001b[A\n",
|
||||
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
|
||||
]
|
||||
},
|
||||
|
@ -383,7 +384,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loss: 0.04469434472186298\n"
|
||||
"Train loss: 0.048355102719706965\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -392,10 +393,10 @@
|
|||
"text": [
|
||||
"\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:43<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:42<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:13<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Epoch: 80%|████████ | 4/5 [05:02<01:15, 75.57s/it]1s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:12<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Epoch: 80%|████████ | 4/5 [05:03<01:15, 75.92s/it]2s/it]\u001b[A\n",
|
||||
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
|
||||
]
|
||||
},
|
||||
|
@ -403,7 +404,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loss: 0.023166427043555624\n"
|
||||
"Train loss: 0.030773249510996813\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -412,16 +413,17 @@
|
|||
"text": [
|
||||
"\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:47<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 40%|███▉ | 23/58 [00:46<00:46, 1.32s/it]\u001b[A\n",
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Epoch: 100%|██████████| 5/5 [06:18<00:00, 75.79s/it]2s/it]\u001b[A"
|
||||
"Iteration: 79%|███████▉ | 46/58 [01:16<00:15, 1.32s/it]\u001b[A\n",
|
||||
"Epoch: 100%|██████████| 5/5 [06:20<00:00, 76.06s/it]2s/it]\u001b[A"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loss: 0.014027320994626218\n"
|
||||
"Train loss: 0.018075179055750627\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -436,9 +438,9 @@
|
|||
"token_classifier.fit(token_ids=train_token_ids, \n",
|
||||
" input_mask=train_input_mask, \n",
|
||||
" labels=train_label_ids,\n",
|
||||
" num_epochs=num_train_epochs, \n",
|
||||
" batch_size=batch_size, \n",
|
||||
" learning_rate=learning_rate)"
|
||||
" num_epochs=NUM_TRAIN_EPOCHS, \n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" learning_rate=LEARNING_RATE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -450,7 +452,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -474,14 +476,14 @@
|
|||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Iteration: 100%|██████████| 58/58 [00:25<00:00, 2.30it/s]"
|
||||
"Iteration: 100%|██████████| 58/58 [00:25<00:00, 2.29it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluation loss: 0.13430038492741256\n"
|
||||
"Evaluation loss: 0.11602881579691994\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -496,7 +498,7 @@
|
|||
"pred_label_ids = token_classifier.predict(token_ids=test_token_ids, \n",
|
||||
" input_mask=test_input_mask, \n",
|
||||
" labels=test_label_ids, \n",
|
||||
" batch_size=batch_size)"
|
||||
" batch_size=BATCH_SIZE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -509,7 +511,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -520,14 +522,14 @@
|
|||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" LOC 0.91 0.84 0.88 558\n",
|
||||
" X 0.97 0.98 0.98 2060\n",
|
||||
" PER 0.90 0.95 0.93 569\n",
|
||||
" ORG 0.70 0.82 0.76 553\n",
|
||||
" MISC 0.74 0.71 0.72 399\n",
|
||||
" MISC 0.73 0.71 0.72 396\n",
|
||||
" ORG 0.72 0.79 0.76 538\n",
|
||||
" X 0.96 0.97 0.97 1983\n",
|
||||
" PER 0.93 0.93 0.93 550\n",
|
||||
" LOC 0.84 0.89 0.86 543\n",
|
||||
"\n",
|
||||
"micro avg 0.89 0.91 0.90 4139\n",
|
||||
"macro avg 0.90 0.91 0.90 4139\n",
|
||||
"micro avg 0.88 0.90 0.89 4010\n",
|
||||
"macro avg 0.88 0.90 0.89 4010\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
@ -551,7 +553,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -562,13 +564,13 @@
|
|||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" LOC 0.90 0.85 0.88 507\n",
|
||||
" PER 0.89 0.94 0.92 460\n",
|
||||
" ORG 0.65 0.81 0.72 440\n",
|
||||
" MISC 0.67 0.68 0.68 328\n",
|
||||
" ORG 0.69 0.79 0.74 442\n",
|
||||
" PER 0.92 0.92 0.92 455\n",
|
||||
" MISC 0.71 0.68 0.69 344\n",
|
||||
" LOC 0.83 0.87 0.85 503\n",
|
||||
"\n",
|
||||
"micro avg 0.77 0.83 0.80 1735\n",
|
||||
"macro avg 0.79 0.83 0.81 1735\n",
|
||||
"micro avg 0.77 0.83 0.79 1744\n",
|
||||
"macro avg 0.79 0.83 0.81 1744\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ def read_data(data_file):
|
|||
return text
|
||||
|
||||
|
||||
def get_train_test_data(text, test_percentage=0.5):
|
||||
def get_train_test_data(text, test_percentage=0.5, random_seed=None):
|
||||
"""
|
||||
Get the training and testing data based on test_percentage.
|
||||
|
||||
|
@ -41,6 +41,7 @@ def get_train_test_data(text, test_percentage=0.5):
|
|||
test_percentage (float, optional): Percentage of data to use for
|
||||
testing. Since this is a small dataset, the default testing
|
||||
percentage is set to 0.5
|
||||
random_seed (int, optional): Random seed used to shuffle the data. If None or 0, the random generator is not re-seeded before shuffling.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing four lists:
|
||||
|
@ -51,11 +52,13 @@ def get_train_test_data(text, test_percentage=0.5):
|
|||
test_labels_list: List of lists. Each sublist contains the
|
||||
entity labels of the word in the testing sentence.
|
||||
"""
|
||||
|
||||
# Input data are separated by empty lines
|
||||
text_split = text.split("\n\n")
|
||||
# Remove empty line at EOF
|
||||
text_split = text_split[:-1]
|
||||
|
||||
if random_seed:
|
||||
random.seed(random_seed)
|
||||
random.shuffle(text_split)
|
||||
|
||||
sentence_count = len(text_split)
|
||||
|
|
Загрузка…
Ссылка в новой задаче