This commit is contained in:
Sujay Kumar Jauhar 2019-08-26 17:53:12 -07:00
Parent aa3bffc42b
Commit 5254314c7a
17 changed files: 3344 additions and 104 deletions

218
.gitignore vendored

@@ -1,104 +1,114 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# Mac
.DS_Store
# model cache
*.pt
# dataset
rt-polaritydata
trees
*.tar
*.zip
# pycharm
.idea
.idea/
output/
*.7z
.DS_Store
data/
data/glove/
data/processed/enwiki*
data/raw/
data_processing/extractor.txt

165
README.md

@@ -1,3 +1,168 @@
Modeling the Relationship Between Comments and Edits in Document Revisions
======
This is a PyTorch implementation for modeling the relationship between comments and edits in Wikipedia revision data, as described in our EMNLP 2019 paper:
**Modeling the Relationship between User Comments and Edits in Document Revision**, Xuchao Zhang, Dheeraj Rajagopal, Michael Gamon, Sujay Kumar Jauhar and ChangTien Lu, 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), Hong Kong, China, Nov 3-7, 2019.
Two distinct but related tasks are proposed in this work:
- **Comment Ranking**: ranking a list of comments based on their relevance to a specific edit
- **Edit Anchoring**: anchoring a comment to the most relevant sentences in a document
## Requirements
- python 3.6
- [tqdm](https://github.com/noamraph/tqdm)
- [pytorch 0.4.0](https://pytorch.org/)
- [numpy v1.13+](http://www.numpy.org/)
- [scipy](https://www.scipy.org/)
- [scikit-learn](http://scikit-learn.org/stable/)
- [spacy v2.0.11](https://spacy.io/)
- and some basic packages.
## Usage
### Data Preparation
- Download the raw data from Wikipedia and generate the preprocessed data file ```wikicmnt.json``` by following the steps in our [Data Extractor](./data_processing/README.md). Place the generated ```wikicmnt.json``` in the dataset folder, e.g., ```./data/processed/```.
<!-- For demo purpose, we include a small dump file ```enwiki-20181001-pages-meta-history24.xml-p33948436p33952815.bz2``` (17MB) in the ```./dataset/``` folder. -->
- Download the GloVe embeddings from ```https://nlp.stanford.edu/projects/glove/``` and copy them into the folder specified by the ```--glove_path``` parameter, e.g., ```./data/glove/```. A quick sanity check is sketched below.
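Before training, it can help to confirm that both inputs are in place. The snippet below is a minimal, hypothetical sanity check (the GloVe file names depend on which embeddings you downloaded, so the ```.txt``` pattern used here is only an assumption):
```
import os

data_path = "./data/processed/"   # should contain wikicmnt.json
glove_path = "./data/glove/"      # should contain the downloaded GloVe .txt files

# fail early with a clear message instead of deep inside training
assert os.path.isfile(os.path.join(data_path, "wikicmnt.json")), "run the data extractor first"
assert os.path.isdir(glove_path) and any(
    f.endswith(".txt") for f in os.listdir(glove_path)), "download the GloVe embeddings first"
print("data and embeddings found")
```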
### Training
In this section, we train models for the comment ranking and edit anchoring tasks, either individually or jointly. Before training a model, make sure the ```wikicmnt.json``` file exists in the ```--data_path``` directory.
##### Comment Ranking Training
To train a model for the comment ranking task, run the following command:
```
python3 main.py --cr_train --data_path="./data/processed/"
```
##### Edit Anchoring Training
To train a model for the edit anchoring task, run the following command:
```
python3 main.py --ea_train --data_path="./data/processed/"
```
##### Jointly train on both Comment Ranking & Edit Anchoring tasks
To train a multi-task model with default parameters, simply combine the flags of the individual tasks:
```
python3 main.py --cr_train --ea_train --data_path="./data/processed/"
```
The common parameters you can change:
```
--data_path="./dataset/"
--glove_path="./dataset/glove/"
--epoch=20
--batch_size=10
--word_embd_size=100
```
The best model is saved to the default folder ```./checkpoint/```.
### Test
Test a saved model. The following metrics are reported by default: P@1, P@3, MRR and NDCG.
```
python main.py --test --checkpoint="./checkpoint/saved_best_model.pt"
```
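For reference, each ranking metric is computed per test example from the score of the true comment and the scores of the sampled negative comments; the sketch below mirrors the per-example computation in ```compute_rank_score``` in ```eval.py```:
```
import math

def rank_metrics(pos_score, neg_scores):
    # number of negatives scored at least as high as the positive (ties count against it)
    pos = sum(1 for s in neg_scores if s >= pos_score)
    p1 = 1 if pos == 0 else 0       # P@1: positive ranked first
    p3 = 1 if pos < 3 else 0        # P@3: positive ranked in the top 3
    mrr = 1 / (pos + 1)             # reciprocal rank of the positive
    ndcg = 1 / math.log2(pos + 2)   # binary relevance, so IDCG = 1
    return p1, p3, mrr, ndcg
```
These per-example values are averaged over the test set to produce the reported numbers.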
### Full Usage Options
The full options of our code are listed below:
```
usage: main.py [-h] [--data_path DATA_PATH]
[--checkpoint_path CHECKPOINT_PATH] [--glove_path GLOVE_PATH]
[--log_interval LOG_INTERVAL] [--test_interval TEST_INTERVAL]
[--save_best SAVE_BEST] [--lr LR] [--ngpu NGPU]
[--word_embd_size WORD_EMBD_SIZE]
[--max_ctx_length MAX_CTX_LENGTH]
[--max_diff_length MAX_DIFF_LENGTH]
[--max_cmnt_length MAX_CMNT_LENGTH] [--ctx_mode CTX_MODE]
[--rnn_model] [--epoch EPOCH] [--start_epoch START_EPOCH]
[--batch_size BATCH_SIZE] [--cr_train] [--ea_train]
[--no_action] [--no_attention] [--no_hadamard]
[--src_train SRC_TRAIN] [--train_ratio TRAIN_RATIO]
[--val_ratio VAL_RATIO] [--val_size VAL_SIZE]
[--manualSeed MANUALSEED] [--test] [--case_study]
[--resume PATH] [--seed SEED] [--gpu GPU]
[--checkpoint CHECKPOINT] [--rank_num RANK_NUM]
[--anchor_num ANCHOR_NUM] [--use_target_only] [--predict]
[--pred_cmnt PRED_CMNT] [--pred_ctx PRED_CTX]
optional arguments:
-h, --help show this help message and exit
--data_path DATA_PATH
the data directory
--checkpoint_path CHECKPOINT_PATH
the checkpoint directory
--glove_path GLOVE_PATH
the glove directory
--log_interval LOG_INTERVAL
how many steps to wait before logging training status
[default: 100]
--test_interval TEST_INTERVAL
how many steps to wait before testing [default: 1000]
--save_best SAVE_BEST
whether to save when get best performance
--lr LR learning rate, default=0.5
--ngpu NGPU number of GPUs to use
--word_embd_size WORD_EMBD_SIZE
word embedding size
--max_ctx_length MAX_CTX_LENGTH
the maximum words in the context [default: 300]
--max_diff_length MAX_DIFF_LENGTH
the maximum words in the revision difference [default:
200]
--max_cmnt_length MAX_CMNT_LENGTH
the maximum words in the comment [default: 30]
--ctx_mode CTX_MODE whether to use change context in training [default:
True]
--rnn_model use rnn baseline model
--epoch EPOCH number of epoch, default=10
--start_epoch START_EPOCH
resume epoch count, default=1
--batch_size BATCH_SIZE
input batch size
--cr_train whether to train the comment ranking task
--ea_train whether to train the revision anchoring task
--no_action whether to use action encoding to train the model
--no_attention whether to use mutual attention to train the model
--no_hadamard whether to use hadamard product to train the model
--src_train SRC_TRAIN
whether to train the comment ranking task without the
before-editing version
--train_ratio TRAIN_RATIO
ratio of training data in the entire data [default:
0.7]
--val_ratio VAL_RATIO
ratio of validation data in the entire data [default:
0.1]
--val_size VAL_SIZE force the size of the validation dataset; this parameter
will disregard the setting of val_ratio [default: -1]
--manualSeed MANUALSEED
manual seed
--test use test mode
--case_study use case study mode
--resume PATH path to saved params
--seed SEED random seed
--gpu GPU gpu to use for iterating data, -1 means CPU [default: -1]
--checkpoint CHECKPOINT
filename of model checkpoint [default: None]
--rank_num RANK_NUM the number of ranking comments
--anchor_num ANCHOR_NUM
the number of candidate edits for anchoring
--use_target_only use target context only in model
--predict predict the sentence given
--pred_cmnt PRED_CMNT
the comment of prediction
--pred_ctx PRED_CTX the context of prediction
```
## Author
If you have any trouble or questions, please contact [Xuchao Zhang]().
August, 2018
# Contributing

112
data_processing/.gitignore vendored Normal file

@@ -0,0 +1,112 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# Mac
.DS_Store
# model cache
*.pt
# dataset
rt-polaritydata
trees
*.tar
*.zip
# pycharm
.idea
data/
dataset/
output/
wikidata/
*.7z
*.bz2
sumdata_api.py

91
data_processing/README.md Normal file

@@ -0,0 +1,91 @@
Wikipedia Revision Data Extractor
======
This is a Python implementation of the data extractor tool set for Wikipedia revision data, as described in our EMNLP 2019 paper:
**Modeling the Relationship between User Comments and Edits in Document Revision**, Xuchao Zhang, Dheeraj Rajagopal, Michael Gamon, Sujay Kumar Jauhar and ChangTien Lu, 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP), Hong Kong, China, Nov 3-7, 2019.
We provide three tools to extract and preprocess the Wikipedia revision history data from scratch:
- download the entire enwiki revision dumps from Wikipedia
- extract the revision data for the comment modeling task from the wiki dump files
- extract the summarization task dataset from the wiki dump
Note: The collected Wikipedia revision data can be used as the input to the proposed models in our EMNLP paper, or on its own for other tasks.
## Requirements
- python 3.6
- [tqdm](https://github.com/noamraph/tqdm)
- [numpy v1.13+](http://www.numpy.org/)
- [scipy](https://www.scipy.org/)
- [scikit-learn](http://scikit-learn.org/stable/)
- [spacy v2.0.11](https://spacy.io/)
- and some basic packages.
## Usage
### Download Wiki dump files
First, choose a dump such as ```https://dumps.wikimedia.org/enwiki/20190801``` (the latest wiki dump at the time our code was released). This page lists all the information related to the dump, such as the files it contains. Then download the machine-readable dump status file ```dumpstatus.json``` from the Wikipedia dump page and copy it into the default data path, e.g., ```./data/```.
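The download script reads the list of dump files from this status file. For reference, a minimal sketch of that lookup (same JSON structure the script relies on, with the path to ```dumpstatus.json``` adjusted to wherever you copied it) is:
```
import json

compress_type = "bz2"  # or "7z"
with open("./data/dumpstatus.json", encoding="utf-8") as f:
    job = json.load(f)["jobs"]["metahistory" + compress_type + "dump"]

# each entry carries a relative URL and an md5 checksum
for name, info in sorted(job["files"].items()):
    print("https://dumps.wikimedia.org" + info["url"], info["md5"])
```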
###### Important Note:
* Check that the dump contains the complete page edit history: ```All pages with complete page edit history (.bz2)```. The edit history is sometimes omitted in specific dump versions.
* Always choose a recent dump, since Wikipedia periodically removes old dumps and marks them as deprecated.
Finally, run our wiki dump download script to download dump files as follows:
```
python wiki_dump_download.py --data_path="./data/raw/" --compress_type="bz2" --threads=3
```
You need to specify the data path and compression type (```bz2``` by default). Since the download process is extremely slow, you can use multiple threads to download the dump files. However, Wikipedia only allows three simultaneous HTTP connections per IP address, so we recommend at most three threads unless you can assign a different IP address to each thread.
At the beginning of the download process, all the files are listed with unique IDs as follows:
```
All files to download ...
1 enwiki-20190801-pages-meta-history1.xml-p1043p2036.bz2
2 enwiki-20190801-pages-meta-history1.xml-p10675p11499.bz2
3 enwiki-20190801-pages-meta-history1.xml-p10p1042.bz2
4 enwiki-20190801-pages-meta-history1.xml-p11500p12310.bz2
...
...
648 enwiki-20190801-pages-meta-history9.xml-p2326206p2336422.bz2
```
The entire download usually takes one to two days. You can download a subset of the files by specifying the ```--start``` and ```--end``` parameters. You can also use the ```--verify``` parameter to verify the completeness of your dump files.
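The ```--verify``` mode compares each downloaded file against the md5 checksum recorded in ```dumpstatus.json```. A stripped-down version of that check (the file path and expected checksum below are placeholders) looks roughly like:
```
import hashlib

def file_md5(path, chunk_size=40960000):
    # hash the file in large chunks so big dumps do not need to fit in memory
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# expected_md5 comes from the matching entry in dumpstatus.json
# ok = file_md5(downloaded_dump_path) == expected_md5
```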
### Revision Data Preprocessing
For preprocessing the revision data, we provide both single-thread and multi-thread versions.
To preprocess a single dump file, specify its file index with the ```--single_thread``` parameter:
```
python3 wikicmnt_extractor.py --single_thread=5
```
Here the number 5 means the 5th dump file (in sorted order) in the dump data folder.
To preprocess multiple dump files in parallel, run:
```
python3 wikicmnt_extractor.py --threads=10
```
You need to specify some common parameters:
```
--data_path="../data/raw/"
--output_path="../data/processed/"
--sample_ratio=0.1
--ctx_window=5
--min_page_tokens=50
--max_page_tokens=2000
--min_cmnt_length=8
```
Finally, if you used the single-thread mode to generate the processed files one by one, merge the outputs of all the dump files by running:
```
python3 wikicmnt_extractor.py --merge_only
```
The output is a single file ```wikicmnt.json``` that contains all the processed data.
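Each line of ```wikicmnt.json``` is one JSON record whose fields follow the format documented in the extractor (```revision_id```, ```parent_id```, ```timestamp```, ```diff_url```, ```page_title```, ```comment```, ```src_token```/```src_action```, ```tgt_token```/```tgt_action```, ```neg_cmnts```, ```pos_edits```, ```neg_edits```). A quick, illustrative way to inspect the first record:
```
import json

with open("./data/processed/wikicmnt.json", encoding="utf-8") as f:
    record = json.loads(next(f))

print(sorted(record.keys()))
print(record["page_title"], record["diff_url"])
```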
## Author
If you have any trouble or questions, please contact [Xuchao Zhang](xuczhang@gmail.com).
August, 2019


@@ -0,0 +1,186 @@
import argparse
import glob
import hashlib
import json
import logging
import os
import threading
import urllib.request
from datetime import datetime
parser = argparse.ArgumentParser(description='WikiDump Downloader')
parser.add_argument('--data-path', type=str, default="./data/", help='the data directory')
parser.add_argument('--compress-type', type=str, default='bz2',
help='the compressed file type to download: 7z or bz2 [default: bz2]')
parser.add_argument('--threads', type=int, default=3, help='number of threads [default: 3]')
parser.add_argument('--start', type=int, default=1, help='the first file to download [default: 1]')
parser.add_argument('--end', type=int, default=-1, help='the last file to download [default: -1]')
parser.add_argument('--verify', action='store_true', default=False, help='verify the dump files in the specific path')
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG,
format='(%(threadName)s) %(message)s',
)
def download(dump_status_file, data_path, compress_type, start, end, thread_num):
url_list = []
file_list = []
with open(dump_status_file) as json_data:
# Two dump types: compressed by 7z (metahistory7zdump) or bz2 (metahistorybz2dump)
history_dump = json.load(json_data)['jobs']['metahistory' + compress_type + 'dump']
dump_dict = history_dump['files']
dump_files = sorted(list(dump_dict.keys()))
if args.end > 0 and args.end <= len(dump_files):
dump_files = dump_files[start - 1:end]
else:
dump_files = dump_files[start - 1:]
# print all files to be downloaded.
print("All files to download ...")
for i, file in enumerate(dump_files):
print(i + args.start, file)
file_num = 0
for dump_file in dump_files:
file_name = data_path + dump_file
file_list.append(file_name)
# url example: https://dumps.wikimedia.org/enwiki/20180501/enwiki-20180501-pages-meta-history1.xml-p10p2123.7z
url = "https://dumps.wikimedia.org" + dump_dict[dump_file]['url']
url_list.append(url)
file_num += 1
print('Total file ', file_num, ' to be downloaded ...')
json_data.close()
task = WikiDumpTask(file_list, url_list)
threads = []
for i in range(thread_num):
t = threading.Thread(target=worker, args=(i, task))
threads.append(t)
t.start()
logging.debug('Waiting for worker threads')
main_thread = threading.currentThread()
for t in threading.enumerate():
if t is not main_thread:
t.join()
def existFile(data_path, cur_file):
exist_file_list = glob.glob(data_path + "*." + args.compress_type)
exist_file_names = [os.path.basename(i) for i in exist_file_list]
cur_file_name = os.path.basename(cur_file)
if cur_file_name in exist_file_names:
return True
return False
def md5(file):
hash_md5 = hashlib.md5()
with open(file, "rb") as f:
for chunk in iter(lambda: f.read(40960000), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def verify(dump_status_file, compress_type, data_path):
print("Verify the file in folder:", data_path)
pass_files, miss_files, crash_files = [], [], []
with open(dump_status_file) as json_data:
# Two dump types: compressed by 7z (metahistory7zdump) or bz2 (metahistorybz2dump)
history_dump = json.load(json_data)['jobs']['metahistory' + compress_type + 'dump']
dump_dict = history_dump['files']
for i, (file, value) in enumerate(dump_dict.items()):
gt_md5 = value['md5']
print("#", i, " ", file, ' ', value['md5'], sep='')
if existFile(data_path, file):
file_md5 = md5(data_path + file)
if file_md5 == gt_md5:
pass_files.append(file)
else:
crash_files.append(file)
else:
miss_files.append(file)
print(len(pass_files), "files passed, ", len(miss_files), "files missed, ", len(crash_files), "files crashed.")
if len(miss_files):
print("==== Missed Files ====")
print(miss_files)
if len(crash_files):
print("==== Crashed Files ====")
print(crash_files)
def main():
dump_status_file = args.data_path + "dumpstatus.json"
if args.verify:
verify(dump_status_file, args.compress_type, args.data_path)
else:
download(dump_status_file, args.data_path, args.compress_type, args.start, args.end, args.threads)
'''
The WikiDumpTask class contains a list of dump files to be downloaded.
The assign_task function will be called by workers to grab a task.
'''
class WikiDumpTask(object):
def __init__(self, file_list, url_list):
self.lock = threading.Lock()
self.url_list = url_list
self.file_list = file_list
self.total_num = len(url_list)
def assign_task(self):
logging.debug('Assign tasks ... Waiting for lock')
self.lock.acquire()
url = None
file_name = None
cur_progress = None
try:
# logging.debug('Acquired lock')
if len(self.url_list) > 0:
url = self.url_list.pop(0)
file_name = self.file_list.pop(0)
cur_progress = self.total_num - len(self.url_list)
finally:
self.lock.release()
return url, file_name, cur_progress, self.total_num
'''
worker is the main function for each thread.
'''
def worker(work_id, tasks):
logging.debug('Starting.')
# grab one task from task_list
while 1:
url, file_name, cur_progress, total_num = tasks.assign_task()
if not url:
break
logging.debug('Assigned task (' + str(cur_progress) + '/' + str(total_num) + '): ' + str(url))
if not existFile(args.data_path, file_name):
urllib.request.urlretrieve(url, file_name)
logging.debug("File Downloaded: " + url)
else:
logging.debug("File Exists, Skip: " + url)
logging.debug('Exiting.')
return
if __name__ == '__main__':
start_time = datetime.now()
main()
time_elapsed = datetime.now() - start_time
print('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))


@@ -0,0 +1,118 @@
#!/usr/bin/env python
import sys
sys.path.append('../')
from wiki_util import *
def extract_json(raw_file, ctx10_file, output_file, ctx_window=10, negative_edit_num=10):
# load tokenized comment and neg_comments from ctx10_file
print("Loading tokenized comment and neg_comments from ctx10_file")
cmnt_dict = {}
with open(ctx10_file, 'r', encoding='utf-8') as f:
for idx, json_line in enumerate(tqdm(f)):
article = json.loads(json_line.strip('\n'))
rev_id = article["revision_id"]
comment = article["comment"]
neg_cmnts = article["neg_cmnts"]
cmnt_dict[rev_id] = (comment, neg_cmnts)
json_file = open(output_file, "w", buffering=1, encoding='utf-8')
with open(raw_file, 'r', encoding='utf-8') as f:
for idx, json_line in enumerate(tqdm(f)):
article = json.loads(json_line.strip('\n'))
'''
Json file format:
=================
revision_id: The revision ID
parent_id: The parent revision ID
timestamp: Timestamp
diff_url: The wikipedia link to show the difference between previous and current version.
page_title: The title of page.
comment: Revision comment.
src_token: List of tokens in before-editing version
src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
tgt_token: List of tokens in after-editing version
tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
neg_cmnts: Negative samples of user comments in the same page.
pos_edits: Edit sentences for comments.
neg_edits: Negative edit sentences for comments.
'''
rev_id = article["revision_id"]
parent_id = article["parent_id"]
timestamp = article["timestamp"]
diff_url = article["diff_url"]
page_title = article["page_title"]
# comment = word_tokenize(article["comment"])
# neg_comments = [word_tokenize(cmnt) for cmnt in article['neg_comments']]
# lookup comment and neg_comments from dictionary
comment, neg_cmnts = cmnt_dict[rev_id]
src_text = article["src_text"]
tgt_text = article["tgt_text"]
src_sents = article["src_sents"]
src_tokens = article["src_tokens"]
tgt_sents = article["tgt_sents"]
tgt_tokens = article["tgt_tokens"]
src_token_diff = article["src_token_diff"]
tgt_token_diff = article["tgt_token_diff"]
# src_sents, src_tokens = tokenizeText(src_text)
# tgt_sents, tgt_tokens = tokenizeText(tgt_text)
# extract the offset of the changed tokens in both src and tgt
# src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)
src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)
# src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)
# generate the positive edits
pos_edits = [tgt_sents[i] for i in tgt_sent_diff]
# generate negative edits
neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
if negative_edit_num > len(neg_edits_idx):
sampled_neg_edits_idx = neg_edits_idx
else:
sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]
if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
"diff_url": diff_url, "page_title": page_title, \
"comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
"tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
"neg_cmnts": neg_cmnts, "neg_edits": neg_edits, "pos_edits": pos_edits
}
json_str = json.dumps(json_dict,
indent=None, sort_keys=False,
separators=(',', ': '), ensure_ascii=False)
json_file.write(json_str + '\n')
def main():
root_path = "../dataset/raw/"
ctx_window = int(sys.argv[1])
raw_file = root_path + "enwiki-sample_output_raw.json"
ctx10_file = root_path + "wikicmnt_ctx10.json"
output_file = root_path + "wikicmnt_ctx" + str(ctx_window) + ".json"
extract_json(raw_file, ctx10_file, output_file, ctx_window=ctx_window)
if __name__ == '__main__':
start_time = datetime.now()
main()
time_elapsed = datetime.now() - start_time


@@ -0,0 +1,155 @@
#!/usr/bin/env python
import argparse
import sys
sys.path.append('../')
from wiki_util import *
from wikicmnt_extractor_st import randSampleRev
# arguments
parser = argparse.ArgumentParser(description='Wiki Extractor')
parser.add_argument('--data_path', type=str, default="../data/raw/", help='the data directory')
parser.add_argument('--output_path', type=str, default="../data/processed/", help='the sample output path')
parser.add_argument('--min_page_tokens', type=int, default=50,
help='the minimum size of tokens in page to extract [default: 50]')
parser.add_argument('--max_page_tokens', type=int, default=2000,
help='the maximum size of tokens in page to extract [default: 2000]')
parser.add_argument('--min_cmnt_length', type=int, default=8,
help='the minimum words contained in the comments [default: 8]')
parser.add_argument('--ctx_window', type=int, default=5, help='the window size of context [default: 5]')
parser.add_argument('--sample_ratio', type=float, default=0.01, help='the ratio of sampling [default: 0.01]')
parser.add_argument('--threads', type=int, default=3, help='the number of sampling threads [default: 3]')
parser.add_argument('--single_thread', type=int, default=0,
help='the dump file index when using single thread mode. If the index equals to zero, '
'use multi-thread mode to preprocess all the dump files [default: 0]')
parser.add_argument('--user_stat', type=bool, default=False, help='whether to do user statistics')
parser.add_argument('--merge_only', action='store_true', default=False, help='merge the results only')
parser.add_argument('--neg_cmnt_num', type=int, default=10,
help='how many negative comments sampled for ranking problem [default: 10]')
parser.add_argument('--count_revision_only', type=bool, default=False,
help='count the revision only without sampling anything [default: False]')
args = parser.parse_args()
# create sample output folder if it doesn't exist
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
# logging configuration
logging.basicConfig(level=logging.DEBUG,
format='(%(threadName)s) %(message)s',
)
'''
WikiSampleTask class contains a list of dump files to be sampled.
The assign_task function will be called by workers to grab a task.
'''
class WikiSampleTask(object):
def __init__(self, dump_list):
self.lock = threading.Lock()
self.dump_list = dump_list
self.total_num = len(dump_list)
def assign_task(self):
logging.debug('Assign tasks ... Waiting for lock')
self.lock.acquire()
dump_name = None
cur_progress = None
try:
# logging.debug('Acquired lock')
if len(self.dump_list) > 0:
dump_name = self.dump_list.pop(0)
cur_progress = self.total_num - len(self.dump_list)
finally:
self.lock.release()
return dump_name, cur_progress, self.total_num
'''
worker is the main function for each thread.
'''
def worker(work_id, tasks):
logging.debug('Starting.')
output_file = args.data_path + 'sample/enwiki_sample_' + str(work_id) + '.json'
# grab one task from task_list
while 1:
dump_file, cur_progress, total_num = tasks.assign_task()
if not dump_file:
break
logging.debug('Assigned task (' + str(cur_progress) + '/' + str(total_num) + '): ' + str(dump_file))
# start to sample the dump file
output_file = args.output_path + 'enwiki-sample-' + os.path.basename(dump_file)[27:-4] + '.json'
randSampleRev(work_id, dump_file, output_file, args.sample_ratio, args.min_cmnt_length, args.ctx_window,
args.neg_cmnt_num)
logging.debug('Exiting.')
def initLogger(file_idx):
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) # or whatever
handler = logging.FileHandler('extractor.txt', 'a', 'utf-8') # or whatever
# handler = logging.StreamHandler(sys.stdout)
handler.setFormatter = logging.Formatter('%(asctime)s' + '[' + str(file_idx) + '] - %(message)s') # or whatever
logger.addHandler(handler)
return logger
def main():
# single file mode
if args.single_thread:
logger = initLogger(args.single_thread)
data_path = args.data_path
output_path = args.output_path
dump_list = sorted(glob.glob(data_path + "*.bz2"))
print(dump_list)
dump_file = dump_list[args.single_thread - 1]
output_file = output_path + 'enwiki-sample-' + os.path.basename(dump_file)[27:-4] + '.json'
# create sample output folder if it doesn't exist
if not os.path.exists(output_path):
os.makedirs(output_path)
print("start to preprocess dump file:", str(dump_file))
logger.info("[" + str(args.single_thread) + "] Start to sample dump file " + dump_file)
randSampleRev(args.single_thread, dump_file, output_file, args.sample_ratio, args.min_cmnt_length,
args.ctx_window, args.neg_cmnt_num)
return
if not args.merge_only:
dump_list = glob.glob(args.data_path + "*.bz2")
# # testing
# dump_list = dump_list[:5]
dump_num = len(dump_list)
logging.debug("Samping revisions from " + str(dump_num) + " dump files")
task = WikiSampleTask(dump_list)
threads = []
for i in range(args.threads):
t = threading.Thread(target=worker, args=(i, task))
threads.append(t)
t.start()
logging.debug('Waiting for worker threads')
main_thread = threading.currentThread()
for t in threading.enumerate():
if t is not main_thread:
t.join()
logging.debug('Merging the sample outputs from each dump file')
# merge the result
mergeOutputs(args.output_path)
if __name__ == '__main__':
start_time = datetime.now()
main()
time_elapsed = datetime.now() - start_time
logging.debug('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))


@@ -0,0 +1,182 @@
#!/usr/bin/env python
import sys
sys.path.append('../')
from wiki_util import *
def printSample(task_id, sample_count, revision_count, page_title, sect_title, comment, diff_url, parent_tokens,
target_tokens, origin_diff, target_diff, delimitor='^'):
# print the sampled revision in excel format
revision_info = '[' + str(len(parent_tokens)) + '|' + str(len(target_tokens)) + ']'
origin_diff_tokens = '[' + (
str(origin_diff[0]) if len(origin_diff) == 1 else ','.join([str(i) for i in origin_diff])) + ']'
target_diff_tokens = '[' + (
str(target_diff[0]) if len(target_diff) == 1 else ','.join([str(i) for i in target_diff])) + ']'
# print(sample_count, '/', revision_count, delimitor, page_title, delimitor, sect_title, delimitor, comment, delimitor,
# diff_url, delimitor, revision_info, delimitor, origin_diff_tokens, delimitor, target_diff_tokens, sep='')
logging.info("[" + str(task_id) + "] " + str(sample_count) + '/' + str(
revision_count) + delimitor + page_title + delimitor + sect_title +
delimitor + comment + delimitor + diff_url + delimitor + revision_info + delimitor + origin_diff_tokens + delimitor + target_diff_tokens)
def randSampleRev(task_id, dump_file, output_file, sample_ratio, min_cmnt_length, ctx_window, negative_cmnt_num=10,
negative_edit_num=10, count_revision_only=False, MIN_COMMENT_SIZE=20):
# if os.path.exists(output_file):
# print("Output file already exists. Please remove first so I don't destroy your stuff please")
# sys.exit(1)
json_file = open(output_file, "w", buffering=1, encoding='utf-8')
start_time = datetime.now()
wiki_file = bz2.open(dump_file, "rt", encoding='utf-8')
# template = '{"revision_id": %s, "diff_url": %s, "parent_id": %s, "timestamp": %s, "page_title": %s, "user_name": %s, \
# "user_id": %s, "user_ip": %s, "comment": %s, "parent_text": %s, "text": %s, "diff_tokens": %s}\n'
sample_count = 0
revision_count = 0
# time_elapsed = datetime.now() - start_time
# print("=== ", i, "/", file_num, ' === ', revision_count, " revisions extracted.", ' Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed), sep='')
sample_parent_id = None
sample_parent_text = None
page_comment_list = []
prev_page_title = ''
try:
for page_title, revision in split_records(wiki_file):
revision_count += 1
if count_revision_only:
if revision_count % 1000 == 0:
print("= revision", revision_count, "=")
continue
# fields
rev_id, parent_id, timestamp, username, userid, userip, comment, text = extract_data(revision)
# if the length of the comment is less than MIN_COMMENT_SIZE (default 20), skip the revision directly.
if len(comment) < MIN_COMMENT_SIZE:
continue
# clean the comment text
comment = cleanCmntText(comment)
# extract the section title and the comment without section info
sect_title, comment = extractSectionTitle(comment)
# skip the revision if no section title included or the length is too short.
if not sect_title or len(comment) < MIN_COMMENT_SIZE:
continue
# store the comments
if prev_page_title != page_title:
prev_page_title = page_title
page_comment_list.clear()
_, comment_tokens = tokenizeText(comment)
if checkComment(comment, comment_tokens, min_cmnt_length):
# page_comment_list.append(comment)
page_comment_list.append(comment)
# write the sampled revision to json file
# Skip the revision if it does not satisfy the sampling criteria
if sample_parent_id == parent_id:
# do sample
diff_url = 'https://en.wikipedia.org/w/index.php?title=' + page_title.replace(" ",
'%20') + '&type=revision&diff=' + rev_id + '&oldid=' + parent_id
# print(diff_url)
# check whether the comment is appropriate by some criteria
# check_comment(comment, length_only=True)
try:
src_text = extractSectionText(sample_parent_text, sect_title)
tgt_text = extractSectionText(text, sect_title)
except:
print("ERROR-RegularExpression:", sample_parent_text, text, " Skip!!")
# skip the revision if any exception happens
continue
# clean the wiki text
src_text = cleanWikiText(src_text)
tgt_text = cleanWikiText(tgt_text)
if (src_text and tgt_text) and (len(src_text) < 1000000 and len(tgt_text) < 1000000):
# tokenization
src_sents, src_tokens = tokenizeText(src_text)
tgt_sents, tgt_tokens = tokenizeText(tgt_text)
# extract the offset of the changed tokens in both src and tgt
src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)
if len(src_token_diff) == 0 and len(tgt_token_diff) == 0:
continue
if (len(src_token_diff) > 0 and src_token_diff[0] < 0) or (
len(tgt_token_diff) > 0 and tgt_token_diff[0] < 0):
continue
if src_sents == None or tgt_sents == None:
continue
src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)
# src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)
# randomly sample the negative comments
if negative_cmnt_num > len(page_comment_list) - 1:
neg_cmnt_idx = range(len(page_comment_list) - 1)
else:
neg_cmnt_idx = random.sample(range(len(page_comment_list) - 1), negative_cmnt_num)
neg_comments = [page_comment_list[i] for i in neg_cmnt_idx]
# generate the positive edits
pos_edits = [tgt_sents[i] for i in tgt_sent_diff]
# generate negative edits
neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
if negative_edit_num > len(neg_edits_idx):
sampled_neg_edits_idx = neg_edits_idx
else:
sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]
if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
"diff_url": diff_url, "page_title": page_title, \
"comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
"tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
"neg_cmnts": neg_comments, "neg_edits": neg_edits, "pos_edits": pos_edits
}
json_str = json.dumps(json_dict,
indent=None, sort_keys=False,
separators=(',', ': '), ensure_ascii=False)
json_file.write(json_str + '\n')
sample_count += 1
printSample(task_id, sample_count, revision_count, page_title, sect_title, comment, diff_url,
src_tokens, tgt_tokens, src_token_diff, tgt_token_diff)
# filterRevision(comment, diff_list):
# if sample_parent_id != parent_id:
# logging.debug("ALERT: Parent Id Missing" + str(sample_parent_id) + "/" + str(parent_id) + "--- DATA MIGHT BE CORRUPTED!!")
# decide to sample next
if sampleNext(sample_ratio):
sample_parent_id = rev_id
sample_parent_text = text
else:
sample_parent_id = None
sample_parent_text = None
# if revision_count % 1000 == 0:
# print("Finished ", str(revision_count))
finally:
time_elapsed = datetime.now() - start_time
logging.debug("=== " + str(sample_count) + " revisions sampled in total " + str(revision_count) + " revisions. " \
+ 'Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed) + ' ===')
json_file.close()


@@ -0,0 +1,120 @@
#!/usr/bin/env python
import json
import sys
sys.path.append('../')
from wiki_util import *
from process_data import word_tokenize
def extract_json(input_file, output_file, ctx_window=10, negative_edit_num=10):
json_file = open(output_file, "w", buffering=1, encoding='utf-8')
with open(input_file, 'r', encoding='utf-8') as f:
for idx, json_line in enumerate(tqdm(f)):
article = json.loads(json_line.strip('\n'))
'''
Json file format:
=================
revision_id: The revision ID
parent_id: The parent revision ID
timestamp: Timestamp
diff_url: The wikipedia link to show the difference between previous and current version.
page_title: The title of page.
comment: Revision comment.
src_token: List of tokens in before-editing version
src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
tgt_token: List of tokens in after-editing version
tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
neg_cmnts: Negative samples of user comments in the same page.
pos_edits: Edit sentences for comments.
neg_edits: Negative edit sentences for comments.
'''
rev_id = article["revision_id"]
parent_id = article["parent_id"]
timestamp = article["timestamp"]
diff_url = article["diff_url"]
page_title = article["page_title"]
# # tokenize comment
# cmnt_tokens = word_tokenize(article['comment'])
# cmnt_list.append(cmnt_tokens)
#
# # negative comments
# neg_cmnts = [word_tokenize(cmnt) for cmnt in article['neg_comments']]
# neg_cmnts_list.append(neg_cmnts)
comment = word_tokenize(article["comment"])
neg_comments = [word_tokenize(cmnt) for cmnt in article['neg_comments']]
src_text = article["src_text"]
tgt_text = article["tgt_text"]
src_sents = article["src_sents"]
src_tokens = article["src_tokens"]
tgt_sents = article["tgt_sents"]
tgt_tokens = article["tgt_tokens"]
src_token_diff = article["src_token_diff"]
tgt_token_diff = article["tgt_token_diff"]
#src_sents, src_tokens = tokenizeText(src_text)
#tgt_sents, tgt_tokens = tokenizeText(tgt_text)
# extract the offset of the changed tokens in both src and tgt
#src_token_diff, tgt_token_diff = diffRevision(src_tokens, tgt_tokens)
src_ctx_tokens, src_action = extContext(src_tokens, src_token_diff, ctx_window)
tgt_ctx_tokens, tgt_action = extContext(tgt_tokens, tgt_token_diff, ctx_window)
# src_sent_diff = findSentDiff(src_sents, src_tokens, src_token_diff)
tgt_sent_diff = findSentDiff(tgt_sents, tgt_tokens, tgt_token_diff)
# generate the positive edits
pos_edits = [tgt_sents[i] for i in tgt_sent_diff]
# generate negative edits
neg_edits_idx = [i for i in range(len(tgt_sents)) if i not in tgt_sent_diff]
if negative_edit_num > len(neg_edits_idx):
sampled_neg_edits_idx = neg_edits_idx
else:
sampled_neg_edits_idx = random.sample(neg_edits_idx, negative_edit_num)
neg_edits = [tgt_sents[i] for i in sampled_neg_edits_idx]
if (len(src_token_diff) > 0 or len(tgt_token_diff) > 0):
json_dict = {"revision_id": rev_id, "parent_id": parent_id, "timestamp": timestamp, \
"diff_url": diff_url, "page_title": page_title, \
"comment": comment, "src_token": src_ctx_tokens, "src_action": src_action, \
"tgt_token": tgt_ctx_tokens, "tgt_action": tgt_action, \
"neg_cmnts": neg_comments, "neg_edits": neg_edits, "pos_edits": pos_edits
}
json_str = json.dumps(json_dict,
indent=None, sort_keys=False,
separators=(',', ': '), ensure_ascii=False)
json_file.write(json_str + '\n')
def main():
root_path = "../dataset/raw/"
data_path = root_path + "split_files/"
output_path = root_path + "output/"
file_idx = int(sys.argv[1])
dump_list = sorted(glob.glob(data_path + "*.json"))
dump_file = dump_list[file_idx - 1]
file_name = os.path.basename(dump_file)
output_file = output_path + file_name[:-5] + '_output.json'
# original way
# input_file = root_path + "enwiki-sample_output_raw.json"
# output_file = root_path + "wikicmnt.json"
extract_json(dump_file, output_file, ctx_window=10)
if __name__ == '__main__':
start_time = datetime.now()
main()
time_elapsed = datetime.now() - start_time


@@ -0,0 +1,75 @@
#!/usr/bin/env python
import os
import sys
import argparse
import bz2
import glob
from datetime import datetime
import random
import difflib
import nltk
import re
import threading
import logging
import operator
import html
import spacy
import en_core_web_sm
import json
from wiki_util import *
import random
from random import shuffle
def cal_total_line(input_file):
total_line = 0
with open(input_file, 'r', encoding='utf-8') as f:
for json_line in tqdm(f):
total_line += 1
return total_line
def rand_sample(input_file, output_file, sample_num, total_num):
sample_indices = random.sample(range(1, total_num), sample_num)
sample_file = open(output_file, "w", encoding='utf-8')
sample_list = []
with open(input_file, 'r', encoding='utf-8') as f:
for idx, json_line in enumerate(tqdm(f)):
# if idx in sample_indices:
# sample_list.append(json_line)
j = json.loads(json_line)
tgt_sent_len = len(j['tgt_sents'])
tgt_diff_sent_size = len(j['tgt_sent_diff'])
if tgt_sent_len - tgt_diff_sent_size >= 10:
sample_list.append(json_line)
# sample_list.append(json_line)
# if idx == sample_num - 1:
# break
print("Start to shuffle sample list ...")
shuffle(sample_list)
print("Start to write output ...")
for line in tqdm(sample_list):
sample_file.write(line)
def main():
data_folder = "../data/"
#input_file = data_folder + "wikicmnt.json"
input_file = data_folder + "wiki_comment_orig.json"
output_file = data_folder + "wiki_comment.json"
sample_num = 260000
#sample_num = 500000
#total_num = cal_total_line(input_file)
total_num = 786886
print("total line:", total_num)
rand_sample(input_file, output_file, sample_num, total_num)
if __name__ == '__main__':
start_time = datetime.now()
main()
time_elapsed = datetime.now() - start_time
logging.debug('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))

335
eval.py Normal file

@@ -0,0 +1,335 @@
import math
from statistics import mean
import sklearn
import sklearn.metrics
import torch
from process_data import _make_char_vector, _make_word_vector, gen_cmntrank_batches, gen_editanch_batches
from process_data import to_var, make_vector
from wiki_util import tokenizeText
## general evaluation
def predict(pred_cmnt, pred_ctx, w2i, c2i, model, max_ctx_length):
print("prediction on single sample ...")
model.eval()
_, pred_cmnt_words = tokenizeText(pred_cmnt)
pred_cmnt_chars = [list(i) for i in pred_cmnt_words]
_, pred_ctx_words = tokenizeText(pred_ctx)
pred_ctx_chars = [list(i) for i in pred_ctx_words]
cmnt_sent_len = len(pred_cmnt_words)
cmnt_word_len = int(mean([len(w) for w in pred_cmnt_chars]))
ctx_sent_len = max_ctx_length
ctx_word_len = int(mean([len(w) for w in pred_ctx_chars]))
cmnt_words, cmnt_chars, ctx_words, ctx_chars = [], [], [], []
# c, cc, q, cq, a in batch
cmnt_words.append(_make_word_vector(pred_cmnt_words, w2i, cmnt_sent_len))
cmnt_chars.append(_make_char_vector(pred_cmnt_chars, c2i, cmnt_sent_len, cmnt_word_len))
ctx_words.append(_make_word_vector(pred_ctx_words, w2i, ctx_sent_len))
ctx_chars.append(_make_char_vector(pred_ctx_chars, c2i, ctx_sent_len, ctx_word_len))
cmnt_words = to_var(torch.LongTensor(cmnt_words))
cmnt_chars = to_var(torch.stack(cmnt_chars, 0))
ctx_words = to_var(torch.LongTensor(ctx_words))
ctx_chars = to_var(torch.stack(ctx_chars, 0))
logit, _ = model(ctx_words, ctx_chars, cmnt_words, cmnt_chars)
a = torch.max(logit.cpu(), -1)
print(logit)
print(a)
y_pred = a[1].data[0]
y_prob = a[0].data[0]
y_pred_2 = [int(i) for i in (torch.max(logit, -1)[1].view(1).data).tolist()][0]
print(y_pred_2)
# y_pred = y_pred[0]
print(y_pred, y_prob)
return y_pred
def compute_rank_score(pos_score, neg_scores):
p1, p3, p5 = 0, 0, 0
pos_list = [0 if pos_score > neg_score else 1 for neg_score in neg_scores]
pos = sum(pos_list)
# precision @K
if pos == 0: p1 = 1
if pos < 3: p3 = 1
if pos < 5: p5 = 1
# MRR
mrr = 1 / (pos + 1)
# NDCG: DCG/IDCG (in our case, we set rel=1 if relevant, otherwise rel=0; then IDCG=1)
ndcg = 1 / math.log2(pos + 2)
return p1, p3, p5, mrr, ndcg
def get_rank(pos_score, neg_scores):
pos_list = [0 if pos_score > neg_score else 1 for neg_score in neg_scores]
pos = sum(pos_list)
return pos + 1
def isEditPredCorrect(pred, truth):
for i in range(len(pred)):
if pred[i] != truth[i]:
return False
return True
def eval_rank(score_pos, score_neg, cand_num):
score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
correct_p1, correct_p3, correct_p5, total_mrr, total_ndcg = 0, 0, 0, 0, 0
neg_num = cand_num - 1
batch_num = int(len(score_neg) / neg_num)
rank_list = []
for i in range(batch_num):
score_pos_i = score_pos_list[i * neg_num: (i + 1) * neg_num]
score_neg_i = score_neg_list[i * neg_num: (i + 1) * neg_num]
p1, p3, p5, mrr, ndcg = compute_rank_score(score_pos_i[0], score_neg_i)
rank = get_rank(score_pos_i[0], score_neg_i)
rank_list.append(rank)
correct_p1 += p1
correct_p3 += p3
correct_p5 += p5
total_mrr += mrr
total_ndcg += ndcg
return correct_p1, correct_p3, correct_p5, total_mrr, total_ndcg, rank_list
# def eval_rank_orig(score_pos, score_neg, batch_size):
# score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
# score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
#
# total_p1, total_p3, total_p5, total_mrr, total_ndcg = 0, 0, 0, 0, 0
# sample_num = int(len(score_neg) / batch_size)
# for i in range(batch_size):
# score_pos_i = score_pos_list[i * sample_num: (i+1) * sample_num]
# score_neg_i = score_neg_list[i * sample_num: (i+1) * sample_num]
# pos_list = [0 if score_pos_i[i] >= score_neg_i[i] else 1 for i in range(sample_num)]
# pos = sum(pos_list)
# sorted_neg = ["%.4f" % i for i in sorted(score_neg_i, reverse=True)]
# #print(pos, "%.4f" % score_pos_i[0], "\t".join(sorted_neg), sep='\t')
# if pos == 0:
# total_p1 += 1
#
# if pos < 3:
# total_p3 += 1
#
# if pos < 5:
# total_p5 += 1
#
# # MRR
# total_mrr += 1 / (pos + 1)
#
# # NDCG: DCG/IDCG (In our case, we set the rel=1 if relevent, otherwise rel=0; Then IDCG=1)
# total_ndcg += 1 / math.log2(pos + 2)
#
# return total_p1, total_p3, total_p5, total_mrr, total_ndcg
## general evaluation
def eval(dataset, val_df, w2i, model, args):
# print(" evaluation on", val_df.shape[0], " samples ...")
model.eval()
corrects, avg_loss = 0, 0
cmnt_rank_p1, cmnt_rank_p3, cmnt_rank_p5, cmnt_rank_mrr, cmnt_rank_ndcg = 0, 0, 0, 0, 0
ea_pred, ea_truth = [], []
cr_total, ea_total = 0, 0
pred_cr_list, pred_ea_list = [], []
for batch in dataset.iterate_minibatches(val_df, args.batch_size):
cmnt_sent_len = args.max_cmnt_length
ctx_sent_len = args.max_ctx_length
diff_sent_len = args.max_diff_length
###########################################################
# Comment Ranking Task
###########################################################
# generate positive and negative batches
pos_batch, neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
args.rank_num)
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
make_vector(pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action, cr_mode=True)
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=True)
cr_p1_corr, cr_p3_corr, cr_p5_corr, cr_mrr, cr_ndcg, pred_rank = eval_rank(score_pos, score_neg, args.rank_num)
cmnt_rank_p1 += cr_p1_corr
cmnt_rank_p3 += cr_p3_corr
cmnt_rank_p5 += cr_p5_corr
cmnt_rank_mrr += cr_mrr
cmnt_rank_ndcg += cr_ndcg
cr_total += int(len(score_pos) / (args.rank_num - 1))
pred_cr_list += pred_rank
###########################################################
# Edits Anchoring
###########################################################
# generate positive and negative batches
ea_batch, ea_truth_cur = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
args.anchor_num)
if len(pos_batch[0]) > 0:
cmnt, src_token, src_action, tgt_token, tgt_action = \
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
# neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
# make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
# logit_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=False)
ea_pred_cur = (torch.max(logit, 1)[1].view(logit.size(0)).data).tolist()
# ea_truth_cur = [1] * logit_pos.size(0) + [0] * logit_neg.size(0)
ea_pred += ea_pred_cur
ea_truth += ea_truth_cur
ea_total += int(len(score_pos) / (args.anchor_num - 1))
# # output the prediction results
# with open(args.checkpoint_path + 'test_out.txt', 'w') as f:
# for i in range(len(y_truth)):
# line = cmnt_readable_all[i] + '\t' + ctx_readable_all[i] + '\t' + str(y_pred[i]) + '\t' + str(y_truth[i])
# f.write(line + '\n')
# if args.test:
# print(total_rank)
# print("\t".join([str(i) for i in pred_cr_list]))
# print("\t".join([str(i) for i in ea_pred]))
cr_p1_acc = cmnt_rank_p1 / cr_total
cr_p3_acc = cmnt_rank_p3 / cr_total
cr_p5_acc = cmnt_rank_p5 / cr_total
cr_mrr = cmnt_rank_mrr / cr_total
cr_ndcg = cmnt_rank_ndcg / cr_total
ea_acc = (sklearn.metrics.accuracy_score(ea_truth, ea_pred))
ea_f1 = (sklearn.metrics.f1_score(ea_truth, ea_pred, pos_label=1))
ea_prec = (sklearn.metrics.precision_score(ea_truth, ea_pred, pos_label=1))
ea_recall = (sklearn.metrics.recall_score(ea_truth, ea_pred, pos_label=1))
print("\n*** Validation Results *** ")
# print("[Task-CR] P@1:", "%.3f" % cr_p1_acc, "% P@3:", "%.3f" % cr_p3_acc, "% P@5:", "%.3f" % cr_p5_acc,\
# '%', ' (', cmnt_rank_p1, '/', cr_total, ',', cmnt_rank_p3, '/', cr_total, ',', cmnt_rank_p5, '/', cr_total,')', sep='')
# print("[Task-RA] P@1:", "%.3f" % ea_p1_acc, "% P@3:", "%.3f" % ea_p3_acc, "% P@5:", "%.3f" % ea_p5_acc,\
# '%', ' (', edit_anch_p1, '/', ea_total, ',', edit_anch_p3, '/', ea_total, ',', edit_anch_p5, '/', ea_total,')', sep='')
print("[Task-CR] P@1:", "%.3f" % cr_p1_acc, " P@3:", "%.3f" % cr_p3_acc, " P@5:", "%.3f" % cr_p5_acc, " MRR:",
"%.3f" % cr_mrr, " NDCG:", "%.3f" % cr_ndcg, sep='')
print("[Task-EA] ACC:", "%.3f" % ea_acc, " F1:", "%.3f" % ea_f1, " Precision:", "%.3f" % ea_prec, " Recall:",
"%.3f" % ea_recall, sep='')
return cr_p1_acc, ea_f1
def dump_cmntrank_case(pos_batch, neg_batch, idx, rank_num, diff_url, rank, pos_score, neg_scores):
neg_num = rank_num - 1
pos_cmnt = pos_batch[0][idx * neg_num]
neg_cmnts = neg_batch[0][idx * neg_num: (idx + 1) * neg_num]
before_edit = pos_batch[1][idx * neg_num]
after_edit = pos_batch[3][idx * neg_num]
match = False
for token in pos_cmnt:
if token in before_edit + after_edit:
match = True
break
neg_match_words = []
neg_match = False
for neg_cmnt in neg_cmnts:
for token in neg_cmnt:
if len(token) <= 3:
continue
if token in before_edit + after_edit:
neg_match = True
neg_match_words.append(token)
if not match and neg_match:
print("\n ====== cmntrank case (Not Matched) ======")
print("Rank", rank)
print(diff_url)
print("pos_cmnt (", "{0:.3f}".format(pos_score), "): ", " ".join(pos_cmnt), sep='')
for i, neg_cmnt in enumerate(neg_cmnts):
print("neg_cmnt ", i, " (", "{0:.3f}".format(neg_scores[i]), "): ", " ".join(neg_cmnt), sep='')
pass
print("neg_match_words:", " ".join(neg_match_words))
def dump_editanch_case(comment, edit, pred, truth):
print("\n ====== editanch case ======")
print("pred/truth: ", pred, "/", truth)
print("comment:", " ".join(comment))
print("edit:", " ".join(edit))
def case_study(dataset, val_df, w2i, model, args):
model.eval()
print("Start the case study")
# for batch in dataset.iterate_minibatches(val_df[:500], args.batch_size):
for batch in dataset.iterate_minibatches(val_df, args.batch_size):
cmnt_sent_len = args.max_cmnt_length
ctx_sent_len = args.max_ctx_length
diff_sent_len = args.max_diff_length
###########################################################
# Comment Ranking Task
###########################################################
# generate positive and negative batches
if args.cr_train:
pos_batch, neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
args.rank_num)
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
make_vector(pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action, cr_mode=True)
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=True)
score_pos_list = score_pos.data.cpu().squeeze(1).numpy().tolist()
score_neg_list = score_neg.data.cpu().squeeze(1).numpy().tolist()
neg_num = args.rank_num - 1
batch_num = int(len(score_neg) / neg_num)
for i in range(batch_num):
score_pos_i = score_pos_list[i * neg_num: (i + 1) * neg_num]
score_neg_i = score_neg_list[i * neg_num: (i + 1) * neg_num]
rank = get_rank(score_pos_i[0], score_neg_i)
dump_cmntrank_case(pos_batch, neg_batch, i, args.rank_num, batch[8][i], rank, score_pos_i[0],
score_neg_i)
###########################################################
# Edits Anchoring
###########################################################
# generate positive and negative batches
if args.ea_train:
ea_batch, ea_truth_cur = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
args.anchor_num)
cmnt, src_token, src_action, tgt_token, tgt_action = \
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
# neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
# make_vector(neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
# logit_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action, cr_mode=False)
ea_pred_cur = (torch.max(logit, 1)[1].view(logit.size(0)).data).tolist()
for i in range(len(ea_truth_cur)):
# if ea_pred_cur[i] == ea_truth_cur[i]:
dump_editanch_case(ea_batch[0][i], ea_batch[3][i], ea_pred_cur[i], ea_truth_cur[i])
pass
# ea_truth_cur = [1] * logit_pos.size(0) + [0] * logit_neg.size(0)

185
main.py Normal file

@@ -0,0 +1,185 @@
import argparse
import time
from eval import case_study, predict
from process_data import load_glove_weights
from train import *
from wikicmnt_dataset import Wiki_DataSet
from wikicmnt_model import CmntModel
#############################################################################################
# ArgumentParser
#############################################################################################
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='./data/processed/', help='the data directory')
parser.add_argument('--checkpoint_path', type=str, default='./checkpoint/', help='the checkpoint directory')
parser.add_argument('--glove_path', type=str, default='./data/glove/', help='the glove directory')
# model
parser.add_argument('--log_interval', type=int, default=100,
help='how many steps to wait before logging training status [default: 100]')
parser.add_argument('--test_interval', type=int, default=1,
help='how many steps to wait before testing [default: 1000]')
parser.add_argument('--save_best', type=bool, default=True, help='whether to save when get best performance')
parser.add_argument('--lr', type=float, default=0.5, help='learning rate, default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--word_embd_size', type=int, default=100, help='word embedding size')
parser.add_argument('--max_ctx_length', type=int, default=300, help='the maximum words in the context [default: 300]')
parser.add_argument('--max_diff_length', type=int, default=300,
help='the maximum words in the revision difference [default: 200]')
parser.add_argument('--max_cmnt_length', type=int, default=30, help='the maximum words in the comment [default: 30]')
parser.add_argument('--ctx_mode', type=bool, default=True,
help='whether to use change context in training [default: True]')
# training
parser.add_argument('--epoch', type=int, default=10, help='number of epoch, default=10')
parser.add_argument('--start_epoch', type=int, default=1, help='resume epoch count, default=1')
parser.add_argument('--batch_size', type=int, default=10, help='input batch size')
parser.add_argument('--cr_train', action='store_true', default=False, help='whether to train the comment ranking task')
parser.add_argument('--ea_train', action='store_true', default=False,
                    help='whether to train the revision anchoring task')
# ablation testing
parser.add_argument('--no_action', action='store_true', default=False,
help='whether to use action encoding to train the model')
parser.add_argument('--no_attention', action='store_true', default=False,
help='whether to use mutual attention to train the model')
parser.add_argument('--no_hadamard', action='store_true', default=False,
help='whether to use hadamard product to train the model')
parser.add_argument('--src_train', type=bool, default=False,
                    help='whether to train the comment ranking task without the before-editing version')
parser.add_argument('--train_ratio', type=float, default=0.7,
                    help='ratio of training data in the entire data [default: 0.7]')
parser.add_argument('--val_ratio', type=float, default=0.1,
                    help='ratio of validation data in the entire data [default: 0.1]')
parser.add_argument('--val_size', type=int, default=10000,
                    help='force the size of the validation dataset; overrides the val_ratio setting when > 0 [default: 10000]')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--test', action='store_true', default=False, help='use test model')
parser.add_argument('--case_study', action='store_true', default=False, help='use case study mode')
parser.add_argument('--resume', default='./checkpoints/model_best.tar', type=str, metavar='PATH',
                    help='path to the saved checkpoint to resume from')
parser.add_argument('--seed', type=int, default=1111, help='random seed')
# device
parser.add_argument('--gpu', type=int, default=-1, help='gpu to use for iterating over data, -1 means cpu [default: -1]')
parser.add_argument('--checkpoint', type=str, default=None, help='filename of model checkpoint [default: None]')
parser.add_argument('--rank_num', type=int, default=5, help='the number of ranking comments')
parser.add_argument('--anchor_num', type=int, default=5, help='the number of candidate edit sentences per comment for edit anchoring')
parser.add_argument('--use_target_only', action='store_true', default=False, help='use target context only in model')
# single case prediction
parser.add_argument('--predict', action='store_true', default=False, help='predict the sentence given')
parser.add_argument('--pred_cmnt', type=str, default=None, help='the comment of prediction')
parser.add_argument('--pred_ctx', type=str, default=None, help='the context of prediction')
args = parser.parse_args()
if args.gpu >= 0:
torch.cuda.set_device(args.gpu)
print("\nParameters:")
for attr, value in sorted(args.__dict__.items()):
print("\t{}={}".format(attr.upper(), value))
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
print(os.name)
sys.stdout.flush()
# load data
dataset = Wiki_DataSet(args)
train_df, val_df, test_df, vocab_json = dataset.load_data(train_ratio=args.train_ratio, val_ratio=args.val_ratio)
w2i = vocab_json['word2idx']
print('----')
print('n_train', train_df.shape[0])
print('n_val', val_df.shape[0])
print('n_test', test_df.shape[0])
print('vocab_size:', len(w2i))
# load glove
glove_embd_w = torch.from_numpy(load_glove_weights(args.glove_path, args.word_embd_size, len(w2i), w2i)).type(
torch.FloatTensor)
# save_pickle(glove_embd_w, './pickle/glove_embd_w.pickle')
args.vocab_size_w = len(w2i)
args.pre_embd_w = glove_embd_w
args.filters = [[1, 5]]
args.out_chs = 100
# generate save directory
base_name = os.path.basename(os.path.normpath(args.data_path))
if args.cr_train and not args.ea_train:
    task_str = "cr"
elif args.ea_train and not args.cr_train:
    task_str = "ea"
elif args.cr_train and args.ea_train:
    task_str = "mt"
else:
    task_str = "none"  # neither task flag is set, e.g. a pure --test/--predict/--case_study run
folder_name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + base_name + "_" + task_str
if args.no_action:
folder_name += "_noaction"
if args.word_embd_size == 300:
folder_name += "_d300"
args.save_dir = os.path.join(args.checkpoint_path, folder_name)
print("Save to ", args.save_dir)
sys.stdout.flush()
# initialize model
model = CmntModel(args)
if args.checkpoint is not None:
print('\nLoading model from {}...'.format(args.checkpoint))
model.load_state_dict(torch.load(args.checkpoint))
if torch.cuda.is_available() and os.name != 'nt':
print('use cuda')
model.cuda()
# model = torch.nn.DataParallel(model, device_ids=[0])
# optimizer = torch.optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), lr=0.5)
# optimizer = torch.optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
# optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()))
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
# best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
print(model)
print('parameters-----')
for name, param in model.named_parameters():
if param.requires_grad:
print(name, param.data.size())
if args.predict:
print('Prediction mode')
print('#Comment:', args.pred_cmnt)
print('#Context:', args.pred_ctx)
predict(args.pred_cmnt, args.pred_ctx, w2i, model, args.max_ctx_length)
elif args.test:
print('Test mode')
eval(dataset, test_df, w2i, model, args)
elif args.case_study:
start_time = time.time()
case_study(dataset, test_df, w2i, model, args)
else:
print('Train mode')
start_time = time.time()
train(args, model, dataset, train_df, val_df, optimizer, w2i, \
n_epoch=args.epoch, start_epoch=args.start_epoch, batch_size=args.batch_size)
print("Training duration: %s seconds" % (time.time() - start_time))
print('finish')
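# Example invocation (illustrative only; the paths below are the argparse defaults
# above and may need to be adapted to your local setup):
#   python main.py --cr_train --ea_train --gpu 0 --batch_size 10 --epoch 10 \
#       --data_path ./data/processed/ --glove_path ./data/glove/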

361
process_data.py Normal file

@@ -0,0 +1,361 @@
import json
import os
import pickle
import random
import numpy as np
import spacy
import torch
from torch.autograd import Variable
# TODO global
NULL = "-NULL-"
UNK = "-UNK-"
ENT = "-ENT-"
# initialize the spacy
nlp = spacy.load('en')
def word_tokenize(text):
doc = nlp(text)
tokens = [token.string.strip() for token in doc]
return tokens
def save_pickle(d, path):
print('save pickle to', path)
with open(path, mode='wb') as f:
pickle.dump(d, f)
def load_pickle(path):
print('load', path)
with open(path, mode='rb') as f:
return pickle.load(f)
def lower_list(str_list):
return [str_var.lower() for str_var in str_list]
def load_processed_json(fpath_data, fpath_shared):
data = json.load(open(fpath_data))
shared = json.load(open(fpath_shared))
return data, shared
def load_glove_weights(glove_dir, embd_dim, vocab_size, word_index):
embeddings_index = {}
if embd_dim < 300:
glove_version = 'glove.6B.'
else:
glove_version = 'glove.840B.'
with open(os.path.join(glove_dir, glove_version + str(embd_dim) + 'd.txt'), encoding='utf-8') as f:
for line in f:
try:
values = line.split()
word = values[0]
vector = np.array(values[1:], dtype='float32')
embeddings_index[word] = vector
except:
continue
print('Found %s word vectors in glove.' % len(embeddings_index))
embedding_matrix = np.zeros((vocab_size, embd_dim))
print('embed_matrix.shape', embedding_matrix.shape)
found_ct = 0
for word, i in word_index.items():
embedding_vector = embeddings_index.get(word)
# words not found in embedding index will be all-zeros.
if embedding_vector is not None:
embedding_matrix[i] = embedding_vector
found_ct += 1
print(found_ct, 'words are found in glove')
return embedding_matrix
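# Note: with the default word_embd_size of 100 this expects a file named
# <glove_path>/glove.6B.100d.txt (and glove.840B.300d.txt for 300-d vectors),
# matching the naming convention of the official GloVe releases.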
def to_var(x):
# if torch.cuda.is_available():
# x = x.cuda()
# return Variable(x)
x = Variable(x)
if torch.cuda.is_available():
x = x.cuda()
return x
def to_np(x):
return x.data.cpu().numpy()
def _make_action_vector(actions, seq_len):
index_vec = [action for action in actions]
pad_len = max(0, seq_len - len(index_vec))
index_vec += [-1] * pad_len
index_vec = index_vec[:seq_len]
return index_vec
def _make_word_vector(sentence, w2i, seq_len):
index_vec = [w2i[w] if w in w2i else w2i[UNK] for w in sentence]
pad_len = max(0, seq_len - len(index_vec))
index_vec += [w2i[NULL]] * pad_len
index_vec = index_vec[:seq_len]
return index_vec
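# A small worked example of the padding/truncation behaviour above (a sketch,
# assuming 'good' and 'fix' are both in the vocabulary):
#   _make_word_vector(['good', 'fix'], w2i, 4)
#   -> [w2i['good'], w2i['fix'], w2i[NULL], w2i[NULL]]
# and with seq_len=1 the result is truncated to [w2i['good']].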
def _make_char_vector(data, c2i, sent_len, word_len):
tmp = torch.ones(sent_len, word_len).type(torch.LongTensor) # TODO use fills
for i, word in enumerate(data):
for j, ch in enumerate(word):
tmp[i][j] = c2i[ch] if ch in c2i else c2i[UNK]
if j == word_len - 1:
break
if i == sent_len - 1:
break
return tmp
def make_diff(diffs_raw):
return diffs_raw
def make_vector_one_sample(pred_cmnt, pred_ctx, w2i, c2i, ctx_sent_len, ctx_word_len, query_sent_len, query_word_len):
    # Build padded word/char index tensors for a single (comment, context) pair,
    # mirroring make_vector() below but for one sample instead of a batch.
    cmnt_words = [_make_word_vector(pred_cmnt, w2i, ctx_sent_len)]
    cmnt_chars = [_make_char_vector(pred_cmnt, c2i, ctx_sent_len, ctx_word_len)]
    ctx_words = [_make_word_vector(pred_ctx, w2i, query_sent_len)]
    ctx_chars = [_make_char_vector(pred_ctx, c2i, query_sent_len, query_word_len)]
    cmnt_words = to_var(torch.LongTensor(cmnt_words))
    cmnt_chars = to_var(torch.stack(cmnt_chars, 0))
    ctx_words = to_var(torch.LongTensor(ctx_words))
    ctx_chars = to_var(torch.stack(ctx_chars, 0))
    return cmnt_words, cmnt_chars, ctx_words, ctx_chars
'''
Generate the word vector for each batch
'''
def make_vector(batch, w2i, cmnt_sent_len, ctx_sent_len):
cmnt_words, src_token, src_action, tgt_token, tgt_action = [], [], [], [], []
# batch_cmnt, batch_neg_cmnt, batch_origin, batch_target
for i in range(len(batch[0])):
cmnt_words.append(_make_word_vector(batch[0][i], w2i, cmnt_sent_len))
src_token.append(_make_word_vector(batch[1][i], w2i, ctx_sent_len))
src_action.append(_make_action_vector(batch[2][i], ctx_sent_len))
tgt_token.append(_make_word_vector(batch[3][i], w2i, ctx_sent_len))
tgt_action.append(_make_action_vector(batch[4][i], ctx_sent_len))
cmnt_words = to_var(torch.LongTensor(cmnt_words))
# neg_cmnt_words = to_var(torch.LongTensor(neg_cmnt_words))
src_token = to_var(torch.LongTensor(src_token))
src_action = to_var(torch.LongTensor(src_action))
tgt_token = to_var(torch.LongTensor(tgt_token))
tgt_action = to_var(torch.LongTensor(tgt_action))
return cmnt_words, src_token, src_action, tgt_token, tgt_action
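# Usage sketch (illustrative shapes only): for a batch of B samples with
# cmnt_sent_len=30 and ctx_sent_len=300, make_vector returns five LongTensor
# variables of sizes (B, 30), (B, 300), (B, 300), (B, 300) and (B, 300), e.g.:
#   cmnt, src_tok, src_act, tgt_tok, tgt_act = make_vector(cr_pos_batch, w2i, 30, 300)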
'''
generate the batches for training and evaluation
type definition:
type 1: for the comment rank task
type 2: for the diff anchoring task
type 3: use the target diff only
'''
def gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len, rank_num):
'''
Batch Content:
0,1. batch comment, batch neg_cmnt
2,3. batch src_tokens, batch src_actions
4,5. batch tgt_tokens, batch tgt_actions
6,7. batch pos_edits, batch neg_edits
'''
pos_cmnts, pos_src_tokens, pos_src_actions, pos_tgt_tokens, pos_tgt_actions = [], [], [], [], []
neg_cmnts, neg_src_tokens, neg_src_actions, neg_tgt_tokens, neg_tgt_actions = [], [], [], [], []
sample_index_list = []
for i in range(len(batch[0])):
sample_size = 0
cmnt = batch[0][i]
neg_cmnt = batch[1][i]
src_tokens = batch[2][i]
src_actions = batch[3][i]
tgt_tokens = batch[4][i]
tgt_actions = batch[5][i]
if rank_num - 1 > len(neg_cmnt):
continue
neg_sample_indices = random.sample(range(len(neg_cmnt)), rank_num - 1)
for neg_idx in neg_sample_indices:
pos_cmnts.append(cmnt)
pos_src_tokens.append(src_tokens)
pos_src_actions.append(src_actions)
pos_tgt_tokens.append(tgt_tokens)
pos_tgt_actions.append(tgt_actions)
neg_cmnts.append(neg_cmnt[neg_idx])
neg_src_tokens.append(src_tokens)
neg_src_actions.append(src_actions)
neg_tgt_tokens.append(tgt_tokens)
neg_tgt_actions.append(tgt_actions)
sample_size += 1
sample_index_list.append(sample_size)
return (pos_cmnts, pos_src_tokens, pos_src_actions, pos_tgt_tokens, pos_tgt_actions), \
(neg_cmnts, neg_src_tokens, neg_src_actions, neg_tgt_tokens, neg_tgt_actions)
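# Sketch of the output layout: for every kept sample the positive tuple repeats
# the gold comment (rank_num - 1) times, while the negative tuple pairs the same
# edit with (rank_num - 1) sampled negative comments, so with rank_num=5 each
# original sample contributes 4 positive/negative rows that are scored pairwise
# by the ranking loss in train.py.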
def gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len, anchor_num):
'''
Batch Content:
0,1. batch comment, batch neg_cmnt
2,3. batch src_tokens, batch src_actions
4,5. batch tgt_tokens, batch tgt_actions
6,7. batch pos_edits, batch neg_edits
'''
cmnts, src_tokens, src_actions, tgt_tokens, tgt_actions, ea_truth = [], [], [], [], [], []
for i in range(len(batch[0])):
cmnt = batch[0][i]
pos_edits = batch[6][i]
neg_edits = batch[7][i]
if len(pos_edits) > anchor_num:
pos_edits = pos_edits[:anchor_num]
if anchor_num - len(pos_edits) < 0:
neg_sample_indices = []
elif anchor_num - len(pos_edits) > len(neg_edits):
neg_sample_indices = range(len(neg_edits))
else:
neg_sample_indices = random.sample(range(len(neg_edits)), anchor_num - len(pos_edits))
for pos_edit in pos_edits:
cmnts.append(cmnt)
src_tokens.append([])
src_actions.append([])
tgt_tokens.append(pos_edit)
tgt_actions.append([1] * len(pos_edit))
ea_truth.append(1)
for neg_idx in neg_sample_indices:
cmnts.append(cmnt)
src_tokens.append([])
src_actions.append([])
tgt_tokens.append(neg_edits[neg_idx])
tgt_actions.append([1] * len(neg_edits[neg_idx]))
ea_truth.append(0)
return (cmnts, src_tokens, src_actions, tgt_tokens, tgt_actions), ea_truth
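# Sketch of the output layout: each comment is paired with up to anchor_num
# candidate edit sentences -- its positive edits (label 1) padded with randomly
# sampled negative edits (label 0) -- and ea_truth holds the matching 0/1 labels,
# which train.py feeds to F.cross_entropy against the 2-way anchor_linear output.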
def find_cont_diffs(tokens, token_diff):
# split the token_diff into the consecutive parts
token_cont_list = []
if len(token_diff) == 0:
return token_cont_list
if len(token_diff) == 1:
token_cont_list.append(token_diff)
return token_cont_list
start_idx, cur_idx = 0, 1
while cur_idx < len(token_diff):
# if cur_idx == len(token_diff) - 1:
# token_list.append(list(range(start_idx, cur_idx + 1)))
# cur_idx += 1
if token_diff[cur_idx] != token_diff[cur_idx - 1] + 1:
token_cont_list.append(list(range(token_diff[start_idx], token_diff[cur_idx - 1] + 1)))
start_idx = cur_idx
cur_idx += 1
else:
cur_idx += 1
# handle the last list
token_cont_list.append(list(range(token_diff[start_idx], token_diff[cur_idx - 1] + 1)))
return token_cont_list
def find_diff_context(tokens, token_diff, context_length=50):
cont_difflist = find_cont_diffs(tokens, token_diff)
diff_context = set()
for cont_diff in cont_difflist:
# avoid the case when only one markup or space included in the context
if len(cont_diff) == 1 and len(tokens[cont_diff[0]]) <= 1:
continue
diff_context_cur = find_diff_context_int(tokens, cont_diff, context_length)
diff_context.update(diff_context_cur)
diff_context = [x for x in diff_context if x not in token_diff]
return sorted(list(diff_context))
'''
Find the context token indices around a difference span.
The function requires token_diff to be a consecutive range of indices.
'''
def find_diff_context_int(tokens, token_diff, context_length):
if len(token_diff) == 0:
return []
# if len(token_diff) == 1:
# diff_context = [token_diff[0]]
# else:
# diff_context = range(token_diff[0], token_diff[-1] + 1)
# diff_context = [x for x in diff_context if x not in token_diff]
start_idx = token_diff[0]
end_idx = token_diff[-1]
context_start = start_idx - context_length
context_start = context_start if context_start > 0 else 0
context_end = end_idx + context_length + 1
context_end = context_end if context_end < len(tokens) else len(tokens)
diff_context = list(range(context_start, start_idx)) + list(range(end_idx + 1, context_end))
diff_words = [tokens[i] for i in diff_context]
# if len(diff_context) > context_length:
# diff_context = diff_context[:int(context_length/2)] + diff_context[-int(context_length/2):]
# else:
# remain_length = context_length - len(diff_context)
return diff_context

148
train.py Normal file

@@ -0,0 +1,148 @@
import datetime
import glob
import os
import sys
import torch
import torch.nn.functional as F
from eval import eval
from process_data import to_var, make_vector, gen_cmntrank_batches, gen_editanch_batches
def rank_loss(score_pos, score_neg):
# normalize context attention
# ctx_att_norm = F.normalize(ctx_att, p=2, dim=1)
batch_size = len(score_pos)
# y = Variable(torch.FloatTensor([1] * batch_size))
margin = to_var(torch.FloatTensor(score_pos.size()).fill_(1))
# loss = F.margin_ranking_loss(score_pos, -score_neg, y, margin=10.0)
loss_list = margin - score_pos + score_neg
loss_list = loss_list.clamp(min=0)
loss = loss_list.sum()
return loss, loss_list
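# Worked example of the margin (hinge) ranking loss above, with the margin fixed
# at 1: for score_pos = [2.0, 0.3] and score_neg = [1.5, 0.8] the per-pair losses
# are max(0, 1 - 2.0 + 1.5) = 0.5 and max(0, 1 - 0.3 + 0.8) = 1.5, so the summed
# loss is 2.0.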
def cal_batch_loss(loss_list, batch_size, index_list):
    loss_list = loss_list.data.cpu().squeeze(1).numpy().tolist()
loss_step = int(len(loss_list) / batch_size)
# return [sum(loss_list[i * loss_step: (i + 1) * loss_step]) / loss_step for i in range(batch_size)]
start = 0
batch_loss_list = []
for i in range(batch_size):
cur_bs = index_list[i]
end = start + cur_bs
if cur_bs == 0:
batch_loss_list.append(0)
else:
batch_loss_list.append(sum(loss_list[start:end]) / cur_bs)
start = end
return batch_loss_list
def train(args, model, dataset, train_df, val_df, optimizer, w2i, n_epoch, start_epoch, batch_size):
print('----Train---')
label = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
model.train()
cmnt_sent_len = args.max_cmnt_length
diff_sent_len = args.max_diff_length
ctx_sent_len = args.max_ctx_length
train_size = len(train_df)
# if args.use_cl:
# sample_size = int(train_size * 0.5)
# else:
# sample_size = -1
sample_size = int(train_size)
cr_best_acc, ea_best_acc = 0, 0
for epoch in range(start_epoch, n_epoch + 1):
print('============================== Epoch ', epoch, ' ==============================')
for step, batch in enumerate(dataset.iterate_minibatches(train_df, batch_size, epoch, n_epoch), start=1):
batch_sample_weights = None
total_loss = 0
if args.cr_train:
cr_pos_batch, cr_neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len,
ctx_sent_len, args.rank_num)
if len(cr_pos_batch[0]) > 0:
                    # TODO: make this more efficient by combining the positive and negative samples into a single forward pass
pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
make_vector(cr_pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
make_vector(cr_neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action,
cr_mode=True)
score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action,
cr_mode=True)
loss, _ = rank_loss(score_pos, score_neg)
# batch_sample_weights = cal_batch_loss(loss_list, batch_size, index_list)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# train revision anchoring
if args.ea_train:
ea_batch, ea_truth = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
args.anchor_num)
if len(ea_batch[0]) > 0:
cmnt, src_token, src_action, tgt_token, tgt_action = \
make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
target = to_var(torch.tensor(ea_truth))
loss = F.cross_entropy(logit, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# update the loss for each data sample
# new_sample_weights += batch_sample_weights
if step % args.log_interval == 0:
# corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
# p1_corr, p3_corr, p5_corr, mrr, ndcg = eval_rank(score_pos, score_neg, args.batch_size)
# p1_acc = p1_corr / args.batch_size * 100
# p3_acc = p3_corr / args.batch_size * 100
# p5_acc = p5_corr / args.batch_size * 100
try:
sys.stdout.write(
'\rEpoch[{}] Batch[{}] - loss: {:.6f}\n'.format(epoch, step, loss.data.item()))
sys.stdout.flush()
except:
print("Unexpected error:", sys.exc_info()[0])
if step % args.test_interval == 0:
if args.val_size > 0:
val_df = val_df[:args.val_size]
cr_acc, ea_acc = eval(dataset, val_df, w2i, model, args)
model.train() # change model back to training mode
if args.cr_train and cr_acc > cr_best_acc:
cr_best_acc = cr_acc
if args.save_best:
save(model, args.save_dir, 'best_cr', epoch, step, cr_best_acc, args.no_action)
if args.ea_train and ea_acc > ea_best_acc:
ea_best_acc = ea_acc
if args.save_best:
save(model, args.save_dir, 'best_ea', epoch, step, ea_best_acc, args.no_action)
# sample_weights = new_sample_weights
def save(model, save_dir, save_prefix, epoch, steps, best_result, no_action=False):
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
save_prefix = os.path.join(save_dir, save_prefix)
# delete previously saved checkpoints
exist_files = sorted(glob.glob(save_prefix + '*'))
for file_name in exist_files:
if os.path.exists(file_name):
os.remove(file_name)
result_str = '%.3f' % best_result
save_path = '{}_steps_{:02}_{:06}_{}.pt'.format(save_prefix, epoch, steps, result_str)
print("Save best model", save_path)
torch.save(model.state_dict(), save_path)

492
wiki_util.py Normal file

@@ -0,0 +1,492 @@
import difflib
import glob
import html
import random
import re
import spacy
from tqdm import tqdm
# initialize the spacy
nlp = spacy.load('en')
'''
Extract the content between two delimiters
'''
def extract_with_delims(content, start_delim, end_delim, start_idx):
delims_start = content.find(start_delim, start_idx)
if delims_start == -1:
return '', start_idx
delims_end = content.find(end_delim, start_idx)
if delims_end == -1:
return '', start_idx
if delims_end <= delims_start:
return '', start_idx
delims_start += len(start_delim)
return content[delims_start:delims_end], delims_end
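# Quick example of the helper above:
#   extract_with_delims("<id>42</id>", "<id>", "</id>", 0)
# returns ("42", 6), i.e. the enclosed text plus the index of the closing tag,
# so consecutive fields can be extracted by threading next_idx through the calls.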
'''
Extract the contents of revisions, e.g., revision_id, parent_id, user_name, comment, text
'''
def extract_data(revision_part):
rev_id, next_idx = extract_with_delims(revision_part, "<id>", "</id>", 0)
parent_id, next_idx = extract_with_delims(revision_part, "<parentid>", "</parentid>", next_idx)
timestamp, next_idx = extract_with_delims(revision_part, "<timestamp>", "</timestamp>", next_idx)
username, next_idx = extract_with_delims(revision_part, "<username>", "</username>", next_idx)
userid, next_idx = extract_with_delims(revision_part, "<id>", "</id>", next_idx)
    # For anonymous users, the IP address is recorded instead of the user name and id
userip, next_idx = extract_with_delims(revision_part, "<ip>", "</ip>", next_idx)
comment, next_idx = extract_with_delims(revision_part, "<comment>", "</comment>", next_idx)
text, next_idx = extract_with_delims(revision_part, "<text xml:space=\"preserve\">", "</text>", next_idx)
return (rev_id, parent_id, timestamp, username, userid, userip, comment, text)
'''
Extract the revision text buffer, which has the format "<revision> ... </revision>".
'''
def split_records(wiki_file, chunk_size=150 * 1024):
text_buffer = ""
cur_index = 0
while True:
chunk = wiki_file.read(chunk_size)
if chunk:
text_buffer += chunk
cur_index = 0
REVISION_START = "<revision>"
REVISION_END = "</revision>"
PAGE_START = "<page>"
PAGE_TITLE_START = "<title>"
PAGE_TITLE_END = "</title>"
while True:
page_start_index = text_buffer.find(PAGE_START, cur_index)
if page_start_index != -1:
# update the current page title/ID
page_title, _ = extract_with_delims(text_buffer, PAGE_TITLE_START, PAGE_TITLE_END, 0)
if not page_title:
# no complete page title
break
#logging.debug("Error: page information is cut. FIX THIS ISSUE!!!")
# find the revision start position
revision_start_index = text_buffer.find(REVISION_START, cur_index)
# No revision in the buffer, continue loading data
if revision_start_index == -1:
break
# find the revision end position
revision_end_index = text_buffer.find(REVISION_END, revision_start_index)
# No complete page in buffer
if revision_end_index == -1:
break
yield page_title, text_buffer[revision_start_index:revision_end_index + len(REVISION_END)]
cur_index = revision_end_index + len(REVISION_END)
# No more data
if chunk == "":
break
if cur_index == -1:
text_buffer = ""
else:
text_buffer = text_buffer[cur_index:]
def sampleNext(sample_ratio):
return random.random() < sample_ratio
def cleanCmntText(comment):
filter_words = []
comment = comment.replace("(edited with [[User:ProveIt_GT|ProveIt]]", "")
#comment = re.sub("(edited with \[\[User\:ProveIt_GT\|ProveIt\]\]", "", comment)
return comment
def checkComment(comment, comment_tokens, min_comment_length):
if len(comment_tokens) < min_comment_length:
return False
filter_words = ["[[Project:AWB|AWB]]", "[[Project:AutoWikiBrowser|AWB]]", "Undid revision"]
if any(word in comment for word in filter_words):
return False
return True
'''
clean the wiki text
E.g. "[[link name]] a&quot; bds&quot; ''markup''" to "link name a bds markup"
'''
def cleanWikiText(wiki_text):
# replace link: [[link_name]] and quotes
wiki_text = re.sub("\[\[", "", wiki_text)
wiki_text = re.sub("\]\]", "", wiki_text)
wiki_text = re.sub("''", "", wiki_text)
# replace '<', '>', '&'
# wiki_text = re.sub("&quot;", "", wiki_text)
# wiki_text = re.sub("&lt;", "<", wiki_text)
# wiki_text = re.sub("&gt;", ">", wiki_text)
# wiki_text = re.sub("&amp;", "&", wiki_text)
# use html unescape to decode the html special characters
wiki_text = html.unescape(wiki_text)
return wiki_text
def tokenizeText(text):
doc = nlp(text)
sentences = [sent.string.strip() for sent in doc.sents]
sent_tokens = []
tokens = []
for sent in sentences:
sent_doc = nlp(sent)
token_one_sent = [token.string.strip() for token in sent_doc]
sent_tokens.append(token_one_sent)
tokens += token_one_sent
return sent_tokens, tokens
# return the difference indices starting at 0.
def diffRevision(parent_sent_list, sent_list):
# make diff
    origin_start_idx, origin_end_idx = -1, -1
target_start_idx, target_end_idx = -1, -1
origin_diff_list, target_diff_list = [], []
for line in difflib.context_diff(parent_sent_list, sent_list, 'origin', 'target'):
#print(line)
        # parse the origin hunk header, e.g., "*** 56,62 ****"
if line.startswith("*** ") and line.endswith(" ****\n"):
target_start_idx, target_end_idx = -1, -1 # reset the target indices
range_line = line[4:-6]
if ',' not in range_line:
origin_start_idx = int(range_line)
origin_end_idx = origin_start_idx
else:
origin_start_idx, origin_end_idx = [int(i) for i in range_line.split(',')]
origin_sent_idx = origin_start_idx
continue
        # parse the target hunk header, e.g., "--- 56,62 ----"
if line.startswith("--- ") and line.endswith(" ----\n"):
origin_start_idx, origin_end_idx = -1, -1 # reset the origin indices
range_line = line[4:-6]
if ',' not in range_line:
target_start_idx = int(range_line)
target_end_idx = target_start_idx
else:
target_start_idx, target_end_idx = [int(i) for i in range_line.split(',')]
target_sent_idx = target_start_idx
continue
if origin_start_idx >= 0:
if len(line.strip('\n')) == 0:
continue
elif line.startswith('-') or line.startswith('!'):
origin_diff_list.append(origin_sent_idx - 1) # adding the index starting at 0
origin_sent_idx += 1
else:
origin_sent_idx += 1
if target_start_idx >= 0:
if len(line.strip('\n')) == 0:
continue
elif line.startswith('+') or line.startswith('!'):
target_diff_list.append(target_sent_idx - 1) # adding the index starting at 0
target_sent_idx += 1
else:
target_sent_idx += 1
#print("Extracted Diff:", diff_sent_list)
return origin_diff_list, target_diff_list
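# Sanity example (indices are 0-based sentence positions, as noted above):
#   diffRevision(['a', 'b', 'c'], ['a', 'x', 'c']) -> ([1], [1])
# i.e. sentence 1 changed in the origin and sentence 1 differs in the target.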
def findSentDiff(sents, tokens, token_diff):
diff_idx, sent_idx, token_offset = 0, 0, 0
diff_sents = set()
if len(token_diff) == 0 or len(sents) == 0:
return list(diff_sents)
token_offset = len(sents[0])
while diff_idx < len(token_diff) and sent_idx < len(sents):
if token_offset >= token_diff[diff_idx]:
            cur_token = tokens[token_diff[diff_idx]]
# avoid the case that one more sentence added because only one markup included.
if len(cur_token) > 1:
diff_sents.add(sent_idx)
diff_idx += 1
else:
            sent_idx += 1
            if sent_idx < len(sents):
                token_offset += len(sents[sent_idx])
return list(diff_sents)
def extContext(tokens, token_diff, ctx_window):
'''
Extend the context into the token difference
:param token_diff: a list of tokens which represents the difference between previous version and current version.
ctx_window: the size of context before and after the edits
:return:
For example: token_diff = [2, 3, 4, 11, 12, 13, 14, 16] and ctx_window = 2
The function will return ctx_tokens = [0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
action = [0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0]
'''
ctx_set = set(token_diff)
for idx in token_diff:
for i in range(idx - ctx_window, idx + ctx_window + 1):
if i < 0 or i >= len(tokens):
continue
ctx_set.add(i)
action = []
ctx_token = []
ctx_token_idx = sorted(list(ctx_set))
diff_set = set(token_diff)
for i in ctx_token_idx:
ctx_token.append(tokens[i])
if i in diff_set:
action.append(1)
else:
action.append(0)
return ctx_token, action
# def findDiffSents(sents, diff_list):
# token_idx = 0
# start_sent_idx, end_sent_idx = -1, -1
# for i, sent in enumerate(sents):
# token_idx += len(sent)
# if start_sent_idx < 0 and diff_list[0] < token_idx:
# start_sent_idx = i
# if end_sent_idx < 0 and diff_list[-1] < token_idx:
# end_sent_idx = i
# if start_sent_idx >= 0 and end_sent_idx >= 0:
# break
# return start_sent_idx, end_sent_idx
def extDiffSents(sents, start, end, max_tokens=200, max_sents_add=2):
token_size = 0
for i in range(start, end + 1):
token_size += len(sents[i])
context_sents = sents[start:end + 1]
sents_head_added, sents_tail_added = 0, 0
while token_size <= max_tokens and len(context_sents) < len(sents)\
and sents_head_added < max_sents_add and sents_tail_added < max_sents_add:
if start > 0:
start -= 1
insert_sent = sents[start]
context_sents.insert(0, insert_sent)
token_size += len(insert_sent)
sents_head_added += 1
if end < len(sents) - 1:
end += 1
insert_sent = sents[end]
context_sents.append(insert_sent)
token_size += len(insert_sent)
sents_tail_added += 1
diff_offset = sum([len(sents[i]) for i in range(start)])
return context_sents, diff_offset
def mapDiffContext(sents, start_sent, end_sent):
start_idx = -1
end_idx = -1
start_sent_str = " ".join(start_sent)
end_sent_str = " ".join(end_sent)
for i, sent in enumerate(sents):
sent_str = " ".join(sent)
if start_idx < 0 and start_sent_str in sent_str:
start_idx = i
if end_sent_str in sent_str:
end_idx = i
break
# if start_idx == -1:
# start_idx = 0
# if end_idx == -1:
# context_sents = sents[start_idx:]
# else:
# context_sents = sents[start_idx:end_idx + 1]
if start_idx == -1 or end_idx == -1:
return None, None
diff_offset = sum([len(sents[i]) for i in range(start_idx)])
return sents[start_idx:end_idx + 1], diff_offset
# def extDiffContextInt(target_context_sents, target_sents, target_start, target_end, token_size, max_tokens=200, max_sents_add=2):
# sents_head_added, sents_tail_added = 0, 0
# return target_context_sents
def calTokenSize(sents):
return sum([len(sent) for sent in sents])
def isSameSent(origin_sent, target_sent):
origin_size = len(origin_sent)
target_size = len(target_sent)
if origin_size != target_size:
return False
for i in range(origin_size):
if origin_sent[i] != target_sent[i]:
return False
return True
def stripContext(origin_sents, target_sents, max_token_size=200):
start_match = True
end_match = True
origin_token_size = calTokenSize(origin_sents)
target_token_size = calTokenSize(target_sents)
diff_offset = 0
while origin_token_size > max_token_size or target_token_size > max_token_size:
if len(origin_sents) == 0 or len(target_sents) == 0:
break
if start_match and isSameSent(origin_sents[0], target_sents[0]):
# remove the sentence from both origin and target
sent_size = len(origin_sents[0])
origin_sents = origin_sents[1:]
target_sents = target_sents[1:]
diff_offset += sent_size
origin_token_size -= sent_size
target_token_size -= sent_size
else:
start_match = False
if len(origin_sents) == 0 or len(target_sents) == 0:
break
if end_match and isSameSent(origin_sents[-1], target_sents[-1]):
sent_size = len(origin_sents[-1])
origin_sents = origin_sents[:-1]
target_sents = target_sents[:-1]
origin_token_size -= sent_size
target_token_size -= sent_size
else:
end_match = False
if not start_match and not end_match:
break
if origin_token_size > max_token_size or target_token_size > max_token_size:
origin_sents, target_sents = None, None
return origin_sents, target_sents, diff_offset
def extractDiffContext(origin_sents, target_sents, origin_diff, target_diff, max_tokens=200):
    origin_context, target_context, diff_offset = stripContext(origin_sents, target_sents, max_tokens)
    if origin_context is not None and target_context is not None:
origin_diff = [i - diff_offset for i in origin_diff]
target_diff = [i - diff_offset for i in target_diff]
# fix the issue in the case that the appended dot belonging to current sentence is marked as the previous sentence. See example:
# + .
# + Warren
# + ,
# ...
# + -
# + syndicalists
# .
if len(target_diff) > 0 and target_diff[0] == -1:
target_diff = target_diff[1:]
return origin_context, target_context, origin_diff, target_diff
def filterRevision(comment, diff_list, max_sent_length):
filter_pattern = "^.*\s*\W*(revert|undo|undid)(.*)$"
filter_regex = re.compile(filter_pattern, re.IGNORECASE)
if len(diff_list) == 0 or len(diff_list) > max_sent_length:
return True
comment = comment.strip()
if not comment.startswith('/*'):
return True
elif comment.startswith('/*') and comment.endswith('*/'):
return True
filter_match = filter_regex.match(comment)
if filter_match:
return True
return False
#return filter_match
comment_pattern = "^/\*(.+)\*/(.*)$"
comment_regex = re.compile(comment_pattern)
def extractSectionTitle(comment):
sect_title, sect_cmnt = '', comment
comment_match = comment_regex.match(comment)
if comment_match:
sect_title = comment_match.group(1).strip()
sect_cmnt = html.unescape(comment_match.group(2).strip()).strip()
return sect_title, sect_cmnt
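# Example: extractSectionTitle("/* History */ fixed a typo") returns
# ("History", "fixed a typo"); comments without the /* ... */ prefix come back
# unchanged as ('', comment).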
def extractSectionText(text, sect_title):
sect_content = ''
    text_match = re.search('(=+)\s*' + re.escape(sect_title) + '\s*(=+)', text)
if text_match:
sect_sign = text_match.group(1)
sect_sign_end = text_match.group(2)
if sect_sign != sect_sign_end:
print("ALERT: Section Data Corrupted!! Skip!!")
return sect_content
sect_start = text_match.regs[2][1]
remain_sect = text[sect_start:].strip().strip('\n')
# TODO: Fix the bug of comparing the ===Section Title=== after ==Section==
next_sect_match = re.search(sect_sign + '.*' + sect_sign, remain_sect)
if next_sect_match:
sect_end = next_sect_match.regs[0][0]
sect_content = remain_sect[:sect_end].strip().strip('\n')
else:
sect_content = remain_sect
return sect_content
'''
Merge the sample outputs of all the dump files
'''
def mergeOutputs(output_path):
# merge sample results
print("Merging the sampled outputs from each files ...")
sample_list = glob.glob(output_path + '*.json')
sample_file = open(output_path + 'wikicmnt.json', "w", encoding='utf-8')
for fn in tqdm(sample_list):
with open(fn, 'r', encoding='utf-8') as fi:
sample_file.write(fi.read())

309
wikicmnt_dataset.py Normal file

@@ -0,0 +1,309 @@
import itertools
import json
import os
import os.path
from collections import Counter
import numpy as np
# from torchtext import data
import pandas as pd
import torch
from nltk.stem import WordNetLemmatizer
from torch.autograd import Variable
from tqdm import tqdm
from wiki_util import tokenizeText
class Wiki_DataSet:
def __init__(self, args):
"""Create an Wiki dataset instance. """
# self.word_embed_file = self.data_folder + 'embedding/wiki.ar.vec'
# word_embed_file = data_folder + "embedding/Wiki-CBOW"
self.data_dir = args.data_path
self.data_file = self.data_dir + "wikicmnt.json"
self.vocab_file = self.data_dir + 'vocabulary.json'
self.train_df_file = self.data_dir + 'train_df.pkl'
self.val_df_file = self.data_dir + 'val_df.pkl'
self.test_df_file = self.data_dir + 'test_df.pkl'
self.tf_file = self.data_dir + 'term_freq.json'
self.weight_file = self.data_dir + 'train_weights.json'
self.glove_dir = args.glove_path
self.glove_vec_size = args.word_embd_size
self.lemmatizer = WordNetLemmatizer()
self.class_num = -1
self.rank_num = args.rank_num
self.anchor_num = args.anchor_num
self.max_ctx_length = int(args.max_ctx_length / 2)
pass
'''
Extract the data from raw json file
'''
def extract_data_from_json(self):
'''
Parse the json file and return the
:return:
'''
cmnt_list, neg_cmnts_list = [], []
src_token_list, src_action_list = [], []
tgt_token_list, tgt_action_list = [], []
pos_edits_list, neg_edits_list = [], []
diff_url_list = []
word_counter, lower_word_counter = Counter(), Counter()
print("Sample file:", self.data_file)
with open(self.data_file, 'r', encoding='utf-8') as f:
for idx, json_line in enumerate(tqdm(f)):
# if idx % 100 == 0:
# print("== processed ", idx)
article = json.loads(json_line.strip('\n'))
# print(article['diff_url'])
'''
Json file format:
=================
revision_id: The revision ID
parent_id: The parent revision ID
timestamp: Timestamp
diff_url: The wikipedia link to show the difference between previous and current version.
page_title: The title of page.
comment: Revision comment.
src_token: List of tokens in before-editing version
src_action: Action flags for each token in before-editing version. E.g., 0 represents no action; -1 represents removed token.
tgt_token: List of tokens in after-editing version
tgt_action: Action flags for each token in after-editing version. E.g., 0 represents no action; 1 represents added token.
neg_cmnts: Negative samples of user comments in the same page.
pos_edits: Edit sentences for comments.
neg_edits: Negative edit sentences for comments.
'''
try:
# comment
comment = article['comment']
_, cmnt_tokens = tokenizeText(comment)
# cmnt_tokens = article['comment']
cmnt_list.append(cmnt_tokens)
# negative comments
# neg_cmnts = article['neg_cmnts']
neg_cmnts = []
for neg_cmnt in article['neg_cmnts']:
                        _, tokens = tokenizeText(neg_cmnt)
neg_cmnts.append(tokens)
neg_cmnts_list.append(neg_cmnts)
# source tokens and actions
src_token = article['src_token']
src_token_list.append(src_token)
src_action_list.append(article['src_action'])
# target tokens and actions
tgt_token = article['tgt_token']
tgt_token_list.append(tgt_token)
tgt_action_list.append(article['tgt_action'])
# positive and negative edits
pos_edits = article['pos_edits']
pos_edits_list.append(pos_edits)
neg_edits = article['neg_edits']
neg_edits_list.append(neg_edits)
# diff url
diff_url = article['diff_url']
diff_url_list.append(diff_url)
# for counters
for word in cmnt_tokens + src_token + tgt_token + \
list(itertools.chain.from_iterable(neg_cmnts + pos_edits + neg_edits)):
word_counter[word] += 1
lower_word_counter[word.lower()] += 1
except:
# ignore the index error
print("ERROR: Index Error", article['revision_id'])
continue
# if idx >= 100:
# break
return cmnt_list, neg_cmnts_list, src_token_list, src_action_list, tgt_token_list, tgt_action_list, \
pos_edits_list, neg_edits_list, word_counter, lower_word_counter, diff_url_list
'''
Create dataset objects for wiki revision data.
Arguments:
args: arguments
val_ratio: The ratio that will be used to get split validation dataset.
shuffle: Whether to shuffle the data before split.
'''
def load_data(self, train_ratio, val_ratio):
print("loading wiki data ...")
# check the existence of data files
        if os.path.isfile(self.train_df_file) and os.path.isfile(self.val_df_file) and \
                os.path.isfile(self.test_df_file) and os.path.isfile(self.vocab_file):
print("dataframe file exists:", self.train_df_file)
train_df = pd.read_pickle(self.train_df_file)
val_df = pd.read_pickle(self.val_df_file)
test_df = pd.read_pickle(self.test_df_file)
vocab_json = json.load(open(self.vocab_file))
else:
cmnts, neg_cmnts, src_tokens, src_actions, \
tgt_tokens, tgt_actions, pos_edits, neg_edits, \
word_counter, lower_word_counter, diff_url = self.extract_data_from_json()
word2vec_dict = self.get_word2vec(word_counter)
lower_word2vec_dict = self.get_word2vec(lower_word_counter)
df = pd.DataFrame(
{
'cmnt_words': cmnts, "neg_cmnts": neg_cmnts,
'src_tokens': src_tokens, "src_actions": src_actions,
'tgt_tokens': tgt_tokens, "tgt_actions": tgt_actions,
'pos_edits': pos_edits, 'neg_edits': neg_edits, 'diff_url': diff_url
}
)
total_size = len(df)
self.train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
# test_ratio = 0.2
# train_df, test_df = train_test_split(df,
# test_size=test_ratio, random_state=967898)
train_df = df[:self.train_size]
val_df = df[self.train_size:self.train_size + val_size]
test_df = df[self.train_size + val_size:]
print("saving data into pickle ...")
train_df.to_pickle(self.data_dir + 'train_df.pkl')
test_df.to_pickle(self.data_dir + 'test_df.pkl')
val_df.to_pickle(self.data_dir + 'val_df.pkl')
w2i = {w: i for i, w in enumerate(word_counter.keys(), 3)}
NULL = "-NULL-"
UNK = "-UNK-"
ENT = "-ENT-"
w2i[NULL] = 0
w2i[UNK] = 1
w2i[ENT] = 2
# save word2vec dictionary
vocab_json = {'word2idx': w2i, 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict}
json.dump(vocab_json, open(self.vocab_file, 'w', encoding='utf-8'))
return train_df, val_df, test_df, vocab_json
'''
batch padding
'''
def pad_batch(self, mini_batch, padding_size):
mini_batch_size = len(mini_batch)
# mean_sent_len = int(np.mean([len(x) for x in mini_batch]))
main_matrix = np.zeros((mini_batch_size, padding_size), dtype=np.long)
for i in range(main_matrix.shape[0]):
for j in range(main_matrix.shape[1]):
try:
main_matrix[i, j] = mini_batch[i][j]
except IndexError:
pass
# transfer the tensor to LongTensor for some compatibility issues
return Variable(torch.from_numpy(main_matrix).transpose(0, 1).type(torch.LongTensor))
'''
Generate minibatches from data frame
'''
def iterate_minibatches(self, df, batch_size, cur_epoch=-1, n_epoch=-1):
cmnt_words = df.cmnt_words.tolist()
neg_cmnts = df.neg_cmnts.tolist()
src_tokens = df.src_tokens.tolist()
src_actions = df.src_actions.tolist()
tgt_tokens = df.tgt_tokens.tolist()
tgt_actions = df.tgt_actions.tolist()
pos_edits = df.pos_edits.tolist()
neg_edits = df.neg_edits.tolist()
diff_urls = df.diff_url.tolist()
indices = np.arange(len(cmnt_words))
np.random.shuffle(indices)
cmnt_words = [cmnt_words[i] for i in indices]
neg_cmnts = [neg_cmnts[i] for i in indices]
src_tokens = [src_tokens[i] for i in indices]
src_actions = [src_actions[i] for i in indices]
tgt_tokens = [tgt_tokens[i] for i in indices]
tgt_actions = [tgt_actions[i] for i in indices]
pos_edits = [pos_edits[i] for i in indices]
neg_edits = [neg_edits[i] for i in indices]
diff_urls = [diff_urls[i] for i in indices]
for start_idx in range(0, len(cmnt_words) - batch_size + 1, batch_size):
# initialize batch variables
batch_cmnt, batch_neg_cmnt, batch_src_tokens, batch_src_actions, \
batch_tgt_tokens, batch_tgt_actions, batch_pos_edits, batch_neg_edits, batch_diffurls = [], [], [], [], [], [], [], [], []
for i in range(start_idx, start_idx + batch_size):
batch_cmnt.append(cmnt_words[i])
batch_neg_cmnt.append(neg_cmnts[i])
batch_src_tokens.append(src_tokens[i])
batch_src_actions.append(src_actions[i])
batch_tgt_tokens.append(tgt_tokens[i])
batch_tgt_actions.append(tgt_actions[i])
batch_pos_edits.append(pos_edits[i])
batch_neg_edits.append(neg_edits[i])
batch_diffurls.append(diff_urls[i])
yield batch_cmnt, batch_neg_cmnt, batch_src_tokens, batch_src_actions, batch_tgt_tokens, batch_tgt_actions, batch_pos_edits, batch_neg_edits, batch_diffurls
def get_datafile(self):
return self.data_file
def get_tffile(self):
return self.tf_file
def get_weight_file(self):
return self.weight_file
def get_train_size(self):
return self.train_size
def get_word2vec(self, word_counter):
glove_path = os.path.join(self.glove_dir, "glove.6B.{}d.txt".format(self.glove_vec_size))
print('----glove_path', glove_path)
sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
total = sizes['6B']
word2vec_dict = {}
with open(glove_path, 'r', encoding='utf-8') as fh:
for line in tqdm(fh, total=total):
array = line.lstrip().rstrip().split(" ")
word = array[0]
vector = list(map(float, array[1:]))
if word in word_counter:
word2vec_dict[word] = vector
elif word.capitalize() in word_counter:
word2vec_dict[word.capitalize()] = vector
elif word.lower() in word_counter:
word2vec_dict[word.lower()] = vector
elif word.upper() in word_counter:
word2vec_dict[word.upper()] = vector
print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter),
glove_path))
return word2vec_dict

196
wikicmnt_model.py Normal file

@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from layers.highway import Highway
class WordEmbedding(nn.Module):
'''
In : (N, sentence_len)
Out: (N, sentence_len, embd_size)
'''
def __init__(self, args, is_train_embd=False):
super(WordEmbedding, self).__init__()
self.embedding = nn.Embedding(args.vocab_size_w, args.word_embd_size)
if args.pre_embd_w is not None:
self.embedding.weight = nn.Parameter(args.pre_embd_w, requires_grad=is_train_embd)
def forward(self, x):
return self.embedding(x)
class CmntModel(nn.Module):
def __init__(self, args):
super(CmntModel, self).__init__()
self.batch_size = args.batch_size
self.embd_size = args.word_embd_size
self.cmnt_length = args.max_cmnt_length
self.d = self.embd_size # word_embedding
self.no_action = args.no_action
self.no_attention = args.no_attention
self.no_hadamard = args.no_hadamard
# self.d = self.embd_size # only word_embedding
# self.char_embd_net = CharEmbedding(args)
self.word_embd_net = WordEmbedding(args)
# self.highway_net = Highway(self.d)
self.ctx_embd_layer = nn.GRU(self.d, self.d, bidirectional=True, dropout=0.2, batch_first=True)
self.W = nn.Linear(6 * self.d + 1, 1, bias=False)
# self.W2_noact = nn.Linear(2 * self.d, 1, bias=False)
# self.W2 = nn.Linear(2 * self.d + 1, 1, bias=False)
# weights for attention layer
if self.no_hadamard and self.no_action:
# (1, 1)
self.W2_nhna = nn.Linear(1, 1, bias=False)
elif self.no_hadamard and not self.no_action:
# (2, 1)
self.W2_nha = nn.Linear(2, 1, bias=False)
elif not self.no_hadamard and self.no_action:
# (2d, 1)
self.W2_hna = nn.Linear(2 * self.d, 1, bias=False)
elif not self.no_hadamard and not self.no_action:
# (2d+1, 1)
self.W2_ha = nn.Linear(2 * self.d + 1, 1, bias=False)
self.modeling_layer = nn.GRU(8 * self.d, self.d, num_layers=2, bidirectional=True, dropout=0.2,
batch_first=True)
# Linear function for comment ranking
self.rank_linear = nn.Linear(self.cmnt_length * 2, 1, bias=True)
self.rank_ctx_linear = nn.Linear(self.cmnt_length * 4, 1, bias=True)
# Linear function for edit anchoring
self.anchor_linear = nn.Linear(self.cmnt_length * 2, 2, bias=True)
self.use_target_only = args.use_target_only
self.ctx_mode = 1
# self.p2_lstm_layer = nn.GRU(2*self.d, self.d, bidirectional=True, dropout=0.2, batch_first=True)
def build_contextual_embd(self, x_w):
# 1. Word Embedding Layer
embd = self.word_embd_net(x_w) # (N, seq_len, embd_size)
# 2. Highway Networks for 1.
# embd = self.highway_net(word_embd) # (N, seq_len, d=embd_size)
# 3. Contextual Embedding Layer
ctx_embd_out, _h = self.ctx_embd_layer(embd)
return ctx_embd_out
def build_cmnt_sim(self, embd_context, embd_cmnt, embd_action, batch_size, T, J):
shape = (batch_size, T, J, 2 * self.d) # (N, T, J, 2d)
embd_context_ex = embd_context.unsqueeze(2) # (N, T, 1, 2d)
embd_cmnt_ex = embd_cmnt.unsqueeze(1) # (N, 1, J, 2d)
# action embedding
embd_action_ex = embd_action.float().unsqueeze(2).unsqueeze(2)
embd_action_ex = embd_action_ex.expand((batch_size, T, J, 1))
if self.no_hadamard:
if self.no_action:
                raise Exception('--no_hadamard cannot be combined with --no_action')
# use inner product to replace the hadamard product
# generate (N, T, J, 1)
embd_cmnt_ex = embd_cmnt_ex.permute(0, 2, 3, 1) # (N, J, 2d, 1)
# batch1 = torch.randn(10, 3, 4)
# batch2 = torch.randn(10, 4, 5)
# (N, T, 1, 2d) * (N, J, 2d, 1) => (N, T, J, 1)
a_dotprod_mul_b = torch.einsum('ntid,njdi->ntji', [embd_context_ex, embd_cmnt_ex])
# no hadamard & action
cat_data = torch.cat((a_dotprod_mul_b, embd_action_ex), 3) # (N, T, J, 2), [h◦u; a]
S = self.W2_nha(cat_data).view(batch_size, T, J) # (N, T, J)
else:
embd_context_ex = embd_context_ex.expand(shape) # (N, T, J, 2d)
embd_cmnt_ex = embd_cmnt_ex.expand(shape) # (N, T, J, 2d)
a_elmwise_mul_b = torch.mul(embd_context_ex, embd_cmnt_ex) # (N, T, J, 2d)
if self.no_action:
# hadamard & no action
S = self.W2_hna(a_elmwise_mul_b).view(batch_size, T, J) # (N, T, J)
else:
# hadamard & action
cat_data = torch.cat((a_elmwise_mul_b, embd_action_ex), 3) # (N, T, J, 2d + 1), [h◦u; a]
S = self.W2_ha(cat_data).view(batch_size, T, J) # (N, T, J)
if self.no_attention:
# without using attention, simply use the mean of similarity matrix in edit dimension
S_cmnt = torch.mean(S, 1)
else:
# attention implementation:
# b: attention weights on the context
b = F.softmax(torch.max(S, 2)[0], dim=-1) # (N, T)
S_cmnt = torch.bmm(b.unsqueeze(1), S) # (N, 1, J) = bmm( (N, 1, T), (N, T, J) )
S_cmnt = S_cmnt.squeeze(1) # (N, J)
# max implementation
# S_cmnt = torch.max(S, 1)[0]
# c: attention weights on the comment
# c = torch.max(S, 1)[0] # (N, J)
# S_cmnt = c * S_cmnt # (N, J) = (N, J) * (N, J)
# c2q = torch.bmm(F.softmax(S, dim=-1), embd_cmnt) # (N, T, 2d) = bmm( (N, T, J), (N, J, 2d) )
# c2q = torch.bmm(F.softmax(S, dim=-1), embd_cmnt) # (N, J, 1) = bmm( (N, J, T), (N, T, 1) )
return S_cmnt, S
# cmnt_words, neg_cmnt_words, src_diff_words, tgt_diff_words
def forward(self, cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=True, cl_mode=False):
batch_size = cmnt.size(0)
T = src_token.size(1) # sentence length = 100 (word level)
# C = src_token.size(1) # context sentence length = 200 (word level)
J = cmnt.size(1) # cmnt sentence length = 30 (word level)
# ####################################################################################
# 1. Word Embedding Layer
# 2. Contextual Embedding Layer (GRU)
######################################################################################
embd_src_diff = self.build_contextual_embd(src_token) # (N, T, 2d)
embd_tgt_diff = self.build_contextual_embd(tgt_token) # (N, T, 2d)
if cl_mode:
return embd_src_diff + embd_tgt_diff # (N, T, 2d)
embd_cmnt = self.build_contextual_embd(cmnt) # (N, J, 2d)
# if self.ctx_mode:
# embd_src_ctx = self.build_contextual_embd(src_ctx) # (N, C, 2d)
# embd_tgt_ctx = self.build_contextual_embd(tgt_ctx) # (N, C, 2d)
# ####################################################################################
# 3. Similarity Layer
######################################################################################
S_src_diff, _ = self.build_cmnt_sim(embd_src_diff, embd_cmnt, src_action, batch_size, T, J) # (N, J)
S_tgt_diff, _ = self.build_cmnt_sim(embd_tgt_diff, embd_cmnt, tgt_action, batch_size, T, J) # (N, J)
S_diff = torch.cat((S_src_diff, S_tgt_diff), 1) # (N, 2J)
# if self.ctx_mode:
# S_src_ctx, _ = self.build_cmnt_sim(embd_src_ctx, embd_cmnt, batch_size, C, J)
# S_tgt_ctx, _ = self.build_cmnt_sim(embd_tgt_ctx, embd_cmnt, batch_size, C, J)
# S_ctx = torch.cat((S_src_ctx, S_tgt_ctx), 1) # (N, 2J)
# score = self.rank_ctx_linear(torch.cat((S_diff, S_ctx), 1)) # (N, 2J) -> (N, 1)
# else:
if cr_mode:
result = self.rank_linear(S_diff) # (N, 2J) -> (N, 1)
else:
result = self.anchor_linear(S_diff) # (N, 2J) -> (N, 2)
# if self.use_target_only:
# S_diff = S_tgt_diff # (N, J)
# else:
# #S_diff = S_src_diff + S_tgt_diff # (N, J)
# S_diff = torch.cat((S_src_diff, S_tgt_diff), 1)
# #S = (torch.cat((S_src_diff, S_tgt_diff), 1)
return result, S_diff
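if __name__ == '__main__':
    # Minimal smoke test -- a sketch for sanity-checking tensor shapes, not part of
    # the training pipeline. The hyper-parameter values below are illustrative
    # assumptions rather than the settings used in the paper.
    import argparse
    _args = argparse.Namespace(batch_size=2, word_embd_size=100, max_cmnt_length=30,
                               no_action=False, no_attention=False, no_hadamard=False,
                               vocab_size_w=1000, pre_embd_w=None, use_target_only=False)
    _model = CmntModel(_args)
    _N, _T, _J = 2, 50, _args.max_cmnt_length
    _cmnt = torch.randint(0, _args.vocab_size_w, (_N, _J), dtype=torch.long)
    _tok = torch.randint(0, _args.vocab_size_w, (_N, _T), dtype=torch.long)
    _act = torch.zeros(_N, _T, dtype=torch.long)
    _score, _sim = _model(_cmnt, _tok, _act, _tok, _act, cr_mode=True)
    print(_score.size(), _sim.size())  # expected: (2, 1) and (2, 2 * _J) = (2, 60)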