ENH add option to delete intermediate data as dataset is generated
This commit is contained in:
Родитель
25deb3674c
Коммит
fc674fc013
|
@ -13,8 +13,9 @@ This package is based on [Luigi](https://luigi.readthedocs.io/en/stable/index.ht
|
|||
### Requirements
|
||||
|
||||
- Python 3.5+
|
||||
- **~210 GB** space for constructing the dialogues
|
||||
- **~210 GB** space for constructing the dialogues with default settings
|
||||
- Final zip is only **4.2 GB** though
|
||||
- [You can get away with less disk space, ~30GB](https://github.com/microsoft/dstc8-reddit-corpus/#i-dont-have-enough-disk-space)
|
||||
- An internet connection
|
||||
- 24-72 hours to generate the data
|
||||
- Depends on speed of internet connection, how many cores, how much RAM
|
||||
|
@ -155,6 +156,11 @@ The raw data takes up the most space (>144 GB) but also takes the longest time t
|
|||
|
||||
Filtering and building the dialogues discards a lot of the data, so only keeping things in the `dialogues*` directories is safe.
|
||||
|
||||
**If you just want the final dataset you can use the `--small` option to delete raw and intermediate data the dataset is generated, e.g.**
|
||||
|
||||
```bash
|
||||
python scripts/reddit.py generate --small
|
||||
```
|
||||
|
||||
#### Windows
|
||||
|
||||
|
|
|
@ -132,6 +132,7 @@ class RawConfig(BaseModel):
|
|||
'RC_2018-08.xz': 'b8939ecd280b48459c929c532eda923f3a2514db026175ed953a7956744c6003',
|
||||
'RC_2018-10.xz': 'cadb242a4b5f166071effdd9adbc1d7a78c978d3622bc01cd0f20d3a4c269bd0',
|
||||
}
|
||||
delete_intermediate_data: bool = False
|
||||
|
||||
|
||||
class RedditConfig:
|
||||
|
@ -139,12 +140,16 @@ class RedditConfig:
|
|||
_cfg = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, cfgyaml=None):
|
||||
def initialize(cls, cfgyaml=None, extra_config=None):
|
||||
kwargs = {}
|
||||
if cfgyaml:
|
||||
with open(cfgyaml, 'r', encoding='utf-8') as f:
|
||||
kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Will override any preset and anything from the config yaml
|
||||
if extra_config:
|
||||
kwargs.update(extra_config)
|
||||
|
||||
kwargs['all_subreddits'] = set(kwargs.get('all_subreddits', []))
|
||||
kwargs['held_out_subreddits'] = set(kwargs.get('held_out_subreddits', []))
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import rapidjson as json
|
|||
|
||||
from dstc8_reddit.config import RedditConfig
|
||||
from dstc8_reddit.tasks.filtering import FilterRawSubmissions, FilterRawComments
|
||||
from dstc8_reddit.util import delete_requires
|
||||
|
||||
|
||||
class BuildDialogues(luigi.Task):
|
||||
|
@ -113,3 +114,7 @@ class BuildDialogues(luigi.Task):
|
|||
if dlgs_to_write:
|
||||
f.write(''.join(dlgs_to_write))
|
||||
f.close()
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
||||
|
|
|
@ -4,7 +4,7 @@ import rapidjson as json
|
|||
from dstc8_reddit.config import RedditConfig
|
||||
from dstc8_reddit.constants import Patterns, SUBMISSION_ID_PREFIX
|
||||
from dstc8_reddit.tasks.download import DownloadRawFile
|
||||
from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter
|
||||
from dstc8_reddit.util import process_file_linewise, SubmissionJsonOutputter, CommentJsonOutputter, delete_requires
|
||||
|
||||
|
||||
class RawSubmissionFilterer:
|
||||
|
@ -155,6 +155,10 @@ class FilterRawSubmissions(luigi.Task):
|
|||
buffer_size=RedditConfig().dump_interval
|
||||
)
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
||||
|
||||
|
||||
class FilterRawComments(luigi.Task):
|
||||
date = luigi.Parameter()
|
||||
|
@ -182,3 +186,7 @@ class FilterRawComments(luigi.Task):
|
|||
outputter=CommentJsonOutputter(),
|
||||
buffer_size=RedditConfig().dump_interval
|
||||
)
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
||||
|
|
|
@ -14,10 +14,11 @@ from zipfile import ZipFile, ZIP_DEFLATED
|
|||
|
||||
from dstc8_reddit.config import RedditConfig, Subset
|
||||
from dstc8_reddit.tasks.sampling import SampleDialogues
|
||||
from dstc8_reddit.util import delete_requires
|
||||
|
||||
|
||||
class SplitDialogues(luigi.Task):
|
||||
""" This is a heavy task, but needed to do the splitting required for reaggregation """
|
||||
""" This is a heavy task, but needed to do the splitting required for reaggregation. """
|
||||
date = luigi.Parameter()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -92,9 +93,16 @@ class SplitDialogues(luigi.Task):
|
|||
if key in buffers and len(buffers[key]) > 0:
|
||||
outfile.write(''.join(buffers[key]))
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
||||
|
||||
|
||||
class MergeDialoguesOverDates(luigi.Task):
|
||||
""" Decompresses because haven't tested thread safety in zip """
|
||||
"""
|
||||
Decompresses because haven't tested thread safety in zip.
|
||||
Also cannot call `delete_requires` in `on_success()` since there are cross-dependencies.
|
||||
"""
|
||||
split = luigi.IntParameter()
|
||||
subreddit = luigi.Parameter()
|
||||
|
||||
|
@ -155,3 +163,14 @@ class ZipDataset(luigi.Task):
|
|||
f.write('\n'.join([make_json_for_subreddit(t) for t in sorted(tasks)]) + '\n')
|
||||
|
||||
archive.close()
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
||||
|
||||
parent_reqs = self.requires()
|
||||
if not isinstance(parent_reqs, list):
|
||||
parent_reqs = [parent_reqs]
|
||||
|
||||
for r in parent_reqs:
|
||||
delete_requires(r.requires())
|
||||
|
|
|
@ -10,6 +10,7 @@ from numpy import random
|
|||
from dstc8_reddit.config import RedditConfig
|
||||
from dstc8_reddit.constants import Patterns
|
||||
from dstc8_reddit.tasks.construction import BuildDialogues
|
||||
from dstc8_reddit.util import delete_requires
|
||||
from dstc8_reddit.validation import SessionItem
|
||||
|
||||
|
||||
|
@ -214,3 +215,7 @@ class SampleDialogues(luigi.Task):
|
|||
logging.debug(f" > [{self.date}] # DLGS: before sample={len(dlgs)}, after sample={len(sampled_dlgs)}")
|
||||
lens = [len(d) for d in sampled_dlgs]
|
||||
logging.debug(f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, avg={sum(lens) / len(lens):2.2f}")
|
||||
|
||||
def on_success(self):
|
||||
if RedditConfig().delete_intermediate_data:
|
||||
delete_requires(self.requires())
|
|
@ -2,6 +2,7 @@ import bz2
|
|||
import gzip
|
||||
import logging
|
||||
import lzma
|
||||
import os
|
||||
import rapidjson as json
|
||||
|
||||
from dstc8_reddit.constants import Patterns, OUTPUT_FIELDS, SELF_BREAK_TOKEN, SUBMISSION_ID_PREFIX, COMMENT_ID_PREFIX
|
||||
|
@ -90,3 +91,19 @@ def process_file_linewise(
|
|||
if out_ids_filepath:
|
||||
with make_file_handle(out_ids_filepath, 'wt') as ids_outfile:
|
||||
ids_outfile.write('\n'.join(list(ids_set)) + '\n')
|
||||
|
||||
|
||||
def delete_requires(requires):
|
||||
if not isinstance(requires, list):
|
||||
requires = [requires]
|
||||
|
||||
for req in requires:
|
||||
outputs = req.output()
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
for out in outputs:
|
||||
fp = out.path
|
||||
if os.path.exists(fp):
|
||||
logging.info(f"[delete] Removed `{fp}`")
|
||||
os.remove(fp)
|
||||
|
|
|
@ -40,8 +40,18 @@ def download(workers, config, log_level):
|
|||
@click.option('-c', '--config', type=click.Path(dir_okay=False, file_okay=True, exists=True),
|
||||
default='configs/config.prod.yaml')
|
||||
@click.option('-l', '--log-level', default='ERROR')
|
||||
def generate(workers, config, log_level):
|
||||
RedditConfig.initialize(config)
|
||||
@click.option('--small', is_flag=True,
|
||||
help='If set, will use reduced storage by deleting intermediate data')
|
||||
def generate(workers, config, log_level, small):
|
||||
|
||||
extra_config = {}
|
||||
|
||||
if small:
|
||||
extra_config.update(dict(delete_intermediate_data=True,
|
||||
max_concurrent_downloads=1))
|
||||
|
||||
RedditConfig.initialize(config, extra_config)
|
||||
|
||||
print(RedditConfig())
|
||||
|
||||
luigi.configuration.get_config().set('resources', 'max_concurrent_downloads',
|
||||
|
|
Загрузка…
Ссылка в новой задаче