diff --git a/README.md b/README.md
index 67f66b9..d7310e8 100644
--- a/README.md
+++ b/README.md
@@ -13,8 +13,9 @@ This package is based on [Luigi](https://luigi.readthedocs.io/en/stable/index.ht
 ### Requirements
 
 - Python 3.5+
-- **~210 GB** space for constructing the dialogues
+- **~210 GB** space for constructing the dialogues with default settings
   - Final zip is only **4.2 GB** though
+  - [You can get away with less disk space, ~30GB](https://github.com/microsoft/dstc8-reddit-corpus/#i-dont-have-enough-disk-space)
 - An internet connection
 - 24-72 hours to generate the data
   - Depends on speed of internet connection, how many cores, how much RAM
@@ -155,6 +156,11 @@ The raw data takes up the most space (>144 GB) but also takes the longest time t
 
 Filtering and building the dialogues discards a lot of the data, so only keeping things in the `dialogues*` directories is safe.
 
+**If you just want the final dataset, you can use the `--small` option to delete raw and intermediate data as the dataset is generated, e.g.:**
+
+```bash
+python scripts/reddit.py generate --small
+```
 
 #### Windows
diff --git a/dstc8_reddit/config.py b/dstc8_reddit/config.py
index 49ac3f7..5cd938a 100644
--- a/dstc8_reddit/config.py
+++ b/dstc8_reddit/config.py
@@ -132,6 +132,7 @@ class RawConfig(BaseModel):
     'RC_2018-08.xz': 'b8939ecd280b48459c929c532eda923f3a2514db026175ed953a7956744c6003',
     'RC_2018-10.xz': 'cadb242a4b5f166071effdd9adbc1d7a78c978d3622bc01cd0f20d3a4c269bd0',
   }
+  delete_intermediate_data: bool = False
 
 
 class RedditConfig:
@@ -139,12 +140,16 @@ class RedditConfig:
   _cfg = None
 
   @classmethod
-  def initialize(cls, cfgyaml=None):
+  def initialize(cls, cfgyaml=None, extra_config=None):
     kwargs = {}
 
     if cfgyaml:
       with open(cfgyaml, 'r', encoding='utf-8') as f:
         kwargs = yaml.load(f, Loader=yaml.FullLoader)
 
+    # Overrides any presets and anything loaded from the config yaml
+    if extra_config:
+      kwargs.update(extra_config)
+
     kwargs['all_subreddits'] = set(kwargs.get('all_subreddits', []))
     kwargs['held_out_subreddits'] = set(kwargs.get('held_out_subreddits', []))
diff --git a/dstc8_reddit/tasks/construction.py b/dstc8_reddit/tasks/construction.py
index 6439ba6..18791c9 100644
--- a/dstc8_reddit/tasks/construction.py
+++ b/dstc8_reddit/tasks/construction.py
@@ -5,6 +5,7 @@ import rapidjson as json
 
 from dstc8_reddit.config import RedditConfig
 from dstc8_reddit.tasks.filtering import FilterRawSubmissions, FilterRawComments
+from dstc8_reddit.util import delete_requires
 
 
 class BuildDialogues(luigi.Task):
@@ -113,3 +114,7 @@ class BuildDialogues(luigi.Task):
       if dlgs_to_write:
         f.write(''.join(dlgs_to_write))
       f.close()
+
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
diff --git a/dstc8_reddit/tasks/filtering.py b/dstc8_reddit/tasks/filtering.py
index ec8b807..51c338d 100644
--- a/dstc8_reddit/tasks/filtering.py
+++ b/dstc8_reddit/tasks/filtering.py
@@ -4,7 +4,7 @@ import rapidjson as json
 from dstc8_reddit.config import RedditConfig
 from dstc8_reddit.constants import Patterns, SUBMISSION_ID_PREFIX
 from dstc8_reddit.tasks.download import DownloadRawFile
-from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter
+from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter, delete_requires
 
 
 class RawSubmissionFilterer:
@@ -155,6 +155,10 @@ class FilterRawSubmissions(luigi.Task):
       buffer_size=RedditConfig().dump_interval
     )
 
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
+
 
 class FilterRawComments(luigi.Task):
   date = luigi.Parameter()
@@ -182,3 +186,7 @@ class FilterRawComments(luigi.Task):
       outputter=CommentJsonOutputter(),
       buffer_size=RedditConfig().dump_interval
     )
+
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
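For context on the `on_success` hooks added above: `on_success` is a standard Luigi task callback, invoked only after a task's `run()` completes without raising, which is what makes it a safe point to delete that task's inputs. Below is a minimal, runnable sketch of the same pattern in isolation; the `Upstream`/`Downstream` task names and file paths are hypothetical, not part of this repo.

```python
import os

import luigi


class Upstream(luigi.Task):
  def output(self):
    return luigi.LocalTarget('upstream.txt')

  def run(self):
    with self.output().open('w') as f:
      f.write('intermediate data\n')


class Downstream(luigi.Task):
  def requires(self):
    return Upstream()

  def output(self):
    return luigi.LocalTarget('downstream.txt')

  def run(self):
    with self.input().open('r') as fin, self.output().open('w') as fout:
      fout.write(fin.read().upper())

  def on_success(self):
    # Luigi calls this only after run() succeeded, so this task can no
    # longer need its upstream output; deleting it here mirrors what the
    # patch's on_success hooks do via delete_requires.
    path = self.input().path
    if os.path.exists(path):
      os.remove(path)


if __name__ == '__main__':
  luigi.build([Downstream()], local_scheduler=True)
```

Deleting inputs in `on_success` rather than at the end of `run()` means a failed or partially-completed task leaves its inputs intact for a retry.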
diff --git a/dstc8_reddit/tasks/packaging.py b/dstc8_reddit/tasks/packaging.py
index b1ef8a8..f530f80 100644
--- a/dstc8_reddit/tasks/packaging.py
+++ b/dstc8_reddit/tasks/packaging.py
@@ -14,10 +14,11 @@ from zipfile import ZipFile, ZIP_DEFLATED
 
 from dstc8_reddit.config import RedditConfig, Subset
 from dstc8_reddit.tasks.sampling import SampleDialogues
+from dstc8_reddit.util import delete_requires
 
 
 class SplitDialogues(luigi.Task):
-  """ This is a heavy task, but needed to do the splitting required for reaggregation """
+  """ This is a heavy task, but needed to do the splitting required for reaggregation. """
   date = luigi.Parameter()
 
   def __init__(self, *args, **kwargs):
@@ -92,9 +93,16 @@ class SplitDialogues(luigi.Task):
         if key in buffers and len(buffers[key]) > 0:
           outfile.write(''.join(buffers[key]))
 
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
+
 
 class MergeDialoguesOverDates(luigi.Task):
-  """ Decompresses because haven't tested thread safety in zip """
+  """
+  Decompresses because thread safety in zip hasn't been tested.
+  Also cannot call `delete_requires` in `on_success()` since there are cross-dependencies.
+  """
   split = luigi.IntParameter()
   subreddit = luigi.Parameter()
 
@@ -155,3 +163,14 @@ class ZipDataset(luigi.Task):
       f.write('\n'.join([make_json_for_subreddit(t) for t in sorted(tasks)]) + '\n')
 
     archive.close()
+
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
+
+      parent_reqs = self.requires()
+      if not isinstance(parent_reqs, list):
+        parent_reqs = [parent_reqs]
+
+      for r in parent_reqs:
+        delete_requires(r.requires())
diff --git a/dstc8_reddit/tasks/sampling.py b/dstc8_reddit/tasks/sampling.py
index 6f31de9..fb38567 100644
--- a/dstc8_reddit/tasks/sampling.py
+++ b/dstc8_reddit/tasks/sampling.py
@@ -10,6 +10,7 @@ from numpy import random
 
 from dstc8_reddit.config import RedditConfig
 from dstc8_reddit.constants import Patterns
 from dstc8_reddit.tasks.construction import BuildDialogues
+from dstc8_reddit.util import delete_requires
 from dstc8_reddit.validation import SessionItem
 
 
@@ -214,3 +215,7 @@ class SampleDialogues(luigi.Task):
     logging.debug(f" > [{self.date}] # DLGS: before sample={len(dlgs)}, after sample={len(sampled_dlgs)}")
     lens = [len(d) for d in sampled_dlgs]
     logging.debug(f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, avg={sum(lens) / len(lens):2.2f}")
+
+  def on_success(self):
+    if RedditConfig().delete_intermediate_data:
+      delete_requires(self.requires())
\ No newline at end of file
diff --git a/dstc8_reddit/util.py b/dstc8_reddit/util.py
index 9630fab..449e59a 100644
--- a/dstc8_reddit/util.py
+++ b/dstc8_reddit/util.py
@@ -2,6 +2,7 @@ import bz2
 import gzip
 import logging
 import lzma
+import os
 
 import rapidjson as json
 
 from dstc8_reddit.constants import Patterns, OUTPUT_FIELDS, SELF_BREAK_TOKEN, SUBMISSION_ID_PREFIX, COMMENT_ID_PREFIX
@@ -90,3 +91,19 @@ def process_file_linewise(
   if out_ids_filepath:
     with make_file_handle(out_ids_filepath, 'wt') as ids_outfile:
       ids_outfile.write('\n'.join(list(ids_set)) + '\n')
+
+
+def delete_requires(requires):
+  if not isinstance(requires, list):
+    requires = [requires]
+
+  for req in requires:
+    outputs = req.output()
+    if not isinstance(outputs, list):
+      outputs = [outputs]
+
+    for out in outputs:
+      fp = out.path
+      if os.path.exists(fp):
+        os.remove(fp)
+        logging.info(f"[delete] Removed `{fp}`")
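To make the contract of `delete_requires` concrete: it accepts either a single task or a list of tasks, and each task's `output()` may in turn be a single target or a list of targets; everything is normalised to lists before the output files are removed. A small sketch exercising it with stand-ins (the `FakeTask`/`FakeTarget` classes and `make_tmp` helper are hypothetical, for illustration only, and assume the repo is importable):

```python
import os
import tempfile

from dstc8_reddit.util import delete_requires


class FakeTarget:
  """Stand-in for a luigi Target; delete_requires only needs `.path`."""

  def __init__(self, path):
    self.path = path


class FakeTask:
  """Stand-in for a luigi Task; delete_requires only needs `.output()`."""

  def __init__(self, paths):
    self._paths = paths

  def output(self):
    targets = [FakeTarget(p) for p in self._paths]
    return targets if len(targets) > 1 else targets[0]


def make_tmp():
  # Create a real temp file so os.remove has something to delete.
  f = tempfile.NamedTemporaryFile(delete=False)
  f.close()
  return f.name


# A single task whose output() returns one target...
single_path = make_tmp()
delete_requires(FakeTask([single_path]))      # single task normalised to [task]

# ...and a list of tasks, one of which returns a list of targets.
paths = [make_tmp(), make_tmp(), make_tmp()]
delete_requires([FakeTask(paths[:1]), FakeTask(paths[1:])])

assert not any(os.path.exists(p) for p in [single_path] + paths)
```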
diff --git a/scripts/reddit.py b/scripts/reddit.py
index d6793c6..a6ad4cc 100644
--- a/scripts/reddit.py
+++ b/scripts/reddit.py
@@ -5,7 +5,7 @@ import luigi
 from multiprocessing import cpu_count
 
 from dstc8_reddit.config import RedditConfig
-from dstc8_reddit.tasks import DownloadRawFile, ZipDataset
+from dstc8_reddit.tasks import DownloadRawFile, ZipDataset, BuildDialogues
 
 
 @click.group()
@@ -40,8 +40,18 @@ def download(workers, config, log_level):
 @click.option('-c', '--config', type=click.Path(dir_okay=False, file_okay=True, exists=True),
               default='configs/config.prod.yaml')
 @click.option('-l', '--log-level', default='ERROR')
-def generate(workers, config, log_level):
-  RedditConfig.initialize(config)
+@click.option('--small', is_flag=True,
+              help='If set, reduces disk usage by deleting intermediate data as the dataset is generated')
+def generate(workers, config, log_level, small):
+
+  extra_config = {}
+
+  if small:
+    extra_config.update(dict(delete_intermediate_data=True,
+                             max_concurrent_downloads=2))
+
+  RedditConfig.initialize(config, extra_config)
+  print(RedditConfig())
 
   luigi.configuration.get_config().set('resources',
                                        'max_concurrent_downloads',
@@ -51,6 +61,16 @@ def generate(workers, config, log_level):
   luigi.configuration.get_config().set('resources',
                                        'max_concurrent_sample',
                                        str(RedditConfig().max_concurrent_sample))
 
+  if small:
+    for d in RedditConfig().make_all_dates():
+      luigi.interface.build(
+        [BuildDialogues(d)],
+        workers=workers,
+        local_scheduler=True,
+        log_level=log_level,
+        detailed_summary=True,
+      )
+
   result = luigi.interface.build(
     [ZipDataset()],
     workers=workers,
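One consequence of the new `extra_config` plumbing worth spelling out: YAML values are loaded first and `extra_config` is applied on top via `dict.update`, so flag-driven settings such as `--small` always win over the config file. A minimal sketch of that precedence, assuming PyYAML; the YAML snippet is illustrative, not one of the shipped configs:

```python
import yaml

yaml_text = """
delete_intermediate_data: false
max_concurrent_downloads: 4
"""

# As in RedditConfig.initialize: the YAML file is loaded first...
kwargs = yaml.safe_load(yaml_text)

# ...then extra_config overwrites any overlapping keys.
extra_config = dict(delete_intermediate_data=True, max_concurrent_downloads=2)
kwargs.update(extra_config)

assert kwargs == {'delete_intermediate_data': True, 'max_concurrent_downloads': 2}
```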