2020-04-11 23:31:42 +03:00
|
|
|
from pathlib import Path
|
2020-04-12 01:09:46 +03:00
|
|
|
from typing import Any, Dict, List, Set, Union
|
2020-04-11 23:31:42 +03:00
|
|
|
|
|
|
|
import srsly
|
2020-04-13 21:35:39 +03:00
|
|
|
from spacy.util import ensure_path
|
2020-04-11 23:31:42 +03:00
|
|
|
|
2020-04-08 04:19:04 +03:00
|
|
|
from .types import Example
|
|
|
|
|
|
|
|
|
|
|
|
class ExampleStore:
    """In-memory collection of Examples keyed by their hash.

    Each Example is stored under ``hash(example)``, so adding an Example
    whose hash is already present replaces the existing entry.
    """

    def __init__(self, examples: List[Example] = None):
        """Create a store, optionally pre-populated.

        Args:
            examples (List[Example], optional): Examples to add to the store.
                Defaults to None, which creates an empty store.
        """
        self._map: Dict[int, Example] = {}
        if examples is not None:
            for e in examples:
                self.add(e)

    def __getitem__(self, example_hash: int) -> Example:
        """Look up an Example by its hash.

        Args:
            example_hash (int): Hash of the Example to retrieve

        Returns:
            Example: The Example stored under that hash

        Raises:
            KeyError: If no Example is stored under example_hash
        """
        return self._map[example_hash]

    def __len__(self) -> int:
        """The number of examples in the store.

        Returns:
            Number of examples in store
        """
        return len(self._map)

    def __contains__(self, example: Union[int, Example]) -> bool:
        """Check whether an example is in the store.

        Args:
            example (Union[int, Example]): The example to check, given either
                as an Example or directly as its hash

        Returns:
            Whether the store contains the example.
        """
        # An Example is looked up by its hash; anything else is used as the
        # key as-is.
        example_hash = hash(example) if isinstance(example, Example) else example
        return example_hash in self._map

    def add(self, example: Example) -> None:
        """Add an Example to the store

        An Example whose hash is already present overwrites the stored one.

        Args:
            example (Example): example to add
        """
        self._map[hash(example)] = example

    def from_disk(self, path: Path) -> "ExampleStore":
        """Load store from disk

        Reads a JSONL file in which every record has an ``"example_hash"``
        and an ``"example"`` entry (the layout written by ``to_disk``).
        Loaded Examples are added to this store; existing entries are kept
        unless a loaded Example has the same hash.

        Args:
            path (Path): Path to file to load from

        Returns:
            ExampleStore: Initialized ExampleStore

        Raises:
            ValueError: If the hash recorded for an Example does not match
                the hash recomputed after loading it
        """
        path = ensure_path(path)
        for record in srsly.read_jsonl(path):
            stored_hash = record["example_hash"]
            example = Example(**record["example"])
            recomputed_hash = hash(example)
            # Explicit check instead of `assert`: assertions are removed
            # under `python -O`, which would let mismatched data load
            # silently.
            if recomputed_hash != stored_hash:
                raise ValueError(
                    f"Example hash mismatch while loading {path}: "
                    f"stored {stored_hash}, recomputed {recomputed_hash}"
                )
            self.add(example)

        return self

    def to_disk(self, path: Path) -> None:
        """Save store to disk

        Writes one JSONL record per stored Example, holding its hash under
        ``"example_hash"`` and ``example.dict()`` under ``"example"``.

        Args:
            path (Path): Path to save store to
        """
        path = ensure_path(path)
        examples = [
            {"example_hash": example_hash, "example": example.dict()}
            for example_hash, example in self._map.items()
        ]
        srsly.write_jsonl(path, examples)
|