Initial commit with code.
This commit is contained in:
Parent
aa3bffc42b
Commit
5254314c7a
@@ -1,104 +1,114 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Mac
.DS_Store

# model cache
*.pt

# dataset
rt-polaritydata
trees
*.tar
*.zip

# pycharm
.idea

.idea/
output/
*.7z
.DS_Store
data/
data/glove/
data/processed/enwiki*
data/raw/
data_processing/extractor.txt
165 README.md
@@ -1,3 +1,168 @@
Modeling the Relationship Between Comments and Edits in Document Revisions
======

This is a PyTorch implementation for modeling the relationship between comments and edits in Wikipedia revision data, as described in our EMNLP 2019 paper:

**Modeling the Relationship between User Comments and Edits in Document Revision**, Xuchao Zhang, Dheeraj Rajagopal, Michael Gamon, Sujay Kumar Jauhar and ChangTien Lu, 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), Hong Kong, China, Nov 3-7, 2019.

Two distinct but related tasks are proposed in this work:
- **Comment Ranking**: ranking a list of comments based on their relevance to a specific edit
- **Edit Anchoring**: anchoring a comment to the most relevant sentences in a document

## Requirements
- python 3.6
- [tqdm](https://github.com/noamraph/tqdm)
- [pytorch 0.4.0](https://pytorch.org/)
- [numpy v1.13+](http://www.numpy.org/)
- [scipy](https://www.scipy.org/)
- [scikit-learn](http://scikit-learn.org/stable/)
- [spacy v2.0.11](https://spacy.io/)
- and some other basic packages.


## Usage
### Data Preparation
- Download the raw data from Wikipedia and generate the preprocessed data file ```wikicmnt.json``` by following the steps in our [Data Extractor](./data_processing/README.md). Place the generated ```wikicmnt.json``` file in the dataset folder, e.g., ```./data/processed/```.

<!-- For demo purpose, we include a small dump file ```enwiki-20181001-pages-meta-history24.xml-p33948436p33952815.bz2``` (17MB) in the ```./dataset/``` folder. -->

- Download the GloVe embeddings from ```https://nlp.stanford.edu/projects/glove/``` and copy them into the folder specified by the corresponding parameter, e.g., ```./dataset/glove/```.

### Training
In this section, we train the models for the comment ranking and edit anchoring tasks, either individually or jointly. Before training a model, make sure the ```wikicmnt.json``` file exists in the directory given by ```--data_path```.

##### Comment Ranking Training
To train the model for the comment ranking task, run the following command:
```
python3 main.py --cr_train --data_path="./data/processed/"
```

##### Edit Anchoring Training
To train the model for the edit anchoring task, run the following command:
```
python3 main.py --ea_train --data_path="./data/processed/"
```
##### Jointly train on both Comment Ranking & Edit Anchoring tasks
To train a multi-task model with default parameters, simply combine the flags of the two individual tasks:
```
python3 main.py --cr_train --ea_train --data_path="./dataset/processed/"
```
The common parameters you can change:
```
--data_path="./dataset/"
--glove_path="./dataset/glove/"
--epoch=20
--batch_size=10
--word_embd_size=100
```
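For example, a joint training run that overrides these defaults could look like the following (the flag names come from the full usage options listed below; the paths are placeholders for your own layout):
```
python3 main.py --cr_train --ea_train --data_path="./data/processed/" --glove_path="./data/glove/" --epoch=20 --batch_size=10 --word_embd_size=100
```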

The best model is saved to the default folder ```./checkpoint/```.

### Test
Test a saved model. The following metrics are reported by default: P@1, P@3, MRR, and NDCG.
```
python main.py --test --checkpoint="./checkpoint/saved_best_model.pt"
```
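
These metrics follow the per-example ranking logic implemented in ```eval.py```. The snippet below is a minimal illustrative sketch (not part of the repository API): given the model's score for the true comment and the scores of its negative candidates, each example contributes to P@k, MRR, and NDCG based on the rank of the positive comment.
```
import math

def rank_metrics(pos_score, neg_scores):
    # rank = 1 + number of negative candidates scored at least as high as the positive one
    rank = 1 + sum(1 for s in neg_scores if s >= pos_score)
    p_at_1 = 1.0 if rank <= 1 else 0.0
    p_at_3 = 1.0 if rank <= 3 else 0.0
    mrr = 1.0 / rank                    # reciprocal rank; averaged over examples it gives MRR
    ndcg = 1.0 / math.log2(rank + 1)    # binary relevance, so IDCG = 1
    return p_at_1, p_at_3, mrr, ndcg

# Example: a positive comment scored 0.8 against negatives [0.9, 0.5, 0.3] is ranked 2nd
print(rank_metrics(0.8, [0.9, 0.5, 0.3]))  # (0.0, 1.0, 0.5, 0.6309...)
```
With binary relevance the ideal DCG is 1, so NDCG reduces to 1 / log2(rank + 1), matching the computation in ```compute_rank_score```.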

### Full Usage Options
The full options of our code are listed below:
```
usage: main.py [-h] [--data_path DATA_PATH]
               [--checkpoint_path CHECKPOINT_PATH] [--glove_path GLOVE_PATH]
               [--log_interval LOG_INTERVAL] [--test_interval TEST_INTERVAL]
               [--save_best SAVE_BEST] [--lr LR] [--ngpu NGPU]
               [--word_embd_size WORD_EMBD_SIZE]
               [--max_ctx_length MAX_CTX_LENGTH]
               [--max_diff_length MAX_DIFF_LENGTH]
               [--max_cmnt_length MAX_CMNT_LENGTH] [--ctx_mode CTX_MODE]
               [--rnn_model] [--epoch EPOCH] [--start_epoch START_EPOCH]
               [--batch_size BATCH_SIZE] [--cr_train] [--ea_train]
               [--no_action] [--no_attention] [--no_hadamard]
               [--src_train SRC_TRAIN] [--train_ratio TRAIN_RATIO]
               [--val_ratio VAL_RATIO] [--val_size VAL_SIZE]
               [--manualSeed MANUALSEED] [--test] [--case_study]
               [--resume PATH] [--seed SEED] [--gpu GPU]
               [--checkpoint CHECKPOINT] [--rank_num RANK_NUM]
               [--anchor_num ANCHOR_NUM] [--use_target_only] [--predict]
               [--pred_cmnt PRED_CMNT] [--pred_ctx PRED_CTX]

optional arguments:
  -h, --help            show this help message and exit
  --data_path DATA_PATH
                        the data directory
  --checkpoint_path CHECKPOINT_PATH
                        the checkpoint directory
  --glove_path GLOVE_PATH
                        the glove directory
  --log_interval LOG_INTERVAL
                        how many steps to wait before logging training status
                        [default: 100]
  --test_interval TEST_INTERVAL
                        how many steps to wait before testing [default: 1000]
  --save_best SAVE_BEST
                        whether to save when get best performance
  --lr LR               learning rate, default=0.5
  --ngpu NGPU           number of GPUs to use
  --word_embd_size WORD_EMBD_SIZE
                        word embedding size
  --max_ctx_length MAX_CTX_LENGTH
                        the maximum words in the context [default: 300]
  --max_diff_length MAX_DIFF_LENGTH
                        the maximum words in the revision difference [default:
                        200]
  --max_cmnt_length MAX_CMNT_LENGTH
                        the maximum words in the comment [default: 30]
  --ctx_mode CTX_MODE   whether to use change context in training [default:
                        True]
  --rnn_model           use rnn baseline model
  --epoch EPOCH         number of epoch, default=10
  --start_epoch START_EPOCH
                        resume epoch count, default=1
  --batch_size BATCH_SIZE
                        input batch size
  --cr_train            whether to training the comment rank task
  --ea_train            whether to training the revision anchoring task
  --no_action           whether to use action encoding to train the model
  --no_attention        whether to use mutual attention to train the model
  --no_hadamard         whether to use hadamard product to train the model
  --src_train SRC_TRAIN
                        whether to training the comment rank task without
                        before-editing version
  --train_ratio TRAIN_RATIO
                        ratio of training data in the entire data [default:
                        0.7]
  --val_ratio VAL_RATIO
                        ratio of validation data in the entire data [default:
                        0.1]
  --val_size VAL_SIZE   force the size of validation dataset, the parameter
                        will disgard the setting of val_ratio [default: -1]
  --manualSeed MANUALSEED
                        manual seed
  --test                use test model
  --case_study          use case study mode
  --resume PATH         path saved params
  --seed SEED           random seed
  --gpu GPU             gpu to use for iterate data, -1 mean cpu [default: -1]
  --checkpoint CHECKPOINT
                        filename of model checkpoint [default: None]
  --rank_num RANK_NUM   the number of ranking comments
  --anchor_num ANCHOR_NUM
                        the number of ranking comments
  --use_target_only     use target context only in model
  --predict             predict the sentence given
  --pred_cmnt PRED_CMNT
                        the comment of prediction
  --pred_ctx PRED_CTX   the context of prediction

```


## Author

If you have any trouble or questions, please contact [Xuchao Zhang]().

August, 2018


# Contributing

@@ -0,0 +1,112 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Mac
.DS_Store

# model cache
*.pt

# dataset
rt-polaritydata
trees
*.tar
*.zip

# pycharm
.idea

data/
dataset/
output/
wikidata/
*.7z
*.bz2
sumdata_api.py
@@ -0,0 +1,91 @@
Wikipedia Revision Data Extractor
======

This is a Python implementation of the data extraction tool set for Wikipedia revision data, as described in our EMNLP 2019 paper:

**Modeling the Relationship between User Comments and Edits in Document Revision**, Xuchao Zhang, Dheeraj Rajagopal, Michael Gamon, Sujay Kumar Jauhar and ChangTien Lu, 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), Hong Kong, China, Nov 3-7, 2019.


We provide three tools to extract and preprocess the Wikipedia revision history data from scratch:
- download the entire enwiki revision dumps from Wikipedia
- extract the revision data for the comment modeling task from the wiki dump files
- extract the summarization task dataset from the wiki dump

Note: The collected Wikipedia revision data can be used as the input for the proposed models in our EMNLP paper, or used on its own for other tasks.

## Requirements
- python 3.6
- [tqdm](https://github.com/noamraph/tqdm)
- [numpy v1.13+](http://www.numpy.org/)
- [scipy](https://www.scipy.org/)
- [scikit-learn](http://scikit-learn.org/stable/)
- [spacy v2.0.11](https://spacy.io/)
- and some other basic packages.


## Usage
### Download Wiki dump files
First, choose a dump such as ```https://dumps.wikimedia.org/enwiki/20190801``` (the latest wiki dump when our code was released). This page lists all the information related to the dump, such as the files it contains. Then download the machine-readable dump status file ```dumpstatus.json``` from the Wikipedia dump page and copy it into the default data path, e.g., ```./data/```.
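
The download script only relies on a small part of this status file. The following minimal sketch (illustrative only, assuming the bz2 dump type and the default ```./data/``` path) mirrors how ```wiki_dump_download.py``` reads it:
```
import json

# dumpstatus.json groups the dump's files by job; the downloader reads the
# 'metahistorybz2dump' job (or 'metahistory7zdump' for 7z) and its 'files' map.
with open("./data/dumpstatus.json") as f:
    job = json.load(f)["jobs"]["metahistorybz2dump"]

for name, info in sorted(job["files"].items()):
    # each entry carries a relative 'url' and an 'md5' checksum used by --verify
    print(name, "https://dumps.wikimedia.org" + info["url"], info["md5"])
```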

###### Important Note:
* Make sure the dump files contain the complete page edit history: ```All pages with complete page edit history (.bz2)```. The edit history is skipped in some specific dump versions.
* Always choose a recent dump, since Wikipedia cleans up old dumps and deprecates them.


Finally, run our wiki dump download script to download the dump files as follows:
```
python wiki_dump_download.py --data_path="./data/raw/" --compress_type="bz2" --threads=3
```
You need to specify the data path and compression type (bz2 by default). Since the download process is extremely slow, you can use multiple threads to download the dump files. However, Wikipedia only allows three simultaneous HTTP download connections per IP address, so the maximum number of threads I recommend is three unless you can assign a different IP address to each thread.

At the beginning of the download process, all the files are listed with unique IDs as follows:

```
All files to download ...
1 enwiki-20190801-pages-meta-history1.xml-p1043p2036.bz2
2 enwiki-20190801-pages-meta-history1.xml-p10675p11499.bz2
3 enwiki-20190801-pages-meta-history1.xml-p10p1042.bz2
4 enwiki-20190801-pages-meta-history1.xml-p11500p12310.bz2
...
...
648 enwiki-20190801-pages-meta-history9.xml-p2326206p2336422.bz2
```
The entire download process usually takes one to two days. You can download a specific range of files by specifying the ```--start``` and ```--end``` parameters, and you can use the ```--verify``` parameter to check the completeness of your dump files.


### Revision Data Preprocessing

For preprocessing the revision data, we provide both single-thread and multi-thread versions.

To preprocess a single dump file, specify the index of the dump file with the ```--single_thread``` parameter as follows:
```
python3 wikicmnt_extractor.py --single_thread=5
```
Here the number 5 in the example means the 5th dump file in the dump data folder.

To preprocess multiple dump files in parallel, run:
```
python3 wikicmnt_extractor.py --threads=10
```

You need to specify some common parameters:
```
--data_path="../data/raw/"
--output_path="../data/processed/"
--sample_ratio=0.1
--ctx_window=5
--min_page_tokens=50
--max_page_tokens=2000
--min_cmnt_length=8
```
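For example, a multi-threaded run that sets these options explicitly could look like the following (the flag names come from ```wikicmnt_extractor.py```; the paths are placeholders for your own layout):
```
python3 wikicmnt_extractor.py --threads=3 --data_path="../data/raw/" --output_path="../data/processed/" --sample_ratio=0.1 --ctx_window=5 --min_cmnt_length=8
```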
Finally, if you used the single-thread mode to generate the processed files one by one, you need to merge the outputs of all the dump files together by running the following command:
```
python3 wikicmnt_extractor.py --merge_only
```
The output of the command is a single processed file ```wikicmnt.json``` which includes all the processed data.
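
```wikicmnt.json``` is a JSON-lines file: each line is one revision record whose fields follow the format documented in the extractor source. The sketch below (illustrative only, with an assumed output path) shows how to inspect it:
```
import json

# wikicmnt.json is a JSON-lines file: one revision record per line.
# The field names follow the extractor output documented in the source.
with open("./data/processed/wikicmnt.json", encoding="utf-8") as f:
    for line in f:
        rec = json.loads(line)
        # rec["comment"]    - the revision comment
        # rec["src_token"]  - tokens around the edit in the before-editing version
        # rec["src_action"] - per-token flags (0 = unchanged, -1 = removed)
        # rec["tgt_token"]  - tokens around the edit in the after-editing version
        # rec["tgt_action"] - per-token flags (0 = unchanged, 1 = added)
        # rec["neg_cmnts"], rec["pos_edits"], rec["neg_edits"] - ranking / anchoring candidates
        print(rec["page_title"], rec["revision_id"], rec["comment"][:60])
        break
```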

## Author

If you have any trouble or questions, please contact [Xuchao Zhang](xuczhang@gmail.com).

August, 2019
@@ -0,0 +1,186 @@
import argparse
import glob
import hashlib
import json
import logging
import os
import threading
import urllib.request
from datetime import datetime

parser = argparse.ArgumentParser(description='WikiDump Downloader')
parser.add_argument('--data-path', type=str, default="./data/", help='the data directory')
parser.add_argument('--compress-type', type=str, default='bz2',
                    help='the compressed file type to download: 7z or bz2 [default: bz2]')
parser.add_argument('--threads', type=int, default=3, help='number of threads [default: 3]')
parser.add_argument('--start', type=int, default=1, help='the first file to download [default: 0]')
parser.add_argument('--end', type=int, default=-1, help='the last file to download [default: -1]')
parser.add_argument('--verify', action='store_true', default=False, help='verify the dump files in the specific path')
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG,
                    format='(%(threadName)s) %(message)s',
                    )


def download(dump_status_file, data_path, compress_type, start, end, thread_num):
    url_list = []
    file_list = []
    with open(dump_status_file) as json_data:
        # Two dump types: compressed by 7z (metahistory7zdump) or bz2 (metahistorybz2dump)
        history_dump = json.load(json_data)['jobs']['metahistory' + compress_type + 'dump']
        dump_dict = history_dump['files']
        dump_files = sorted(list(dump_dict.keys()))

        if args.end > 0 and args.end <= len(dump_files):
            dump_files = dump_files[start - 1:end]
        else:
            dump_files = dump_files[start - 1:]

        # print all files to be downloaded.
        print("All files to download ...")
        for i, file in enumerate(dump_files):
            print(i + args.start, file)

        file_num = 0
        for dump_file in dump_files:
            file_name = data_path + dump_file
            file_list.append(file_name)

            # url example: https://dumps.wikimedia.org/enwiki/20180501/enwiki-20180501-pages-meta-history1.xml-p10p2123.7z
            url = "https://dumps.wikimedia.org" + dump_dict[dump_file]['url']
            url_list.append(url)
            file_num += 1

        print('Total file ', file_num, ' to be downloaded ...')
        json_data.close()

    task = WikiDumpTask(file_list, url_list)
    threads = []
    for i in range(thread_num):
        t = threading.Thread(target=worker, args=(i, task))
        threads.append(t)
        t.start()

    logging.debug('Waiting for worker threads')
    main_thread = threading.currentThread()
    for t in threading.enumerate():
        if t is not main_thread:
            t.join()


def existFile(data_path, cur_file):
    exist_file_list = glob.glob(data_path + "*." + args.compress_type)
    exist_file_names = [os.path.basename(i) for i in exist_file_list]
    cur_file_name = os.path.basename(cur_file)
    if cur_file_name in exist_file_names:
        return True
    return False


def md5(file):
    hash_md5 = hashlib.md5()
    with open(file, "rb") as f:
        for chunk in iter(lambda: f.read(40960000), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def verify(dump_status_file, compress_type, data_path):
    print("Verify the file in folder:", data_path)
    pass_files, miss_files, crash_files = [], [], []
    with open(dump_status_file) as json_data:
        # Two dump types: compressed by 7z (metahistory7zdump) or bz2 (metahistorybz2dump)
        history_dump = json.load(json_data)['jobs']['metahistory' + compress_type + 'dump']
        dump_dict = history_dump['files']
        for i, (file, value) in enumerate(dump_dict.items()):
            gt_md5 = value['md5']
            print("#", i, " ", file, ' ', value['md5'], sep='')
            if existFile(data_path, file):
                file_md5 = md5(data_path + file)
                if file_md5 == gt_md5:
                    pass_files.append(file)
                else:
                    crash_files.append(file)
            else:
                miss_files.append(file)

    print(len(pass_files), "files passed, ", len(miss_files), "files missed, ", len(crash_files), "files crashed.")

    if len(miss_files):
        print("==== Missed Files ====")
        print(miss_files)

    if len(crash_files):
        print("==== Crashed Files ====")
        print(crash_files)


def main():
    dump_status_file = args.data_path + "dumpstatus.json"

    if args.verify:
        verify(dump_status_file, args.compress_type, args.data_path)
    else:
        download(dump_status_file, args.data_path, args.compress_type, args.start, args.end, args.threads)


'''
WikiDumpTask class contains a list of dump files to be downloaded.
The assign_task function will be called by workers to grab a task.
'''


class WikiDumpTask(object):
    def __init__(self, file_list, url_list):
        self.lock = threading.Lock()
        self.url_list = url_list
        self.file_list = file_list
        self.total_num = len(url_list)

    def assign_task(self):
        logging.debug('Assign tasks ... Waiting for lock')
        self.lock.acquire()
        url = None
        file_name = None
        cur_progress = None
        try:
            # logging.debug('Acquired lock')
            if len(self.url_list) > 0:
                url = self.url_list.pop(0)
                file_name = self.file_list.pop(0)
                cur_progress = self.total_num - len(self.url_list)
        finally:
            self.lock.release()
        return url, file_name, cur_progress, self.total_num


'''
worker is the main function for each thread.
'''


def worker(work_id, tasks):
    logging.debug('Starting.')

    # grab one task from task_list
    while 1:
        url, file_name, cur_progress, total_num = tasks.assign_task()
        if not url:
            break
        logging.debug('Assigned task (' + str(cur_progress) + '/' + str(total_num) + '): ' + str(url))
        if not existFile(args.data_path, file_name):
            urllib.request.urlretrieve(url, file_name)
            logging.debug("File Downloaded: " + url)
        else:
            logging.debug("File Exists, Skip: " + url)
    logging.debug('Exiting.')

    return


if __name__ == '__main__':
    start_time = datetime.now()
    main()
    time_elapsed = datetime.now() - start_time
    print('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))
@@ -0,0 +1,118 @@
#!/usr/bin/env python

import sys

sys.path.append('../')
from wiki_util import *


def extract_json(raw_file, ctx10_file, output_file, ctx_window=10, negative_edit_num=10):
    # load tokenized comment and neg_comments from ctx10_file
    print("Loading tokenized comment and neg_comments from ctx10_file")
    cmnt_dict = {}
    with open(ctx10_file, 'r', encoding='utf-8') as f:
        for idx, json_line in enumerate(tqdm(f)):
            article = json.loads(json_line.strip('\n'))
            rev_id = article["revision_id"]
            comment = article["comment"]
            neg_cmnts = article["neg_cmnts"]
            cmnt_dict[rev_id] = (comment, neg_cmnts)

    json_file = open(output_file, "w", buffering=1, encoding='utf-8')
    with open(raw_file, 'r', encoding='utf-8') as f:
        for idx, json_line in enumerate(tqdm(f)):

            article = json.loads(json_line.strip('\n'))

            '''
            Json file format:
            =================
            revision_id: The revision ID
            parent_id: The parent revision ID
            timestamp: Timestamp
            diff_url: The wikipedia link to show the difference between previous and current version.
            page_title: The title of page.
            comment: Revision comment.
            src_token: List of tokens in before-editing version
            src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
            tgt_token: List of tokens in after-editing version
            tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
            neg_cmnts: Negative samples of user comments in the same page.
            pos_edits: Edit sentences for comments.
            neg_edits: Negative edit sentences for comments.
            '''
            rev_id = article["revision_id"]
            parent_id = article["parent_id"]
            timestamp = article["timestamp"]
            diff_url = article["diff_url"]
            page_title = article["page_title"]

            # comment = word_tokenize(article["comment"])
            # neg_comments = [word_tokenize(cmnt) for cmnt in article['neg_comments']]
            # lookup comment and neg_comments from dictionary
            comment, neg_cmnts = cmnt_dict[rev_id]

            src_text = article["src_text"]
            tgt_text = article["tgt_text"]

            src_sents = article["src_sents"]
            src_tokens = article["src_tokens"]

            tgt_sents = article["tgt_sents"]
            tgt_tokens = article["tgt_tokens"]

            src_token_diff = article["src_token_diff"]
            tgt_token_diff = article["tgt_token_diff"]

            # src_sents, src_tokens = tokenizeText(src_text)
            # tgt_sents, tgt_tokens = tokenizeText(tgt_text)

            # extract the offset of the changed tokens in both src and tgt
            # src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)

            src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
            tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)

            # src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
            tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)

            # generate the positive edits
            pos_edits = [tgt_sents[i] for i in tgt_sent_diff]

            # generate negative edits
            neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
            if negative_edit_num > len(neg_edits_idx):
                sampled_neg_edits_idx = neg_edits_idx
            else:
                sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
            neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]

            if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
                json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
                             "diff_url": diff_url, "page_title": page_title, \
                             "comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
                             "tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
                             "neg_cmnts": neg_cmnts, "neg_edits": neg_edits, "pos_edits": pos_edits
                             }

                json_str = json.dumps(json_dict,
                                      indent=None, sort_keys=False,
                                      separators=(',', ': '), ensure_ascii=False)
                json_file.write(json_str + '\n')


def main():
    root_path = "../dataset/raw/"

    ctx_window = int(sys.argv[1])
    raw_file = root_path + "enwiki-sample_output_raw.json"
    ctx10_file = root_path + "wikicmnt_ctx10.json"
    output_file = root_path + "wikicmnt_ctx" + str(ctx_window) + ".json"

    extract_json(raw_file, ctx10_file, output_file, ctx_window=ctx_window)


if __name__ == '__main__':
    start_time = datetime.now()
    main()
    time_elapsed = datetime.now() - start_time
@@ -0,0 +1,155 @@
#!/usr/bin/env python
import argparse
import sys

sys.path.append('../')
from wiki_util import *
from wikicmnt_extractor_st import randSampleRev

# arguments
parser = argparse.ArgumentParser(description='Wiki Extractor')
parser.add_argument('--data_path', type=str, default="../data/raw/", help='the data directory')
parser.add_argument('--output_path', type=str, default="../data/processed/", help='the sample output path')
parser.add_argument('--min_page_tokens', type=int, default=50,
                    help='the minimum size of tokens in page to extract [default: 100]')
parser.add_argument('--max_page_tokens', type=int, default=2000,
                    help='the maximum size of tokens in page to extract [default: 1000]')
parser.add_argument('--min_cmnt_length', type=int, default=8,
                    help='the minimum words contained in the comments [default: 8]')
parser.add_argument('--ctx_window', type=int, default=5, help='the window size of context [default: 5]')
parser.add_argument('--sample_ratio', type=float, default=0.01, help='the ratio of sampling [default: 0.001]')
parser.add_argument('--threads', type=int, default=3, help='the number of sampling threads [default: 5]')
parser.add_argument('--single_thread', type=int, default=0,
                    help='the dump file index when using single thread mode. If the index equals to zero, '
                         'use multi-thread mode to preprocess all the dump files [default: 0]')
parser.add_argument('--user_stat', type=bool, default=False, help='whether to do user statistics')
parser.add_argument('--merge_only', action='store_true', default=False, help='merge the results only')
parser.add_argument('--neg_cmnt_num', type=int, default=10,
                    help='how many negative comments sampled for ranking problem [default: 10]')
parser.add_argument('--count_revision_only', type=bool, default=False,
                    help='count the revision only without sampling anything [default: False]')
args = parser.parse_args()

# create sample output folder if it doesn't exist
if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)

# logging configuration
logging.basicConfig(level=logging.DEBUG,
                    format='(%(threadName)s) %(message)s',
                    )

'''
WikiSampleTask class contains a list of dump files to be sampled.
The assign_task function will be called by workers to grab a task.
'''


class WikiSampleTask(object):
    def __init__(self, dump_list):
        self.lock = threading.Lock()
        self.dump_list = dump_list
        self.total_num = len(dump_list)

    def assign_task(self):
        logging.debug('Assign tasks ... Waiting for lock')
        self.lock.acquire()
        dump_name = None
        cur_progress = None
        try:
            # logging.debug('Acquired lock')
            if len(self.dump_list) > 0:
                dump_name = self.dump_list.pop(0)
                cur_progress = self.total_num - len(self.dump_list)
        finally:
            self.lock.release()
        return dump_name, cur_progress, self.total_num


'''
worker is the main function for each thread.
'''


def worker(work_id, tasks):
    logging.debug('Starting.')
    output_file = args.data_path + 'sample/enwiki_sample_' + str(work_id) + '.json'
    # grab one task from task_list
    while 1:
        dump_file, cur_progress, total_num = tasks.assign_task()
        if not dump_file:
            break
        logging.debug('Assigned task (' + str(cur_progress) + '/' + str(total_num) + '): ' + str(dump_file))
        # start to sample the dump file
        output_file = args.output_path + 'enwiki-sample-' + os.path.basename(dump_file)[27:-4] + '.json'
        randSampleRev(work_id, dump_file, output_file, args.sample_ratio, args.min_cmnt_length, args.ctx_window,
                      args.neg_cmnt_num)
    logging.debug('Exiting.')


def initLogger(file_idx):
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)  # or whatever

    handler = logging.FileHandler('extractor.txt', 'a', 'utf-8')  # or whatever
    # handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(asctime)s' + '[' + str(file_idx) + '] - %(message)s'))  # or whatever
    logger.addHandler(handler)
    return logger


def main():
    # single file mode
    if args.single_thread:

        logger = initLogger(args.single_thread)

        data_path = args.data_path
        output_path = args.output_path

        dump_list = sorted(glob.glob(data_path + "*.bz2"))
        print(dump_list)
        dump_file = dump_list[args.single_thread - 1]
        output_file = output_path + 'enwiki-sample-' + os.path.basename(dump_file)[27:-4] + '.json'

        # create sample output folder if it doesn't exist
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        print("start to preprocess dump file:", str(dump_file))
        logger.info("[" + str(args.single_thread) + "] Start to sample dump file " + dump_file)
        randSampleRev(args.single_thread, dump_file, output_file, args.sample_ratio, args.min_cmnt_length,
                      args.ctx_window, args.neg_cmnt_num)
        return

    if not args.merge_only:
        dump_list = glob.glob(args.data_path + "*.bz2")
        # # testing
        # dump_list = dump_list[:5]
        dump_num = len(dump_list)
        logging.debug("Sampling revisions from " + str(dump_num) + " dump files")

        task = WikiSampleTask(dump_list)
        threads = []
        for i in range(args.threads):
            t = threading.Thread(target=worker, args=(i, task))
            threads.append(t)
            t.start()

        logging.debug('Waiting for worker threads')
        main_thread = threading.currentThread()
        for t in threading.enumerate():
            if t is not main_thread:
                t.join()

    logging.debug('Merging the sample outputs from each dump file')

    # merge the result
    mergeOutputs(args.output_path)


if __name__ == '__main__':
    start_time = datetime.now()
    main()
    time_elapsed = datetime.now() - start_time
    logging.debug('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))
@@ -0,0 +1,182 @@
#!/usr/bin/env python

import sys

sys.path.append('../')
from wiki_util import *


def printSample(task_id, sample_count, revision_count, page_title, sect_title, comment, diff_url, parent_tokens,
                target_tokens, origin_diff, target_diff, delimitor='^'):
    # print the sampled revision in excel format
    revision_info = '[' + str(len(parent_tokens)) + '|' + str(len(target_tokens)) + ']'
    origin_diff_tokens = '[' + (
        str(origin_diff[0]) if len(origin_diff) == 1 else ','.join([str(i) for i in origin_diff])) + ']'
    target_diff_tokens = '[' + (
        str(target_diff[0]) if len(target_diff) == 1 else ','.join([str(i) for i in target_diff])) + ']'
    # print(sample_count, '/', revision_count, delimitor, page_title, delimitor, sect_title, delimitor, comment, delimitor,
    #       diff_url, delimitor, revision_info, delimitor, origin_diff_tokens, delimitor, target_diff_tokens, sep='')
    logging.info("[" + str(task_id) + "] " + str(sample_count) + '/' + str(
        revision_count) + delimitor + page_title + delimitor + sect_title +
                 delimitor + comment + delimitor + diff_url + delimitor + revision_info + delimitor + origin_diff_tokens + delimitor + target_diff_tokens)


def randSampleRev(task_id, dump_file, output_file, sample_ratio, min_cmnt_length, ctx_window, negative_cmnt_num=10,
                  negative_edit_num=10, count_revision_only=False, MIN_COMMENT_SIZE=20):
    # if os.path.exists(output_file):
    #     print("Output file already exists. Please remove first so I don't destroy your stuff please")
    #     sys.exit(1)

    json_file = open(output_file, "w", buffering=1, encoding='utf-8')

    start_time = datetime.now()
    wiki_file = bz2.open(dump_file, "rt", encoding='utf-8')
    # template = '{"revision_id": %s, "diff_url": %s, "parent_id": %s, "timestamp": %s, "page_title": %s, "user_name": %s, \
    #             "user_id": %s, "user_ip": %s, "comment": %s, "parent_text": %s, "text": %s, "diff_tokens": %s}\n'
    sample_count = 0
    revision_count = 0

    # time_elapsed = datetime.now() - start_time
    # print("=== ", i, "/", file_num, ' === ', revision_count, " revisions extracted.", ' Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed), sep='')

    sample_parent_id = None
    sample_parent_text = None
    page_comment_list = []
    prev_page_title = ''

    try:
        for page_title, revision in split_records(wiki_file):

            revision_count += 1
            if count_revision_only:
                if revision_count % 1000 == 0:
                    print("= revision", revision_count, "=")
                continue

            # fields
            rev_id, parent_id, timestamp, username, userid, userip, comment, text = extract_data(revision)

            # if the length of comment is less than MIN_COMMENT_SIZE (default 20), skip the revision directly.
            if len(comment) < MIN_COMMENT_SIZE:
                continue

            # clean the comment text
            comment = cleanCmntText(comment)
            # extract the section title and the comment without section info
            sect_title, comment = extractSectionTitle(comment)

            # skip the revision if no section title included or the length is too short.
            if not sect_title or len(comment) < MIN_COMMENT_SIZE:
                continue

            # store the comments
            if prev_page_title != page_title:
                prev_page_title = page_title
                page_comment_list.clear()

            _, comment_tokens = tokenizeText(comment)
            if checkComment(comment, comment_tokens, min_cmnt_length):
                # page_comment_list.append(comment)
                page_comment_list.append(comment)

            # write the sampled revision to json file
            # Skip the line if it is not satisfied with some criteria
            if sample_parent_id == parent_id:
                # do sample
                diff_url = 'https://en.wikipedia.org/w/index.php?title=' + page_title.replace(" ",
                    '%20') + '&type=revision&diff=' + rev_id + '&oldid=' + parent_id
                # print(diff_url)

                # check whether the comment is appropriate by some criteria
                # check_comment(comment, length_only=True)
                try:
                    src_text = extractSectionText(sample_parent_text, sect_title)
                    tgt_text = extractSectionText(text, sect_title)
                except:
                    print("ERROR-RegularExpression:", sample_parent_text, text, " Skip!!")
                    # skip the revision if any exception happens
                    continue

                # clean the wiki text
                src_text = cleanWikiText(src_text)
                tgt_text = cleanWikiText(tgt_text)

                if (src_text and tgt_text) and (len(src_text) < 1000000 and len(tgt_text) < 1000000):

                    # tokenization
                    src_sents, src_tokens = tokenizeText(src_text)
                    tgt_sents, tgt_tokens = tokenizeText(tgt_text)

                    # extract the offset of the changed tokens in both src and tgt
                    src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)

                    if len(src_token_diff) == 0 and len(tgt_token_diff) == 0:
                        continue

                    if (len(src_token_diff) > 0 and src_token_diff[0] < 0) or (
                            len(tgt_token_diff) > 0 and tgt_token_diff[0] < 0):
                        continue

                    if src_sents == None or tgt_sents == None:
                        continue

                    src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
                    tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)

                    # src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
                    tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)

                    # randomly sample the negative comments
                    if negative_cmnt_num > len(page_comment_list) - 1:
                        neg_cmnt_idx = range(len(page_comment_list) - 1)
                    else:
                        neg_cmnt_idx = random.sample(range(len(page_comment_list) - 1), negative_cmnt_num)
                    neg_comments = [page_comment_list[i] for i in neg_cmnt_idx]

                    # generate the positive edits
                    pos_edits = [tgt_sents[i] for i in tgt_sent_diff]

                    # generate negative edits
                    neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
                    if negative_edit_num > len(neg_edits_idx):
                        sampled_neg_edits_idx = neg_edits_idx
                    else:
                        sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
                    neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]

                    if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
                        json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
                                     "diff_url": diff_url, "page_title": page_title, \
                                     "comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
                                     "tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
                                     "neg_cmnts": neg_comments, "neg_edits": neg_edits, "pos_edits": pos_edits
                                     }

                        json_str = json.dumps(json_dict,
                                              indent=None, sort_keys=False,
                                              separators=(',', ': '), ensure_ascii=False)
                        json_file.write(json_str + '\n')
                        sample_count += 1

                        printSample(task_id, sample_count, revision_count, page_title, sect_title, comment, diff_url,
                                    src_tokens, tgt_tokens, src_token_diff, tgt_token_diff)

            # filterRevision(comment, diff_list):
            # if sample_parent_id != parent_id:
            #     logging.debug("ALERT: Parent Id Missing" + str(sample_parent_id) + "/" + str(parent_id) + "--- DATA MIGHT BE CORRUPTED!!")

            # decide to sample next
            if sampleNext(sample_ratio):
                sample_parent_id = rev_id
                sample_parent_text = text
            else:
                sample_parent_id = None
                sample_parent_text = None

            # if revision_count % 1000 == 0:
            #     print("Finished ", str(revision_count))
    finally:
        time_elapsed = datetime.now() - start_time
        logging.debug("=== " + str(sample_count) + " revisions sampled in total " + str(revision_count) + " revisions. " \
                      + 'Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed) + ' ===')
        json_file.close()
@@ -0,0 +1,120 @@
#!/usr/bin/env python
import json
import sys

sys.path.append('../')
from wiki_util import *
from process_data import word_tokenize


def extract_json(input_file, output_file, ctx_window=10, negative_edit_num=10):

    json_file = open(output_file, "w", buffering=1, encoding='utf-8')
    with open(input_file, 'r', encoding='utf-8') as f:
        for idx, json_line in enumerate(tqdm(f)):

            article = json.loads(json_line.strip('\n'))

            '''
            Json file format:
            =================
            revision_id: The revision ID
            parent_id: The parent revision ID
            timestamp: Timestamp
            diff_url: The wikipedia link to show the difference between previous and current version.
            page_title: The title of page.
            comment: Revision comment.
            src_token: List of tokens in before-editing version
            src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
            tgt_token: List of tokens in after-editing version
            tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
            neg_cmnts: Negative samples of user comments in the same page.
            pos_edits: Edit sentences for comments.
            neg_edits: Negative edit sentences for comments.
            '''
            rev_id = article["revision_id"]
            parent_id = article["parent_id"]
            timestamp = article["timestamp"]
            diff_url = article["diff_url"]
            page_title = article["page_title"]

            # # tokenize comment
            # cmnt_tokens = word_tokenize(article['comment'])
            # cmnt_list.append(cmnt_tokens)
            #
            # # negative comments
            # neg_cmnts = [word_tokenize(cmnt) for cmnt in article['neg_comments']]
            # neg_cmnts_list.append(neg_cmnts)

            comment = word_tokenize(article["comment"])
            neg_comments = [word_tokenize(cmnt) for cmnt in article['neg_comments']]

            src_text = article["src_text"]
            tgt_text = article["tgt_text"]

            src_sents = article["src_sents"]
            src_tokens = article["src_tokens"]

            tgt_sents = article["tgt_sents"]
            tgt_tokens = article["tgt_tokens"]

            src_token_diff = article["src_token_diff"]
            tgt_token_diff = article["tgt_token_diff"]

            # src_sents, src_tokens = tokenizeText(src_text)
            # tgt_sents, tgt_tokens = tokenizeText(tgt_text)

            # extract the offset of the changed tokens in both src and tgt
            # src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)

            src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
            tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)

            # src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
            tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)

            # generate the positive edits
            pos_edits = [tgt_sents[i] for i in tgt_sent_diff]

            # generate negative edits
            neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
            if negative_edit_num > len(neg_edits_idx):
                sampled_neg_edits_idx = neg_edits_idx
            else:
                sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
            neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]

            if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
                json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
                             "diff_url": diff_url, "page_title": page_title, \
                             "comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
                             "tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
                             "neg_cmnts": neg_comments, "neg_edits": neg_edits, "pos_edits": pos_edits
                             }

                json_str = json.dumps(json_dict,
                                      indent=None, sort_keys=False,
                                      separators=(',', ': '), ensure_ascii=False)
                json_file.write(json_str + '\n')


def main():

    root_path = "../dataset/raw/"
    data_path = root_path + "split_files/"
    output_path = root_path + "output/"

    file_idx = int(sys.argv[1])

    dump_list = sorted(glob.glob(data_path + "*.json"))
    dump_file = dump_list[file_idx - 1]
    file_name = os.path.basename(dump_file)
    output_file = output_path + file_name[:-5] + '_output.json'

    # original way
    # input_file = root_path + "enwiki-sample_output_raw.json"
    # output_file = root_path + "wikicmnt.json"

    extract_json(dump_file, output_file, ctx_window=10)


if __name__ == '__main__':
    start_time = datetime.now()
    main()
    time_elapsed = datetime.now() - start_time
@@ -0,0 +1,75 @@
#!/usr/bin/env python

import os
import sys
import argparse
import bz2
import glob
from datetime import datetime
import random
import difflib
import nltk
import re
import threading
import logging
import operator
import html
import spacy
import en_core_web_sm
import json
from wiki_util import *
import random
from random import shuffle


def cal_total_line(input_file):
    total_line = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for json_line in tqdm(f):
            total_line += 1
    return total_line


def rand_sample(input_file, output_file, sample_num, total_num):
    sample_indices = random.sample(range(1, total_num), sample_num)
    sample_file = open(output_file, "w", encoding='utf-8')
    sample_list = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for idx, json_line in enumerate(tqdm(f)):
            # if idx in sample_indices:
            #     sample_list.append(json_line)

            j = json.loads(json_line)
            tgt_sent_len = len(j['tgt_sents'])
            tgt_diff_sent_size = len(j['tgt_sent_diff'])
            if tgt_sent_len - tgt_diff_sent_size >= 10:
                sample_list.append(json_line)

            # sample_list.append(json_line)
            # if idx == sample_num - 1:
            #     break

    print("Start to shuffle sample list ...")
    shuffle(sample_list)

    print("Start to write output ...")
    for line in tqdm(sample_list):
        sample_file.write(line)


def main():

    data_folder = "../data/"

    # input_file = data_folder + "wikicmnt.json"
    input_file = data_folder + "wiki_comment_orig.json"
    output_file = data_folder + "wiki_comment.json"
    sample_num = 260000
    # sample_num = 500000
    # total_num = cal_total_line(input_file)
    total_num = 786886
    print("total line:", total_num)
    rand_sample(input_file, output_file, sample_num, total_num)


if __name__ == '__main__':
    start_time = datetime.now()
    main()
    time_elapsed = datetime.now() - start_time
    logging.debug('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))
@@ -0,0 +1,335 @@
import math
from statistics import mean

import sklearn
import sklearn.metrics
import torch

from process_data import _make_char_vector, _make_word_vector, gen_cmntrank_batches, gen_editanch_batches
from process_data import to_var, make_vector
from wiki_util import tokenizeText


## general evaluation
def predict(pred_cmnt, pred_ctx, w2i, c2i, model, max_ctx_length):
    print("prediction on single sample ...")
    model.eval()

    _, pred_cmnt_words = tokenizeText(pred_cmnt)
    pred_cmnt_chars = [list(i) for i in pred_cmnt_words]
    _, pred_ctx_words = tokenizeText(pred_ctx)
    pred_ctx_chars = [list(i) for i in pred_ctx_words]

    cmnt_sent_len = len(pred_cmnt_words)
    cmnt_word_len = int(mean([len(w) for w in pred_cmnt_chars]))
    ctx_sent_len = max_ctx_length
    ctx_word_len = int(mean([len(w) for w in pred_ctx_chars]))

    cmnt_words, cmnt_chars, ctx_words, ctx_chars = [], [], [], []
    # c, cc, q, cq, a in batch

    cmnt_words.append(_make_word_vector(pred_cmnt_words, w2i, cmnt_sent_len))
    cmnt_chars.append(_make_char_vector(pred_cmnt_chars, c2i, cmnt_sent_len, cmnt_word_len))
    ctx_words.append(_make_word_vector(pred_ctx_words, w2i, ctx_sent_len))
    ctx_chars.append(_make_char_vector(pred_ctx_chars, c2i, ctx_sent_len, ctx_word_len))

    cmnt_words = to_var(torch.LongTensor(cmnt_words))
    cmnt_chars = to_var(torch.stack(cmnt_chars, 0))
    ctx_words = to_var(torch.LongTensor(ctx_words))
    ctx_chars = to_var(torch.stack(ctx_chars, 0))

    logit, _ = model(ctx_words, ctx_chars, cmnt_words, cmnt_chars)
    a = torch.max(logit.cpu(), -1)
    print(logit)
    print(a)
    y_pred = a[1].data[0]
    y_prob = a[0].data[0]
    y_pred_2 = [int(i) for i in (torch.max(logit, -1)[1].view(1).data).tolist()][0]
    print(y_pred_2)
    # y_pred = y_pred[0]
    print(y_pred, y_prob)
    return y_pred


def compute_rank_score(pos_score, neg_scores):
    p1, p3, p5 = 0, 0, 0

    pos_list = [0 if pos_score > neg_score else 1 for neg_score in neg_scores]
    pos = sum(pos_list)
    # precision @K
    if pos == 0: p1 = 1
    if pos < 3: p3 = 1
    if pos < 5: p5 = 1
    # MRR
    mrr = 1 / (pos + 1)
    # NDCG: DCG/IDCG (In our case, we set the rel=1 if relevent, otherwise rel=0; Then IDCG=1)
    ndcg = 1 / math.log2(pos + 2)

    return p1, p3, p5, mrr, ndcg


def get_rank(pos_score, neg_scores):
    pos_list = [0 if pos_score > neg_score else 1 for neg_score in neg_scores]
    pos = sum(pos_list)
    return pos + 1

def isEditPredCorrect(pred, truth):
    for i in range(len(pred)):
        if pred[i] != truth[i]:
            return False

    return True
|
||||
|
||||
|
||||
def eval_rank(score_pos, score_neg, cand_num):
|
||||
score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
|
||||
score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
|
||||
|
||||
correct_p1, correct_p3, correct_p5, total_mrr, total_ndcg = 0, 0, 0, 0, 0
|
||||
neg_num = cand_num - 1
|
||||
batch_num = int(len(score_neg) / neg_num)
|
||||
rank_list = []
|
||||
for i in range(batch_num):
|
||||
score_pos_i = score_pos_list[i * neg_num: (i + 1) * neg_num]
|
||||
score_neg_i = score_neg_list[i * neg_num: (i + 1) * neg_num]
|
||||
|
||||
p1, p3, p5, mrr, ndcg = compute_rank_score(score_pos_i[0], score_neg_i)
|
||||
rank = get_rank(score_pos_i[0], score_neg_i)
|
||||
rank_list.append(rank)
|
||||
correct_p1 += p1
|
||||
correct_p3 += p3
|
||||
correct_p5 += p5
|
||||
total_mrr += mrr
|
||||
total_ndcg += ndcg
|
||||
|
||||
return correct_p1, correct_p3, correct_p5, total_mrr, total_ndcg, rank_list
|
||||
|
||||
|
||||
# def eval_rank_orig(score_pos, score_neg, batch_size):
|
||||
# score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
|
||||
# score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
|
||||
#
|
||||
# total_p1, total_p3, total_p5, total_mrr, total_ndcg = 0, 0, 0, 0, 0
|
||||
# sample_num = int(len(score_neg) / batch_size)
|
||||
# for i in range(batch_size):
|
||||
# score_pos_i = score_pos_list[i * sample_num: (i+1) * sample_num]
|
||||
# score_neg_i = score_neg_list[i * sample_num: (i+1) * sample_num]
|
||||
# pos_list = [0 if score_pos_i[i] >= score_neg_i[i] else 1 for i in range(sample_num)]
|
||||
# pos = sum(pos_list)
|
||||
# sorted_neg = ["%.4f" % i for i in sorted(score_neg_i, reverse=True)]
|
||||
# #print(pos, "%.4f" % score_pos_i[0], "\t".join(sorted_neg), sep='\t')
|
||||
# if pos == 0:
|
||||
# total_p1 += 1
|
||||
#
|
||||
# if pos < 3:
|
||||
# total_p3 += 1
|
||||
#
|
||||
# if pos < 5:
|
||||
# total_p5 += 1
|
||||
#
|
||||
# # MRR
|
||||
# total_mrr += 1 / (pos + 1)
|
||||
#
|
||||
# # NDCG: DCG/IDCG (In our case, we set the rel=1 if relevent, otherwise rel=0; Then IDCG=1)
|
||||
# total_ndcg += 1 / math.log2(pos + 2)
|
||||
#
|
||||
# return total_p1, total_p3, total_p5, total_mrr, total_ndcg
|
||||
|
||||
## general evaluation
|
||||
def eval(dataset, val_df, w2i, model, args):
|
||||
# print(" evaluation on", val_df.shape[0], " samples ...")
|
||||
model.eval()
|
||||
corrects, avg_loss = 0, 0
|
||||
|
||||
cmnt_rank_p1, cmnt_rank_p3, cmnt_rank_p5, cmnt_rank_mrr, cmnt_rank_ndcg = 0, 0, 0, 0, 0
|
||||
ea_pred, ea_truth = [], []
|
||||
cr_total, ea_total = 0, 0
|
||||
pred_cr_list, pred_ea_list = [], []
|
||||
for batch in dataset.iterate_minibatches(val_df, args.batch_size):
|
||||
cmnt_sent_len = args.max_cmnt_length
|
||||
ctx_sent_len = args.max_ctx_length
|
||||
diff_sent_len = args.max_diff_length
|
||||
|
||||
###########################################################
|
||||
# Comment Ranking Task
|
||||
###########################################################
|
||||
# generate positive and negative batches
|
||||
pos_batch, neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
|
||||
args.rank_num)
|
||||
|
||||
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
|
||||
make_vector(pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
|
||||
make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action, cr_mode=True)
|
||||
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=True)
|
||||
|
||||
cr_p1_corr, cr_p3_corr, cr_p5_corr, cr_mrr, cr_ndcg, pred_rank = eval_rank(score_pos, score_neg, args.rank_num)
|
||||
cmnt_rank_p1 += cr_p1_corr
|
||||
cmnt_rank_p3 += cr_p3_corr
|
||||
cmnt_rank_p5 += cr_p5_corr
|
||||
cmnt_rank_mrr += cr_mrr
|
||||
cmnt_rank_ndcg += cr_ndcg
|
||||
cr_total += int(len(score_pos) / (args.rank_num - 1))
|
||||
pred_cr_list += pred_rank
|
||||
|
||||
###########################################################
|
||||
# Edits Anchoring
|
||||
###########################################################
|
||||
# generate positive and negative batches
|
||||
ea_batch, ea_truth_cur = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
|
||||
args.anchor_num)
|
||||
        if len(ea_batch[0]) > 0:
|
||||
cmnt, src_token, src_action, tgt_token, tgt_action = \
|
||||
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
# neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
|
||||
# make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
|
||||
# logit_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=False)
|
||||
|
||||
ea_pred_cur = (torch.max(logit, 1)[1].view(logit.size(0)).data).tolist()
|
||||
# ea_truth_cur = [1] * logit_pos.size(0) + [0] * logit_neg.size(0)
|
||||
|
||||
ea_pred += ea_pred_cur
|
||||
ea_truth += ea_truth_cur
|
||||
|
||||
ea_total += int(len(score_pos) / (args.anchor_num - 1))
|
||||
|
||||
# # output the prediction results
|
||||
# with open(args.checkpoint_path + 'test_out.txt', 'w') as f:
|
||||
# for i in range(len(y_truth)):
|
||||
# line = cmnt_readable_all[i] + '\t' + ctx_readable_all[i] + '\t' + str(y_pred[i]) + '\t' + str(y_truth[i])
|
||||
# f.write(line + '\n')
|
||||
|
||||
# if args.test:
|
||||
# print(total_rank)
|
||||
# print("\t".join([str(i) for i in pred_cr_list]))
|
||||
# print("\t".join([str(i) for i in ea_pred]))
|
||||
|
||||
cr_p1_acc = cmnt_rank_p1 / cr_total
|
||||
cr_p3_acc = cmnt_rank_p3 / cr_total
|
||||
cr_p5_acc = cmnt_rank_p5 / cr_total
|
||||
cr_mrr = cmnt_rank_mrr / cr_total
|
||||
cr_ndcg = cmnt_rank_ndcg / cr_total
|
||||
|
||||
ea_acc = (sklearn.metrics.accuracy_score(ea_truth, ea_pred))
|
||||
ea_f1 = (sklearn.metrics.f1_score(ea_truth, ea_pred, pos_label=1))
|
||||
ea_prec = (sklearn.metrics.precision_score(ea_truth, ea_pred, pos_label=1))
|
||||
ea_recall = (sklearn.metrics.recall_score(ea_truth, ea_pred, pos_label=1))
|
||||
|
||||
print("\n*** Validation Results *** ")
|
||||
# print("[Task-CR] P@1:", "%.3f" % cr_p1_acc, "% P@3:", "%.3f" % cr_p3_acc, "% P@5:", "%.3f" % cr_p5_acc,\
|
||||
# '%', ' (', cmnt_rank_p1, '/', cr_total, ',', cmnt_rank_p3, '/', cr_total, ',', cmnt_rank_p5, '/', cr_total,')', sep='')
|
||||
# print("[Task-RA] P@1:", "%.3f" % ea_p1_acc, "% P@3:", "%.3f" % ea_p3_acc, "% P@5:", "%.3f" % ea_p5_acc,\
|
||||
# '%', ' (', edit_anch_p1, '/', ea_total, ',', edit_anch_p3, '/', ea_total, ',', edit_anch_p5, '/', ea_total,')', sep='')
|
||||
print("[Task-CR] P@1:", "%.3f" % cr_p1_acc, " P@3:", "%.3f" % cr_p3_acc, " P@5:", "%.3f" % cr_p5_acc, " MRR:",
|
||||
"%.3f" % cr_mrr, " NDCG:", "%.3f" % cr_ndcg, sep='')
|
||||
print("[Task-EA] ACC:", "%.3f" % ea_acc, " F1:", "%.3f" % ea_f1, " Precision:", "%.3f" % ea_prec, " Recall:",
|
||||
"%.3f" % ea_recall, sep='')
|
||||
return cr_p1_acc, ea_f1
|
||||
|
||||
|
||||
def dump_cmntrank_case(pos_batch, neg_batch, idx, rank_num, diff_url, rank, pos_score, neg_scores):
|
||||
neg_num = rank_num - 1
|
||||
pos_cmnt = pos_batch[0][idx * neg_num]
|
||||
neg_cmnts = neg_batch[0][idx * neg_num: (idx + 1) * neg_num]
|
||||
|
||||
before_edit = pos_batch[1][idx * neg_num]
|
||||
after_edit = pos_batch[3][idx * neg_num]
|
||||
|
||||
match = False
|
||||
for token in pos_cmnt:
|
||||
if token in before_edit + after_edit:
|
||||
match = True
|
||||
break
|
||||
|
||||
neg_match_words = []
|
||||
neg_match = False
|
||||
for neg_cmnt in neg_cmnts:
|
||||
for token in neg_cmnt:
|
||||
if len(token) <= 3:
|
||||
continue
|
||||
if token in before_edit + after_edit:
|
||||
neg_match = True
|
||||
neg_match_words.append(token)
|
||||
|
||||
if not match and neg_match:
|
||||
print("\n ====== cmntrank case (Not Matched) ======")
|
||||
print("Rank", rank)
|
||||
print(diff_url)
|
||||
print("pos_cmnt (", "{0:.3f}".format(pos_score), "): ", " ".join(pos_cmnt), sep='')
|
||||
|
||||
for i, neg_cmnt in enumerate(neg_cmnts):
|
||||
print("neg_cmnt ", i, " (", "{0:.3f}".format(neg_scores[i]), "): ", " ".join(neg_cmnt), sep='')
|
||||
pass
|
||||
print("neg_match_words:", " ".join(neg_match_words))
|
||||
|
||||
|
||||
def dump_editanch_case(comment, edit, pred, truth):
|
||||
print("\n ====== editanch case ======")
|
||||
print("pred/truth: ", pred, "/", truth)
|
||||
print("comment:", " ".join(comment))
|
||||
print("edit:", " ".join(edit))
|
||||
|
||||
|
||||
def case_study(dataset, val_df, w2i, model, args):
|
||||
model.eval()
|
||||
print("Start the case study")
|
||||
# for batch in dataset.iterate_minibatches(val_df[:500], args.batch_size):
|
||||
for batch in dataset.iterate_minibatches(val_df, args.batch_size):
|
||||
cmnt_sent_len = args.max_cmnt_length
|
||||
ctx_sent_len = args.max_ctx_length
|
||||
diff_sent_len = args.max_diff_length
|
||||
|
||||
###########################################################
|
||||
# Comment Ranking Task
|
||||
###########################################################
|
||||
# generate positive and negative batches
|
||||
|
||||
if args.cr_train:
|
||||
pos_batch, neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
|
||||
args.rank_num)
|
||||
|
||||
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
|
||||
make_vector(pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
|
||||
make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action, cr_mode=True)
|
||||
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=True)
|
||||
|
||||
score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
|
||||
score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
|
||||
|
||||
neg_num = args.rank_num - 1
|
||||
batch_num = int(len(score_neg) / neg_num)
|
||||
for i in range(batch_num):
|
||||
score_pos_i = score_pos_list[i * neg_num: (i + 1) * neg_num]
|
||||
score_neg_i = score_neg_list[i * neg_num: (i + 1) * neg_num]
|
||||
|
||||
rank = get_rank(score_pos_i[0], score_neg_i)
|
||||
dump_cmntrank_case(pos_batch, neg_batch, i, args.rank_num, batch[8][i], rank, score_pos_i[0],
|
||||
score_neg_i)
|
||||
|
||||
###########################################################
|
||||
# Edits Anchoring
|
||||
###########################################################
|
||||
# generate positive and negative batches
|
||||
|
||||
if args.ea_train:
|
||||
ea_batch, ea_truth_cur = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
|
||||
args.anchor_num)
|
||||
cmnt, src_token, src_action, tgt_token, tgt_action = \
|
||||
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
# neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
|
||||
# make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
|
||||
# logit_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=False)
|
||||
ea_pred_cur = (torch.max(logit, 1)[1].view(logit.size(0)).data).tolist()
|
||||
|
||||
for i in range(len(ea_truth_cur)):
|
||||
# if ea_pred_cur[i] == ea_truth_cur[i]:
|
||||
dump_editanch_case(ea_batch[0][i], ea_batch[3][i], ea_pred_cur[i], ea_truth_cur[i])
|
||||
|
||||
pass
|
||||
# ea_truth_cur = [1] * logit_pos.size(0) + [0] * logit_neg.size(0)
|
|
@@ -0,0 +1,185 @@
|
|||
import argparse
|
||||
import time
|
||||
|
||||
from eval import case_study, predict
|
||||
from process_data import load_glove_weights
|
||||
from train import *
|
||||
from wikicmnt_dataset import Wiki_DataSet
|
||||
from wikicmnt_model import CmntModel
|
||||
|
||||
#############################################################################################
|
||||
# ArgumentParser
|
||||
#############################################################################################
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_path', type=str, default='./data/processed/', help='the data directory')
|
||||
parser.add_argument('--checkpoint_path', type=str, default='./checkpoint/', help='the checkpoint directory')
|
||||
parser.add_argument('--glove_path', type=str, default='./data/glove/', help='the glove directory')
|
||||
|
||||
# model
|
||||
parser.add_argument('--log_interval', type=int, default=100,
|
||||
help='how many steps to wait before logging training status [default: 100]')
|
||||
parser.add_argument('--test_interval', type=int, default=1,
|
||||
                    help='how many steps to wait before testing [default: 1]')
|
||||
parser.add_argument('--save_best', type=bool, default=True, help='whether to save when get best performance')
|
||||
parser.add_argument('--lr', type=float, default=0.5, help='learning rate, default=0.5')
|
||||
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
|
||||
parser.add_argument('--word_embd_size', type=int, default=100, help='word embedding size')
|
||||
parser.add_argument('--max_ctx_length', type=int, default=300, help='the maximum words in the context [default: 300]')
|
||||
parser.add_argument('--max_diff_length', type=int, default=300,
|
||||
                    help='the maximum words in the revision difference [default: 300]')
|
||||
parser.add_argument('--max_cmnt_length', type=int, default=30, help='the maximum words in the comment [default: 30]')
|
||||
parser.add_argument('--ctx_mode', type=bool, default=True,
|
||||
help='whether to use change context in training [default: True]')
|
||||
|
||||
# training
|
||||
parser.add_argument('--epoch', type=int, default=10, help='number of epoch, default=10')
|
||||
parser.add_argument('--start_epoch', type=int, default=1, help='resume epoch count, default=1')
|
||||
parser.add_argument('--batch_size', type=int, default=10, help='input batch size')
|
||||
|
||||
parser.add_argument('--cr_train', action='store_true', default=False, help='whether to train the comment ranking task')
|
||||
parser.add_argument('--ea_train', action='store_true', default=False,
|
||||
                    help='whether to train the revision anchoring task')
|
||||
|
||||
# ablation testing
|
||||
parser.add_argument('--no_action', action='store_true', default=False,
|
||||
help='whether to use action encoding to train the model')
|
||||
parser.add_argument('--no_attention', action='store_true', default=False,
|
||||
help='whether to use mutual attention to train the model')
|
||||
parser.add_argument('--no_hadamard', action='store_true', default=False,
|
||||
help='whether to use hadamard product to train the model')
|
||||
|
||||
parser.add_argument('--src_train', type=bool, default=False,
|
||||
                    help='whether to train the comment ranking task without the before-editing version')
|
||||
parser.add_argument('--train_ratio', type=float, default=0.7,
|
||||
help='ratio of training data in the entire data [default: 0.7]')
|
||||
parser.add_argument('--val_ratio', type=float, default=0.1,
|
||||
help='ratio of validation data in the entire data [default: 0.1]')
|
||||
parser.add_argument('--val_size', type=int, default=10000,
|
||||
                    help='force the size of the validation dataset; this parameter overrides the setting of val_ratio [default: 10000]')
|
||||
|
||||
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
||||
parser.add_argument('--test', action='store_true', default=False, help='use test model')
|
||||
parser.add_argument('--case_study', action='store_true', default=False, help='use case study mode')
|
||||
parser.add_argument('--resume', default='./checkpoints/model_best.tar', type=str, metavar='PATH',
|
||||
                    help='path to saved params')
|
||||
parser.add_argument('--seed', type=int, default=1111, help='random seed')
|
||||
|
||||
# device
|
||||
parser.add_argument('--gpu', type=int, default=-1, help='gpu to use for iterating data; -1 means cpu [default: -1]')
|
||||
|
||||
parser.add_argument('--checkpoint', type=str, default=None, help='filename of model checkpoint [default: None]')
|
||||
|
||||
parser.add_argument('--rank_num', type=int, default=5, help='the number of ranking comments')
|
||||
parser.add_argument('--anchor_num', type=int, default=5, help='the number of candidate edits for anchoring')
|
||||
parser.add_argument('--use_target_only', action='store_true', default=False, help='use target context only in model')
|
||||
|
||||
# single case prediction
|
||||
parser.add_argument('--predict', action='store_true', default=False, help='predict the sentence given')
|
||||
parser.add_argument('--pred_cmnt', type=str, default=None, help='the comment of prediction')
|
||||
parser.add_argument('--pred_ctx', type=str, default=None, help='the context of prediction')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.gpu >= 0:
|
||||
torch.cuda.set_device(args.gpu)
|
||||
|
||||
print("\nParameters:")
|
||||
for attr, value in sorted(args.__dict__.items()):
|
||||
print("\t{}={}".format(attr.upper(), value))
|
||||
|
||||
# Set the random seed manually for reproducibility.
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
print(os.name)
|
||||
sys.stdout.flush()
|
||||
|
||||
# load data
|
||||
dataset = Wiki_DataSet(args)
|
||||
train_df, val_df, test_df, vocab_json = dataset.load_data(train_ratio=args.train_ratio, val_ratio=args.val_ratio)
|
||||
w2i = vocab_json['word2idx']
|
||||
|
||||
print('----')
|
||||
print('n_train', train_df.shape[0])
|
||||
print('n_val', val_df.shape[0])
|
||||
print('n_test', test_df.shape[0])
|
||||
print('vocab_size:', len(w2i))
|
||||
|
||||
# load glove
|
||||
glove_embd_w = torch.from_numpy(load_glove_weights(args.glove_path, args.word_embd_size, len(w2i), w2i)).type(
|
||||
torch.FloatTensor)
|
||||
# save_pickle(glove_embd_w, './pickle/glove_embd_w.pickle')
|
||||
|
||||
args.vocab_size_w = len(w2i)
|
||||
args.pre_embd_w = glove_embd_w
|
||||
args.filters = [[1, 5]]
|
||||
args.out_chs = 100
|
||||
|
||||
# generate save directory
|
||||
base_name = os.path.basename(os.path.normpath(args.data_path))
|
||||
if args.cr_train and not args.ea_train:
|
||||
task_str = "cr"
|
||||
elif args.ea_train and not args.cr_train:
|
||||
task_str = "ea"
|
||||
elif args.cr_train and args.ea_train:
    task_str = "mt"
else:
    task_str = "none"  # neither task flag is set (e.g., predict/test only); keep a valid folder name
|
||||
folder_name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + base_name + "_" + task_str
|
||||
if args.no_action:
|
||||
folder_name += "_noaction"
|
||||
if args.word_embd_size == 300:
|
||||
folder_name += "_d300"
|
||||
args.save_dir = os.path.join(args.checkpoint_path, folder_name)
|
||||
print("Save to ", args.save_dir)
|
||||
sys.stdout.flush()
|
||||
|
||||
# initialize model
|
||||
model = CmntModel(args)
|
||||
|
||||
if args.checkpoint is not None:
|
||||
print('\nLoading model from {}...'.format(args.checkpoint))
|
||||
model.load_state_dict(torch.load(args.checkpoint))
|
||||
|
||||
if torch.cuda.is_available() and os.name != 'nt':
|
||||
print('use cuda')
|
||||
model.cuda()
|
||||
# model = torch.nn.DataParallel(model, device_ids=[0])
|
||||
|
||||
# optimizer = torch.optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), lr=0.5)
|
||||
# optimizer = torch.optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
# optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> loading checkpoint '{}'".format(args.resume))
|
||||
checkpoint = torch.load(args.resume)
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
# best_prec1 = checkpoint['best_prec1']
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
print(model)
|
||||
print('parameters-----')
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
print(name, param.data.size())
|
||||
|
||||
if args.predict:
|
||||
print('Prediction mode')
|
||||
print('#Comment:', args.pred_cmnt)
|
||||
print('#Context:', args.pred_ctx)
|
||||
predict(args.pred_cmnt, args.pred_ctx, w2i, model, args.max_ctx_length)
|
||||
elif args.test:
|
||||
print('Test mode')
|
||||
eval(dataset, test_df, w2i, model, args)
|
||||
elif args.case_study:
|
||||
start_time = time.time()
|
||||
case_study(dataset, test_df, w2i, model, args)
|
||||
else:
|
||||
print('Train mode')
|
||||
start_time = time.time()
|
||||
train(args, model, dataset, train_df, val_df, optimizer, w2i, \
|
||||
n_epoch=args.epoch, start_epoch=args.start_epoch, batch_size=args.batch_size)
|
||||
print("Training duration: %s seconds" % (time.time() - start_time))
|
||||
print('finish')
|
|
@@ -0,0 +1,361 @@
|
|||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import spacy
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
# TODO global
|
||||
NULL = "-NULL-"
|
||||
UNK = "-UNK-"
|
||||
ENT = "-ENT-"
|
||||
|
||||
# initialize the spacy
|
||||
nlp = spacy.load('en')
|
||||
|
||||
|
||||
def word_tokenize(text):
|
||||
doc = nlp(text)
|
||||
tokens = [token.string.strip() for token in doc]
|
||||
return tokens
|
||||
|
||||
|
||||
def save_pickle(d, path):
|
||||
print('save pickle to', path)
|
||||
with open(path, mode='wb') as f:
|
||||
pickle.dump(d, f)
|
||||
|
||||
|
||||
def load_pickle(path):
|
||||
print('load', path)
|
||||
with open(path, mode='rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def lower_list(str_list):
|
||||
return [str_var.lower() for str_var in str_list]
|
||||
|
||||
|
||||
def load_processed_json(fpath_data, fpath_shared):
|
||||
data = json.load(open(fpath_data))
|
||||
shared = json.load(open(fpath_shared))
|
||||
return data, shared
|
||||
|
||||
|
||||
def load_glove_weights(glove_dir, embd_dim, vocab_size, word_index):
|
||||
embeddings_index = {}
|
||||
if embd_dim < 300:
|
||||
glove_version = 'glove.6B.'
|
||||
else:
|
||||
glove_version = 'glove.840B.'
|
||||
|
||||
with open(os.path.join(glove_dir, glove_version + str(embd_dim) + 'd.txt'), encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
values = line.split()
|
||||
word = values[0]
|
||||
vector = np.array(values[1:], dtype='float32')
|
||||
embeddings_index[word] = vector
|
||||
except:
|
||||
continue
|
||||
|
||||
print('Found %s word vectors in glove.' % len(embeddings_index))
|
||||
embedding_matrix = np.zeros((vocab_size, embd_dim))
|
||||
print('embed_matrix.shape', embedding_matrix.shape)
|
||||
found_ct = 0
|
||||
for word, i in word_index.items():
|
||||
embedding_vector = embeddings_index.get(word)
|
||||
# words not found in embedding index will be all-zeros.
|
||||
if embedding_vector is not None:
|
||||
embedding_matrix[i] = embedding_vector
|
||||
found_ct += 1
|
||||
print(found_ct, 'words are found in glove')
|
||||
|
||||
return embedding_matrix
|
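# Usage sketch (the path and toy vocabulary below are illustrative, not this repo's config):
#   w2i = {"-NULL-": 0, "-UNK-": 1, "the": 2}
#   embd = load_glove_weights('./data/glove/', 100, len(w2i), w2i)
#   # embd.shape == (3, 100); rows for words missing from GloVe remain all-zero.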
||||
|
||||
|
||||
def to_var(x):
|
||||
# if torch.cuda.is_available():
|
||||
# x = x.cuda()
|
||||
# return Variable(x)
|
||||
x = Variable(x)
|
||||
if torch.cuda.is_available():
|
||||
x = x.cuda()
|
||||
return x
|
||||
|
||||
|
||||
def to_np(x):
|
||||
return x.data.cpu().numpy()
|
||||
|
||||
|
||||
def _make_action_vector(actions, seq_len):
|
||||
index_vec = [action for action in actions]
|
||||
pad_len = max(0, seq_len - len(index_vec))
|
||||
index_vec += [-1] * pad_len
|
||||
index_vec = index_vec[:seq_len]
|
||||
return index_vec
|
||||
|
||||
|
||||
def _make_word_vector(sentence, w2i, seq_len):
|
||||
index_vec = [w2i[w] if w in w2i else w2i[UNK] for w in sentence]
|
||||
pad_len = max(0, seq_len - len(index_vec))
|
||||
index_vec += [w2i[NULL]] * pad_len
|
||||
index_vec = index_vec[:seq_len]
|
||||
return index_vec
|
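# Example with a hypothetical vocabulary: out-of-vocabulary words map to -UNK-, and the
# index list is padded with -NULL- (or truncated) to exactly seq_len entries.
#   w2i = {"-NULL-": 0, "-UNK-": 1, "fix": 2, "typo": 3}
#   _make_word_vector(["fix", "typo", "pls"], w2i, 5)  # -> [2, 3, 1, 0, 0]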
||||
|
||||
|
||||
def _make_char_vector(data, c2i, sent_len, word_len):
|
||||
tmp = torch.ones(sent_len, word_len).type(torch.LongTensor) # TODO use fills
|
||||
for i, word in enumerate(data):
|
||||
for j, ch in enumerate(word):
|
||||
tmp[i][j] = c2i[ch] if ch in c2i else c2i[UNK]
|
||||
if j == word_len - 1:
|
||||
break
|
||||
if i == sent_len - 1:
|
||||
break
|
||||
return tmp
|
||||
|
||||
|
||||
def make_diff(diffs_raw):
|
||||
return diffs_raw
|
||||
|
||||
|
||||
def make_vector_one_sample(pred_cmnt, pred_ctx, w2i, c2i, ctx_sent_len, ctx_word_len, query_sent_len, query_word_len):
|
||||
    # Build a batch of size one directly from the given comment/context token lists;
    # the answer/diff fields of the batched pipeline are not needed for a single
    # prediction sample.
    cmnt_words = [_make_word_vector(pred_cmnt, w2i, ctx_sent_len)]
    cmnt_chars = [_make_char_vector(pred_cmnt, c2i, ctx_sent_len, ctx_word_len)]
    ctx_words = [_make_word_vector(pred_ctx, w2i, query_sent_len)]
    ctx_chars = [_make_char_vector(pred_ctx, c2i, query_sent_len, query_word_len)]

    cmnt_words = to_var(torch.LongTensor(cmnt_words))
    cmnt_chars = to_var(torch.stack(cmnt_chars, 0))
    ctx_words = to_var(torch.LongTensor(ctx_words))
    ctx_chars = to_var(torch.stack(ctx_chars, 0))
    return cmnt_words, cmnt_chars, ctx_words, ctx_chars
|
||||
|
||||
|
||||
'''
|
||||
Generate the word vector for each batch
|
||||
'''
|
||||
|
||||
|
||||
def make_vector(batch, w2i, cmnt_sent_len, ctx_sent_len):
|
||||
cmnt_words, src_token, src_action, tgt_token, tgt_action = [], [], [], [], []
|
||||
# batch_cmnt, batch_neg_cmnt, batch_origin, batch_target
|
||||
|
||||
for i in range(len(batch[0])):
|
||||
cmnt_words.append(_make_word_vector(batch[0][i], w2i, cmnt_sent_len))
|
||||
|
||||
src_token.append(_make_word_vector(batch[1][i], w2i, ctx_sent_len))
|
||||
src_action.append(_make_action_vector(batch[2][i], ctx_sent_len))
|
||||
|
||||
tgt_token.append(_make_word_vector(batch[3][i], w2i, ctx_sent_len))
|
||||
tgt_action.append(_make_action_vector(batch[4][i], ctx_sent_len))
|
||||
|
||||
cmnt_words = to_var(torch.LongTensor(cmnt_words))
|
||||
# neg_cmnt_words = to_var(torch.LongTensor(neg_cmnt_words))
|
||||
|
||||
src_token = to_var(torch.LongTensor(src_token))
|
||||
src_action = to_var(torch.LongTensor(src_action))
|
||||
|
||||
tgt_token = to_var(torch.LongTensor(tgt_token))
|
||||
tgt_action = to_var(torch.LongTensor(tgt_action))
|
||||
|
||||
return cmnt_words, src_token, src_action, tgt_token, tgt_action
|
||||
|
||||
|
||||
'''
|
||||
generate the batches for training and evaluation
|
||||
type definition:
|
||||
type 1: for the comment rank task
|
||||
type 2: for the diff anchoring task
|
||||
type 3: use the target diff only
|
||||
'''
|
||||
|
||||
|
||||
def gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len, rank_num):
|
||||
'''
|
||||
Batch Content:
|
||||
0,1. batch comment, batch neg_cmnt
|
||||
2,3. batch src_tokens, batch src_actions
|
||||
4,5. batch tgt_tokens, batch tgt_actions
|
||||
6,7. batch pos_edits, batch neg_edits
|
||||
'''
|
||||
|
||||
pos_cmnts, pos_src_tokens, pos_src_actions, pos_tgt_tokens, pos_tgt_actions = [], [], [], [], []
|
||||
neg_cmnts, neg_src_tokens, neg_src_actions, neg_tgt_tokens, neg_tgt_actions = [], [], [], [], []
|
||||
sample_index_list = []
|
||||
for i in range(len(batch[0])):
|
||||
|
||||
sample_size = 0
|
||||
|
||||
cmnt = batch[0][i]
|
||||
neg_cmnt = batch[1][i]
|
||||
|
||||
src_tokens = batch[2][i]
|
||||
src_actions = batch[3][i]
|
||||
|
||||
tgt_tokens = batch[4][i]
|
||||
tgt_actions = batch[5][i]
|
||||
|
||||
if rank_num - 1 > len(neg_cmnt):
|
||||
continue
|
||||
|
||||
neg_sample_indices = random.sample(range(len(neg_cmnt)), rank_num - 1)
|
||||
|
||||
for neg_idx in neg_sample_indices:
|
||||
pos_cmnts.append(cmnt)
|
||||
pos_src_tokens.append(src_tokens)
|
||||
pos_src_actions.append(src_actions)
|
||||
pos_tgt_tokens.append(tgt_tokens)
|
||||
pos_tgt_actions.append(tgt_actions)
|
||||
|
||||
neg_cmnts.append(neg_cmnt[neg_idx])
|
||||
neg_src_tokens.append(src_tokens)
|
||||
neg_src_actions.append(src_actions)
|
||||
neg_tgt_tokens.append(tgt_tokens)
|
||||
neg_tgt_actions.append(tgt_actions)
|
||||
sample_size += 1
|
||||
|
||||
sample_index_list.append(sample_size)
|
||||
return (pos_cmnts, pos_src_tokens, pos_src_actions, pos_tgt_tokens, pos_tgt_actions), \
|
||||
(neg_cmnts, neg_src_tokens, neg_src_actions, neg_tgt_tokens, neg_tgt_actions)
|
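# Note on the pairing scheme: every revision with at least rank_num - 1 negative comments
# contributes rank_num - 1 rows to both returned tuples, i.e. the positive comment is
# repeated once per sampled negative so score_pos and score_neg can be compared row-wise.
# For example (illustrative numbers), with rank_num=5 and 10 usable revisions in the batch,
# len(pos_cmnts) == len(neg_cmnts) == 10 * 4 == 40.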
||||
|
||||
|
||||
def gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len, anchor_num):
|
||||
'''
|
||||
Batch Content:
|
||||
0,1. batch comment, batch neg_cmnt
|
||||
2,3. batch src_tokens, batch src_actions
|
||||
4,5. batch tgt_tokens, batch tgt_actions
|
||||
6,7. batch pos_edits, batch neg_edits
|
||||
'''
|
||||
|
||||
cmnts, src_tokens, src_actions, tgt_tokens, tgt_actions, ea_truth = [], [], [], [], [], []
|
||||
|
||||
for i in range(len(batch[0])):
|
||||
|
||||
cmnt = batch[0][i]
|
||||
|
||||
pos_edits = batch[6][i]
|
||||
neg_edits = batch[7][i]
|
||||
|
||||
if len(pos_edits) > anchor_num:
|
||||
pos_edits = pos_edits[:anchor_num]
|
||||
|
||||
if anchor_num - len(pos_edits) < 0:
|
||||
neg_sample_indices = []
|
||||
elif anchor_num - len(pos_edits) > len(neg_edits):
|
||||
neg_sample_indices = range(len(neg_edits))
|
||||
else:
|
||||
neg_sample_indices = random.sample(range(len(neg_edits)), anchor_num - len(pos_edits))
|
||||
|
||||
for pos_edit in pos_edits:
|
||||
cmnts.append(cmnt)
|
||||
src_tokens.append([])
|
||||
src_actions.append([])
|
||||
tgt_tokens.append(pos_edit)
|
||||
tgt_actions.append([1] * len(pos_edit))
|
||||
ea_truth.append(1)
|
||||
|
||||
for neg_idx in neg_sample_indices:
|
||||
cmnts.append(cmnt)
|
||||
src_tokens.append([])
|
||||
src_actions.append([])
|
||||
tgt_tokens.append(neg_edits[neg_idx])
|
||||
tgt_actions.append([1] * len(neg_edits[neg_idx]))
|
||||
ea_truth.append(0)
|
||||
|
||||
return (cmnts, src_tokens, src_actions, tgt_tokens, tgt_actions), ea_truth
|
||||
|
||||
|
||||
def find_cont_diffs(tokens, token_diff):
|
||||
# split the token_diff into the consecutive parts
|
||||
token_cont_list = []
|
||||
|
||||
if len(token_diff) == 0:
|
||||
return token_cont_list
|
||||
|
||||
if len(token_diff) == 1:
|
||||
token_cont_list.append(token_diff)
|
||||
return token_cont_list
|
||||
|
||||
start_idx, cur_idx = 0, 1
|
||||
while cur_idx < len(token_diff):
|
||||
# if cur_idx == len(token_diff) - 1:
|
||||
# token_list.append(list(range(start_idx, cur_idx + 1)))
|
||||
# cur_idx += 1
|
||||
if token_diff[cur_idx] != token_diff[cur_idx - 1] + 1:
|
||||
token_cont_list.append(list(range(token_diff[start_idx], token_diff[cur_idx - 1] + 1)))
|
||||
start_idx = cur_idx
|
||||
cur_idx += 1
|
||||
else:
|
||||
cur_idx += 1
|
||||
|
||||
# handle the last list
|
||||
token_cont_list.append(list(range(token_diff[start_idx], token_diff[cur_idx - 1] + 1)))
|
||||
|
||||
return token_cont_list
|
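# Example with hypothetical indices: consecutive token indices are grouped into runs.
#   find_cont_diffs(tokens, [2, 3, 4, 11, 12, 16])
#   # -> [[2, 3, 4], [11, 12], [16]]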
||||
|
||||
|
||||
def find_diff_context(tokens, token_diff, context_length=50):
|
||||
cont_difflist = find_cont_diffs(tokens, token_diff)
|
||||
diff_context = set()
|
||||
for cont_diff in cont_difflist:
|
||||
# avoid the case when only one markup or space included in the context
|
||||
if len(cont_diff) == 1 and len(tokens[cont_diff[0]]) <= 1:
|
||||
continue
|
||||
diff_context_cur = find_diff_context_int(tokens, cont_diff, context_length)
|
||||
diff_context.update(diff_context_cur)
|
||||
|
||||
diff_context = [x for x in diff_context if x not in token_diff]
|
||||
return sorted(list(diff_context))
|
||||
|
||||
|
||||
'''
|
||||
Find context difference
|
||||
This function requires that token_diff be a consecutive range of token indices.
|
||||
'''
|
||||
|
||||
|
||||
def find_diff_context_int(tokens, token_diff, context_length):
|
||||
if len(token_diff) == 0:
|
||||
return []
|
||||
|
||||
# if len(token_diff) == 1:
|
||||
# diff_context = [token_diff[0]]
|
||||
# else:
|
||||
# diff_context = range(token_diff[0], token_diff[-1] + 1)
|
||||
|
||||
# diff_context = [x for x in diff_context if x not in token_diff]
|
||||
|
||||
start_idx = token_diff[0]
|
||||
end_idx = token_diff[-1]
|
||||
|
||||
context_start = start_idx - context_length
|
||||
context_start = context_start if context_start > 0 else 0
|
||||
context_end = end_idx + context_length + 1
|
||||
context_end = context_end if context_end < len(tokens) else len(tokens)
|
||||
|
||||
diff_context = list(range(context_start, start_idx)) + list(range(end_idx + 1, context_end))
|
||||
diff_words = [tokens[i] for i in diff_context]
|
||||
|
||||
# if len(diff_context) > context_length:
|
||||
# diff_context = diff_context[:int(context_length/2)] + diff_context[-int(context_length/2):]
|
||||
# else:
|
||||
# remain_length = context_length - len(diff_context)
|
||||
return diff_context
|
|
@@ -0,0 +1,148 @@
|
|||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from eval import eval
|
||||
from process_data import to_var, make_vector, gen_cmntrank_batches, gen_editanch_batches
|
||||
|
||||
|
||||
def rank_loss(score_pos, score_neg):
|
||||
# normalize context attention
|
||||
# ctx_att_norm = F.normalize(ctx_att, p=2, dim=1)
|
||||
batch_size = len(score_pos)
|
||||
# y = Variable(torch.FloatTensor([1] * batch_size))
|
||||
margin = to_var(torch.FloatTensor(score_pos.size()).fill_(1))
|
||||
# loss = F.margin_ranking_loss(score_pos, -score_neg, y, margin=10.0)
|
||||
|
||||
loss_list = margin - score_pos + score_neg
|
||||
loss_list = loss_list.clamp(min=0)
|
||||
loss = loss_list.sum()
|
||||
|
||||
return loss, loss_list
|
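# A minimal worked example with hypothetical scores: the hinge term per pair is
# max(0, 1 - score_pos + score_neg), summed over the batch.
#   score_pos = to_var(torch.FloatTensor([[0.8], [0.2]]))
#   score_neg = to_var(torch.FloatTensor([[0.1], [0.5]]))
#   loss, loss_list = rank_loss(score_pos, score_neg)
#   # loss_list -> [[0.3], [1.3]], loss -> 1.6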
||||
|
||||
|
||||
def cal_batch_loss(loss_list, batch_size, index_list):
|
||||
loss_list = loss_list.data.squeeze(1).numpy().tolist()
|
||||
loss_step = int(len(loss_list) / batch_size)
|
||||
# return [sum(loss_list[i * loss_step: (i + 1) * loss_step]) / loss_step for i in range(batch_size)]
|
||||
start = 0
|
||||
batch_loss_list = []
|
||||
for i in range(batch_size):
|
||||
cur_bs = index_list[i]
|
||||
end = start + cur_bs
|
||||
if cur_bs == 0:
|
||||
batch_loss_list.append(0)
|
||||
else:
|
||||
batch_loss_list.append(sum(loss_list[start:end]) / cur_bs)
|
||||
start = end
|
||||
return batch_loss_list
|
||||
|
||||
|
||||
def train(args, model, dataset, train_df, val_df, optimizer, w2i, n_epoch, start_epoch, batch_size):
|
||||
print('----Train---')
|
||||
label = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
model.train()
|
||||
|
||||
cmnt_sent_len = args.max_cmnt_length
|
||||
diff_sent_len = args.max_diff_length
|
||||
ctx_sent_len = args.max_ctx_length
|
||||
train_size = len(train_df)
|
||||
# if args.use_cl:
|
||||
# sample_size = int(train_size * 0.5)
|
||||
# else:
|
||||
# sample_size = -1
|
||||
sample_size = int(train_size)
|
||||
cr_best_acc, ea_best_acc = 0, 0
|
||||
|
||||
for epoch in range(start_epoch, n_epoch + 1):
|
||||
print('============================== Epoch ', epoch, ' ==============================')
|
||||
|
||||
for step, batch in enumerate(dataset.iterate_minibatches(train_df, batch_size, epoch, n_epoch), start=1):
|
||||
batch_sample_weights = None
|
||||
total_loss = 0
|
||||
if args.cr_train:
|
||||
cr_pos_batch, cr_neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len,
|
||||
ctx_sent_len, args.rank_num)
|
||||
if len(cr_pos_batch[0]) > 0:
|
||||
                    # TODO: tune this for efficiency: combine the positive and negative samples into a single forward pass
|
||||
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
|
||||
make_vector(cr_pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
|
||||
make_vector(cr_neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action,
|
||||
cr_mode=True)
|
||||
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action,
|
||||
cr_mode=True)
|
||||
|
||||
loss, _ = rank_loss(score_pos, score_neg)
|
||||
# batch_sample_weights = cal_batch_loss(loss_list, batch_size, index_list)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# train revision anchoring
|
||||
if args.ea_train:
|
||||
ea_batch, ea_truth = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
|
||||
args.anchor_num)
|
||||
if len(ea_batch[0]) > 0:
|
||||
cmnt, src_token, src_action, tgt_token, tgt_action = \
|
||||
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
|
||||
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
|
||||
target = to_var(torch.tensor(ea_truth))
|
||||
|
||||
loss = F.cross_entropy(logit, target)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# update the loss for each data sample
|
||||
# new_sample_weights += batch_sample_weights
|
||||
|
||||
if step % args.log_interval == 0:
|
||||
# corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
|
||||
# p1_corr, p3_corr, p5_corr, mrr, ndcg = eval_rank(score_pos, score_neg, args.batch_size)
|
||||
# p1_acc = p1_corr / args.batch_size * 100
|
||||
# p3_acc = p3_corr / args.batch_size * 100
|
||||
# p5_acc = p5_corr / args.batch_size * 100
|
||||
try:
|
||||
sys.stdout.write(
|
||||
'\rEpoch[{}] Batch[{}] - loss: {:.6f}\n'.format(epoch, step, loss.data.item()))
|
||||
sys.stdout.flush()
|
||||
except:
|
||||
print("Unexpected error:", sys.exc_info()[0])
|
||||
|
||||
if step % args.test_interval == 0:
|
||||
if args.val_size > 0:
|
||||
val_df = val_df[:args.val_size]
|
||||
cr_acc, ea_acc = eval(dataset, val_df, w2i, model, args)
|
||||
model.train() # change model back to training mode
|
||||
if args.cr_train and cr_acc > cr_best_acc:
|
||||
cr_best_acc = cr_acc
|
||||
if args.save_best:
|
||||
save(model, args.save_dir, 'best_cr', epoch, step, cr_best_acc, args.no_action)
|
||||
if args.ea_train and ea_acc > ea_best_acc:
|
||||
ea_best_acc = ea_acc
|
||||
if args.save_best:
|
||||
save(model, args.save_dir, 'best_ea', epoch, step, ea_best_acc, args.no_action)
|
||||
# sample_weights = new_sample_weights
|
||||
|
||||
|
||||
def save(model, save_dir, save_prefix, epoch, steps, best_result, no_action=False):
|
||||
if not os.path.isdir(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_prefix = os.path.join(save_dir, save_prefix)
|
||||
|
||||
# delete previously saved checkpoints
|
||||
exist_files = sorted(glob.glob(save_prefix + '*'))
|
||||
for file_name in exist_files:
|
||||
if os.path.exists(file_name):
|
||||
os.remove(file_name)
|
||||
|
||||
result_str = '%.3f' % best_result
|
||||
save_path = '{}_steps_{:02}_{:06}_{}.pt'.format(save_prefix, epoch, steps, result_str)
|
||||
print("Save best model", save_path)
|
||||
torch.save(model.state_dict(), save_path)
|
|
@@ -0,0 +1,492 @@
|
|||
import difflib
|
||||
import glob
|
||||
import html
|
||||
import random
|
||||
import re
|
||||
|
||||
import spacy
|
||||
from tqdm import tqdm
|
||||
|
||||
# initialize the spacy
|
||||
nlp = spacy.load('en')
|
||||
|
||||
'''
|
||||
Extract the content between the given start and end delimiters
|
||||
'''
|
||||
def extract_with_delims(content, start_delim, end_delim, start_idx):
|
||||
delims_start = content.find(start_delim, start_idx)
|
||||
if delims_start == -1:
|
||||
return '', start_idx
|
||||
|
||||
delims_end = content.find(end_delim, start_idx)
|
||||
if delims_end == -1:
|
||||
return '', start_idx
|
||||
|
||||
if delims_end <= delims_start:
|
||||
return '', start_idx
|
||||
|
||||
delims_start += len(start_delim)
|
||||
|
||||
return content[delims_start:delims_end], delims_end
|
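# Example on a toy XML snippet (not real dump content):
#   extract_with_delims("<id>42</id><comment>fix</comment>", "<comment>", "</comment>", 0)
#   # -> ('fix', 23); the second value is the index of the closing delimiter.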
||||
|
||||
'''
|
||||
Extract the contents of revisions, e.g., revision_id, parent_id, user_name, comment, text
|
||||
'''
|
||||
def extract_data(revision_part):
|
||||
rev_id, next_idx = extract_with_delims(revision_part, "<id>", "</id>", 0)
|
||||
parent_id, next_idx = extract_with_delims(revision_part, "<parentid>", "</parentid>", next_idx)
|
||||
timestamp, next_idx = extract_with_delims(revision_part, "<timestamp>", "</timestamp>", next_idx)
|
||||
username, next_idx = extract_with_delims(revision_part, "<username>", "</username>", next_idx)
|
||||
userid, next_idx = extract_with_delims(revision_part, "<id>", "</id>", next_idx)
|
||||
    # For an anonymous user, the IP address is used instead of the user name and ID
|
||||
userip, next_idx = extract_with_delims(revision_part, "<ip>", "</ip>", next_idx)
|
||||
comment, next_idx = extract_with_delims(revision_part, "<comment>", "</comment>", next_idx)
|
||||
text, next_idx = extract_with_delims(revision_part, "<text xml:space=\"preserve\">", "</text>", next_idx)
|
||||
return (rev_id, parent_id, timestamp, username, userid, userip, comment, text)
|
||||
|
||||
'''
|
||||
Extract the revision text buffer, which has the format "<revision> ... </revision>".
|
||||
'''
|
||||
def split_records(wiki_file, chunk_size=150 * 1024):
|
||||
|
||||
text_buffer = ""
|
||||
cur_index = 0
|
||||
|
||||
while True:
|
||||
chunk = wiki_file.read(chunk_size)
|
||||
|
||||
if chunk:
|
||||
text_buffer += chunk
|
||||
|
||||
cur_index = 0
|
||||
REVISION_START = "<revision>"
|
||||
REVISION_END = "</revision>"
|
||||
PAGE_START = "<page>"
|
||||
PAGE_TITLE_START = "<title>"
|
||||
PAGE_TITLE_END = "</title>"
|
||||
while True:
|
||||
page_start_index = text_buffer.find(PAGE_START, cur_index)
|
||||
if page_start_index != -1:
|
||||
# update the current page title/ID
|
||||
page_title, _ = extract_with_delims(text_buffer, PAGE_TITLE_START, PAGE_TITLE_END, 0)
|
||||
if not page_title:
|
||||
# no complete page title
|
||||
break
|
||||
#logging.debug("Error: page information is cut. FIX THIS ISSUE!!!")
|
||||
|
||||
# find the revision start position
|
||||
revision_start_index = text_buffer.find(REVISION_START, cur_index)
|
||||
|
||||
# No revision in the buffer, continue loading data
|
||||
if revision_start_index == -1:
|
||||
break
|
||||
|
||||
# find the revision end position
|
||||
revision_end_index = text_buffer.find(REVISION_END, revision_start_index)
|
||||
|
||||
# No complete page in buffer
|
||||
if revision_end_index == -1:
|
||||
break
|
||||
|
||||
yield page_title, text_buffer[revision_start_index:revision_end_index + len(REVISION_END)]
|
||||
|
||||
cur_index = revision_end_index + len(REVISION_END)
|
||||
|
||||
# No more data
|
||||
if chunk == "":
|
||||
break
|
||||
|
||||
if cur_index == -1:
|
||||
text_buffer = ""
|
||||
else:
|
||||
text_buffer = text_buffer[cur_index:]
|
||||
|
||||
def sampleNext(sample_ratio):
|
||||
return random.random() < sample_ratio
|
||||
|
||||
def cleanCmntText(comment):
|
||||
filter_words = []
|
||||
comment = comment.replace("(edited with [[User:ProveIt_GT|ProveIt]]", "")
|
||||
#comment = re.sub("(edited with \[\[User\:ProveIt_GT\|ProveIt\]\]", "", comment)
|
||||
return comment
|
||||
|
||||
|
||||
def checkComment(comment, comment_tokens, min_comment_length):
|
||||
|
||||
if len(comment_tokens) < min_comment_length:
|
||||
return False
|
||||
|
||||
filter_words = ["[[Project:AWB|AWB]]", "[[Project:AutoWikiBrowser|AWB]]", "Undid revision"]
|
||||
if any(word in comment for word in filter_words):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
'''
|
||||
clean the wiki text: strip wiki link brackets and quote markup, then decode HTML entities.
E.g. "[[link name]] with ''markup''" becomes "link name with markup"
|
||||
'''
|
||||
def cleanWikiText(wiki_text):
|
||||
|
||||
# replace link: [[link_name]] and quotes
|
||||
wiki_text = re.sub("\[\[", "", wiki_text)
|
||||
wiki_text = re.sub("\]\]", "", wiki_text)
|
||||
wiki_text = re.sub("''", "", wiki_text)
|
||||
|
||||
    # replace HTML entities: '&quot;', '&lt;', '&gt;', '&amp;'
    # wiki_text = re.sub("&quot;", "", wiki_text)
    # wiki_text = re.sub("&lt;", "<", wiki_text)
    # wiki_text = re.sub("&gt;", ">", wiki_text)
    # wiki_text = re.sub("&amp;", "&", wiki_text)
|
||||
|
||||
# use html unescape to decode the html special characters
|
||||
wiki_text = html.unescape(wiki_text)
|
||||
|
||||
return wiki_text
|
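# Example on an illustrative string: link brackets and quote markup are stripped and
# HTML entities are decoded.
#   cleanWikiText("[[World War II]] was a ''global war'' &amp; conflict")
#   # -> "World War II was a global war & conflict"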
||||
|
||||
def tokenizeText(text):
|
||||
doc = nlp(text)
|
||||
sentences = [sent.string.strip() for sent in doc.sents]
|
||||
sent_tokens = []
|
||||
tokens = []
|
||||
for sent in sentences:
|
||||
sent_doc = nlp(sent)
|
||||
token_one_sent = [token.string.strip() for token in sent_doc]
|
||||
sent_tokens.append(token_one_sent)
|
||||
tokens += token_one_sent
|
||||
return sent_tokens, tokens
|
||||
|
||||
# return the difference indices starting at 0.
|
||||
def diffRevision(parent_sent_list, sent_list):
|
||||
|
||||
# make diff
|
||||
    origin_start_idx, origin_end_idx = -1, -1
|
||||
target_start_idx, target_end_idx = -1, -1
|
||||
origin_diff_list, target_diff_list = [], []
|
||||
for line in difflib.context_diff(parent_sent_list, sent_list, 'origin', 'target'):
|
||||
|
||||
#print(line)
|
||||
        # parse the origin diff hunk header, e.g., *** 56,62 ****
|
||||
if line.startswith("*** ") and line.endswith(" ****\n"):
|
||||
target_start_idx, target_end_idx = -1, -1 # reset the target indices
|
||||
range_line = line[4:-6]
|
||||
if ',' not in range_line:
|
||||
origin_start_idx = int(range_line)
|
||||
origin_end_idx = origin_start_idx
|
||||
else:
|
||||
origin_start_idx, origin_end_idx = [int(i) for i in range_line.split(',')]
|
||||
origin_sent_idx = origin_start_idx
|
||||
continue
|
||||
|
||||
        # parse the target diff hunk header, e.g., --- 56,62 ----
|
||||
if line.startswith("--- ") and line.endswith(" ----\n"):
|
||||
origin_start_idx, origin_end_idx = -1, -1 # reset the origin indices
|
||||
range_line = line[4:-6]
|
||||
if ',' not in range_line:
|
||||
target_start_idx = int(range_line)
|
||||
target_end_idx = target_start_idx
|
||||
else:
|
||||
target_start_idx, target_end_idx = [int(i) for i in range_line.split(',')]
|
||||
target_sent_idx = target_start_idx
|
||||
continue
|
||||
|
||||
if origin_start_idx >= 0:
|
||||
if len(line.strip('\n')) == 0:
|
||||
continue
|
||||
elif line.startswith('-') or line.startswith('!'):
|
||||
origin_diff_list.append(origin_sent_idx - 1) # adding the index starting at 0
|
||||
origin_sent_idx += 1
|
||||
else:
|
||||
origin_sent_idx += 1
|
||||
|
||||
if target_start_idx >= 0:
|
||||
if len(line.strip('\n')) == 0:
|
||||
continue
|
||||
elif line.startswith('+') or line.startswith('!'):
|
||||
target_diff_list.append(target_sent_idx - 1) # adding the index starting at 0
|
||||
target_sent_idx += 1
|
||||
else:
|
||||
target_sent_idx += 1
|
||||
|
||||
#print("Extracted Diff:", diff_sent_list)
|
||||
return origin_diff_list, target_diff_list
|
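# Small sanity check with hypothetical sentence lists: a single changed sentence at
# position 1 (0-based) is reported for both versions.
#   diffRevision(["A", "B", "C"], ["A", "X", "C"])
#   # -> ([1], [1])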
||||
|
||||
def findSentDiff(sents, tokens, token_diff):
|
||||
diff_idx, sent_idx, token_offset = 0, 0, 0
|
||||
diff_sents = set()
|
||||
if len(token_diff) == 0 or len(sents) == 0:
|
||||
return list(diff_sents)
|
||||
|
||||
token_offset = len(sents[0])
|
||||
while diff_idx < len(token_diff) and sent_idx < len(sents):
|
||||
if token_offset >= token_diff[diff_idx]:
|
||||
            cur_token = tokens[token_diff[diff_idx]]
            # skip the case where the only changed token is a single markup character
|
||||
if len(cur_token) > 1:
|
||||
diff_sents.add(sent_idx)
|
||||
diff_idx += 1
|
||||
        else:
            sent_idx += 1
            if sent_idx < len(sents):
                token_offset += len(sents[sent_idx])
|
||||
|
||||
return list(diff_sents)
|
||||
|
||||
|
||||
def extContext(tokens, token_diff, ctx_window):
|
||||
'''
|
||||
Extend the context into the token difference
|
||||
|
||||
:param token_diff: a list of tokens which represents the difference between previous version and current version.
|
||||
ctx_window: the size of context before and after the edits
|
||||
:return:
|
||||
|
||||
For example: token_diff = [2, 3, 4, 11, 12, 13, 14, 16] and ctx_window = 2
|
||||
The function will return ctx_tokens = [0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
|
||||
action = [0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0]
|
||||
'''
|
||||
|
||||
|
||||
ctx_set = set(token_diff)
|
||||
|
||||
for idx in token_diff:
|
||||
for i in range(idx - ctx_window, idx + ctx_window + 1):
|
||||
if i < 0 or i >= len(tokens):
|
||||
continue
|
||||
ctx_set.add(i)
|
||||
|
||||
action = []
|
||||
ctx_token = []
|
||||
ctx_token_idx = sorted(list(ctx_set))
|
||||
diff_set = set(token_diff)
|
||||
|
||||
for i in ctx_token_idx:
|
||||
ctx_token.append(tokens[i])
|
||||
if i in diff_set:
|
||||
action.append(1)
|
||||
else:
|
||||
action.append(0)
|
||||
|
||||
return ctx_token, action
|
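# Commented sanity check mirroring the docstring example above (indices are illustrative):
#   tokens = ["t%d" % i for i in range(20)]
#   ctx_token, action = extContext(tokens, [2, 3, 4, 11, 12, 13, 14, 16], 2)
#   # len(ctx_token) == 17; action == [0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0]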
||||
|
||||
|
||||
|
||||
# def findDiffSents(sents, diff_list):
|
||||
# token_idx = 0
|
||||
# start_sent_idx, end_sent_idx = -1, -1
|
||||
# for i, sent in enumerate(sents):
|
||||
# token_idx += len(sent)
|
||||
# if start_sent_idx < 0 and diff_list[0] < token_idx:
|
||||
# start_sent_idx = i
|
||||
# if end_sent_idx < 0 and diff_list[-1] < token_idx:
|
||||
# end_sent_idx = i
|
||||
# if start_sent_idx >= 0 and end_sent_idx >= 0:
|
||||
# break
|
||||
# return start_sent_idx, end_sent_idx
|
||||
|
||||
def extDiffSents(sents, start, end, max_tokens=200, max_sents_add=2):
|
||||
token_size = 0
|
||||
for i in range(start, end + 1):
|
||||
token_size += len(sents[i])
|
||||
|
||||
context_sents = sents[start:end + 1]
|
||||
sents_head_added, sents_tail_added = 0, 0
|
||||
while token_size <= max_tokens and len(context_sents) < len(sents)\
|
||||
and sents_head_added < max_sents_add and sents_tail_added < max_sents_add:
|
||||
if start > 0:
|
||||
start -= 1
|
||||
insert_sent = sents[start]
|
||||
context_sents.insert(0, insert_sent)
|
||||
token_size += len(insert_sent)
|
||||
sents_head_added += 1
|
||||
|
||||
if end < len(sents) - 1:
|
||||
end += 1
|
||||
insert_sent = sents[end]
|
||||
context_sents.append(insert_sent)
|
||||
token_size += len(insert_sent)
|
||||
sents_tail_added += 1
|
||||
|
||||
diff_offset = sum([len(sents[i]) for i in range(start)])
|
||||
return context_sents, diff_offset
|
||||
|
||||
def mapDiffContext(sents, start_sent, end_sent):
|
||||
|
||||
start_idx = -1
|
||||
end_idx = -1
|
||||
start_sent_str = " ".join(start_sent)
|
||||
end_sent_str = " ".join(end_sent)
|
||||
for i, sent in enumerate(sents):
|
||||
sent_str = " ".join(sent)
|
||||
if start_idx < 0 and start_sent_str in sent_str:
|
||||
start_idx = i
|
||||
|
||||
if end_sent_str in sent_str:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
# if start_idx == -1:
|
||||
# start_idx = 0
|
||||
|
||||
# if end_idx == -1:
|
||||
# context_sents = sents[start_idx:]
|
||||
# else:
|
||||
# context_sents = sents[start_idx:end_idx + 1]
|
||||
if start_idx == -1 or end_idx == -1:
|
||||
return None, None
|
||||
|
||||
diff_offset = sum([len(sents[i]) for i in range(start_idx)])
|
||||
return sents[start_idx:end_idx + 1], diff_offset
|
||||
|
||||
|
||||
# def extDiffContextInt(target_context_sents, target_sents, target_start, target_end, token_size, max_tokens=200, max_sents_add=2):
|
||||
# sents_head_added, sents_tail_added = 0, 0
|
||||
|
||||
# return target_context_sents
|
||||
|
||||
def calTokenSize(sents):
|
||||
return sum([len(sent) for sent in sents])
|
||||
|
||||
def isSameSent(origin_sent, target_sent):
|
||||
|
||||
origin_size = len(origin_sent)
|
||||
target_size = len(target_sent)
|
||||
|
||||
if origin_size != target_size:
|
||||
return False
|
||||
|
||||
for i in range(origin_size):
|
||||
if origin_sent[i] != target_sent[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def stripContext(origin_sents, target_sents, max_token_size=200):
|
||||
|
||||
start_match = True
|
||||
end_match = True
|
||||
origin_token_size = calTokenSize(origin_sents)
|
||||
target_token_size = calTokenSize(target_sents)
|
||||
diff_offset = 0
|
||||
while origin_token_size > max_token_size or target_token_size > max_token_size:
|
||||
|
||||
if len(origin_sents) == 0 or len(target_sents) == 0:
|
||||
break
|
||||
|
||||
if start_match and isSameSent(origin_sents[0], target_sents[0]):
|
||||
# remove the sentence from both origin and target
|
||||
sent_size = len(origin_sents[0])
|
||||
origin_sents = origin_sents[1:]
|
||||
target_sents = target_sents[1:]
|
||||
diff_offset += sent_size
|
||||
origin_token_size -= sent_size
|
||||
target_token_size -= sent_size
|
||||
else:
|
||||
start_match = False
|
||||
|
||||
if len(origin_sents) == 0 or len(target_sents) == 0:
|
||||
break
|
||||
|
||||
if end_match and isSameSent(origin_sents[-1], target_sents[-1]):
|
||||
sent_size = len(origin_sents[-1])
|
||||
origin_sents = origin_sents[:-1]
|
||||
target_sents = target_sents[:-1]
|
||||
origin_token_size -= sent_size
|
||||
target_token_size -= sent_size
|
||||
else:
|
||||
end_match = False
|
||||
|
||||
if not start_match and not end_match:
|
||||
break
|
||||
|
||||
if origin_token_size > max_token_size or target_token_size > max_token_size:
|
||||
origin_sents, target_sents = None, None
|
||||
|
||||
return origin_sents, target_sents, diff_offset
|
||||
|
||||
|
||||
def extractDiffContext(origin_sents, target_sents, origin_diff, target_diff, max_tokens=200):
|
||||
|
||||
    origin_context, target_context, diff_offset = stripContext(origin_sents, target_sents, max_token_size=max_tokens)
    if origin_context is not None and target_context is not None:
|
||||
origin_diff = [i - diff_offset for i in origin_diff]
|
||||
target_diff = [i - diff_offset for i in target_diff]
|
||||
|
||||
# fix the issue in the case that the appended dot belonging to current sentence is marked as the previous sentence. See example:
|
||||
# + .
|
||||
# + Warren
|
||||
# + ,
|
||||
# ...
|
||||
# + -
|
||||
# + syndicalists
|
||||
# .
|
||||
if len(target_diff) > 0 and target_diff[0] == -1:
|
||||
target_diff = target_diff[1:]
|
||||
|
||||
return origin_context, target_context, origin_diff, target_diff
|
||||
|
||||
|
||||
def filterRevision(comment, diff_list, max_sent_length):
|
||||
filter_pattern = "^.*\s*\W*(revert|undo|undid)(.*)$"
|
||||
filter_regex = re.compile(filter_pattern, re.IGNORECASE)
|
||||
|
||||
if len(diff_list) == 0 or len(diff_list) > max_sent_length:
|
||||
return True
|
||||
comment = comment.strip()
|
||||
if not comment.startswith('/*'):
|
||||
return True
|
||||
elif comment.startswith('/*') and comment.endswith('*/'):
|
||||
return True
|
||||
filter_match = filter_regex.match(comment)
|
||||
if filter_match:
|
||||
return True
|
||||
|
||||
return False
|
||||
#return filter_match
|
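# Examples with made-up comments: a revision is filtered (True) when the comment lacks a
# /* section */ marker, is only the marker, looks like a revert, or the diff is empty/too long.
#   filterRevision("/* History */ fixed a typo", [3, 4], 50)   # -> False (kept)
#   filterRevision("Reverted edits by Example", [3, 4], 50)    # -> True  (filtered)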
||||
|
||||
comment_pattern = "^/\*(.+)\*/(.*)$"
|
||||
comment_regex = re.compile(comment_pattern)
|
||||
def extractSectionTitle(comment):
|
||||
sect_title, sect_cmnt = '', comment
|
||||
comment_match = comment_regex.match(comment)
|
||||
if comment_match:
|
||||
sect_title = comment_match.group(1).strip()
|
||||
sect_cmnt = html.unescape(comment_match.group(2).strip()).strip()
|
||||
return sect_title, sect_cmnt
|
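# Example with a hypothetical comment string: the section title between /* ... */ is
# separated from the free-text part of the comment.
#   extractSectionTitle("/* Early life */ added citation")
#   # -> ('Early life', 'added citation')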
||||
|
||||
def extractSectionText(text, sect_title):
|
||||
sect_content = ''
|
||||
    text_match = re.search('(=+)\s*' + re.escape(sect_title) + '\s*(=+)', text)
|
||||
|
||||
if text_match:
|
||||
sect_sign = text_match.group(1)
|
||||
sect_sign_end = text_match.group(2)
|
||||
|
||||
if sect_sign != sect_sign_end:
|
||||
print("ALERT: Section Data Corrupted!! Skip!!")
|
||||
return sect_content
|
||||
|
||||
sect_start = text_match.regs[2][1]
|
||||
remain_sect = text[sect_start:].strip().strip('\n')
|
||||
|
||||
# TODO: Fix the bug of comparing the ===Section Title=== after ==Section==
|
||||
next_sect_match = re.search(sect_sign + '.*' + sect_sign, remain_sect)
|
||||
if next_sect_match:
|
||||
sect_end = next_sect_match.regs[0][0]
|
||||
sect_content = remain_sect[:sect_end].strip().strip('\n')
|
||||
else:
|
||||
sect_content = remain_sect
|
||||
|
||||
return sect_content
|
||||
|
||||
|
||||
'''
|
||||
Merge the sample outputs of all the dump files
|
||||
'''
|
||||
def mergeOutputs(output_path):
|
||||
|
||||
# merge sample results
|
||||
print("Merging the sampled outputs from each files ...")
|
||||
sample_list = glob.glob(output_path + '*.json')
|
||||
sample_file = open(output_path + 'wikicmnt.json', "w", encoding='utf-8')
|
||||
for fn in tqdm(sample_list):
|
||||
with open(fn, 'r', encoding='utf-8') as fi:
|
||||
sample_file.write(fi.read())
|
|
@@ -0,0 +1,309 @@
|
|||
import itertools
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
# from torchtext import data
|
||||
import pandas as pd
|
||||
import torch
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from torch.autograd import Variable
|
||||
from tqdm import tqdm
|
||||
|
||||
from wiki_util import tokenizeText
|
||||
|
||||
|
||||
class Wiki_DataSet:
|
||||
|
||||
def __init__(self, args):
|
||||
"""Create an Wiki dataset instance. """
|
||||
|
||||
# self.word_embed_file = self.data_folder + 'embedding/wiki.ar.vec'
|
||||
# word_embed_file = data_folder + "embedding/Wiki-CBOW"
|
||||
self.data_dir = args.data_path
|
||||
|
||||
self.data_file = self.data_dir + "wikicmnt.json"
|
||||
|
||||
self.vocab_file = self.data_dir + 'vocabulary.json'
|
||||
self.train_df_file = self.data_dir + 'train_df.pkl'
|
||||
self.val_df_file = self.data_dir + 'val_df.pkl'
|
||||
self.test_df_file = self.data_dir + 'test_df.pkl'
|
||||
self.tf_file = self.data_dir + 'term_freq.json'
|
||||
self.weight_file = self.data_dir + 'train_weights.json'
|
||||
|
||||
self.glove_dir = args.glove_path
|
||||
self.glove_vec_size = args.word_embd_size
|
||||
self.lemmatizer = WordNetLemmatizer()
|
||||
self.class_num = -1
|
||||
self.rank_num = args.rank_num
|
||||
self.anchor_num = args.anchor_num
|
||||
self.max_ctx_length = int(args.max_ctx_length / 2)
|
||||
pass
|
||||
|
||||
'''
|
||||
Extract the data from raw json file
|
||||
'''
|
||||
|
||||
def extract_data_from_json(self):
|
||||
'''
|
||||
        Parse the json file and return the extracted field lists together with word counters.
|
||||
:return:
|
||||
'''
|
||||
cmnt_list, neg_cmnts_list = [], []
|
||||
src_token_list, src_action_list = [], []
|
||||
tgt_token_list, tgt_action_list = [], []
|
||||
pos_edits_list, neg_edits_list = [], []
|
||||
diff_url_list = []
|
||||
|
||||
word_counter, lower_word_counter = Counter(), Counter()
|
||||
|
||||
print("Sample file:", self.data_file)
|
||||
|
||||
with open(self.data_file, 'r', encoding='utf-8') as f:
|
||||
for idx, json_line in enumerate(tqdm(f)):
|
||||
|
||||
# if idx % 100 == 0:
|
||||
# print("== processed ", idx)
|
||||
|
||||
article = json.loads(json_line.strip('\n'))
|
||||
# print(article['diff_url'])
|
||||
|
||||
'''
|
||||
Json file format:
|
||||
=================
|
||||
revision_id: The revision ID
|
||||
parent_id: The parent revision ID
|
||||
timestamp: Timestamp
|
||||
diff_url: The wikipedia link to show the difference between previous and current version.
|
||||
page_title: The title of page.
|
||||
comment: Revision comment.
|
||||
src_token: List of tokens in before-editing version
|
||||
src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
|
||||
tgt_token: List of tokens in after-editing version
|
||||
tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
|
||||
neg_cmnts: Negative samples of user comments in the same page.
|
||||
pos_edits: Edit sentences for comments.
|
||||
neg_edits: Negative edit sentences for comments.
|
||||
'''
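
                # A single input line might look roughly like this (field values are
                # hypothetical and abbreviated, shown only to illustrate the schema above):
                #   {"revision_id": 861234567, "parent_id": 861234000,
                #    "timestamp": "2018-09-25T12:34:56Z",
                #    "diff_url": "https://en.wikipedia.org/w/index.php?diff=861234567",
                #    "page_title": "Example", "comment": "/* History */ fixed a typo",
                #    "src_token": ["the", "hisotry", "of"], "src_action": [0, -1, 0],
                #    "tgt_token": ["the", "history", "of"], "tgt_action": [0, 1, 0],
                #    "neg_cmnts": [["reverted", "vandalism"]],
                #    "pos_edits": [["the", "history", "of"]], "neg_edits": [["unrelated", "edit"]]}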

                try:

                    # comment
                    comment = article['comment']
                    _, cmnt_tokens = tokenizeText(comment)
                    # cmnt_tokens = article['comment']
                    cmnt_list.append(cmnt_tokens)

                    # negative comments: tokenize each negative comment, not the positive one
                    # neg_cmnts = article['neg_cmnts']
                    neg_cmnts = []
                    for neg_cmnt in article['neg_cmnts']:
                        _, tokens = tokenizeText(neg_cmnt)
                        neg_cmnts.append(tokens)
                    neg_cmnts_list.append(neg_cmnts)

                    # source tokens and actions
                    src_token = article['src_token']
                    src_token_list.append(src_token)
                    src_action_list.append(article['src_action'])

                    # target tokens and actions
                    tgt_token = article['tgt_token']
                    tgt_token_list.append(tgt_token)
                    tgt_action_list.append(article['tgt_action'])

                    # positive and negative edits
                    pos_edits = article['pos_edits']
                    pos_edits_list.append(pos_edits)

                    neg_edits = article['neg_edits']
                    neg_edits_list.append(neg_edits)

                    # diff url
                    diff_url = article['diff_url']
                    diff_url_list.append(diff_url)

                    # update the word counters
                    for word in cmnt_tokens + src_token + tgt_token + \
                            list(itertools.chain.from_iterable(neg_cmnts + pos_edits + neg_edits)):
                        word_counter[word] += 1
                        lower_word_counter[word.lower()] += 1

                except (KeyError, IndexError):
                    # skip revisions with missing or malformed fields
                    print("ERROR: missing or malformed field in revision", article.get('revision_id', 'unknown'))
                    continue

                # if idx >= 100:
                #     break

        return cmnt_list, neg_cmnts_list, src_token_list, src_action_list, tgt_token_list, tgt_action_list, \
            pos_edits_list, neg_edits_list, word_counter, lower_word_counter, diff_url_list

    '''
    Create dataset objects for wiki revision data.

    Arguments:
        train_ratio: The fraction of the data used for the training split.
        val_ratio: The fraction of the data used for the validation split; the rest goes to the test split.
    '''

    def load_data(self, train_ratio, val_ratio):

        print("loading wiki data ...")

        # check the existence of cached data files
        if os.path.isfile(self.train_df_file) and os.path.isfile(self.test_df_file) and os.path.isfile(self.vocab_file):
            print("dataframe file exists:", self.train_df_file)
            train_df = pd.read_pickle(self.train_df_file)
            val_df = pd.read_pickle(self.val_df_file)
            test_df = pd.read_pickle(self.test_df_file)
            with open(self.vocab_file, 'r', encoding='utf-8') as vf:
                vocab_json = json.load(vf)
            # keep train_size consistent with the cached split so get_train_size() still works
            self.train_size = len(train_df)
        else:
            cmnts, neg_cmnts, src_tokens, src_actions, \
                tgt_tokens, tgt_actions, pos_edits, neg_edits, \
                word_counter, lower_word_counter, diff_url = self.extract_data_from_json()

            word2vec_dict = self.get_word2vec(word_counter)
            lower_word2vec_dict = self.get_word2vec(lower_word_counter)

            df = pd.DataFrame(
                {
                    'cmnt_words': cmnts, "neg_cmnts": neg_cmnts,
                    'src_tokens': src_tokens, "src_actions": src_actions,
                    'tgt_tokens': tgt_tokens, "tgt_actions": tgt_actions,
                    'pos_edits': pos_edits, 'neg_edits': neg_edits, 'diff_url': diff_url
                }
            )

            total_size = len(df)
            self.train_size = int(total_size * train_ratio)
            val_size = int(total_size * val_ratio)

            # test_ratio = 0.2
            # train_df, test_df = train_test_split(df, test_size=test_ratio, random_state=967898)
            train_df = df[:self.train_size]
            val_df = df[self.train_size:self.train_size + val_size]
            test_df = df[self.train_size + val_size:]

            print("saving data into pickle ...")
            train_df.to_pickle(self.train_df_file)
            val_df.to_pickle(self.val_df_file)
            test_df.to_pickle(self.test_df_file)

            w2i = {w: i for i, w in enumerate(word_counter.keys(), 3)}
            NULL = "-NULL-"
            UNK = "-UNK-"
            ENT = "-ENT-"
            w2i[NULL] = 0
            w2i[UNK] = 1
            w2i[ENT] = 2

            # save the vocabulary and word2vec dictionaries
            vocab_json = {'word2idx': w2i, 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict}
            with open(self.vocab_file, 'w', encoding='utf-8') as vf:
                json.dump(vocab_json, vf)

        return train_df, val_df, test_df, vocab_json
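
    # Usage sketch (hypothetical values; `args` is whatever namespace the training
    # script builds, with data_path, glove_path, word_embd_size, rank_num, anchor_num
    # and max_ctx_length set):
    #   dataset = Wiki_DataSet(args)
    #   train_df, val_df, test_df, vocab = dataset.load_data(train_ratio=0.7, val_ratio=0.1)
    #   word2idx = vocab['word2idx']   # the remaining rows form the test split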

    '''
    batch padding
    '''

    def pad_batch(self, mini_batch, padding_size):
        mini_batch_size = len(mini_batch)
        # mean_sent_len = int(np.mean([len(x) for x in mini_batch]))
        # np.long was removed from recent NumPy releases; np.int64 is the equivalent dtype
        main_matrix = np.zeros((mini_batch_size, padding_size), dtype=np.int64)
        for i in range(main_matrix.shape[0]):
            for j in range(main_matrix.shape[1]):
                try:
                    main_matrix[i, j] = mini_batch[i][j]
                except IndexError:
                    # shorter sequences keep the zero padding
                    pass

        # transpose to (padding_size, batch) and convert to LongTensor for compatibility
        return Variable(torch.from_numpy(main_matrix).transpose(0, 1).type(torch.LongTensor))
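
    # Worked example (hypothetical inputs): pad_batch([[5, 6], [7]], padding_size=3)
    # fills a (2, 3) matrix [[5, 6, 0], [7, 0, 0]] and returns its transpose as a
    # LongTensor of shape (3, 2), i.e. [[5, 7], [6, 0], [0, 0]].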

    '''
    Generate minibatches from data frame
    '''

    def iterate_minibatches(self, df, batch_size, cur_epoch=-1, n_epoch=-1):

        cmnt_words = df.cmnt_words.tolist()
        neg_cmnts = df.neg_cmnts.tolist()
        src_tokens = df.src_tokens.tolist()
        src_actions = df.src_actions.tolist()
        tgt_tokens = df.tgt_tokens.tolist()
        tgt_actions = df.tgt_actions.tolist()
        pos_edits = df.pos_edits.tolist()
        neg_edits = df.neg_edits.tolist()
        diff_urls = df.diff_url.tolist()

        # shuffle every column with the same permutation so the examples stay aligned
        indices = np.arange(len(cmnt_words))
        np.random.shuffle(indices)

        cmnt_words = [cmnt_words[i] for i in indices]
        neg_cmnts = [neg_cmnts[i] for i in indices]
        src_tokens = [src_tokens[i] for i in indices]
        src_actions = [src_actions[i] for i in indices]
        tgt_tokens = [tgt_tokens[i] for i in indices]
        tgt_actions = [tgt_actions[i] for i in indices]
        pos_edits = [pos_edits[i] for i in indices]
        neg_edits = [neg_edits[i] for i in indices]
        diff_urls = [diff_urls[i] for i in indices]

        for start_idx in range(0, len(cmnt_words) - batch_size + 1, batch_size):

            # initialize batch variables
            batch_cmnt, batch_neg_cmnt, batch_src_tokens, batch_src_actions, \
                batch_tgt_tokens, batch_tgt_actions, batch_pos_edits, batch_neg_edits, batch_diffurls = \
                [], [], [], [], [], [], [], [], []
            for i in range(start_idx, start_idx + batch_size):
                batch_cmnt.append(cmnt_words[i])
                batch_neg_cmnt.append(neg_cmnts[i])
                batch_src_tokens.append(src_tokens[i])
                batch_src_actions.append(src_actions[i])
                batch_tgt_tokens.append(tgt_tokens[i])
                batch_tgt_actions.append(tgt_actions[i])
                batch_pos_edits.append(pos_edits[i])
                batch_neg_edits.append(neg_edits[i])
                batch_diffurls.append(diff_urls[i])

            yield batch_cmnt, batch_neg_cmnt, batch_src_tokens, batch_src_actions, batch_tgt_tokens, \
                batch_tgt_actions, batch_pos_edits, batch_neg_edits, batch_diffurls
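
    # Usage sketch (hypothetical): one pass over the training split.
    #   for batch in dataset.iterate_minibatches(train_df, batch_size=32):
    #       cmnt, neg_cmnt, src_tok, src_act, tgt_tok, tgt_act, pos_ed, neg_ed, urls = batch
    #       # each element is a list of length batch_size; trailing examples that do not
    #       # fill a whole batch are dropped by the range() step above
    #   cur_epoch and n_epoch are currently unused placeholders.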

    def get_datafile(self):
        return self.data_file

    def get_tffile(self):
        return self.tf_file

    def get_weight_file(self):
        return self.weight_file

    def get_train_size(self):
        return self.train_size

    def get_word2vec(self, word_counter):
        glove_path = os.path.join(self.glove_dir, "glove.6B.{}d.txt".format(self.glove_vec_size))
        print("GloVe path:", glove_path)
        # approximate vocabulary sizes of the GloVe releases (used only for the progress bar)
        sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '27B': int(1.2e6)}
        total = sizes['6B']
        word2vec_dict = {}
        with open(glove_path, 'r', encoding='utf-8') as fh:
            for line in tqdm(fh, total=total):
                array = line.lstrip().rstrip().split(" ")
                word = array[0]
                vector = list(map(float, array[1:]))
                if word in word_counter:
                    word2vec_dict[word] = vector
                elif word.capitalize() in word_counter:
                    word2vec_dict[word.capitalize()] = vector
                elif word.lower() in word_counter:
                    word2vec_dict[word.lower()] = vector
                elif word.upper() in word_counter:
                    word2vec_dict[word.upper()] = vector

        print("{}/{} words of the vocabulary have corresponding vectors in {}".format(
            len(word2vec_dict), len(word_counter), glove_path))
        return word2vec_dict
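
    # One possible way to turn the saved vocabulary into the pre-trained matrix the model
    # expects as args.pre_embd_w; this glue code is an assumption, not part of this file:
    #   w2i, w2v = vocab['word2idx'], vocab['word2vec']
    #   embd = np.random.uniform(-0.05, 0.05, (len(w2i), args.word_embd_size))
    #   for w, i in w2i.items():
    #       if w in w2v:
    #           embd[i] = w2v[w]
    #   args.pre_embd_w = torch.from_numpy(embd).float()  # rows 0-2 stay random for -NULL-/-UNK-/-ENT-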

@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# from layers.highway import Highway


class WordEmbedding(nn.Module):
    '''
    In : (N, sentence_len)
    Out: (N, sentence_len, embd_size)
    '''

    def __init__(self, args, is_train_embd=False):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(args.vocab_size_w, args.word_embd_size)
        if args.pre_embd_w is not None:
            # initialize from the pre-trained word vectors; freeze them unless is_train_embd is set
            self.embedding.weight = nn.Parameter(args.pre_embd_w, requires_grad=is_train_embd)

    def forward(self, x):
        return self.embedding(x)
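
    # Shape check (hypothetical sizes): with args.vocab_size_w=1000, args.word_embd_size=100
    # and args.pre_embd_w=None, WordEmbedding(args)(torch.zeros(8, 30, dtype=torch.long))
    # returns a tensor of shape (8, 30, 100), i.e. (N, sentence_len, embd_size).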


class CmntModel(nn.Module):
    def __init__(self, args):
        super(CmntModel, self).__init__()
        self.batch_size = args.batch_size
        self.embd_size = args.word_embd_size
        self.cmnt_length = args.max_cmnt_length
        self.d = self.embd_size  # word_embedding
        self.no_action = args.no_action
        self.no_attention = args.no_attention
        self.no_hadamard = args.no_hadamard
        # self.d = self.embd_size  # only word_embedding

        # self.char_embd_net = CharEmbedding(args)
        self.word_embd_net = WordEmbedding(args)
        # self.highway_net = Highway(self.d)
        self.ctx_embd_layer = nn.GRU(self.d, self.d, bidirectional=True, dropout=0.2, batch_first=True)

        self.W = nn.Linear(6 * self.d + 1, 1, bias=False)

        # self.W2_noact = nn.Linear(2 * self.d, 1, bias=False)
        # self.W2 = nn.Linear(2 * self.d + 1, 1, bias=False)

        # weights for the attention layer; the input size depends on whether the
        # hadamard product and the action flags are used
        if self.no_hadamard and self.no_action:
            # (1, 1)
            self.W2_nhna = nn.Linear(1, 1, bias=False)
        elif self.no_hadamard and not self.no_action:
            # (2, 1)
            self.W2_nha = nn.Linear(2, 1, bias=False)
        elif not self.no_hadamard and self.no_action:
            # (2d, 1)
            self.W2_hna = nn.Linear(2 * self.d, 1, bias=False)
        elif not self.no_hadamard and not self.no_action:
            # (2d+1, 1)
            self.W2_ha = nn.Linear(2 * self.d + 1, 1, bias=False)

        self.modeling_layer = nn.GRU(8 * self.d, self.d, num_layers=2, bidirectional=True, dropout=0.2,
                                     batch_first=True)

        # Linear functions for comment ranking
        self.rank_linear = nn.Linear(self.cmnt_length * 2, 1, bias=True)
        self.rank_ctx_linear = nn.Linear(self.cmnt_length * 4, 1, bias=True)

        # Linear function for edit anchoring
        self.anchor_linear = nn.Linear(self.cmnt_length * 2, 2, bias=True)

        self.use_target_only = args.use_target_only
        self.ctx_mode = 1
        # self.p2_lstm_layer = nn.GRU(2*self.d, self.d, bidirectional=True, dropout=0.2, batch_first=True)

    def build_contextual_embd(self, x_w):
        # 1. Word Embedding Layer
        embd = self.word_embd_net(x_w)  # (N, seq_len, embd_size)

        # 2. Highway Networks for 1.
        # embd = self.highway_net(word_embd)  # (N, seq_len, d=embd_size)

        # 3. Contextual Embedding Layer
        ctx_embd_out, _h = self.ctx_embd_layer(embd)
        return ctx_embd_out

    def build_cmnt_sim(self, embd_context, embd_cmnt, embd_action, batch_size, T, J):

        shape = (batch_size, T, J, 2 * self.d)  # (N, T, J, 2d)
        embd_context_ex = embd_context.unsqueeze(2)  # (N, T, 1, 2d)
        embd_cmnt_ex = embd_cmnt.unsqueeze(1)  # (N, 1, J, 2d)

        # action embedding
        embd_action_ex = embd_action.float().unsqueeze(2).unsqueeze(2)
        embd_action_ex = embd_action_ex.expand((batch_size, T, J, 1))

        if self.no_hadamard:

            if self.no_action:
                raise ValueError('-no_hadamard cannot be combined with -no_action')
            # use the inner product instead of the hadamard product
            # generate (N, T, J, 1)
            embd_cmnt_ex = embd_cmnt_ex.permute(0, 2, 3, 1)  # (N, J, 2d, 1)
            # batch1 = torch.randn(10, 3, 4)
            # batch2 = torch.randn(10, 4, 5)
            # (N, T, 1, 2d) * (N, J, 2d, 1) => (N, T, J, 1): dot product between
            # context token t and comment token j
            a_dotprod_mul_b = torch.einsum('ntid,njdi->ntji', [embd_context_ex, embd_cmnt_ex])

            # no hadamard & action
            cat_data = torch.cat((a_dotprod_mul_b, embd_action_ex), 3)  # (N, T, J, 2), [h◦u; a]
            S = self.W2_nha(cat_data).view(batch_size, T, J)  # (N, T, J)
        else:
            embd_context_ex = embd_context_ex.expand(shape)  # (N, T, J, 2d)

            embd_cmnt_ex = embd_cmnt_ex.expand(shape)  # (N, T, J, 2d)
            a_elmwise_mul_b = torch.mul(embd_context_ex, embd_cmnt_ex)  # (N, T, J, 2d)

            if self.no_action:
                # hadamard & no action
                S = self.W2_hna(a_elmwise_mul_b).view(batch_size, T, J)  # (N, T, J)
            else:
                # hadamard & action
                cat_data = torch.cat((a_elmwise_mul_b, embd_action_ex), 3)  # (N, T, J, 2d + 1), [h◦u; a]
                S = self.W2_ha(cat_data).view(batch_size, T, J)  # (N, T, J)

        if self.no_attention:
            # without attention, simply average the similarity matrix over the edit dimension
            S_cmnt = torch.mean(S, 1)
        else:
            # attention implementation:
            # b: attention weights over the context tokens
            b = F.softmax(torch.max(S, 2)[0], dim=-1)  # (N, T)
            S_cmnt = torch.bmm(b.unsqueeze(1), S)  # (N, 1, J) = bmm( (N, 1, T), (N, T, J) )
            S_cmnt = S_cmnt.squeeze(1)  # (N, J)

            # max implementation
            # S_cmnt = torch.max(S, 1)[0]

            # c: attention weights on the comment
            # c = torch.max(S, 1)[0]  # (N, J)
            # S_cmnt = c * S_cmnt  # (N, J) = (N, J) * (N, J)

            # c2q = torch.bmm(F.softmax(S, dim=-1), embd_cmnt)  # (N, T, 2d) = bmm( (N, T, J), (N, J, 2d) )
            # c2q = torch.bmm(F.softmax(S, dim=-1), embd_cmnt)  # (N, J, 1) = bmm( (N, J, T), (N, T, 1) )

        return S_cmnt, S
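
    # Shape walk-through with hypothetical sizes N=2, T=3, J=4: S is (2, 3, 4);
    # max over dim 2 gives (2, 3) and softmax turns it into the attention weights b over
    # the context tokens; bmm((2, 1, 3), (2, 3, 4)) collapses the context dimension and
    # leaves S_cmnt with one attended similarity score per comment token, shape (2, 4).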

    # cmnt_words, neg_cmnt_words, src_diff_words, tgt_diff_words
    def forward(self, cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=True, cl_mode=False):

        batch_size = cmnt.size(0)
        T = src_token.size(1)  # edit sentence length (word level), e.g. 100
        # C = src_token.size(1)  # context sentence length (word level), e.g. 200
        J = cmnt.size(1)  # comment sentence length (word level), e.g. 30

        # ####################################################################################
        # 1. Word Embedding Layer
        # 2. Contextual Embedding Layer (GRU)
        # ####################################################################################
        embd_src_diff = self.build_contextual_embd(src_token)  # (N, T, 2d)
        embd_tgt_diff = self.build_contextual_embd(tgt_token)  # (N, T, 2d)

        if cl_mode:
            return embd_src_diff + embd_tgt_diff  # (N, T, 2d)

        embd_cmnt = self.build_contextual_embd(cmnt)  # (N, J, 2d)

        # if self.ctx_mode:
        #     embd_src_ctx = self.build_contextual_embd(src_ctx)  # (N, C, 2d)
        #     embd_tgt_ctx = self.build_contextual_embd(tgt_ctx)  # (N, C, 2d)

        # ####################################################################################
        # 3. Similarity Layer
        # ####################################################################################

        S_src_diff, _ = self.build_cmnt_sim(embd_src_diff, embd_cmnt, src_action, batch_size, T, J)  # (N, J)
        S_tgt_diff, _ = self.build_cmnt_sim(embd_tgt_diff, embd_cmnt, tgt_action, batch_size, T, J)  # (N, J)
        S_diff = torch.cat((S_src_diff, S_tgt_diff), 1)  # (N, 2J)

        # if self.ctx_mode:
        #     S_src_ctx, _ = self.build_cmnt_sim(embd_src_ctx, embd_cmnt, batch_size, C, J)
        #     S_tgt_ctx, _ = self.build_cmnt_sim(embd_tgt_ctx, embd_cmnt, batch_size, C, J)
        #     S_ctx = torch.cat((S_src_ctx, S_tgt_ctx), 1)  # (N, 2J)
        #     score = self.rank_ctx_linear(torch.cat((S_diff, S_ctx), 1))  # (N, 2J) -> (N, 1)
        # else:
        if cr_mode:
            # comment ranking head
            result = self.rank_linear(S_diff)  # (N, 2J) -> (N, 1)
        else:
            # edit anchoring head
            result = self.anchor_linear(S_diff)  # (N, 2J) -> (N, 2)

        # if self.use_target_only:
        #     S_diff = S_tgt_diff  # (N, J)
        # else:
        #     # S_diff = S_src_diff + S_tgt_diff  # (N, J)
        #     S_diff = torch.cat((S_src_diff, S_tgt_diff), 1)
        #     # S = torch.cat((S_src_diff, S_tgt_diff), 1)

        return result, S_diff
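
    # Forward-pass sketch (hypothetical args and sizes, assuming args.max_cmnt_length == J == 30
    # so that rank_linear's input size 2*J matches):
    #   model = CmntModel(args)
    #   cmnt = torch.zeros(4, 30, dtype=torch.long)                  # (N, J) word indices
    #   src_tok = tgt_tok = torch.zeros(4, 100, dtype=torch.long)    # (N, T)
    #   src_act = tgt_act = torch.zeros(4, 100, dtype=torch.long)    # (N, T) action flags
    #   score, S_diff = model(cmnt, src_tok, src_act, tgt_tok, tgt_act, cr_mode=True)
    #   # score: (4, 1) ranking score; S_diff: (4, 60) concatenated similarity vectors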