Коммит
a2f0180eff
|
@ -13,8 +13,9 @@ This package is based on [Luigi](https://luigi.readthedocs.io/en/stable/index.ht
|
||||||
### Requirements
|
### Requirements
|
||||||
|
|
||||||
- Python 3.5+
|
- Python 3.5+
|
||||||
- **~210 GB** space for constructing the dialogues
|
- **~210 GB** space for constructing the dialogues with default settings
|
||||||
- Final zip is only **4.2 GB** though
|
- Final zip is only **4.2 GB** though
|
||||||
|
- [You can get away with less disk space, ~30GB](https://github.com/microsoft/dstc8-reddit-corpus/#i-dont-have-enough-disk-space)
|
||||||
- An internet connection
|
- An internet connection
|
||||||
- 24-72 hours to generate the data
|
- 24-72 hours to generate the data
|
||||||
- Depends on speed of internet connection, how many cores, how much RAM
|
- Depends on speed of internet connection, how many cores, how much RAM
|
||||||
|
@ -155,6 +156,11 @@ The raw data takes up the most space (>144 GB) but also takes the longest time t
|
||||||
|
|
||||||
Filtering and building the dialogues discards a lot of the data, so only keeping things in the `dialogues*` directories is safe.
|
Filtering and building the dialogues discards a lot of the data, so only keeping things in the `dialogues*` directories is safe.
|
||||||
|
|
||||||
|
**If you just want the final dataset you can use the `--small` option to delete raw and intermediate data the dataset is generated, e.g.**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/reddit.py generate --small
|
||||||
|
```
|
||||||
|
|
||||||
#### Windows
|
#### Windows
|
||||||
|
|
||||||
|
|
|
@ -132,6 +132,7 @@ class RawConfig(BaseModel):
|
||||||
'RC_2018-08.xz': 'b8939ecd280b48459c929c532eda923f3a2514db026175ed953a7956744c6003',
|
'RC_2018-08.xz': 'b8939ecd280b48459c929c532eda923f3a2514db026175ed953a7956744c6003',
|
||||||
'RC_2018-10.xz': 'cadb242a4b5f166071effdd9adbc1d7a78c978d3622bc01cd0f20d3a4c269bd0',
|
'RC_2018-10.xz': 'cadb242a4b5f166071effdd9adbc1d7a78c978d3622bc01cd0f20d3a4c269bd0',
|
||||||
}
|
}
|
||||||
|
delete_intermediate_data: bool = False
|
||||||
|
|
||||||
|
|
||||||
class RedditConfig:
|
class RedditConfig:
|
||||||
|
@ -139,12 +140,16 @@ class RedditConfig:
|
||||||
_cfg = None
|
_cfg = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, cfgyaml=None):
|
def initialize(cls, cfgyaml=None, extra_config=None):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if cfgyaml:
|
if cfgyaml:
|
||||||
with open(cfgyaml, 'r', encoding='utf-8') as f:
|
with open(cfgyaml, 'r', encoding='utf-8') as f:
|
||||||
kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
# Will override any preset and anything from the config yaml
|
||||||
|
if extra_config:
|
||||||
|
kwargs.update(extra_config)
|
||||||
|
|
||||||
kwargs['all_subreddits'] = set(kwargs.get('all_subreddits', []))
|
kwargs['all_subreddits'] = set(kwargs.get('all_subreddits', []))
|
||||||
kwargs['held_out_subreddits'] = set(kwargs.get('held_out_subreddits', []))
|
kwargs['held_out_subreddits'] = set(kwargs.get('held_out_subreddits', []))
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import rapidjson as json
|
||||||
|
|
||||||
from dstc8_reddit.config import RedditConfig
|
from dstc8_reddit.config import RedditConfig
|
||||||
from dstc8_reddit.tasks.filtering import FilterRawSubmissions, FilterRawComments
|
from dstc8_reddit.tasks.filtering import FilterRawSubmissions, FilterRawComments
|
||||||
|
from dstc8_reddit.util import delete_requires
|
||||||
|
|
||||||
|
|
||||||
class BuildDialogues(luigi.Task):
|
class BuildDialogues(luigi.Task):
|
||||||
|
@ -113,3 +114,7 @@ class BuildDialogues(luigi.Task):
|
||||||
if dlgs_to_write:
|
if dlgs_to_write:
|
||||||
f.write(''.join(dlgs_to_write))
|
f.write(''.join(dlgs_to_write))
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
||||||
|
|
|
@ -4,7 +4,7 @@ import rapidjson as json
|
||||||
from dstc8_reddit.config import RedditConfig
|
from dstc8_reddit.config import RedditConfig
|
||||||
from dstc8_reddit.constants import Patterns, SUBMISSION_ID_PREFIX
|
from dstc8_reddit.constants import Patterns, SUBMISSION_ID_PREFIX
|
||||||
from dstc8_reddit.tasks.download import DownloadRawFile
|
from dstc8_reddit.tasks.download import DownloadRawFile
|
||||||
from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter
|
from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter, delete_requires
|
||||||
|
|
||||||
|
|
||||||
class RawSubmissionFilterer:
|
class RawSubmissionFilterer:
|
||||||
|
@ -155,6 +155,10 @@ class FilterRawSubmissions(luigi.Task):
|
||||||
buffer_size=RedditConfig().dump_interval
|
buffer_size=RedditConfig().dump_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
||||||
|
|
||||||
|
|
||||||
class FilterRawComments(luigi.Task):
|
class FilterRawComments(luigi.Task):
|
||||||
date = luigi.Parameter()
|
date = luigi.Parameter()
|
||||||
|
@ -182,3 +186,7 @@ class FilterRawComments(luigi.Task):
|
||||||
outputter=CommentJsonOutputter(),
|
outputter=CommentJsonOutputter(),
|
||||||
buffer_size=RedditConfig().dump_interval
|
buffer_size=RedditConfig().dump_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
||||||
|
|
|
@ -14,10 +14,11 @@ from zipfile import ZipFile, ZIP_DEFLATED
|
||||||
|
|
||||||
from dstc8_reddit.config import RedditConfig, Subset
|
from dstc8_reddit.config import RedditConfig, Subset
|
||||||
from dstc8_reddit.tasks.sampling import SampleDialogues
|
from dstc8_reddit.tasks.sampling import SampleDialogues
|
||||||
|
from dstc8_reddit.util import delete_requires
|
||||||
|
|
||||||
|
|
||||||
class SplitDialogues(luigi.Task):
|
class SplitDialogues(luigi.Task):
|
||||||
""" This is a heavy task, but needed to do the splitting required for reaggregation """
|
""" This is a heavy task, but needed to do the splitting required for reaggregation. """
|
||||||
date = luigi.Parameter()
|
date = luigi.Parameter()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
@ -92,9 +93,16 @@ class SplitDialogues(luigi.Task):
|
||||||
if key in buffers and len(buffers[key]) > 0:
|
if key in buffers and len(buffers[key]) > 0:
|
||||||
outfile.write(''.join(buffers[key]))
|
outfile.write(''.join(buffers[key]))
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
||||||
|
|
||||||
|
|
||||||
class MergeDialoguesOverDates(luigi.Task):
|
class MergeDialoguesOverDates(luigi.Task):
|
||||||
""" Decompresses because haven't tested thread safety in zip """
|
"""
|
||||||
|
Decompresses because haven't tested thread safety in zip.
|
||||||
|
Also cannot call `delete_requires` in `on_success()` since there are cross-dependencies.
|
||||||
|
"""
|
||||||
split = luigi.IntParameter()
|
split = luigi.IntParameter()
|
||||||
subreddit = luigi.Parameter()
|
subreddit = luigi.Parameter()
|
||||||
|
|
||||||
|
@ -155,3 +163,14 @@ class ZipDataset(luigi.Task):
|
||||||
f.write('\n'.join([make_json_for_subreddit(t) for t in sorted(tasks)]) + '\n')
|
f.write('\n'.join([make_json_for_subreddit(t) for t in sorted(tasks)]) + '\n')
|
||||||
|
|
||||||
archive.close()
|
archive.close()
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
||||||
|
|
||||||
|
parent_reqs = self.requires()
|
||||||
|
if not isinstance(parent_reqs, list):
|
||||||
|
parent_reqs = [parent_reqs]
|
||||||
|
|
||||||
|
for r in parent_reqs:
|
||||||
|
delete_requires(r.requires())
|
||||||
|
|
|
@ -10,6 +10,7 @@ from numpy import random
|
||||||
from dstc8_reddit.config import RedditConfig
|
from dstc8_reddit.config import RedditConfig
|
||||||
from dstc8_reddit.constants import Patterns
|
from dstc8_reddit.constants import Patterns
|
||||||
from dstc8_reddit.tasks.construction import BuildDialogues
|
from dstc8_reddit.tasks.construction import BuildDialogues
|
||||||
|
from dstc8_reddit.util import delete_requires
|
||||||
from dstc8_reddit.validation import SessionItem
|
from dstc8_reddit.validation import SessionItem
|
||||||
|
|
||||||
|
|
||||||
|
@ -214,3 +215,7 @@ class SampleDialogues(luigi.Task):
|
||||||
logging.debug(f" > [{self.date}] # DLGS: before sample={len(dlgs)}, after sample={len(sampled_dlgs)}")
|
logging.debug(f" > [{self.date}] # DLGS: before sample={len(dlgs)}, after sample={len(sampled_dlgs)}")
|
||||||
lens = [len(d) for d in sampled_dlgs]
|
lens = [len(d) for d in sampled_dlgs]
|
||||||
logging.debug(f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, avg={sum(lens) / len(lens):2.2f}")
|
logging.debug(f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, avg={sum(lens) / len(lens):2.2f}")
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
if RedditConfig().delete_intermediate_data:
|
||||||
|
delete_requires(self.requires())
|
|
@ -2,6 +2,7 @@ import bz2
|
||||||
import gzip
|
import gzip
|
||||||
import logging
|
import logging
|
||||||
import lzma
|
import lzma
|
||||||
|
import os
|
||||||
import rapidjson as json
|
import rapidjson as json
|
||||||
|
|
||||||
from dstc8_reddit.constants import Patterns, OUTPUT_FIELDS, SELF_BREAK_TOKEN, SUBMISSION_ID_PREFIX, COMMENT_ID_PREFIX
|
from dstc8_reddit.constants import Patterns, OUTPUT_FIELDS, SELF_BREAK_TOKEN, SUBMISSION_ID_PREFIX, COMMENT_ID_PREFIX
|
||||||
|
@ -90,3 +91,19 @@ def process_file_linewise(
|
||||||
if out_ids_filepath:
|
if out_ids_filepath:
|
||||||
with make_file_handle(out_ids_filepath, 'wt') as ids_outfile:
|
with make_file_handle(out_ids_filepath, 'wt') as ids_outfile:
|
||||||
ids_outfile.write('\n'.join(list(ids_set)) + '\n')
|
ids_outfile.write('\n'.join(list(ids_set)) + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def delete_requires(requires):
|
||||||
|
if not isinstance(requires, list):
|
||||||
|
requires = [requires]
|
||||||
|
|
||||||
|
for req in requires:
|
||||||
|
outputs = req.output()
|
||||||
|
if not isinstance(outputs, list):
|
||||||
|
outputs = [outputs]
|
||||||
|
|
||||||
|
for out in outputs:
|
||||||
|
fp = out.path
|
||||||
|
if os.path.exists(fp):
|
||||||
|
logging.info(f"[delete] Removed `{fp}`")
|
||||||
|
os.remove(fp)
|
||||||
|
|
|
@ -5,7 +5,7 @@ import luigi
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
from dstc8_reddit.config import RedditConfig
|
from dstc8_reddit.config import RedditConfig
|
||||||
from dstc8_reddit.tasks import DownloadRawFile, ZipDataset
|
from dstc8_reddit.tasks import DownloadRawFile, ZipDataset, BuildDialogues
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
@ -40,8 +40,18 @@ def download(workers, config, log_level):
|
||||||
@click.option('-c', '--config', type=click.Path(dir_okay=False, file_okay=True, exists=True),
|
@click.option('-c', '--config', type=click.Path(dir_okay=False, file_okay=True, exists=True),
|
||||||
default='configs/config.prod.yaml')
|
default='configs/config.prod.yaml')
|
||||||
@click.option('-l', '--log-level', default='ERROR')
|
@click.option('-l', '--log-level', default='ERROR')
|
||||||
def generate(workers, config, log_level):
|
@click.option('--small', is_flag=True,
|
||||||
RedditConfig.initialize(config)
|
help='If set, will use reduced storage by deleting intermediate data')
|
||||||
|
def generate(workers, config, log_level, small):
|
||||||
|
|
||||||
|
extra_config = {}
|
||||||
|
|
||||||
|
if small:
|
||||||
|
extra_config.update(dict(delete_intermediate_data=True,
|
||||||
|
max_concurrent_downloads=2))
|
||||||
|
|
||||||
|
RedditConfig.initialize(config, extra_config)
|
||||||
|
|
||||||
print(RedditConfig())
|
print(RedditConfig())
|
||||||
|
|
||||||
luigi.configuration.get_config().set('resources', 'max_concurrent_downloads',
|
luigi.configuration.get_config().set('resources', 'max_concurrent_downloads',
|
||||||
|
@ -51,6 +61,16 @@ def generate(workers, config, log_level):
|
||||||
luigi.configuration.get_config().set('resources', 'max_concurrent_sample',
|
luigi.configuration.get_config().set('resources', 'max_concurrent_sample',
|
||||||
str(RedditConfig().max_concurrent_sample))
|
str(RedditConfig().max_concurrent_sample))
|
||||||
|
|
||||||
|
if small:
|
||||||
|
for d in RedditConfig().make_all_dates():
|
||||||
|
luigi.interface.build(
|
||||||
|
[BuildDialogues(d)],
|
||||||
|
workers=workers,
|
||||||
|
local_scheduler=True,
|
||||||
|
log_level=log_level,
|
||||||
|
detailed_summary=True,
|
||||||
|
)
|
||||||
|
|
||||||
result = luigi.interface.build(
|
result = luigi.interface.build(
|
||||||
[ZipDataset()],
|
[ZipDataset()],
|
||||||
workers=workers,
|
workers=workers,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче