Mirror of https://github.com/microsoft/genalog.git
Merge pull request #10 from microsoft/laserprec/use_linter
Use flake8 as linter and fix code format issues.
Commit
cda2c1e77d
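For readers who want to reproduce the linting step this change introduces, a minimal local run of flake8 looks like the sketch below. The commands mirror the CI steps in this diff; the note about where configuration lives is a general flake8 convention and an assumption, not a statement about this repository's settings.

    # Illustrative sketch: run the linter locally in a virtual environment.
    python -m venv .venv
    source .venv/bin/activate            # on Windows: source .venv/Scripts/activate
    python -m pip install --upgrade pip flake8
    # Report style violations across the working tree; flake8 typically reads
    # its rules from a [flake8] section in setup.cfg, tox.ini, or a .flake8 file.
    python -m flake8

Running `python -m flake8` with no arguments checks the current directory, which matches how the new 'Run Linter (flake8)' pipeline step invokes it.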
@@ -35,31 +35,20 @@ steps:
displayName: 'Use Python $(python.version)'

- bash: |
python -m venv .venv
displayName: 'Create virtual environment'

- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
pip install --upgrade pip
pip install setuptools wheel
pip install -r requirements.txt
pip install pytest==5.3.5 pytest-cov==2.8.1
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Install dependencies'

- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python -m pytest tests --cov=genalog --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=html
python -m flake8
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Run Linter (flake8)'

- bash: |
python -m pytest tests
env:
BLOB_KEY : $(BLOB_KEY)
SEARCH_SERVICE_KEY: $(SEARCH_SERVICE_KEY)

@@ -86,12 +75,6 @@ steps:
displayName: 'Publish test coverage'

- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python setup.py bdist_wheel --build-number $(Build.BuildNumber) --dist-dir dist
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Building wheel package'

@@ -1,16 +1,20 @@
from genalog.degradation import effect
from enum import Enum
import copy
import inspect
from enum import Enum

from genalog.degradation import effect

DEFAULT_METHOD_PARAM_TO_INCLUDE = "src"


class ImageState(Enum):
ORIGINAL_STATE = "ORIGINAL_STATE"
CURRENT_STATE = "CURRENT_STATE"

class Degrader():

class Degrader:
""" An object for applying multiple degradation effects onto an image"""

def __init__(self, effects):
"""Initialize a Degrader object


@@ -73,14 +77,21 @@ class Degrader():
# Try to find corresponding degradation method in the module
method = getattr(effect, method_name)
except AttributeError:
raise ValueError(f"Method '{method_name}' is not defined in 'genalog.degradation.effect'")
raise ValueError(
f"Method '{method_name}' is not defined in 'genalog.degradation.effect'"
)
# Get the method signatures
method_sign = inspect.signature(method)
# Check if method parameters are valid
for param_name in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if not param_name in method_sign.parameters:
for (
param_name
) in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if param_name not in method_sign.parameters:
method_args = [param for param in method_sign.parameters]
raise ValueError(f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. Method parameter names are: {method_args}")
raise ValueError(
f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. " +
f"Method parameter names are: {method_args}"
)

def _add_default_method_param(self):
"""All methods in "genalog.degradation.effect" module have a required

@@ -90,7 +101,9 @@ class Degrader():
for effect_tuple in self.effects_to_apply:
method_name, method_kwargs = effect_tuple
if DEFAULT_METHOD_PARAM_TO_INCLUDE not in method_kwargs:
method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] = ImageState.CURRENT_STATE
method_kwargs[
DEFAULT_METHOD_PARAM_TO_INCLUDE
] = ImageState.CURRENT_STATE

def apply_effects(self, src):
"""Apply degradation effects in sequence

@@ -1,6 +1,8 @@
from math import floor

import cv2
import numpy as np
from math import floor


def blur(src, radius=5):
"""Wrapper function for cv2.GaussianBlur

@@ -17,6 +19,7 @@ def blur(src, radius=5):
"""
return cv2.GaussianBlur(src, (radius, radius), cv2.BORDER_DEFAULT)


def overlay_weighted(src, background, alpha, beta, gamma=0):
"""overlay two images together, pixels from each image is weighted as follow


@@ -36,6 +39,7 @@ def overlay_weighted(src, background, alpha, beta, gamma=0):
"""
return cv2.addWeighted(src, alpha, background, beta, gamma).astype(np.uint8)


def overlay(src, background):
"""Overlay two images together via bitwise-and:


@@ -50,6 +54,7 @@ def overlay(src, background):
"""
return cv2.bitwise_and(src, background).astype(np.uint8)


def translation(src, offset_x, offset_y):
"""Shift the image in x, y direction


@@ -69,6 +74,7 @@ def translation(src, offset_x, offset_y):
dst = cv2.warpAffine(src, trans_matrix, (cols, rows), borderValue=255)
return dst.astype(np.uint8)


def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y=5):
"""Apply bleed through effect, background is flipped horizontally.


@@ -96,6 +102,7 @@ def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y
beta = 1 - alpha
return overlay_weighted(src, background, alpha, beta, gamma)


def pepper(src, amount=0.05):
"""Randomly sprinkle dark pixels on src image.
Wrapper function for skimage.util.noise.random_noise().

@@ -118,6 +125,7 @@ def pepper(src, amount=0.05):
dst[noise < amount] = 0
return dst.astype(np.uint8)


def salt(src, amount=0.3):
"""Randomly sprinkle white pixels on src image.
Wrapper function for skimage.util.noise.random_noise().

@@ -140,6 +148,7 @@ def salt(src, amount=0.3):
dst[noise < amount] = 255
return dst.astype(np.uint8)


def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
"""Randomly add salt then add pepper onto the image.


@@ -159,6 +168,7 @@ def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
salted = salt(src, amount=salt_amount)
return pepper(salted, amount=pepper_amount)


def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
"""Randomly add pepper then salt onto the image.


@@ -178,6 +188,7 @@ def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
peppered = pepper(src, amount=pepper_amount)
return salt(peppered, amount=salt_amount)


def create_2D_kernel(kernel_shape, kernel_type="ones"):
"""Create 2D kernel for morphological operations.


@@ -243,11 +254,21 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
elif kernel_type == "ellipse":
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape)
else:
valid_kernel_types = {"ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"}
raise ValueError(f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}")
valid_kernel_types = {
"ones",
"upper_triangle",
"lower_triangle",
"x",
"plus",
"ellipse",
}
raise ValueError(
f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}"
)

return kernel.astype(np.uint8)


def morphology(src, operation="open", kernel_shape=(3, 3), kernel_type="ones"):
"""Dynamic calls different morphological operations
("open", "close", "dilate" and "erode") with the given parameters

@@ -280,7 +301,10 @@ def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
return erode(src, kernel)
else:
valid_operations = ["open", "close", "dilate", "erode"]
raise ValueError(f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}")
raise ValueError(
f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}"
)


def open(src, kernel):
""" "open" morphological operation. Like morphological "erosion", it removes

@@ -299,6 +323,7 @@ def open(src, kernel):
"""
return cv2.morphologyEx(src, cv2.MORPH_OPEN, kernel)


def close(src, kernel):
""" "close" morphological operation. Like morphological "dilation", it grows the
boundary of the foreground (white pixels), however, it is less destructive than

@@ -317,6 +342,7 @@ def close(src, kernel):
"""
return cv2.morphologyEx(src, cv2.MORPH_CLOSE, kernel)


def erode(src, kernel):
""" "erode" morphological operation. Erodes foreground pixels (white pixels).
For more information see:

@@ -332,6 +358,7 @@ def erode(src, kernel):
"""
return cv2.erode(src, kernel)


def dilate(src, kernel):
""" "dilate" morphological operation. Grows foreground pixels (white pixels).
For more information see:

@@ -1,4 +1,5 @@
from enum import Enum, auto
from enum import auto, Enum


class ContentType(Enum):
PARAGRAPH = auto()

@@ -6,14 +7,17 @@ class ContentType(Enum):
IMAGE = auto()
COMPOSITE = auto()

class Content():

class Content:
def __init__(self):
self.iterable = True
self._content = None

def set_content_type(self, content_type):
if type(content_type) != ContentType:
raise TypeError(f"Invalid content type: {content_type}, valid types are {list(ContentType)}")
raise TypeError(
f"Invalid content type: {content_type}, valid types are {list(ContentType)}"
)
self.content_type = content_type

def validate_content(self):

@@ -28,6 +32,7 @@ class Content():
def __getitem__(self, key):
return self._content.__getitem__(key)


class Paragraph(Content):
def __init__(self, content):
self.set_content_type(ContentType.PARAGRAPH)

@@ -38,6 +43,7 @@ class Paragraph(Content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")


class Title(Content):
def __init__(self, content):
self.set_content_type(ContentType.TITLE)

@@ -48,6 +54,7 @@ class Title(Content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")


class CompositeContent(Content):
def __init__(self, content_list, content_type_list):
self.set_content_type(ContentType.COMPOSITE)

@@ -82,5 +89,5 @@ class CompositeContent(Content):
"""get a string transparent of the nested object types"""
transparent_str = "["
for content in self._content:
transparent_str += '"' + content.__str__() + "\", "
transparent_str += '"' + content.__str__() + '", '
return transparent_str + "]"

@@ -1,14 +1,12 @@
from jinja2 import PackageLoader, FileSystemLoader
from jinja2 import Environment, select_autoescape
from weasyprint import HTML
from cairocffi import FORMAT_ARGB32
import numpy as np


import itertools
import os
import cv2

import cv2
import numpy as np
from cairocffi import FORMAT_ARGB32
from jinja2 import Environment, select_autoescape
from jinja2 import FileSystemLoader, PackageLoader
from weasyprint import HTML

DEFAULT_DOCUMENT_STYLE = {
"language": "en_US",

@@ -27,8 +25,10 @@ DEFAULT_STYLE_COMBINATION = {
"hyphenate": [False],
}


class Document(object):
""" A composite object that represents a document """

def __init__(self, content, template, **styles):
"""Initialize a Document object with source template and content


@@ -97,21 +97,26 @@ class Document(object):
Arguments:
target -- a filename, file-like object, or None
split_pages {bool} -- true if save each document page as a separate file.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), PNG pixels match the CSS px unit.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default),
PNG pixels match the CSS px unit.

Returns:
The image as bytes if target is not provided or None, otherwise None (the PDF is written to target)
"""
if target != None and split_pages:
if target is not None and split_pages:
# get destination filename and extension
filename, ext = os.path.splitext(target)
for page_num, page in enumerate(self._document.pages):
page_name = filename + f"_pg_{page_num}" + ext
self._document.copy([page]).write_png(target=page_name, resolution=resolution)
self._document.copy([page]).write_png(
target=page_name, resolution=resolution
)
return None
elif target == None:
elif target is None:
# return image bytes string if no target is specified
png_bytes, png_width, png_height = self._document.write_png(target=target, resolution=resolution)
png_bytes, png_width, png_height = self._document.write_png(
target=target, resolution=resolution
)
return png_bytes
else:
return self._document.write_png(target=target, resolution=resolution)

@@ -131,18 +136,23 @@ class Document(object):
"""
# Method below returns a cairocffi.ImageSurface object
# https://cairocffi.readthedocs.io/en/latest/api.html#cairocffi.ImageSurface
surface, width, height = self._document.write_image_surface(resolution=resolution)
surface, width, height = self._document.write_image_surface(
resolution=resolution
)
img_format = surface.get_format()

# This is BGRA channel in little endian (reverse)
if img_format != FORMAT_ARGB32:
raise RuntimeError(f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'")
raise RuntimeError(
f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}." +
"Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'"
)

img_buffer = surface.get_data()
# Returns image array in "BGRA" channel
img_array = np.ndarray(shape=(height, width, 4),
dtype=np.uint8,
buffer=img_buffer)
img_array = np.ndarray(
shape=(height, width, 4), dtype=np.uint8, buffer=img_buffer
)
if channel == "GRAYSCALE":
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)
elif channel == "RGBA":

@@ -155,7 +165,9 @@ class Document(object):
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2BGR)
else:
valid_channels = ["GRAYSCALE", "RGB", "RGBA", "BGR", "BGRA"]
raise ValueError(f"Invalid channel code {channel}. Valid values are: {valid_channels}.")
raise ValueError(
f"Invalid channel code {channel}. Valid values are: {valid_channels}."
)

def update_style(self, **style):
"""Update template variables that controls the document style and re-compile the document to reflect the style change.

@@ -175,11 +187,14 @@ class Document(object):
self.styles.update(style)
# Recompile the html template and the document obj
self.compiled_html = self.render_html()
self._document = HTML(string=self.compiled_html).render() # weasyprinter.document.Document object
self._document = HTML(
string=self.compiled_html
).render() # weasyprinter.document.Document object


class DocumentGenerator():

class DocumentGenerator:
""" Document generator class """

def __init__(self, template_path=None):
"""Initialize a DocumentGenerator class


@@ -191,17 +206,19 @@ class DocumentGenerator():
if template_path:
self.template_env = Environment(
loader=FileSystemLoader(template_path),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
self.template_list = self.template_env.list_templates()
else:
# Loading built-in templates from the genalog package
self.template_env = Environment(
loader=PackageLoader("genalog.generation", "templates"),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
# Remove macros and css templates from rendering
self.template_list = self.template_env.list_templates(filter_func=DocumentGenerator._keep_template)
self.template_list = self.template_env.list_templates(
filter_func=DocumentGenerator._keep_template
)

self.set_styles_to_generate(DEFAULT_STYLE_COMBINATION)


@@ -251,7 +268,9 @@ class DocumentGenerator():
If this parameter is not provided, generator will use default document
styles: DEFAULT_STYLE_COMBINATION
"""
self.styles_to_generate = DocumentGenerator.expand_style_combinations(style_combinations)
self.styles_to_generate = DocumentGenerator.expand_style_combinations(
style_combinations
)

def create_generator(self, content, templates_to_render):
"""Create a Document generator

@@ -266,7 +285,9 @@ class DocumentGenerator():
"""
for template_name in templates_to_render:
if template_name not in self.template_list:
raise FileNotFoundError(f"File '{template_name}' not found. Available templates are {self.template_list}")
raise FileNotFoundError(
f"File '{template_name}' not found. Available templates are {self.template_list}"
)
template = self.template_env.get_template(template_name)
for style in self.styles_to_generate:
yield Document(content, template, **style)

@@ -308,8 +329,12 @@ class DocumentGenerator():
if not styles:
return []
# Python 2.x+ guarantees that the order in keys() and values() is preserved
style_properties = styles.keys() # ex) ["font_family", "font_size", "hyphenate"]
property_values = styles.values() # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]
style_properties = (
styles.keys()
) # ex) ["font_family", "font_size", "hyphenate"]
property_values = (
styles.values()
) # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]

# Generate all possible combinations:
# [("Calibri", "10px", True), ("Calibri", "12px", True), ...]

@@ -1,20 +1,20 @@
"""Uses the python sdk to make operation on Azure Blob storage.
see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
import os
import time
import asyncio
import base64
import hashlib
import json
import asyncio
import os
import random
from multiprocessing import Pool
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from tqdm import tqdm
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
import base64

import aiofiles
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from tqdm import tqdm

from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME

# maximum number of simultaneous requests
REQUEST_SEMAPHORE = asyncio.Semaphore(50)

@@ -24,11 +24,19 @@ FILE_SEMAPHORE = asyncio.Semaphore(500)

MAX_RETRIES = 5


class GrokBlobClient:
"""This class is a client that is used to upload and delete files from Azure Blob storage
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
def __init__(self, datasource_container_name, blob_account_name, blob_key, projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME):

def __init__(
self,
datasource_container_name,
blob_account_name,
blob_key,
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
):
"""Creates the blob storage client given the key and storage account name

Args:

@@ -42,8 +50,10 @@ class GrokBlobClient:
self.PROJECTIONS_CONTAINER_NAME = projections_container_name
self.BLOB_NAME = blob_account_name
self.BLOB_KEY = blob_key
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
self.BLOB_CONNECTION_STRING = (
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
)

@staticmethod
def create_from_env_var():

@@ -55,11 +65,24 @@ class GrokBlobClient:
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
BLOB_NAME = os.environ["BLOB_NAME"]
BLOB_KEY = os.environ["BLOB_KEY"]
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
client = GrokBlobClient(DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY, projections_container_name=PROJECTIONS_CONTAINER_NAME)
PROJECTIONS_CONTAINER_NAME = os.environ.get(
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
)
client = GrokBlobClient(
DATASOURCE_CONTAINER_NAME,
BLOB_NAME,
BLOB_KEY,
projections_container_name=PROJECTIONS_CONTAINER_NAME,
)
return client

def upload_images_to_blob(self, src_folder_path, dest_folder_name=None,check_existing_cache=True, use_async=True):
def upload_images_to_blob(
self,
src_folder_path,
dest_folder_name=None,
check_existing_cache=True,
use_async=True,
):
"""Uploads images from the src_folder_path to blob storage at the destination folder.
The destination folder is created if it doesn't exist. If a destination folder is not
given a folder is created named by the md5 hash of the files.

@@ -73,7 +96,8 @@ class GrokBlobClient:
"""
self._create_container()
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)

if dest_folder_name is None:
dest_folder_name = self.get_folder_hash(src_folder_path)

@@ -96,28 +120,48 @@ class GrokBlobClient:
if check_existing_cache:
existing_blobs, _ = self.list_blobs(dest_folder_name or "")
existing_blobs = list(map(lambda blob: blob["name"], existing_blobs))
file_blob_names = filter(lambda file_blob_names: not file_blob_names[1] in existing_blobs, zip(files_to_upload, blob_names))
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in file_blob_names ]
file_blob_names = filter(
lambda file_blob_names: not file_blob_names[1] in existing_blobs,
zip(files_to_upload, blob_names),
)
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in file_blob_names
]
else:
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in zip(files_to_upload, blob_names)]
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in zip(files_to_upload, blob_names)
]

print("uploading ", len(job_args), "files")
if not use_async:
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
blob_container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
blob_container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
jobs = [(blob_container_client,) + x for x in job_args]
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)):
pass
else:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)

async def async_upload():
async with async_blob_service_client:
async_blob_container_client = async_blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
async_blob_container_client = (
async_blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
)
jobs = [(async_blob_container_client,) + x for x in job_args]
for f in tqdm(asyncio.as_completed(map(_upload_worker_async,jobs)), total=len(jobs)):
for f in tqdm(
asyncio.as_completed(map(_upload_worker_async, jobs)),
total=len(jobs),
):
await f

loop = asyncio.get_event_loop()

@@ -158,53 +202,83 @@ class GrokBlobClient:
blobs_list, blob_service_client = self.list_blobs(folder_name)
for blob in blobs_list:
blob_client = blob_service_client.get_blob_client(
container=self.DATASOURCE_CONTAINER_NAME, blob=blob)
container=self.DATASOURCE_CONTAINER_NAME, blob=blob
)
blob_client.delete_blob()

def list_blobs(self, folder_name):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
return container_client.list_blobs(name_starts_with=folder_name), blob_service_client

self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
return (
container_client.list_blobs(name_starts_with=folder_name),
blob_service_client,
)

def _create_container(self):
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist.
"""
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist."""
# Create the BlobServiceClient object which will be used to create a container client
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)

try:
blob_service_client.create_container(
self.DATASOURCE_CONTAINER_NAME)
blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.DATASOURCE_CONTAINER_NAME)

# create the container for storing ocr projections
try:
print("creating projections storage container:", self.PROJECTIONS_CONTAINER_NAME)
blob_service_client.create_container(
self.PROJECTIONS_CONTAINER_NAME)
print(
"creating projections storage container:",
self.PROJECTIONS_CONTAINER_NAME,
)
blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.PROJECTIONS_CONTAINER_NAME)

def get_ocr_json(self, remote_path, output_folder, use_async=True):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
blobs_list = list(container_client.list_blobs(name_starts_with=remote_path))
container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}"

if use_async:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)

async def async_download():
async with async_blob_service_client:
async_projection_container_client = async_blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob :(blob, async_projection_container_client, container_uri, output_folder), blobs_list ))
for f in tqdm(asyncio.as_completed(map(_download_worker_async,jobs)), total=len(jobs)):
async_projection_container_client = (
async_blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
)
jobs = list(
map(
lambda blob: (
blob,
async_projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
for f in tqdm(
asyncio.as_completed(map(_download_worker_async, jobs)),
total=len(jobs),
):
await f

loop = asyncio.get_event_loop()
if loop.is_running():
result = loop.create_task(async_download())

@@ -212,12 +286,25 @@ class GrokBlobClient:
result = loop.run_until_complete(async_download())
return result
else:
projection_container_client = blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob : (blob, projection_container_client, container_uri, output_folder), blobs_list))
projection_container_client = blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
jobs = list(
map(
lambda blob: (
blob,
projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
print("downloading", len(jobs), "files")
for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)):
pass


def _get_projection_path(container_uri, blob):
blob_uri = f"{container_uri}/{blob.name}"


@@ -229,10 +316,13 @@ def _get_projection_path(container_uri, blob):
projection_path = projection_path.replace("=", "") + str(projection_path.count("="))
return projection_path


def _download_worker_sync(args):
blob, projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri, blob)
blob_client = projection_container_client.get_blob_client(blob=f"{projection_path}/document.json")
blob_client = projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
doc = json.loads(blob_client.download_blob().readall())
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)

@@ -242,10 +332,13 @@ def _download_worker_sync(args):
json.dump(text, open(output_file, "w", encoding="utf-8"), ensure_ascii=False)
return output_file


async def _download_worker_async(args):
blob, async_projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri, blob)
async_blob_client = async_projection_container_client.get_blob_client( blob=f"{projection_path}/document.json")
async_blob_client = async_projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)
for retry in range(MAX_RETRIES):

@@ -260,7 +353,7 @@ async def _download_worker_async(args):
json.dump(text, open(output_file, "w"))
return output_file
except ResourceNotFoundError:
print(f"blob doesn't exist in OCR projection. try rerunning OCR", blob.name)
print(f"Blob '{blob.name}'' doesn't exist in OCR projection. try rerunning OCR")
return
except Exception as e:
print("error getting blob OCR projection", blob.name, e)

@@ -268,6 +361,7 @@ async def _download_worker_async(args):
# sleep for a bit then retry
asyncio.sleep(2 * random.random())


async def _upload_worker_async(args):
async_blob_container_client, upload_file_path, blob_name = args
async with FILE_SEMAPHORE:

@@ -276,22 +370,31 @@ async def _upload_worker_async(args):
for retry in range(MAX_RETRIES):
async with REQUEST_SEMAPHORE:
try:
await async_blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
await async_blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
return blob_name
except ResourceExistsError:
print("blob already exists:", blob_name)
return
except Exception as e:
print(f"blob upload error. retry count: {retry}/{MAX_RETRIES} :", blob_name, e)
print(
f"blob upload error. retry count: {retry}/{MAX_RETRIES} :",
blob_name,
e,
)
# sleep for a bit then retry
asyncio.sleep(2 * random.random())
return blob_name


def _upload_worker_sync(args):
blob_container_client, upload_file_path, blob_name = args
with open(upload_file_path, "rb") as data:
try:
blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
except ResourceExistsError:
print("blob already exists:", blob_name)
except Exception as e:

@@ -1,10 +1,10 @@
from .rest_client import GrokRestClient
from .blob_client import GrokBlobClient
import time

from .blob_client import GrokBlobClient
from .rest_client import GrokRestClient


class Grok:

@staticmethod
def create_from_env_var():
"""Initializes Grok based on keys in the environment variables.

@@ -16,11 +16,20 @@ class Grok:
grok_blob_client = GrokBlobClient.create_from_env_var()
return Grok(grok_rest_client, grok_blob_client)

def __init__(self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient):
def __init__(
self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient
):
self.grok_rest_client = grok_rest_client
self.grok_blob_client = grok_blob_client

def run_grok(self, src_folder_path, dest_folder_path, blob_dest_folder=None,cleanup=False,use_async=True):
def run_grok(
self,
src_folder_path,
dest_folder_path,
blob_dest_folder=None,
cleanup=False,
use_async=True,
):
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
GROK OCR on this blob storage as a source, then dowloads the OCR output json to the destination
folder. There resulting json files are of the same name as the original images except prefixed

@@ -40,7 +49,8 @@ class Grok:
"""
print("uploading images to blob")
blob_folder_name, _ = self.grok_blob_client.upload_images_to_blob(
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async)
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async
)
print(f"images upload under folder {blob_folder_name}")
try:
print("creating and running indexer")

@@ -53,7 +63,10 @@ class Grok:

# if not already running start the indexer
print("indexer_status", indexer_status)
if indexer_status["lastResult"] == None or indexer_status["lastResult"]["status"] != "inProgress":
if (
indexer_status["lastResult"] is None
or indexer_status["lastResult"]["status"] != "inProgress"
):
self.grok_rest_client.run_indexer()

time.sleep(1)

@@ -62,9 +75,13 @@ class Grok:
if indexer_status["lastResult"]["status"] == "success":
time.sleep(30)
print("fetching ocr json results.")
self.grok_blob_client.get_ocr_json(blob_folder_name, dest_folder_path, use_async=use_async)
self.grok_blob_client.get_ocr_json(
blob_folder_name, dest_folder_path, use_async=use_async
)
print(f"indexer status {indexer_status}")
print(f"finished running indexer. json files saved to {dest_folder_path}")
print(
f"finished running indexer. json files saved to {dest_folder_path}"
)
else:
print("GROK failed", indexer_status["status"])
raise RuntimeError("GROK failed", indexer_status["status"])

@@ -6,7 +6,8 @@ OCR Metrics
Accuracy = Correct Words/Total Words (in target strings)

2. Count of edit distance ops:
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". This is based on Levenshtein edit distance.
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing".
This is based on Levenshtein edit distance.

3. By looking at the gaps in alignment we also generate substitution dicts:
e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution

@@ -14,29 +15,33 @@ since the rest of the segments align.The assumption here is that we do not expec
hence collecting and counting these substitutions will be managable.

"""
import string
import argparse
import json
import multiprocessing
import os
import re
import json
import argparse
import multiprocessing
from multiprocessing import Pool

import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool

from genalog.text.alignment import GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
from genalog.text.anchor import align_w_anchor
from genalog.text.ner_label import _find_gap_char_candidates

LOG_LEVEL = 0
WORKERS_PER_CPU = 2


def _log(*args, **kwargs):
if LOG_LEVEL:
print(args)


def _trim_whitespace(src_string):
return re.sub(r"\s+", " ", src_string.strip())


def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
"""Given two string that differ and have no alignment at all,
update the alignment dict and fill in substitution if replacements are found.

@@ -70,13 +75,23 @@ def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
else:
align_stats["replace"] += 1
_log("replacing", source_substr, target_substr)
substitution_dict[source_substr, target_substr] = substitution_dict.get(
(source_substr, target_substr), 0) + 1
substitution_dict[source_substr, target_substr] = (
substitution_dict.get((source_substr, target_substr), 0) + 1
)
_log("spacing count", spacing_count)
align_stats["spacing"] += spacing_count

def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matching_chars_count, \
matching_words_count, matching_alnum_words_count):

def _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
):
"""Given two string segments that align. update the counts of matching words and characters

Args:

@@ -93,7 +108,7 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
tuple(int,int,int): the updated matching_chars_count, matching_words_count, matching_alnum_words_count
"""
aligned_part = aligned_src[start:end]
matching_chars_count += (end-start)
matching_chars_count += end - start
# aligned_part = seq.strip()
_log("aligned", aligned_part, start, end)
if len(aligned_src) != len(aligned_target):

@@ -112,10 +127,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
# to be compared with the full string to see if they have space before or after

if i == 0:
if start != 0 and (aligned_target[start] != " " or aligned_src[start] != " " ):
if start != 0 and (
aligned_target[start] != " " or aligned_src[start] != " "
):
# if this was the start of the string in the target or source
if not(aligned_src[:start].replace(gap_char,"").replace(" ","") == "" and aligned_target[start-1] == " ") and \
not(aligned_target[:start].replace(gap_char,"").replace(" ","") == "" and aligned_src[start-1] == " "):
if not (
aligned_src[:start].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[start - 1] == " "
) and not (
aligned_target[:start].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[start - 1] == " "
):
# beginning word not matching completely
_log("removing first match word from count", word, aligned_part)
matching_words_count -= 1

@@ -124,10 +147,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
continue

if i == len(words) - 1:
if end != len(aligned_target) and (aligned_target[end] != " " or aligned_src[end] != " " ):
if end != len(aligned_target) and (
aligned_target[end] != " " or aligned_src[end] != " "
):
# this was not the end of the string in the src and not end of string in target
if not(aligned_src[end:].replace(gap_char,"").replace(" ","") == "" and aligned_target[end] == " ") and \
not(aligned_target[end:].replace(gap_char,"").replace(" ","") == ""and aligned_src[end] == " "):
if not (
aligned_src[end:].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[end] == " "
) and not (
aligned_target[end:].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[end] == " "
):
# last word not matching completely
_log("removing last match word from count", word, aligned_part)
matching_words_count -= 1

@@ -138,6 +169,7 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
_log("matched alnum count", matching_alnum_words_count)
return matching_chars_count, matching_words_count, matching_alnum_words_count


def _get_align_stats(alignment, src_string, target, gap_char):
"""Given an alignment, this function get the align stats and substitution mapping to
transform the source string to the target string

@@ -174,8 +206,15 @@ def _get_align_stats(alignment, src_string, target, gap_char):
matching_words_count = 0
matching_alnum_words_count = 0

align_stats = {"insert": 0, "delete": 0, "replace": 0, "spacing": 0,
"total_chars": char_count, "total_words": word_count, "total_alnum_words": alnum_words_count}
align_stats = {
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": char_count,
"total_words": word_count,
"total_alnum_words": alnum_words_count,
}
start = 0

_log("######### Alignment ############")

@@ -190,35 +229,83 @@ def _get_align_stats(alignment, src_string, target, gap_char):
# since this substring aligns, simple count the number of matching words and chars in and update
# the word stats
end = i
_log("sequences", aligned_src[start:end], aligned_target[start:end], start, end)
_log(
"sequences",
aligned_src[start:end],
aligned_target[start:end],
start,
end,
)
assert aligned_src[start:end] == aligned_target[start:end]
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)
start = end + 1
if gap_start is None:
gap_start = end
else:
gap_end = i
if not gap_start is None:
if gap_start is not None:
# since characters now match gap_start:i contains a substring of the characters that didnt align before
# handle this gap alignment by calling _update_align_stats
_log("gap", aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], gap_start, gap_end)
_update_align_stats(aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], align_stats, substitution_dict, gap_char)
_log(
"gap",
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
gap_start,
gap_end,
)
_update_align_stats(
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
align_stats,
substitution_dict,
gap_char,
)
gap_start = None

# Now compare any left overs string segments from the for loop
if gap_start is not None:
# handle last alignment gap
_log("last gap", aligned_src[gap_start:], aligned_target[gap_start:])
_update_align_stats(aligned_src[gap_start:], aligned_target[gap_start:], align_stats, substitution_dict, gap_char)
_update_align_stats(
aligned_src[gap_start:],
aligned_target[gap_start:],
align_stats,
substitution_dict,
gap_char,
)
else:
# handle last aligned substring
seq = aligned_src[start:]
aligned_part = seq.strip()
end = len(aligned_src)
_log("last aligned", aligned_part)
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)

align_stats["matching_chars"] = matching_chars_count
align_stats["matching_alnum_words"] = matching_alnum_words_count

@@ -248,11 +335,17 @@ def get_editops_stats(alignment, gap_char):
aligned_src, aligned_target = alignment
if aligned_src == "" or aligned_target == "":
raise ValueError("one of the input strings is empty")
stats = {"edit_insert": 0, "edit_delete": 0, "edit_replace": 0,
"edit_insert_spacing": 0, "edit_delete_spacing": 0}
stats = {
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
}
actions = {}
for i, (char_1, char_2) in enumerate(zip(aligned_src, aligned_target)):
if LOG_LEVEL > 1: _log(char_1, char_2)
if LOG_LEVEL > 1:
_log(char_1, char_2)
if char_1 == gap_char:
# insert
if char_2 == " ":

@@ -266,12 +359,13 @@ def get_editops_stats(alignment, gap_char):
stats["edit_delete_spacing"] += 1
else:
stats["edit_delete"] += 1
actions[i] = ("D")
actions[i] = "D"
elif char_2 != char_1:
stats["edit_replace"] += 1
actions[i] = ("R", char_2)
return stats, actions


def get_align_stats(alignment, src_string, target, gap_char):
"""Get alignment stats


@@ -293,9 +387,11 @@ def get_align_stats(alignment, src_string, target, gap_char):
_log("alignment results")
_log(alignment)
align_stats, substitution_dict = _get_align_stats(
alignment, src_string, target, gap_char)
alignment, src_string, target, gap_char
)
return align_stats, substitution_dict


def get_stats(target, src_string):
"""Get align stats, edit stats, and substitution mappings for transforming the
source string to the target string. Edit stats refers to character level edit operation

@@ -311,16 +407,24 @@ def get_stats(target, src_string):
Returns:
tuple(str, str): One dict containing the edit and align stats, another dict containing the substitutions
"""
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align_w_anchor(src_string, target, gap_char=gap_char)
align_stats, substitution_dict = get_align_stats(alignment,src_string, target, gap_char)
align_stats, substitution_dict = get_align_stats(
alignment, src_string, target, gap_char
)
edit_stats, actions = get_editops_stats(alignment, gap_char)
_log("alignment", align_stats)
return {**edit_stats, **align_stats}, substitution_dict, actions


def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True):
def get_metrics(
src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True
):
"""Given a path to the folder containing the source text and a folder containing
the output OCR json, this generates the metrics for all files in the source folder.
This assumes that the files json folder are of the same name the text files except they

@@ -338,7 +442,6 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
filename and the values are dicts of the substition mappings for that file.
"""


rows = []
substitutions = {}
actions_map = {}

@@ -347,16 +450,25 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
cpu_count = multiprocessing.cpu_count()
n_workers = WORKERS_PER_CPU * cpu_count

job_args = list(map(lambda f: (f, src_text_path, ocr_json_path, folder_hash) , os.listdir(src_text_path)))
job_args = list(
map(
lambda f: (f, src_text_path, ocr_json_path, folder_hash),
os.listdir(src_text_path),
)
)

if use_multiprocessing:
with Pool(n_workers) as pool:
for f, stats, actions, subs in tqdm(pool.imap_unordered(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
pool.imap_unordered(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)
else:
for f, stats, actions, subs in tqdm(map(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
map(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)

@@ -364,16 +476,17 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
df = pd.DataFrame(rows)
return df, substitutions, actions_map


def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
src_filename = os.path.join(src_text_path, f)
if folder_hash:
ocr_filename = os.path.join(
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}")
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}"
)
else:
ocr_filename = os.path.join(
ocr_json_path, f"{f.split('txt')[0] + 'json'}")
ocr_filename = os.path.join(ocr_json_path, f"{f.split('txt')[0] + 'json'}")
try:
src_string = open(src_filename, "r", errors='ignore', encoding="utf8").read()
src_string = open(src_filename, "r", errors="ignore", encoding="utf8").read()
except FileNotFoundError:
print(f"File not found: {src_filename}, skipping this file.")
return f, {}, {}, {}

@@ -395,10 +508,12 @@ def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
stats["filename"] = f
return f, stats, actions, subs


def _worker(args):
(f, src_text_path, ocr_json_path, folder_hash) = args
return get_file_metrics(f, src_text_path, ocr_json_path, folder_hash)


def _get_sorted_text(ocr_json):
if "lines" in ocr_json[0]:
lines = ocr_json[0]["lines"]

@@ -407,22 +522,27 @@ def _get_sorted_text(ocr_json):
else:
return ocr_json[0]["text"]


def substitution_dict_to_json(substitution_dict):
"""Converts substitution dict to list of tuples of (source_substring, target_substring, count)

Args:
substitution_dict ([type]): [description]
"""
to_tuple = lambda x: [(k +(x[k],)) for k in x]
to_tuple = lambda x: [(k + (x[k],)) for k in x] # noqa: E731
out = {}
for filename in substitution_dict:
out[filename] = to_tuple(substitution_dict[filename])
return out


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("src", help="path to folder with text files.")
parser.add_argument("ocr", help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.")
parser.add_argument(
"ocr",
help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.",
)
parser.add_argument("--ocr_prefix", help="the prefix of the ocr files")
parser.add_argument("--output", help="output names of metrics files")

@ -1,17 +1,18 @@
"""Uses the REST api to perform operations on the search service.
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
"""
import requests
import json
import os
import pkgutil
import json
from dotenv import load_dotenv
import time
import sys
import time
from itertools import cycle

import requests

from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME

API_VERSION = '?api-version=2019-05-06-Preview'
API_VERSION = "?api-version=2019-05-06-Preview"

# 15 min schedule
SCHEDULE_INTERVAL = "PT15M"

@ -25,9 +26,20 @@ class GrokRestClient:
|
|||
ongoing indexers. The indexing pipeline can allow you to run batch OCR enrichment of documents.
|
||||
"""
|
||||
|
||||
def __init__(self, cognitive_service_key, search_service_key, search_service_name, skillset_name,
|
||||
index_name, indexer_name, datasource_name, datasource_container_name, blob_account_name, blob_key,
|
||||
projections_container_name = DEFAULT_PROJECTIONS_CONTAINER_NAME):
|
||||
def __init__(
|
||||
self,
|
||||
cognitive_service_key,
|
||||
search_service_key,
|
||||
search_service_name,
|
||||
skillset_name,
|
||||
index_name,
|
||||
indexer_name,
|
||||
datasource_name,
|
||||
datasource_container_name,
|
||||
blob_account_name,
|
||||
blob_key,
|
||||
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
|
||||
):
|
||||
"""Creates the REST client
|
||||
|
||||
Args:
|
||||
|
@ -70,8 +82,10 @@ class GrokRestClient:
|
|||
|
||||
self.API_VERSION = API_VERSION
|
||||
|
||||
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
|
||||
self.BLOB_CONNECTION_STRING = (
|
||||
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
|
||||
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_from_env_var():
|
||||
|
@ -85,51 +99,68 @@ class GrokRestClient:
|
|||
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
|
||||
BLOB_NAME = os.environ["BLOB_NAME"]
|
||||
BLOB_KEY = os.environ["BLOB_KEY"]
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get(
|
||||
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
)
|
||||
|
||||
client = GrokRestClient(COGNITIVE_SERVICE_KEY, SEARCH_SERVICE_KEY, SEARCH_SERVICE_NAME, SKILLSET_NAME, INDEX_NAME,
|
||||
INDEXER_NAME, DATASOURCE_NAME, DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY,projections_container_name=PROJECTIONS_CONTAINER_NAME)
|
||||
client = GrokRestClient(
|
||||
COGNITIVE_SERVICE_KEY,
|
||||
SEARCH_SERVICE_KEY,
|
||||
SEARCH_SERVICE_NAME,
|
||||
SKILLSET_NAME,
|
||||
INDEX_NAME,
|
||||
INDEXER_NAME,
|
||||
DATASOURCE_NAME,
|
||||
DATASOURCE_CONTAINER_NAME,
|
||||
BLOB_NAME,
|
||||
BLOB_KEY,
|
||||
projections_container_name=PROJECTIONS_CONTAINER_NAME,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def create_skillset(self):
|
||||
"""Adds a skillset that performs OCR on images
|
||||
"""
|
||||
"""Adds a skillset that performs OCR on images"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
skillset_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/skillset.json"))
|
||||
skillset_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/skillset.json")
|
||||
)
|
||||
|
||||
skillset_json["name"] = self.SKILLSET_NAME
|
||||
skillset_json["cognitiveServices"]["key"] = self.COGNITIVE_SERVICE_KEY
|
||||
|
||||
knowledge_store_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/knowledge_store.json"))
|
||||
knowledge_store_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/knowledge_store.json")
|
||||
)
|
||||
knowledge_store_json["storageConnectionString"] = self.BLOB_CONNECTION_STRING
|
||||
knowledge_store_json["projections"][0]["objects"][0]["storageContainer"] = self.PROJECTIONS_CONTAINER_NAME
|
||||
knowledge_store_json["projections"][0]["objects"][0][
|
||||
"storageContainer"
|
||||
] = self.PROJECTIONS_CONTAINER_NAME
|
||||
skillset_json["knowledgeStore"] = knowledge_store_json
|
||||
print(skillset_json)
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(skillset_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(skillset_json), headers=headers
|
||||
)
|
||||
print("skillset response", r.text)
|
||||
r.raise_for_status()
|
||||
print("added skillset", self.SKILLSET_NAME, r)
|
||||
|
||||
def create_datasource(self):
|
||||
"""Attaches the blob data store to the search service as a source for image documents
|
||||
"""
|
||||
"""Attaches the blob data store to the search service as a source for image documents"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
datasource_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/datasource.json"))
|
||||
datasource_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/datasource.json")
|
||||
)
|
||||
datasource_json["name"] = self.DATASOURCE_NAME
|
||||
datasource_json["credentials"]["connectionString"] = self.BLOB_CONNECTION_STRING
|
||||
datasource_json["type"] = "azureblob"
|
||||
|
@ -137,27 +168,27 @@ class GrokRestClient:
|
|||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(datasource_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(datasource_json), headers=headers
|
||||
)
|
||||
print("datasource response", r.text)
|
||||
r.raise_for_status()
|
||||
print("added datasource", self.DATASOURCE_NAME, r)
|
||||
|
||||
def create_index(self):
|
||||
"""Create an index with the layoutText column to store OCR output from the enrichment
|
||||
"""
|
||||
"""Create an index with the layoutText column to store OCR output from the enrichment"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
index_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/index.json"))
|
||||
index_json = json.loads(pkgutil.get_data(__name__, "templates/index.json"))
|
||||
index_json["name"] = self.INDEX_NAME
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(index_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(index_json), headers=headers
|
||||
)
|
||||
print("index response", r.text)
|
||||
r.raise_for_status()
|
||||
print("created index", self.INDEX_NAME, r)
|
||||
|
@ -167,24 +198,26 @@ class GrokRestClient:
|
|||
The enriched results are pushed to the index.
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
indexer_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/indexer.json"))
|
||||
indexer_json = json.loads(pkgutil.get_data(__name__, "templates/indexer.json"))
|
||||
|
||||
indexer_json["name"] = self.INDEXER_NAME
|
||||
indexer_json["skillsetName"] = self.SKILLSET_NAME
|
||||
indexer_json["targetIndexName"] = self.INDEX_NAME
|
||||
indexer_json["dataSourceName"] = self.DATASOURCE_NAME
|
||||
indexer_json["schedule"] = {"interval": SCHEDULE_INTERVAL}
|
||||
indexer_json["parameters"]["configuration"]["excludedFileNameExtensions"] = extension_to_exclude
|
||||
indexer_json["parameters"]["configuration"][
|
||||
"excludedFileNameExtensions"
|
||||
] = extension_to_exclude
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(indexer_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(indexer_json), headers=headers
|
||||
)
|
||||
print("indexer response", r.text)
|
||||
r.raise_for_status()
|
||||
print("created indexer", self.INDEXER_NAME, r)
|
||||
|
@ -203,14 +236,14 @@ class GrokRestClient:
|
|||
created
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
endpoints = [
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}",
|
||||
]
|
||||
|
||||
for endpoint in endpoints:
|
||||
|
@ -220,8 +253,8 @@ class GrokRestClient:
|
|||
|
||||
def run_indexer(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/run"
|
||||
|
@ -239,7 +272,10 @@ class GrokRestClient:
|
|||
request_json = self.get_indexer_status()
|
||||
if request_json["status"] == "error":
|
||||
raise RuntimeError("Indexer failed")
|
||||
if request_json["lastResult"] and not request_json["lastResult"]["status"] == "inProgress":
|
||||
if (
|
||||
request_json["lastResult"]
|
||||
and not request_json["lastResult"]["status"] == "inProgress"
|
||||
):
|
||||
print(request_json["lastResult"]["status"], self.INDEXER_NAME)
|
||||
return request_json
|
||||
|
||||
|
@ -250,8 +286,8 @@ class GrokRestClient:
|
|||
|
||||
def get_indexer_status(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/status"
|
||||
response = requests.get(endpoint + self.API_VERSION, headers=headers)
|
||||
|
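# Illustrative usage sketch (not part of the diff): driving the full OCR enrichment
# pipeline with the client above. The module path and the create_indexer() name are
# assumptions; the other methods and the status fields appear in this file, and the
# required environment variables are those read in create_from_env_var().
import time

from genalog.ocr.rest_client import GrokRestClient  # module path assumed

grok_client = GrokRestClient.create_from_env_var()
grok_client.create_skillset()
grok_client.create_datasource()
grok_client.create_index()
grok_client.create_indexer()  # name assumed for the indexer-creation method
grok_client.run_indexer()

while True:
    # Poll until the indexer reaches a terminal state
    status = grok_client.get_indexer_status()
    if status["status"] == "error":
        raise RuntimeError("Indexer failed")
    last_result = status.get("lastResult")
    if last_result and last_result["status"] != "inProgress":
        break
    time.sleep(30)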
|
|
@ -1,14 +1,17 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.content import CompositeContent, ContentType
from genalog.degradation.degrader import Degrader, ImageState
from json import JSONEncoder
from tqdm import tqdm

import concurrent.futures
import timeit
import cv2
import os
import timeit
from json import JSONEncoder

import cv2
from tqdm import tqdm

from genalog.degradation.degrader import Degrader, ImageState
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.document import DocumentGenerator


class ImageStateEncoder(JSONEncoder):
|
||||
def default(self, obj):
|
||||
|
@ -16,8 +19,15 @@ class ImageStateEncoder(JSONEncoder):
|
|||
return obj.value
|
||||
return JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
class AnalogDocumentGeneration(object):
|
||||
def __init__(self, template_path=None, styles=DEFAULT_STYLE_COMBINATION, degradations=[], resolution=300):
|
||||
def __init__(
|
||||
self,
|
||||
template_path=None,
|
||||
styles=DEFAULT_STYLE_COMBINATION,
|
||||
degradations=[],
|
||||
resolution=300,
|
||||
):
|
||||
self.doc_generator = DocumentGenerator(template_path=template_path)
|
||||
self.doc_generator.set_styles_to_generate(styles)
|
||||
self.degrader = Degrader(degradations)
|
||||
|
@ -67,45 +77,71 @@ class AnalogDocumentGeneration(object):
|
|||
cv2.imwrite(img_dst_path, dst)
|
||||
return
|
||||
|
||||
|
||||
def _divide_batches(a, batch_size):
|
||||
for i in range(0, len(a), batch_size):
|
||||
yield a[i: i + batch_size]
|
||||
|
||||
|
||||
def _setup_folder(output_folder):
|
||||
os.makedirs(os.path.join(output_folder, "img"), exist_ok=True)
|
||||
|
||||
|
||||
def batch_img_generate(args):
|
||||
input_files, output_folder, styles, degradations, template, resolution = args
|
||||
generator = AnalogDocumentGeneration(styles=styles, degradations=degradations, resolution=resolution)
|
||||
generator = AnalogDocumentGeneration(
|
||||
styles=styles, degradations=degradations, resolution=resolution
|
||||
)
|
||||
for file in input_files:
|
||||
generator.generate_img(file, template, target_folder=output_folder)
|
||||
|
||||
def _set_batch_generate_args(file_batches, output_folder, styles, degradations, template, resolution):
|
||||
return list(map(
|
||||
lambda batch:
|
||||
(batch, output_folder, styles, degradations, template, resolution),
|
||||
file_batches
|
||||
))
|
||||
|
||||
def _set_batch_generate_args(
|
||||
file_batches, output_folder, styles, degradations, template, resolution
|
||||
):
|
||||
return list(
|
||||
map(
|
||||
lambda batch: (
|
||||
batch,
|
||||
output_folder,
|
||||
styles,
|
||||
degradations,
|
||||
template,
|
||||
resolution,
|
||||
),
|
||||
file_batches,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_dataset_multiprocess(
        input_text_files, output_folder, styles, degradations, template,
        resolution=300, batch_size=25
    input_text_files,
    output_folder,
    styles,
    degradations,
    template,
    resolution=300,
    batch_size=25,
):
    _setup_folder(output_folder)
    print(f"Storing generated images in {output_folder}")

    batches = list(_divide_batches(input_text_files, batch_size))
    print(f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}")
    print(
        f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}"
    )

    batch_img_generate_args = _set_batch_generate_args(batches, output_folder, styles, degradations, template, resolution)
    batch_img_generate_args = _set_batch_generate_args(
        batches, output_folder, styles, degradations, template, resolution
    )

    # Default to the number of processors on the machine
    start_time = timeit.default_timer()
    with concurrent.futures.ProcessPoolExecutor() as executor:
        batch_iterator = executor.map(batch_img_generate, batch_img_generate_args)
        for _ in tqdm(batch_iterator, total=len(batch_img_generate_args)):  # wrapping tqdm for progress report
        for _ in tqdm(
            batch_iterator, total=len(batch_img_generate_args)
        ):  # wrapping tqdm for progress report
            pass
    elapsed = timeit.default_timer() - start_time
    print(f"Time to generate {len(input_text_files)} documents: {elapsed:.3f} sec")

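# Illustrative usage sketch (not part of the diff) for the multiprocess entry point
# above. Only the function signature comes from this file; the module path, template
# name, style dict, and degradation list below are assumptions for the example.
from genalog.pipeline import generate_dataset_multiprocess  # module path assumed

generate_dataset_multiprocess(
    input_text_files=["sample_01.txt", "sample_02.txt"],  # hypothetical input files
    output_folder="generated_imgs",
    styles={"font_size": ["12px"]},                       # assumed style format
    degradations=[("blur", {"radius": 5})],               # assumed degradation spec
    template="text_block.html.jinja",                     # assumed template name
    resolution=300,
    batch_size=25,
)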
@ -1 +0,0 @@
|
|||
|
|
@ -1,25 +1,36 @@
from genalog.text.preprocess import _is_spacing, tokenize
from Bio import pairwise2
import re

from Bio import pairwise2

from genalog.text.preprocess import _is_spacing, tokenize

# Configuration params for global sequence alignment algorithm (Needleman-Wunsch)
MATCH_REWARD = 1
GAP_PENALTY = -0.5
GAP_EXT_PENALTY = -0.5
MISMATCH_PENALTY = -0.5
GAP_CHAR = '@'
GAP_CHAR = "@"
ONE_ALIGNMENT_ONLY = False
SPACE_MISMATCH_PENALTY = .1
SPACE_MISMATCH_PENALTY = 0.1


def _join_char_list(alignment_tuple):
    """ Post-process alignment results for unicode support """
    gt_char_list, noise_char_list, score, start, end = alignment_tuple
    return "".join(gt_char_list), "".join(noise_char_list), score, start, end

def _align_seg(gt, noise,
|
||||
match_reward=MATCH_REWARD, mismatch_pen=MISMATCH_PENALTY,
|
||||
gap_pen=GAP_PENALTY, gap_ext_pen=GAP_EXT_PENALTY, space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
|
||||
gap_char=GAP_CHAR, one_alignment_only=ONE_ALIGNMENT_ONLY):
|
||||
|
||||
def _align_seg(
|
||||
gt,
|
||||
noise,
|
||||
match_reward=MATCH_REWARD,
|
||||
mismatch_pen=MISMATCH_PENALTY,
|
||||
gap_pen=GAP_PENALTY,
|
||||
gap_ext_pen=GAP_EXT_PENALTY,
|
||||
space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
|
||||
gap_char=GAP_CHAR,
|
||||
one_alignment_only=ONE_ALIGNMENT_ONLY,
|
||||
):
|
||||
"""Wrapper function for Bio.pairwise2.align.globalms(), which
|
||||
calls the sequence alignment algorithm (Needleman-Wunsch)
|
||||
|
||||
|
@ -47,6 +58,7 @@ def _align_seg(gt, noise,
|
|||
...
|
||||
]
|
||||
"""
|
||||
|
||||
def match_reward_fn(x, y):
|
||||
if x == y:
|
||||
return match_reward
|
||||
|
@ -55,12 +67,21 @@ def _align_seg(gt, noise,
|
|||
return mismatch_pen - space_mismatch_penalty
|
||||
else:
|
||||
return mismatch_pen
|
||||
|
||||
# NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters
|
||||
alignments = pairwise2.align.globalcs(list(gt), list(noise), match_reward_fn,
|
||||
gap_pen, gap_ext_pen, gap_char=[gap_char], one_alignment_only=ONE_ALIGNMENT_ONLY)
|
||||
alignments = pairwise2.align.globalcs(
|
||||
list(gt),
|
||||
list(noise),
|
||||
match_reward_fn,
|
||||
gap_pen,
|
||||
gap_ext_pen,
|
||||
gap_char=[gap_char],
|
||||
one_alignment_only=ONE_ALIGNMENT_ONLY,
|
||||
)
|
||||
# Alignment result is a list of char instead of string because of the work-around
|
||||
return list(map(_join_char_list, alignments))
|
||||
|
||||
|
||||
def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
||||
"""Return an alignment that contains the desired number
|
||||
of ground truth tokens from a list of possible alignments
|
||||
|
@ -111,11 +132,16 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
|||
if num_aligned_gt_tokens == target_num_gt_tokens:
|
||||
# Invariant 1
|
||||
if len(aligned_gt) != len(aligned_noise):
|
||||
raise ValueError(f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n")
|
||||
raise ValueError(
|
||||
f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n"
|
||||
)
|
||||
# Returns the FIRST candidate that satisfies the invariant
|
||||
return alignment
|
||||
|
||||
raise ValueError(f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}")
|
||||
raise ValueError(
|
||||
f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}"
|
||||
)
|
||||
|
||||
|
||||
def align(gt, noise, gap_char=GAP_CHAR):
|
||||
"""Align two text segments via sequence alignment algorithm
|
||||
|
@ -144,7 +170,7 @@ def align(gt, noise, gap_char=GAP_CHAR):
|
|||
|
||||
"""
|
||||
if not gt and not noise: # Both inputs are empty string
|
||||
return '', ''
|
||||
return "", ""
|
||||
elif not gt: # Either is empty
|
||||
return gap_char * len(noise), noise
|
||||
elif not noise:
|
||||
|
@ -153,11 +179,16 @@ def align(gt, noise, gap_char=GAP_CHAR):
|
|||
num_gt_tokens = len(tokenize(gt))
|
||||
alignments = _align_seg(gt, noise, gap_char=gap_char)
|
||||
try:
|
||||
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(alignments, num_gt_tokens)
|
||||
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(
|
||||
alignments, num_gt_tokens
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error with input strings '{gt}' and '{noise}': \n{str(e)}")
|
||||
raise ValueError(
|
||||
f"Error with input strings '{gt}' and '{noise}': \n{str(e)}"
|
||||
)
|
||||
return aligned_gt, aligned_noise
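# Quick usage note (not part of the diff): align() returns the two input strings
# padded with GAP_CHAR so that they have equal length and their tokens line up,
# e.g. with the default gap char '@' configured above (exact padding positions
# depend on the alignment the algorithm picks):
#
#   aligned_gt, aligned_noise = align("New York is big.", "New Yerk is big.")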
|
||||
|
||||
|
||||
def _format_alignment(align1, align2):
|
||||
"""Wrapper function for Bio.pairwise2.format_alignment()
|
||||
|
||||
|
@ -178,11 +209,14 @@ def _format_alignment(align1, align2):
|
|||
New Yerk@is big.
|
||||
"
|
||||
"""
|
||||
formatted_str = pairwise2.format_alignment(align1, align2, 0, 0, len(align1), full_sequences=True)
|
||||
formatted_str = pairwise2.format_alignment(
|
||||
align1, align2, 0, 0, len(align1), full_sequences=True
|
||||
)
|
||||
# Remove the "Score=0" from the str
|
||||
formatted_str_no_score = formatted_str.replace("\n Score=0", "")
|
||||
return formatted_str_no_score
|
||||
|
||||
|
||||
def _find_token_start(s, index):
|
||||
"""Find the position of the start of token
|
||||
|
||||
|
@ -198,13 +232,16 @@ def _find_token_start(s, index):
|
|||
IndexError: if is out-of-bound index
|
||||
"""
|
||||
max_index = len(s) - 1
|
||||
if len(s) == 0: raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
if len(s) == 0:
|
||||
raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index:
|
||||
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
|
||||
while index < max_index and _is_spacing(s[index]):
|
||||
index += 1
|
||||
return index
|
||||
|
||||
|
||||
def _find_token_end(s, index):
|
||||
"""Find the position of the end of a token
|
||||
|
||||
|
@ -224,13 +261,16 @@ def _find_token_end(s, index):
|
|||
IndexError: if is out-of-bound index
|
||||
"""
|
||||
max_index = len(s) - 1
|
||||
if len(s) == 0: raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
if len(s) == 0:
|
||||
raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index:
|
||||
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
|
||||
while index < max_index and not _is_spacing(s[index]):
|
||||
index += 1
|
||||
return index
|
||||
|
||||
|
||||
def _find_next_token(s, start):
|
||||
"""Return the start and end index of a token in a string
|
||||
|
||||
|
@ -250,6 +290,7 @@ def _find_next_token(s, start):
|
|||
token_end = _find_token_end(s, token_start)
|
||||
return token_start, token_end
|
||||
|
||||
|
||||
def _is_valid_token(token, gap_char=GAP_CHAR):
|
||||
"""Returns true if token is valid (i.e. compose of non-gap characters)
|
||||
Invalid tokens are
|
||||
|
@ -269,9 +310,12 @@ def _is_valid_token(token, gap_char=GAP_CHAR):
|
|||
bool-- True if is a valid token, false otherwise
|
||||
"""
|
||||
# Matches multiples of 'gap_char' that are padded with whitespace characters on either end
|
||||
INVALID_TOKEN_REGEX = rf'^\s*{re.escape(gap_char)}*\s*$' # Escape special regex chars
|
||||
INVALID_TOKEN_REGEX = (
|
||||
rf"^\s*{re.escape(gap_char)}*\s*$" # Escape special regex chars
|
||||
)
|
||||
return not re.match(INVALID_TOKEN_REGEX, token)
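# Tiny illustration (not part of the diff) of the validity rule above, with the
# default gap char '@': a token made up only of gap characters (optionally padded
# by whitespace) is invalid, anything else is valid.
#
#   _is_valid_token("@@@")   # -> False
#   _is_valid_token(" @ ")   # -> False
#   _is_valid_token("w@rd")  # -> True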
|
||||
|
||||
|
||||
def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
||||
"""Parse alignment to pair ground truth tokens with noise tokens
|
||||
|
||||
|
@ -352,8 +396,8 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
total_noise_tokens = len(tokenize(aligned_noise))
|
||||
|
||||
# Initialization
|
||||
aligned_gt += ' ' # add whitespace padding to prevent ptr overflow
|
||||
aligned_noise += ' ' # add whitespace padding to prevent ptr overflow
|
||||
aligned_gt += " " # add whitespace padding to prevent ptr overflow
|
||||
aligned_noise += " " # add whitespace padding to prevent ptr overflow
|
||||
tk_index_gt = tk_index_noise = 0
|
||||
tk_start_gt, tk_end_gt = _find_next_token(aligned_gt, 0)
|
||||
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, 0)
|
||||
|
@ -364,8 +408,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
# If both tokens are aligned (one-to-one case)
|
||||
if tk_end_gt == tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these tokens in the gt_to_noise_mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -381,8 +428,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
elif tk_end_gt < tk_end_noise:
|
||||
while tk_end_gt < tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these tokens in the gt_to_noise_mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -397,8 +447,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
else:
|
||||
while tk_end_gt > tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these token in the gt_to_noise mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -406,7 +459,9 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
index_row = noise_to_gt_mapping[tk_index_noise]
|
||||
index_row.append(tk_index_gt)
|
||||
# Find the next gt_token
|
||||
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, tk_end_noise)
|
||||
tk_start_noise, tk_end_noise = _find_next_token(
|
||||
aligned_noise, tk_end_noise
|
||||
)
|
||||
# Increment index
|
||||
tk_index_noise += 1
|
||||
|
||||
|
|
|
@ -14,15 +14,17 @@
"""
import itertools
from collections import Counter
from genalog.text import preprocess, alignment
from genalog.text.lcs import LCS

from genalog.text import alignment, preprocess
from genalog.text.alignment import GAP_CHAR
from genalog.text.lcs import LCS

# The recursively portion of the algorithm will run on
# segments longer than this value to find anchor points in
# the longer segment (to break it up further).
MAX_ALIGN_SEGMENT_LENGTH = 100  # in characters length


def get_unique_words(tokens, case_sensitive=False):
|
||||
"""Get a set of unique words from a Counter dictionary of word occurrences
|
||||
|
||||
|
@ -44,6 +46,7 @@ def get_unique_words(tokens, case_sensitive=False):
|
|||
word_count = Counter(tokens_lowercase)
|
||||
return {tk for tk in tokens if word_count[tk.lower()] < 2}
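# Worked example (not part of the diff) for the set comprehension above, with the
# default case_sensitive=False: words whose lowercased form appears more than once
# are dropped.
#
#   get_unique_words(["New", "York", "new", "is", "big"])
#   # -> {"York", "is", "big"}   ("New"/"new" collide case-insensitively)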
|
||||
|
||||
|
||||
def segment_len(tokens):
|
||||
"""Get length of the segment
|
||||
|
||||
|
@ -54,6 +57,7 @@ def segment_len(tokens):
|
|||
"""
|
||||
return sum(map(len, tokens))
|
||||
|
||||
|
||||
def get_word_map(unique_words, src_tokens):
|
||||
"""Arrange the set of unique words by the order they original appear in the text
|
||||
|
||||
|
@ -73,6 +77,7 @@ def get_word_map(unique_words, src_tokens):
|
|||
word_map.sort(key=lambda x: x[1]) # Re-arrange order by the index
|
||||
return word_map
|
||||
|
||||
|
||||
def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
||||
"""Find the location of anchor words in both the gt and ocr text.
|
||||
Anchor words are location where we can split both the source gt
|
||||
|
@ -129,18 +134,29 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
|||
anchor_words = lcs_words.intersection(unique_words_common)
|
||||
|
||||
# 6. Filter the unique words to keep the anchor words ONLY
|
||||
anchor_map_gt = list(filter(
|
||||
anchor_map_gt = list(
|
||||
filter(
|
||||
# This is a list of (unique_word, unique_word_index)
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_gt
|
||||
))
|
||||
anchor_map_ocr = list(filter(
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_ocr
|
||||
))
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words,
|
||||
unique_word_map_gt,
|
||||
)
|
||||
)
|
||||
anchor_map_ocr = list(
|
||||
filter(
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words,
|
||||
unique_word_map_ocr,
|
||||
)
|
||||
)
|
||||
return anchor_map_gt, anchor_map_ocr
|
||||
|
||||
def find_anchor_recur(gt_tokens, ocr_tokens,
|
||||
start_pos_gt=0, start_pos_ocr=0,
|
||||
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
|
||||
|
||||
def find_anchor_recur(
|
||||
gt_tokens,
|
||||
ocr_tokens,
|
||||
start_pos_gt=0,
|
||||
start_pos_ocr=0,
|
||||
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH,
|
||||
):
|
||||
"""Recursively find anchor positions in the gt and ocr text
|
||||
|
||||
Arguments:
|
||||
|
@ -186,13 +202,22 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
|
||||
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
|
||||
# 4. Loop through each segment
|
||||
for gt_seg, ocr_seg, gt_start, ocr_start in zip(gt_segments, ocr_segments, seg_start_gt, seg_start_ocr):
|
||||
if segment_len(gt_seg) > max_seg_length or segment_len(ocr_seg) > max_seg_length:
|
||||
for gt_seg, ocr_seg, gt_start, ocr_start in zip(
|
||||
gt_segments, ocr_segments, seg_start_gt, seg_start_ocr
|
||||
):
|
||||
if (
|
||||
segment_len(gt_seg) > max_seg_length
|
||||
or segment_len(ocr_seg) > max_seg_length
|
||||
):
|
||||
# recur on the segment in between the two anchors.
|
||||
# We assume the first token in the segment is an anchor word
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(gt_seg[1:], ocr_seg[1:],
|
||||
start_pos_gt=gt_start + 1, start_pos_ocr=ocr_start + 1,
|
||||
max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(
|
||||
gt_seg[1:],
|
||||
ocr_seg[1:],
|
||||
start_pos_gt=gt_start + 1,
|
||||
start_pos_ocr=ocr_start + 1,
|
||||
max_seg_length=max_seg_length,
|
||||
)
|
||||
# shift the token indices
|
||||
# (these are indices of a subsequence and does not reflect true position in the source sequence)
|
||||
gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors))
|
||||
|
@ -203,6 +228,7 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
|
||||
return sorted(output_gt_anchors), sorted(output_ocr_anchors)
|
||||
|
||||
|
||||
def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
|
||||
"""A faster alignment scheme of two text segments. This method first
|
||||
breaks the strings into smaller segments with anchor words.
|
||||
|
@ -248,11 +274,17 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
ocr_tokens = preprocess.tokenize(ocr)
|
||||
|
||||
# 1. Find anchor positions
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
|
||||
# 2. Split into segments
|
||||
start_n_end_gt = zip(itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None]))
|
||||
start_n_end_ocr = zip(itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None]))
|
||||
start_n_end_gt = zip(
|
||||
itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])
|
||||
)
|
||||
start_n_end_ocr = zip(
|
||||
itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])
|
||||
)
|
||||
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
|
||||
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
|
||||
|
||||
|
@ -263,13 +295,15 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
gt_segment = preprocess.join_tokens(gt_segment)
|
||||
noisy_segment = preprocess.join_tokens(noisy_segment)
|
||||
# Run alignment algorithm
|
||||
aligned_seg_gt, aligned_seg_ocr = alignment.align(gt_segment, noisy_segment, gap_char=gap_char)
|
||||
aligned_seg_gt, aligned_seg_ocr = alignment.align(
|
||||
gt_segment, noisy_segment, gap_char=gap_char
|
||||
)
|
||||
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
|
||||
aligned_segments_gt.append(aligned_seg_gt)
|
||||
aligned_segments_ocr.append(aligned_seg_ocr)
|
||||
|
||||
# Stitch all segments together
|
||||
aligned_gt = ' '.join(aligned_segments_gt)
|
||||
aligned_noise = ' '.join(aligned_segments_ocr)
|
||||
aligned_gt = " ".join(aligned_segments_gt)
|
||||
aligned_noise = " ".join(aligned_segments_ocr)
|
||||
|
||||
return aligned_gt, aligned_noise
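# Usage note (not part of the diff): align_w_anchor() is a drop-in alternative to
# alignment.align() for long inputs. It splits the texts at shared anchor words,
# aligns each segment independently, and stitches the results back together, so it
# returns the same (aligned_gt, aligned_noise) pair as alignment.align().
#
#   aligned_gt, aligned_ocr = align_w_anchor(gt_text, ocr_text, max_seg_length=100)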
|
|
@ -34,21 +34,22 @@ example usage
    python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
    --train_subset
"""
import itertools
import difflib
import argparse
import concurrent.futures
import difflib
import itertools
import json
import os
import sys
import timeit
import concurrent.futures

from tqdm import tqdm
from genalog.text import ner_label, ner_label, alignment

from genalog.text import alignment, ner_label

EMPTY_SENTENCE_SENTINEL = "<<<<EMPTY_OCR_SENTENCE>>>>"
EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O"


def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens):
|
||||
"""
|
||||
propagate_labels_sentences propagates clean labels for clean tokens to ocr tokens and splits ocr tokens into sentences
|
||||
|
@ -73,9 +74,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
# Ensure equal number of tokens in both clean_tokens and clean_sentences
|
||||
merged_sentences = list(itertools.chain(*clean_sentences))
|
||||
if merged_sentences != clean_tokens:
|
||||
delta = "\n".join(difflib.unified_diff(merged_sentences, clean_tokens, fromfile='merged_clean_sentences', tofile="clean_tokens"))
|
||||
raise ValueError(f"Inconsistent tokens. " +
|
||||
f"Delta between clean_text and clean_labels:\n{delta}")
|
||||
delta = "\n".join(
|
||||
difflib.unified_diff(
|
||||
merged_sentences,
|
||||
clean_tokens,
|
||||
fromfile="merged_clean_sentences",
|
||||
tofile="clean_tokens",
|
||||
)
|
||||
)
|
||||
raise ValueError(
|
||||
"Inconsistent tokens. "
|
||||
+ f"Delta between clean_text and clean_labels:\n{delta}"
|
||||
)
|
||||
# Ensure that there's OCR result
|
||||
if len(ocr_tokens) == 0:
|
||||
raise ValueError("Empty OCR tokens.")
|
||||
|
@ -83,14 +93,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
raise ValueError("Empty clean tokens.")
|
||||
|
||||
# 1. Propagate labels + alig
|
||||
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(clean_labels, clean_tokens, ocr_tokens)
|
||||
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
|
||||
clean_labels, clean_tokens, ocr_tokens
|
||||
)
|
||||
|
||||
# 2. Parse alignment to get mapping
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_clean, aligned_ocr, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
|
||||
aligned_clean, aligned_ocr, gap_char=gap_char
|
||||
)
|
||||
|
||||
# 3. Find sentence breaks in clean text sentences
|
||||
gt_to_ocr_mapping_is_empty = [len(mapping) == 0 for mapping in gt_to_ocr_mapping]
|
||||
gt_to_ocr_mapping_is_empty_reverse = gt_to_ocr_mapping_is_empty[::-1]
|
||||
|
||||
sentence_index = []
|
||||
sentence_token_counts = 0
|
||||
for sentence in clean_sentences:
|
||||
|
@ -135,11 +149,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
ocr_labels_sentences.append(ocr_sentence_labels)
|
||||
return ocr_text_sentences, ocr_labels_sentences
|
||||
|
||||
|
||||
def get_sentences_from_iob_format(iob_format_str):
|
||||
sentences = []
|
||||
sentence = []
|
||||
for line in iob_format_str:
|
||||
if line.strip() == '': # if line is empty (sentence separator)
|
||||
if line.strip() == "": # if line is empty (sentence separator)
|
||||
sentences.append(sentence)
|
||||
sentence = []
|
||||
else:
|
||||
|
@ -149,44 +164,78 @@ def get_sentences_from_iob_format(iob_format_str):
|
|||
# filter any empty sentences
|
||||
return list(filter(lambda sentence: len(sentence) > 0, sentences))
|
||||
|
||||
def propagate_labels_sentence_single_file(arg):
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, input_filename = arg
|
||||
|
||||
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(clean_label_ext, ".txt")
|
||||
def propagate_labels_sentence_single_file(arg):
|
||||
(
|
||||
clean_labels_dir,
|
||||
output_text_dir,
|
||||
output_labels_dir,
|
||||
clean_label_ext,
|
||||
input_filename,
|
||||
) = arg
|
||||
|
||||
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(
|
||||
clean_label_ext, ".txt"
|
||||
)
|
||||
ocr_text_file = os.path.join(output_text_dir, input_filename)
|
||||
ocr_labels_file = os.path.join(output_labels_dir, input_filename)
|
||||
if not os.path.exists(clean_labels_file):
|
||||
print(f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index...")
|
||||
print(
|
||||
f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index..."
|
||||
)
|
||||
return
|
||||
elif not os.path.exists(ocr_text_file):
|
||||
print(f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index...")
|
||||
print(
|
||||
f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index..."
|
||||
)
|
||||
return
|
||||
else:
|
||||
with open(clean_labels_file, 'r', encoding='utf-8') as clf:
|
||||
with open(clean_labels_file, "r", encoding="utf-8") as clf:
|
||||
tokens_labels_str = clf.readlines()
|
||||
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_tokens = [
|
||||
line.split()[0].strip()
|
||||
for line in tokens_labels_str
|
||||
if len(line.split()) == 2
|
||||
]
|
||||
clean_labels = [
|
||||
line.split()[1].strip()
|
||||
for line in tokens_labels_str
|
||||
if len(line.split()) == 2
|
||||
]
|
||||
clean_sentences = get_sentences_from_iob_format(tokens_labels_str)
|
||||
# read ocr tokens
|
||||
with open(ocr_text_file, 'r', encoding='utf-8') as otf:
|
||||
ocr_text_str = ' '.join(otf.readlines())
|
||||
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
|
||||
with open(ocr_text_file, "r", encoding="utf-8") as otf:
|
||||
ocr_text_str = " ".join(otf.readlines())
|
||||
ocr_tokens = [
|
||||
token.strip() for token in ocr_text_str.split()
|
||||
] # already tokenized in data
|
||||
try:
|
||||
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file...")
|
||||
print(
|
||||
f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file..."
|
||||
)
|
||||
return
|
||||
# Write result to file
|
||||
with open(ocr_labels_file, 'w', encoding="utf-8") as olf:
|
||||
for ocr_tokens, ocr_labels in zip(ocr_tokens_sentences, ocr_labels_sentences):
|
||||
with open(ocr_labels_file, "w", encoding="utf-8") as olf:
|
||||
for ocr_tokens, ocr_labels in zip(
|
||||
ocr_tokens_sentences, ocr_labels_sentences
|
||||
):
|
||||
if len(ocr_tokens) == 0: # if empty OCR sentences
|
||||
olf.write(f'{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n')
|
||||
olf.write(
|
||||
f"{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n"
|
||||
)
|
||||
else:
|
||||
for token, label in zip(ocr_tokens, ocr_labels):
|
||||
olf.write(f"{token}\t{label}\n")
|
||||
olf.write('\n')
|
||||
olf.write("\n")
|
||||
|
||||
def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext):
|
||||
|
||||
def propagate_labels_sentences_multiprocess(
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext
|
||||
):
|
||||
"""
|
||||
propagate_labels_sentences_all_files propagates labels and sentences for all files in dataset
|
||||
|
||||
|
@ -204,16 +253,24 @@ def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, o
|
|||
file extension of the clean_labels
|
||||
"""
|
||||
clean_label_files = os.listdir(clean_labels_dir)
|
||||
args = list(map(
|
||||
lambda clean_label_filename:
|
||||
(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, clean_label_filename),
|
||||
clean_label_files
|
||||
))
|
||||
args = list(
|
||||
map(
|
||||
lambda clean_label_filename: (
|
||||
clean_labels_dir,
|
||||
output_text_dir,
|
||||
output_labels_dir,
|
||||
clean_label_ext,
|
||||
clean_label_filename,
|
||||
),
|
||||
clean_label_files,
|
||||
)
|
||||
)
|
||||
with concurrent.futures.ProcessPoolExecutor() as executor:
|
||||
iterator = executor.map(propagate_labels_sentence_single_file, args)
|
||||
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
|
||||
pass
|
||||
|
||||
|
||||
def extract_ocr_text(input_file, output_file):
|
||||
"""
|
||||
extract_ocr_text from GROK json
|
||||
|
@ -227,16 +284,17 @@ def extract_ocr_text(input_file, output_file):
|
|||
"""
|
||||
out_dir = os.path.dirname(output_file)
|
||||
in_file_name = os.path.basename(input_file)
|
||||
file_pre = in_file_name.split('_')[-1].split('.')[0]
|
||||
output_file_name = '{}.txt'.format(file_pre)
|
||||
file_pre = in_file_name.split("_")[-1].split(".")[0]
|
||||
output_file_name = "{}.txt".format(file_pre)
|
||||
output_file = os.path.join(out_dir, output_file_name)
|
||||
with open(input_file, 'r', encoding='utf-8') as fin:
|
||||
with open(input_file, "r", encoding="utf-8") as fin:
|
||||
json_data = json.load(fin)
|
||||
json_dict = json_data[0]
|
||||
text = json_dict['text']
|
||||
with open(output_file, 'wb') as fout:
|
||||
text = json_dict["text"]
|
||||
with open(output_file, "wb") as fout:
|
||||
fout.write(text.encode("utf-8"))
|
||||
|
||||
|
||||
def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
||||
"""
|
||||
check_n_sentences prints file name if number of sentences is different in clean and OCR files
|
||||
|
@ -253,21 +311,23 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
|||
text_files = os.listdir(output_labels_dir)
|
||||
skip_files = []
|
||||
for text_filename in tqdm(text_files):
|
||||
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(".txt", clean_label_ext)
|
||||
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(
|
||||
".txt", clean_label_ext
|
||||
)
|
||||
ocr_labels_file = os.path.join(output_labels_dir, text_filename)
|
||||
remove_first_line(clean_labels_file, clean_labels_file)
|
||||
remove_first_line(ocr_labels_file, ocr_labels_file)
|
||||
remove_last_line(clean_labels_file, clean_labels_file)
|
||||
remove_last_line(ocr_labels_file, ocr_labels_file)
|
||||
with open(clean_labels_file, 'r', encoding='utf-8') as lf:
|
||||
with open(clean_labels_file, "r", encoding="utf-8") as lf:
|
||||
clean_tokens_labels = lf.readlines()
|
||||
with open(ocr_labels_file, 'r', encoding='utf-8') as of:
|
||||
with open(ocr_labels_file, "r", encoding="utf-8") as of:
|
||||
ocr_tokens_labels = of.readlines()
|
||||
error = False
|
||||
n_clean_sentences = 0
|
||||
nl = False
|
||||
for line in clean_tokens_labels:
|
||||
if line == '\n':
|
||||
if line == "\n":
|
||||
if nl is True:
|
||||
error = True
|
||||
else:
|
||||
|
@ -278,7 +338,7 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
|||
n_ocr_sentences = 0
|
||||
nl = False
|
||||
for line in ocr_tokens_labels:
|
||||
if line == '\n':
|
||||
if line == "\n":
|
||||
if nl is True:
|
||||
error = True
|
||||
else:
|
||||
|
@ -287,11 +347,14 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
|||
else:
|
||||
nl = False
|
||||
if error or n_ocr_sentences != n_clean_sentences:
|
||||
print(f"Warning: Inconsistent numbers of sentences in '{text_filename}''." +
|
||||
f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}")
|
||||
print(
|
||||
f"Warning: Inconsistent numbers of sentences in '{text_filename}''."
|
||||
+ f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}"
|
||||
)
|
||||
skip_files.append(text_filename)
|
||||
return skip_files
|
||||
|
||||
|
||||
def remove_first_line(input_file, output_file):
|
||||
"""
|
||||
remove_first_line from files (some clean CoNLL files have an empty first line)
|
||||
|
@ -303,13 +366,14 @@ def remove_first_line(input_file, output_file):
|
|||
output_file : str
|
||||
output file path
|
||||
"""
|
||||
with open(input_file, 'r', encoding='utf-8') as in_f:
|
||||
with open(input_file, "r", encoding="utf-8") as in_f:
|
||||
lines = in_f.readlines()
|
||||
if len(lines) > 1 and lines[0].strip() == '':
|
||||
if len(lines) > 1 and lines[0].strip() == "":
|
||||
# the clean CoNLL formatted files had a newline as the first line
|
||||
with open(output_file, 'w', encoding='utf-8') as out_f:
|
||||
with open(output_file, "w", encoding="utf-8") as out_f:
|
||||
out_f.writelines(lines[1:])
|
||||
|
||||
|
||||
def remove_last_line(input_file, output_file):
|
||||
"""
|
||||
remove_last_line from files (some clean CoNLL files have an empty last line)
|
||||
|
@ -322,12 +386,13 @@ def remove_last_line(input_file, output_file):
|
|||
output file path
|
||||
"""
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as in_f:
|
||||
with open(input_file, "r", encoding="utf-8") as in_f:
|
||||
lines = in_f.readlines()
|
||||
if len(lines) > 1 and lines[-1].strip() == '':
|
||||
with open(output_file, 'w', encoding='utf-8') as out_f:
|
||||
if len(lines) > 1 and lines[-1].strip() == "":
|
||||
with open(output_file, "w", encoding="utf-8") as out_f:
|
||||
out_f.writelines(lines[:-1])
|
||||
|
||||
|
||||
def for_all_files(input_dir, output_dir, func):
|
||||
"""
|
||||
for_all_files will apply function to every file in a director
|
||||
|
@ -347,25 +412,34 @@ def for_all_files(input_dir, output_dir, func):
|
|||
output_file = os.path.join(output_dir, text_filename)
|
||||
func(input_file, output_file)
|
||||
|
||||
|
||||
def main(args):
|
||||
if not args.train_subset and not args.test_subset:
|
||||
subsets = ['train', 'test']
|
||||
subsets = ["train", "test"]
|
||||
else:
|
||||
subsets = []
|
||||
if args.train_subset:
|
||||
subsets.append('train')
|
||||
subsets.append("train")
|
||||
|
||||
if args.test_subset:
|
||||
subsets.append('test')
|
||||
subsets.append("test")
|
||||
|
||||
for subset in subsets:
|
||||
print("Processing {} subset...".format(subset))
|
||||
|
||||
clean_labels_dir = os.path.join(args.base_folder, args.gt_folder, subset,'clean_labels')
|
||||
ocr_json_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr')
|
||||
clean_labels_dir = os.path.join(
|
||||
args.base_folder, args.gt_folder, subset, "clean_labels"
|
||||
)
|
||||
ocr_json_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr"
|
||||
)
|
||||
|
||||
output_text_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_text')
|
||||
output_labels_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_labels')
|
||||
output_text_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr_text"
|
||||
)
|
||||
output_labels_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr_labels"
|
||||
)
|
||||
|
||||
# remove first empty line of labels file, if exists
|
||||
for_all_files(clean_labels_dir, clean_labels_dir, remove_first_line)
|
||||
|
@ -380,21 +454,50 @@ def main(args):
|
|||
os.mkdir(output_labels_dir)
|
||||
|
||||
# make ocr labels files by propagating clean labels to ocr_text and creating files in ocr_labels
|
||||
propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext)
|
||||
propagate_labels_sentences_multiprocess(
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext
|
||||
)
|
||||
print("Validating number of sentences in gt and ocr labels")
|
||||
check_n_sentences(clean_labels_dir, output_labels_dir, args.clean_label_ext) # check number of sentences and make sure same; print anomaly files
|
||||
check_n_sentences(
|
||||
clean_labels_dir, output_labels_dir, args.clean_label_ext
|
||||
) # check number of sentences and make sure same; print anomaly files
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("base_folder", help="base directory containing the collection of dataset")
|
||||
parser.add_argument("degraded_folder", help="directory containing train and test subset for degradation")
|
||||
parser.add_argument("--gt_folder", type=str, default="shared", help="directory containing the ground truth")
|
||||
parser.add_argument("--clean_label_ext", type=str, default=".txt", help="file extension of the clean_labels files")
|
||||
parser.add_argument('--train_subset', help="include if only train folder should be processed", action='store_true')
|
||||
parser.add_argument('--test_subset', help="include if only test folder should be processed", action='store_true')
|
||||
parser.add_argument(
|
||||
"base_folder", help="base directory containing the collection of dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"degraded_folder",
|
||||
help="directory containing train and test subset for degradation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gt_folder",
|
||||
type=str,
|
||||
default="shared",
|
||||
help="directory containing the ground truth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean_label_ext",
|
||||
type=str,
|
||||
default=".txt",
|
||||
help="file extension of the clean_labels files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_subset",
|
||||
help="include if only train folder should be processed",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_subset",
|
||||
help="include if only test folder should be processed",
|
||||
action="store_true",
|
||||
)
|
||||
return parser
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
start = timeit.default_timer()
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
|
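# End-to-end usage (not part of the diff), as given in this module's docstring:
#
#   python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' --train_subset
#
# i.e. base_folder, degraded_folder, and optionally --train_subset / --test_subset
# to restrict processing to one subset.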
|
@ -1,7 +1,6 @@
from genalog.text import preprocess

class LCS():
class LCS:
    """ Compute the Longest Common Subsequence (LCS) of two given string."""

    def __init__(self, str_m, str_n):
        self.str_m_len = len(str_m)
        self.str_n_len = len(str_n)

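# Illustrative sketch (not part of the diff): LCS is constructed with the two strings
# to compare; the computed subsequence is exposed through accessor methods defined
# later in this class (not shown in this hunk).
from genalog.text.lcs import LCS

lcs = LCS("New York is big.", "New Yerk is big.")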
@ -1,8 +1,9 @@
from genalog.text import alignment, anchor
from genalog.text import preprocess
import itertools
import re
import string
import itertools

from genalog.text import alignment, anchor
from genalog.text import preprocess

# Both regex below has the following behavior:
# 1. whitespace-tolerant at both ends of the string

@ -10,34 +11,42 @@ import itertools
|
|||
# For example, given a label 'B-PLACE'
|
||||
# Group 1 (denoted by \1): Label Indicator (B-)
|
||||
# Group 2 (denoted by \2): Label Name (PLACE)
|
||||
MULTI_TOKEN_BEGIN_LABEL_REGEX = r'^\s*(B-)([a-z|A-Z]+)\s*$'
|
||||
MULTI_TOKEN_INSIDE_LABEL_REGEX = r'^\s*(I-)([a-z|A-Z]+)\s*$'
|
||||
MULTI_TOKEN_LABEL_REGEX = r'^\s*([B|I]-)([a-z|A-Z]+)\s*'
|
||||
MULTI_TOKEN_BEGIN_LABEL_REGEX = r"^\s*(B-)([a-z|A-Z]+)\s*$"
|
||||
MULTI_TOKEN_INSIDE_LABEL_REGEX = r"^\s*(I-)([a-z|A-Z]+)\s*$"
|
||||
MULTI_TOKEN_LABEL_REGEX = r"^\s*([B|I]-)([a-z|A-Z]+)\s*"
|
||||
|
||||
# To avoid confusion in the Python interpreter,
|
||||
# gap char should not be any of the following special characters
|
||||
SPECIAL_CHAR = set(" \t\n'\x0b''\x0c''\r'") # Notice space characters (' ', '\t', '\n') are in this set.
|
||||
SPECIAL_CHAR = set(
|
||||
" \t\n'\x0b''\x0c''\r'"
|
||||
) # Notice space characters (' ', '\t', '\n') are in this set.
|
||||
GAP_CHAR_SET = set(string.printable).difference(SPECIAL_CHAR)
|
||||
# GAP_CHAR_SET = '!"#$%&()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~'
|
||||
|
||||
|
||||
class GapCharError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _is_begin_label(label):
|
||||
""" Return true if the NER label is a begin label (eg. B-PLACE) """
|
||||
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _is_inside_label(label):
|
||||
""" Return true if the NER label is an inside label (eg. I-PLACE) """
|
||||
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _is_multi_token_label(label):
|
||||
""" Return true if the NER label is a multi token label (eg. B-PLACE, I-PLACE) """
|
||||
return re.match(MULTI_TOKEN_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _clean_multi_token_label(label):
|
||||
""" Rid the multi-token-labels of whitespaces"""
|
||||
return re.sub(MULTI_TOKEN_LABEL_REGEX, r'\1\2', label)
|
||||
return re.sub(MULTI_TOKEN_LABEL_REGEX, r"\1\2", label)
|
||||
|
||||
|
||||
def _convert_to_begin_label(label):
|
||||
"""Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
|
||||
|
@ -50,9 +59,10 @@ def _convert_to_begin_label(label):
|
|||
"""
|
||||
if _is_inside_label(label):
|
||||
# Replace the Label Indicator to 'B-'(\1) and keep the Label Name (\2)
|
||||
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r'B-\2', label)
|
||||
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r"B-\2", label)
|
||||
return label
|
||||
|
||||
|
||||
def _convert_to_inside_label(label):
|
||||
"""Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE)
|
||||
|
||||
|
@ -64,9 +74,10 @@ def _convert_to_inside_label(label):
|
|||
"""
|
||||
if _is_begin_label(label):
|
||||
# Replace the Label Indicator to 'I-'(\1) and keep the Label Name (\2)
|
||||
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r'I-\2', label)
|
||||
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r"I-\2", label)
|
||||
return label
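# Tiny illustration (not part of the diff) of the two converters above:
#
#   _convert_to_begin_label("I-PLACE")   # -> "B-PLACE"
#   _convert_to_inside_label("B-PLACE")  # -> "I-PLACE"
#   _convert_to_begin_label("O")         # -> "O" (not an inside label, unchanged)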
|
||||
|
||||
|
||||
def _is_missing_begin_label(begin_label, inside_label):
|
||||
"""Validate a inside label given an begin label
|
||||
|
||||
|
@ -93,6 +104,7 @@ def _is_missing_begin_label(begin_label, inside_label):
|
|||
else:
|
||||
return True
|
||||
|
||||
|
||||
def correct_ner_labels(labels):
    """Correct the given list of labels for the following case:

@ -119,6 +131,7 @@ def correct_ner_labels(labels):
            cur_begin_label = ""
    return labels


def _select_from_multiple_ner_labels(label_indices):
    """Private method to select a NER label from a list of candidate

@ -147,6 +160,7 @@ def _select_from_multiple_ner_labels(label_indices):
    # TODO: may need a more sophisticated way to select from multiple NER labels
    return label_indices[0]


def _find_gap_char_candidates(gt_tokens, ocr_tokens):
    """Find a set of suitable GAP_CHARs based not in the set of input characters

@ -159,12 +173,15 @@ def _find_gap_char_candidates(gt_tokens, ocr_tokens):
        1. the set of suitable GAP_CHARs
        2. the set of input characters
    """
    input_char_set = set(''.join(itertools.chain(gt_tokens, ocr_tokens)))  # The set of input characters
    input_char_set = set(
        "".join(itertools.chain(gt_tokens, ocr_tokens))
    )  # The set of input characters
    gap_char_set = GAP_CHAR_SET  # The set of possible GAP_CHARs
    # Find a set of gap_char that is NOT in the set of input characters
    gap_char_candidates = gap_char_set.difference(input_char_set)
    return gap_char_candidates, input_char_set


def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
    """Propagate NER label for ground truth tokens to to ocr tokens.

@ -199,20 +216,29 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
        `gap_char` is the char used to alignment for inserting gaps
    """
    # Find a set of suitable GAP_CHAR based not in the set of input characters
    gap_char_candidates, input_char_set = _find_gap_char_candidates(gt_tokens, ocr_tokens)
    gap_char_candidates, input_char_set = _find_gap_char_candidates(
        gt_tokens, ocr_tokens
    )
    if len(gap_char_candidates) == 0:
        raise GapCharError("Exhausted all possible GAP_CHAR candidates for alignment." +
                           " Consider reducing cardinality of the input character set.\n" +
                           f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" +
                           f"The set of input character is: '{''.join(sorted(input_char_set))}'")
        raise GapCharError(
            "Exhausted all possible GAP_CHAR candidates for alignment."
            + " Consider reducing cardinality of the input character set.\n"
            + f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n"
            + f"The set of input character is: '{''.join(sorted(input_char_set))}'"
        )
    else:
        if alignment.GAP_CHAR in gap_char_candidates:
            gap_char = alignment.GAP_CHAR  # prefer to use default GAP_CHAR
        else:
            gap_char = gap_char_candidates.pop()
        return _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor)
        return _propagate_label_to_ocr(
            gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor
        )

def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True):
|
||||
|
||||
def _propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True
|
||||
):
|
||||
"""Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
|
||||
|
||||
NOTE: that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
|
||||
|
@ -342,31 +368,43 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
|
||||
# Sanity check:
|
||||
if len(gt_tokens) != len(gt_labels):
|
||||
raise ValueError(f"Unequal number of gt_tokens ({len(gt_tokens)})" +
|
||||
f"to that of gt_labels ({len(gt_labels)})")
|
||||
raise ValueError(
|
||||
f"Unequal number of gt_tokens ({len(gt_tokens)})"
|
||||
+ f"to that of gt_labels ({len(gt_labels)})"
|
||||
)
|
||||
|
||||
for tk in (gt_tokens + ocr_tokens):
|
||||
for tk in gt_tokens + ocr_tokens:
|
||||
if len(preprocess.tokenize(tk)) > 1:
|
||||
raise ValueError(f"Invalid token '{tk}'. Tokens must be atomic.")
|
||||
if not alignment._is_valid_token(tk, gap_char=gap_char):
|
||||
if re.search(rf'{re.escape(gap_char)}+', tk): # Escape special regex chars
|
||||
raise GapCharError(f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'")
|
||||
if re.search(rf"{re.escape(gap_char)}+", tk): # Escape special regex chars
|
||||
raise GapCharError(
|
||||
f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)")
|
||||
raise ValueError(
|
||||
f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)"
|
||||
)
|
||||
|
||||
# Stitch tokens together into one string for alignment
|
||||
gt_txt = preprocess.join_tokens(gt_tokens)
|
||||
ocr_txt = preprocess.join_tokens(ocr_tokens)
|
||||
# Align the ground truth and ocr text first
|
||||
if use_anchor:
|
||||
aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_txt, ocr_txt, gap_char=gap_char)
|
||||
aligned_gt, aligned_ocr = anchor.align_w_anchor(
|
||||
gt_txt, ocr_txt, gap_char=gap_char
|
||||
)
|
||||
else:
|
||||
aligned_gt, aligned_ocr = alignment.align(gt_txt, ocr_txt, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_gt, aligned_ocr, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
|
||||
aligned_gt, aligned_ocr, gap_char=gap_char
|
||||
)
|
||||
# Check invariant
|
||||
if len(gt_to_ocr_mapping) != len(gt_tokens):
|
||||
raise ValueError(f"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " +
|
||||
f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment().")
|
||||
raise ValueError(
|
||||
"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: "
|
||||
+ f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment()."
|
||||
)
|
||||
|
||||
ocr_labels = []
|
||||
# STEP 1: naively propagate NER label based on text-alignment
|
||||
|
@ -374,7 +412,9 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
# if is not mapping to missing a token (Case 4)
|
||||
if ocr_to_gt_token_relationship:
|
||||
# Find the corresponding gt_token it is aligned to
|
||||
ner_label_index = _select_from_multiple_ner_labels(ocr_to_gt_token_relationship)
|
||||
ner_label_index = _select_from_multiple_ner_labels(
|
||||
ocr_to_gt_token_relationship
|
||||
)
|
||||
# Get the NER label for that particular gt_token
|
||||
ocr_labels.append(gt_labels[ner_label_index])
|
||||
|
||||
|
@ -429,18 +469,26 @@ def format_labels(tokens, labels, label_top=True):
|
|||
len_diff = abs(len(label) - len(token))
|
||||
# Add padding spaces for whichever is shorter
|
||||
if len(label) > len(token):
|
||||
formatted_labels += label + ' '
|
||||
formatted_tokens += token + ' '*len_diff + ' '
|
||||
formatted_labels += label + " "
|
||||
formatted_tokens += token + " " * len_diff + " "
|
||||
else:
|
||||
formatted_labels += label + ' '*len_diff + ' '
|
||||
formatted_tokens += token + ' '
|
||||
formatted_labels += label + " " * len_diff + " "
|
||||
formatted_tokens += token + " "
|
||||
if label_top:
|
||||
return formatted_labels + '\n' + formatted_tokens + '\n'
|
||||
return formatted_labels + "\n" + formatted_tokens + "\n"
|
||||
else:
|
||||
return formatted_tokens + '\n' + formatted_labels + '\n'
|
||||
return formatted_tokens + "\n" + formatted_labels + "\n"
|
||||
|
||||
def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
|
||||
aligned_gt, aligned_ocr, show_alignment=True):
|
||||
|
||||
def format_label_propagation(
|
||||
gt_tokens,
|
||||
gt_labels,
|
||||
ocr_tokens,
|
||||
ocr_labels,
|
||||
aligned_gt,
|
||||
aligned_ocr,
|
||||
show_alignment=True,
|
||||
):
|
||||
"""Format label propagation for display
|
||||
|
||||
Arguments:
|
||||
|
@ -483,4 +531,3 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
|
|||
return gt_label_str + alignment_str + label_str
|
||||
else:
|
||||
return gt_label_str + label_str
|
||||
|
||||
|
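For orientation, the public entry point shown above can be exercised as in the short sketch below. This is not part of the commit: the token and label values are invented, and the four-value return (OCR labels, aligned ground-truth text, aligned OCR text, gap character) is inferred from the docstring fragment shown above.

from genalog.text import ner_label

gt_labels = ["B-PLACE", "I-PLACE", "O", "O"]  # invented example labels
gt_tokens = ["New", "York", "is", "big"]      # ground-truth tokens
ocr_tokens = ["New", "Yerk", "is", "big"]     # simulated OCR noise

ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens, use_anchor=True
)
print(ner_label.format_labels(ocr_tokens, ocr_labels))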
|
|
@ -1,8 +1,9 @@
import re

END_OF_TOKEN = {' ', '\t', '\n'}
END_OF_TOKEN = {" ", "\t", "\n"}
NON_ASCII_REPLACEMENT = "_"


def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
    """Remove non ascii characters in a token

@ -15,12 +16,13 @@ def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
        str -- a word token with non-ASCII characters removed
    """
    # Remove non-ASCII characters in the token
    ascii_token = str(token.encode('utf-8').decode('ascii', 'ignore'))
    ascii_token = str(token.encode("utf-8").decode("ascii", "ignore"))
    # If token becomes an empty string as a result
    if len(ascii_token) == 0 and len(token) != 0:
        ascii_token = replacement  # replace with a default character
    return ascii_token


def tokenize(s):
    """Tokenize string

@ -33,6 +35,7 @@ def tokenize(s):
    # split alignment tokens by spaces, tabs and newline (and excluding them in the tokens)
    return s.split()


def join_tokens(tokens):
    """Join a list of tokens into a string

@ -44,14 +47,17 @@ def join_tokens(tokens):
    """
    return " ".join(tokens)


def _is_spacing(c):
    """ Determine if the character is ignorable """
    return True if c in END_OF_TOKEN else False


def split_sentences(text, delimiter="\n"):
    """ Split a text into sentences with a delimiter"""
    return re.sub(r'(( /?[.!?])+ )', rf'\1{delimiter}', text)
    return re.sub(r"(( /?[.!?])+ )", rf"\1{delimiter}", text)


def is_sentence_separator(token):
    """ Returns true if the token is a sentence splitter """
    return re.match(r'^/?[.!?]$', token) != None
    return re.match(r"^/?[.!?]$", token) is not None

@ -1,5 +1,5 @@
"""This is a utility tool to split CoNLL formated files. It has the capability to pack sentences into generated
pages more tightly.
"""This is a utility tool to split CoNLL formated files.
It has the capability to pack sentences into generated pages more tightly.

usage: splitter.py [-h] [--doc_sep DOC_SEP] [--line_sep LINE_SEP]
                   [--force_doc_sep]

@ -22,16 +22,17 @@ example usage:
    python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train

"""
import re
import os
import multiprocessing
import argparse
from tqdm import tqdm
from genalog.text import preprocess
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import multiprocessing
import os
from multiprocessing.pool import ThreadPool

from tqdm import tqdm

from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
from genalog.text import preprocess

# default buffer. Preferebly set this to something large
# It holds the lines read from the CoNLL file
BUFFER_SIZE = 50000
|
@ -48,6 +49,7 @@ WORKERS_PER_CPU = 2
|
|||
|
||||
default_generator = DocumentGenerator()
|
||||
|
||||
|
||||
def unwrap(size, accumulator):
|
||||
words = []
|
||||
labels = []
|
||||
|
@ -58,7 +60,10 @@ def unwrap(size, accumulator):
|
|||
labels.append((word, tok))
|
||||
return words, labels
|
||||
|
||||
def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='text_block.html.jinja'):
|
||||
|
||||
def find_split_position(
|
||||
accumulator, start_pos, iters=SPLIT_ITERS, template_name="text_block.html.jinja"
|
||||
):
|
||||
"""Run a few iterations of binary search to find the best split point
|
||||
from the start to pack in sentences into a page without overflow.
|
||||
|
||||
|
@ -75,7 +80,10 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
|
|||
best = None
|
||||
count = 0
|
||||
while start <= end:
|
||||
if count==0 and (STARTING_SPLIT_GUESS+start_pos > start and STARTING_SPLIT_GUESS + start_pos < end):
|
||||
if count == 0 and (
|
||||
STARTING_SPLIT_GUESS + start_pos > start
|
||||
and STARTING_SPLIT_GUESS + start_pos < end
|
||||
):
|
||||
split_point = STARTING_SPLIT_GUESS
|
||||
else:
|
||||
split_point = (start + end) // 2
|
||||
|
@ -101,8 +109,15 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
|
|||
return best
|
||||
|
||||
|
||||
def generate_splits(input_file, output_folder, sentence_seperator="", doc_seperator=None, pool=None,
|
||||
force_doc_sep=False, ext="txt"):
|
||||
def generate_splits(
|
||||
input_file,
|
||||
output_folder,
|
||||
sentence_seperator="",
|
||||
doc_seperator=None,
|
||||
pool=None,
|
||||
force_doc_sep=False,
|
||||
ext="txt",
|
||||
):
|
||||
"""Processes the file line by line and add sentences to the buffer for processing.
|
||||
|
||||
Args:
|
||||
|
@ -132,7 +147,9 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
|
|||
continue
|
||||
start_pos = 0
|
||||
while start_pos < len(accumulator):
|
||||
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
|
||||
start_pos = next_doc(
|
||||
accumulator, doc_id, start_pos, output_folder, pool
|
||||
)
|
||||
doc_id += 1
|
||||
progress_bar.update(1)
|
||||
accumulator = []
|
||||
|
@ -145,17 +162,20 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
|
|||
|
||||
# process any left over lines
|
||||
start_pos = 0
|
||||
if len(sentence) > 0 : accumulator.append(sentence)
|
||||
if len(sentence) > 0:
|
||||
accumulator.append(sentence)
|
||||
while start_pos < len(accumulator):
|
||||
start_pos = next_doc(accumulator, doc_id, start_pos, output_folder, pool)
|
||||
doc_id += 1
|
||||
progress_bar.update(1)
|
||||
|
||||
|
||||
def next_doc(accumulator, doc_id, start_pos, output_folder, pool, ext="txt"):
|
||||
split_pos, doc, labels, text = find_split_position(accumulator, start_pos)
|
||||
handle_doc(doc, labels, doc_id, text, output_folder, pool, ext)
|
||||
return split_pos
|
||||
|
||||
|
||||
def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=False):
|
||||
|
||||
if write_png:
|
||||
|
@ -165,45 +185,66 @@ def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=Fal
|
|||
text += " " # adding a space at EOF
|
||||
text = preprocess.split_sentences(text)
|
||||
|
||||
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as l:
|
||||
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as fp:
|
||||
for idx, (token, label) in enumerate(labels):
|
||||
l.write(token + "\t" + label)
|
||||
fp.write(token + "\t" + label)
|
||||
next_token, _ = labels[(idx + 1) % len(labels)]
|
||||
if preprocess.is_sentence_separator(token) and not \
|
||||
preprocess.is_sentence_separator(next_token):
|
||||
l.write("\n")
|
||||
if preprocess.is_sentence_separator(
|
||||
token
|
||||
) and not preprocess.is_sentence_separator(next_token):
|
||||
fp.write("\n")
|
||||
if idx == len(labels): # Reach the end of the document
|
||||
l.write("\n")
|
||||
fp.write("\n")
|
||||
|
||||
with open(f"{output_folder}/clean_text/{doc_id}.txt", "w") as text_file:
|
||||
text_file.write(text)
|
||||
return f"wrote: doc id: {doc_id}"
|
||||
|
||||
|
||||
def _error_callback(err):
|
||||
raise RuntimeError(err)
|
||||
|
||||
|
||||
def handle_doc(doc, labels, doc_id, text, output_folder, pool, ext="txt"):
|
||||
if pool:
|
||||
pool.apply_async(write_doc, args=(doc, doc_id, labels, text, output_folder,ext), error_callback=_error_callback)
|
||||
pool.apply_async(
|
||||
write_doc,
|
||||
args=(doc, doc_id, labels, text, output_folder, ext),
|
||||
error_callback=_error_callback,
|
||||
)
|
||||
else:
|
||||
write_doc(doc, doc_id, labels, text, output_folder)
|
||||
|
||||
|
||||
def setup_folder(output_folder):
|
||||
os.makedirs(os.path.join(output_folder, "clean_text"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_folder, "clean_labels"), exist_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input_file", default="CoNLL-2012_train.txt", help="path to input CoNLL formated file.")
|
||||
parser.add_argument("output_folder", default="conll2012_train", help="folder to write results to.")
|
||||
parser.add_argument(
|
||||
"input_file",
|
||||
default="CoNLL-2012_train.txt",
|
||||
help="path to input CoNLL formated file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_folder", default="conll2012_train", help="folder to write results to."
|
||||
)
|
||||
parser.add_argument("--doc_sep", help="CoNLL doc seperator")
|
||||
parser.add_argument("--ext", help="file extension", default="txt")
|
||||
parser.add_argument("--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator")
|
||||
parser.add_argument("--force_doc_sep", default=False, action="store_true",
|
||||
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)")
|
||||
parser.add_argument(
|
||||
"--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_doc_sep",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
unescape = lambda s: s.encode('utf-8').decode('unicode_escape') if s else None
|
||||
unescape = lambda s: s.encode("utf-8").decode("unicode_escape") if s else None # noqa: E731
|
||||
|
||||
input_file = args.input_file
|
||||
output_folder = args.output_folder
|
||||
|
@ -215,6 +256,14 @@ if __name__ == "__main__":
|
|||
|
||||
n_workers = WORKERS_PER_CPU * multiprocessing.cpu_count()
|
||||
with ThreadPool(processes=n_workers) as pool:
|
||||
generate_splits(input_file, output_folder, line_sep, doc_seperator=doc_sep, pool=pool, force_doc_sep=False, ext=args.ext)
|
||||
generate_splits(
|
||||
input_file,
|
||||
output_folder,
|
||||
line_sep,
|
||||
doc_seperator=doc_sep,
|
||||
pool=pool,
|
||||
force_doc_sep=False,
|
||||
ext=args.ext,
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
|
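As a quick reference for the splitter refactored above, it can be run from the command line exactly as its module docstring states, or driven programmatically through generate_splits. The snippet below is a sketch that uses only names visible in this diff; the input path and worker count are illustrative.

# Command line (from the module docstring above):
#   python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train
#
# Programmatic sketch:
from multiprocessing.pool import ThreadPool

from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits, setup_folder

output_folder = "conll2012_train"
setup_folder(output_folder)  # creates the clean_text/ and clean_labels/ subfolders
with ThreadPool(processes=4) as pool:  # worker count is illustrative
    generate_splits(
        "CoNLL-2012_train.txt",
        output_folder,
        sentence_seperator="",
        doc_seperator=CONLL2003_DOC_SEPERATOR,
        pool=pool,
        force_doc_sep=False,
    )
    pool.close()
    pool.join()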
@ -1,2 +0,0 @@
[pytest]
junit_family=xunit1

@ -0,0 +1,4 @@
pytest
pytest-cov
flake8
flake8-import-order
setup.py
@ -1,6 +1,7 @@
import setuptools
import os

import setuptools

with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'VERSION.txt')) as version_file:
    BUILD_VERSION = version_file.read().strip()

|
|
@ -114,4 +114,6 @@ ns_tokens = [preprocess.tokenize(txt) for txt in ns_txt]
|
|||
|
||||
# test function expect params in tuple of
|
||||
# (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels)
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels))
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(
|
||||
zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels)
|
||||
)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# Initializing test cases
|
||||
# For extensibility, all parameters in a case are append to following arrays
|
||||
gt_txt = []
|
||||
|
@ -23,17 +22,21 @@ gt_to_noise_maps.append(
|
|||
[0, 1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
[3],
|
||||
]
|
||||
)
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
|
||||
noise_to_gt_maps.append(
|
||||
[
|
||||
[0],
|
||||
# Similarly, the following shows that the second token in noise "ewYork" maps to the
|
||||
# first ("New") and second ("York") token in gt
|
||||
[0, 1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
[3],
|
||||
]
|
||||
)
|
||||
|
||||
##############################################################################################
|
||||
|
||||
|
@ -52,18 +55,9 @@ gt_txt.append("Boston is great")
|
|||
aligned_ns.append("B oston@@@ grea t")
|
||||
aligned_gt.append("B@oston is grea@t")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0,1],
|
||||
[1], # 'is' is also mapped to 'oston'
|
||||
[2,3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) # 'is' is also mapped to 'oston'
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[0,1], # 'oston' is to 'Boston' and 'is'
|
||||
[2],
|
||||
[2]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) # 'oston' is to 'Boston' and 'is'
|
||||
############################################################################################
|
||||
|
||||
# Empty Cases:
|
||||
|
@ -95,18 +89,9 @@ ns_txt.append("B oston bi g")
|
|||
aligned_gt.append("B@oston is bi@g")
|
||||
aligned_ns.append("B oston@@@ bi g")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0,1],
|
||||
[1],
|
||||
[2,3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0, 1], [1], [2, 3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[0,1],
|
||||
[2],
|
||||
[2]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [0, 1], [2], [2]])
|
||||
|
||||
############################################################################################
|
||||
gt_txt.append("New York is big.")
|
||||
|
@ -115,17 +100,9 @@ ns_txt.append("NewYork big")
|
|||
aligned_gt.append("New York is big.")
|
||||
aligned_ns.append("New@York @@@big@")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[0],
|
||||
[1],
|
||||
[1]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [0], [1], [1]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0, 1],
|
||||
[2, 3]
|
||||
])
|
||||
noise_to_gt_maps.append([[0, 1], [2, 3]])
|
||||
|
||||
#############################################################################################
|
||||
gt_txt.append("politicians who lag superfluous on the")
|
||||
|
@ -134,23 +111,9 @@ ns_txt.append("politicians who kg superfluous on the")
|
|||
aligned_gt.append("politicians who lag superfluous on the")
|
||||
aligned_ns.append("politicians who @kg superfluous on the")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [2], [3], [4], [5]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1], [2], [3], [4], [5]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -160,20 +123,9 @@ ns_txt.append("faithei uifoimtdon the subject")
|
|||
aligned_gt.append("farther @informed on the subject.")
|
||||
aligned_ns.append("faithei ui@foimtd@on the subject@")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -183,20 +135,9 @@ ns_txt.append("New Yorkis big .")
|
|||
aligned_gt.append("New York is big .")
|
||||
aligned_ns.append("New York@is big .")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -206,23 +147,14 @@ ns_txt.append("New Yo rk is big.")
|
|||
aligned_gt.append("New Yo@rk is big.")
|
||||
aligned_ns.append("New Yo rk is big.")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
# Format tests for pytest
|
||||
# Each test expect in the following format
|
||||
# (aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(
|
||||
aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps
|
||||
)
|
||||
ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns))
|
|
@ -1,20 +1,24 @@
|
|||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import copy
|
||||
|
||||
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
|
||||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
|
||||
MOCK_IMAGE_SHAPE = (4, 3)
|
||||
MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_degrader():
|
||||
effects = []
|
||||
return Degrader(effects)
|
||||
|
||||
@pytest.fixture(params=[
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
[("blur", {"radius": 5})],
|
||||
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
|
||||
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
|
||||
|
@ -26,34 +30,45 @@ def empty_degrader():
|
|||
],
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {
|
||||
(
|
||||
"bleed_through",
|
||||
{
|
||||
"src": ImageState.CURRENT_STATE,
|
||||
"alpha": 0.7,
|
||||
"background": ImageState.ORIGINAL_STATE,
|
||||
}),
|
||||
("morphology", {
|
||||
"operation": "open",
|
||||
"kernel_shape": (3,3),
|
||||
"kernel_type": "ones"
|
||||
}),
|
||||
},
|
||||
),
|
||||
(
|
||||
"morphology",
|
||||
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"},
|
||||
),
|
||||
],
|
||||
]
|
||||
])
|
||||
)
|
||||
def degrader(request):
|
||||
effects = request.param
|
||||
return Degrader(effects)
|
||||
|
||||
def test_degrader_init(empty_degrader):
|
||||
|
||||
def test_empty_degrader_init(empty_degrader):
|
||||
assert empty_degrader.effects_to_apply == []
|
||||
|
||||
|
||||
def test_degrader_init(degrader):
|
||||
assert degrader.effects_to_apply is not []
|
||||
for effect_tuple in degrader.effects_to_apply:
|
||||
method_name, method_kwargs = effect_tuple
|
||||
assert DEFAULT_METHOD_PARAM_TO_INCLUDE in method_kwargs
|
||||
param_value = method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE]
|
||||
assert param_value is ImageState.ORIGINAL_STATE or param_value is ImageState.CURRENT_STATE
|
||||
assert (
|
||||
param_value is ImageState.ORIGINAL_STATE
|
||||
or param_value is ImageState.CURRENT_STATE
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("effects, error_thrown", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"effects, error_thrown",
|
||||
[
|
||||
([], None), # Empty effect
|
||||
(None, TypeError),
|
||||
([("blur", {"radius": 5})], None), # Validate input
|
||||
|
@ -64,17 +79,20 @@ def test_degrader_init(degrader):
|
|||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"alpha": "0.8"}),
|
||||
("morphology", {"operation": "open"})
|
||||
], None
|
||||
("morphology", {"operation": "open"}),
|
||||
],
|
||||
None,
|
||||
), # Multiple effects
|
||||
(
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"not_argument": "0.8"}),
|
||||
("morphology", {"missing value"})
|
||||
], ValueError
|
||||
("morphology", {"missing value"}),
|
||||
],
|
||||
ValueError,
|
||||
), # Multiple effects
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_degrader_validate_effects(effects, error_thrown):
|
||||
if error_thrown:
|
||||
with pytest.raises(error_thrown):
|
||||
|
@ -82,23 +100,26 @@ def test_degrader_validate_effects(effects, error_thrown):
|
|||
else:
|
||||
Degrader.validate_effects(effects)
|
||||
|
||||
|
||||
def test_degrader_apply_effects(degrader):
|
||||
method_names = [effect[0] for effect in degrader.effects_to_apply]
|
||||
with patch("genalog.degradation.effect") as mock_effect:
|
||||
degraded = degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
for method in method_names:
|
||||
assert mock_effect[method].is_called()
|
||||
# assert degraded.shape == MOCK_IMAGE_SHAPE
|
||||
|
||||
|
||||
def test_degrader_apply_effects_e2e(degrader):
|
||||
degraded = degrader.apply_effects(MOCK_IMAGE)
|
||||
assert degraded.shape == MOCK_IMAGE_SHAPE
|
||||
assert degraded.dtype == np.uint8
|
||||
|
||||
|
||||
def test_degrader_instructions(degrader):
|
||||
original_instruction = copy.deepcopy(degrader.effects_to_apply)
|
||||
degraded1 = degrader.apply_effects(MOCK_IMAGE)
|
||||
degraded2 = degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
# Make sure the degradation instructions are not altered
|
||||
assert len(original_instruction) == len(degrader.effects_to_apply)
|
||||
for i in range(len(original_instruction)):
|
||||
|
@ -107,5 +128,5 @@ def test_degrader_instructions(degrader):
|
|||
assert org_method_name == method_name
|
||||
assert len(org_method_arg) == len(method_arg)
|
||||
for key in org_method_arg.keys():
|
||||
assert type(org_method_arg[key]) == type(method_arg[key])
|
||||
assert isinstance(org_method_arg[key], type(method_arg[key]))
|
||||
assert org_method_arg[key] == method_arg[key]
|
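For context, the fixtures above build a Degrader from (method name, keyword arguments) tuples and assert that apply_effects preserves the image shape and dtype. The sketch below strings those same pieces together outside pytest; the input image is a synthetic blank page, and the parameter values simply mirror the fixtures rather than being recommendations.

import numpy as np

from genalog.degradation.degrader import Degrader, ImageState

effects = [
    ("blur", {"radius": 5}),
    ("bleed_through", {
        "src": ImageState.CURRENT_STATE,
        "alpha": 0.7,
        "background": ImageState.ORIGINAL_STATE,
    }),
    ("morphology", {"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"}),
]
degrader = Degrader(effects)

image = np.full((100, 120), 255, dtype=np.uint8)  # blank grayscale page stand-in
degraded = degrader.apply_effects(image)
assert degraded.shape == image.shape and degraded.dtype == np.uint8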
|
@ -1,18 +1,21 @@
|
|||
from genalog.degradation import effect
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from genalog.degradation import effect
|
||||
|
||||
NEW_IMG_SHAPE = (100, 100)
|
||||
MOCK_IMG_SHAPE = (100, 120)
|
||||
MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8)
|
||||
|
||||
|
||||
def test_blur():
|
||||
dst = effect.blur(MOCK_IMG, radius=3)
|
||||
assert dst.dtype == np.uint8 # preverse dtype
|
||||
assert dst.shape == MOCK_IMG_SHAPE # preverse image size
|
||||
|
||||
|
||||
def test_translation():
|
||||
offset_x = offset_y = 1
|
||||
# Test that border pixels are not white (<255)
|
||||
|
@ -25,6 +28,7 @@ def test_translation():
|
|||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_overlay_weighted():
|
||||
src = MOCK_IMG.copy()
|
||||
src[0][0] = 10
|
||||
|
@ -34,6 +38,7 @@ def test_overlay_weighted():
|
|||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
assert dst[0][0] == src[0][0] * alpha + src[0][0] * beta
|
||||
|
||||
|
||||
def test_overlay():
|
||||
src1 = MOCK_IMG.copy()
|
||||
src2 = MOCK_IMG.copy()
|
||||
|
@ -45,6 +50,7 @@ def test_overlay():
|
|||
assert dst[0][0] == 0
|
||||
assert dst[0][1] == 1
|
||||
|
||||
|
||||
@patch("genalog.degradation.effect.translation")
|
||||
def test_bleed_through_default(mock_translation):
|
||||
mock_translation.return_value = MOCK_IMG
|
||||
|
@ -53,11 +59,15 @@ def test_bleed_through_default(mock_translation):
|
|||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.mark.parametrize("foreground, background, error_thrown", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"foreground, background, error_thrown",
|
||||
[
|
||||
(MOCK_IMG, MOCK_IMG, None),
|
||||
# Test unmatched shape
|
||||
(MOCK_IMG, MOCK_IMG[:, :-1], Exception),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_bleed_through_kwargs(foreground, background, error_thrown):
|
||||
if error_thrown:
|
||||
assert foreground.shape != background.shape
|
||||
|
@ -68,36 +78,43 @@ def test_bleed_through_kwargs(foreground, background, error_thrown):
|
|||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_pepper():
|
||||
dst = effect.pepper(MOCK_IMG, amount=0.1)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_salt():
|
||||
dst = effect.salt(MOCK_IMG, amount=0.1)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_salt_then_pepper():
|
||||
dst = effect.salt_then_pepper(MOCK_IMG, 0.5, 0.001)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_pepper_then_salt():
|
||||
dst = effect.pepper_then_salt(MOCK_IMG, 0.001, 0.5)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.mark.parametrize("kernel_shape, kernel_type", [
|
||||
((3,3), "NOT_VALID_TYPE"),
|
||||
(1, "ones"),
|
||||
((1,2,3), "ones")
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_shape, kernel_type",
|
||||
[((3, 3), "NOT_VALID_TYPE"), (1, "ones"), ((1, 2, 3), "ones")],
|
||||
)
|
||||
def test_create_2D_kernel_error(kernel_shape, kernel_type):
|
||||
with pytest.raises(Exception):
|
||||
effect.create_2D_kernel(kernel_shape, kernel_type)
|
||||
|
||||
@pytest.mark.parametrize("kernel_shape, kernel_type, expected_kernel", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_shape, kernel_type, expected_kernel",
|
||||
[
|
||||
((2, 2), "ones", np.array([[1, 1], [1, 1]])), # sq kernel
|
||||
((1, 2), "ones", np.array([[1, 1]])), # horizontal
|
||||
((2, 1), "ones", np.array([[1], [1]])), # vertical
|
||||
|
@ -108,56 +125,77 @@ def test_create_2D_kernel_error(kernel_shape, kernel_type):
|
|||
((2, 2), "plus", np.array([[0, 1], [1, 1]])),
|
||||
((3, 3), "plus", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
|
||||
((3, 3), "ellipse", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
|
||||
((5,5), "ellipse",
|
||||
np.array([
|
||||
(
|
||||
(5, 5),
|
||||
"ellipse",
|
||||
np.array(
|
||||
[
|
||||
[0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[0, 0, 1, 0, 0]
|
||||
])),
|
||||
])
|
||||
[0, 0, 1, 0, 0],
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_create_2D_kernel(kernel_shape, kernel_type, expected_kernel):
|
||||
kernel = effect.create_2D_kernel(kernel_shape, kernel_type)
|
||||
assert np.array_equal(kernel, expected_kernel)
|
||||
|
||||
|
||||
def test_morphology_with_error():
|
||||
INVALID_OPERATION = "NOT_A_OPERATION"
|
||||
with pytest.raises(ValueError):
|
||||
effect.morphology(MOCK_IMG, operation=INVALID_OPERATION)
|
||||
|
||||
@pytest.mark.parametrize("operation, kernel_shape, kernel_type", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation, kernel_shape, kernel_type",
|
||||
[
|
||||
("open", (3, 3), "ones"),
|
||||
("close", (3, 3), "ones"),
|
||||
("dilate", (3, 3), "ones"),
|
||||
("erode", (3, 3), "ones"),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_morphology(operation, kernel_shape, kernel_type):
|
||||
dst = effect.morphology(MOCK_IMG,
|
||||
operation=operation, kernel_shape=kernel_shape,
|
||||
kernel_type=kernel_type)
|
||||
dst = effect.morphology(
|
||||
MOCK_IMG,
|
||||
operation=operation,
|
||||
kernel_shape=kernel_shape,
|
||||
kernel_type=kernel_type,
|
||||
)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.fixture(params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"])
|
||||
|
||||
@pytest.fixture(
|
||||
params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"]
|
||||
)
|
||||
def kernel(request):
|
||||
return effect.create_2D_kernel((5, 5), request.param)
|
||||
|
||||
|
||||
def test_open(kernel):
|
||||
dst = effect.open(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_close(kernel):
|
||||
dst = effect.close(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_erode(kernel):
|
||||
dst = effect.erode(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_dilate(kernel):
|
||||
dst = effect.dilate(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
from genalog.text import alignment, anchor, preprocess
|
||||
|
||||
import glob
|
||||
import pytest
|
||||
import difflib
|
||||
import glob
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
import pytest
|
||||
|
||||
from genalog.text import alignment, anchor, preprocess
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_*.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt")),
|
||||
),
|
||||
)
|
||||
def test_align_w_anchor_and_align(gt_file, ocr_file):
|
||||
gt_text = open(gt_file, "r").read()
|
||||
|
@ -21,17 +24,22 @@ def test_align_w_anchor_and_align(gt_file, ocr_file):
|
|||
aligned_anchor_gt = aligned_anchor_gt.split(".")
|
||||
aligned_gt = aligned_gt.split(".")
|
||||
str_diff = "\n".join(difflib.unified_diff(aligned_gt, aligned_anchor_gt))
|
||||
warnings.warn(UserWarning(
|
||||
"\n"+ f"{str_diff}" +
|
||||
f"\n\n**** Inconsistent Alignment Results between align() and " +
|
||||
f"align_w_anchor(). Ignore this if the delta is not significant. ****\n"))
|
||||
warnings.warn(
|
||||
UserWarning(
|
||||
"\n"
|
||||
+ f"{str_diff}"
|
||||
+ "\n\n**** Inconsistent Alignment Results between align() and "
|
||||
+ "align_w_anchor(). Ignore this if the delta is not significant. ****\n"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_*.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt")),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("max_seg_length", [25, 50, 75, 100, 150])
|
||||
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
||||
|
@ -39,7 +47,9 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
|||
ocr_text = open(ocr_file, "r").read()
|
||||
gt_tokens = preprocess.tokenize(gt_text)
|
||||
ocr_tokens = preprocess.tokenize(ocr_text)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
|
||||
# Ensure that each anchor word is the same word in both text
|
||||
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
|
||||
|
|
|
@ -1,47 +1,62 @@
|
|||
from genalog.text import conll_format
|
||||
from unittest import mock
|
||||
|
||||
import argparse
|
||||
import pytest
|
||||
import glob
|
||||
import itertools
|
||||
|
||||
@pytest.mark.parametrize("required_args", [
|
||||
(["tests/e2e/data/synthetic_dataset", "test_version"])
|
||||
])
|
||||
@pytest.mark.parametrize("optional_args", [
|
||||
import pytest
|
||||
|
||||
from genalog.text import conll_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"required_args", [(["tests/e2e/data/synthetic_dataset", "test_version"])]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"optional_args",
|
||||
[
|
||||
(["--train_subset"]),
|
||||
(["--test_subset"]),
|
||||
(["--gt_folder", "shared"]),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_conll_format(required_args, optional_args):
|
||||
parser = conll_format.create_parser()
|
||||
arg_list = required_args + optional_args
|
||||
args = parser.parse_args(args=arg_list)
|
||||
conll_format.main(args)
|
||||
|
||||
|
||||
basepath = "tests/e2e/data/conll_formatter/"
|
||||
|
||||
@pytest.mark.parametrize("clean_label_filename, ocr_text_filename",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_label_filename, ocr_text_filename",
|
||||
zip(
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/clean_labels/*.txt")),
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")),
|
||||
),
|
||||
)
|
||||
def test_propagate_labels_sentence_single_file(clean_label_filename, ocr_text_filename):
|
||||
with open(clean_label_filename, 'r', encoding='utf-8') as clf:
|
||||
with open(clean_label_filename, "r", encoding="utf-8") as clf:
|
||||
tokens_labels_str = clf.readlines()
|
||||
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_tokens = [
|
||||
line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2
|
||||
]
|
||||
clean_labels = [
|
||||
line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2
|
||||
]
|
||||
clean_sentences = conll_format.get_sentences_from_iob_format(tokens_labels_str)
|
||||
# read ocr tokens
|
||||
with open(ocr_text_filename, 'r', encoding='utf-8') as otf:
|
||||
ocr_text_str = ' '.join(otf.readlines())
|
||||
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
|
||||
with open(ocr_text_filename, "r", encoding="utf-8") as otf:
|
||||
ocr_text_str = " ".join(otf.readlines())
|
||||
ocr_tokens = [
|
||||
token.strip() for token in ocr_text_str.split()
|
||||
] # already tokenized in data
|
||||
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
CONTENT = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH])
|
||||
import pytest
|
||||
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
CONTENT = CompositeContent(
|
||||
["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH]
|
||||
)
|
||||
UNSUPPORTED_CONTENT_FORMAT = ["foo bar"]
|
||||
UNSUPPORTED_CONTENT_TYPE = CompositeContent(["foo"], [ContentType.TITLE])
|
||||
|
||||
|
@ -21,9 +24,10 @@ CUSTOM_STYLE = {
|
|||
"font_family": ["Calibri", "Times"],
|
||||
"font_size": ["10px"],
|
||||
"text_align": ["right"],
|
||||
"hyphenate": [True, False]
|
||||
"hyphenate": [True, False],
|
||||
}
|
||||
|
||||
|
||||
def test_default_template_generation():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(CONTENT, doc_gen.template_list)
|
||||
|
@ -32,20 +36,27 @@ def test_default_template_generation():
|
|||
assert "Unsupported Content Type:" not in html_str
|
||||
assert "No content loaded" not in html_str
|
||||
|
||||
|
||||
def test_default_template_generation_w_unsupported_content_format():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list)
|
||||
generator = doc_gen.create_generator(
|
||||
UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list
|
||||
)
|
||||
for doc in generator:
|
||||
html_str = doc.render_html()
|
||||
assert "No content loaded" in html_str
|
||||
|
||||
|
||||
def test_default_template_generation_w_unsupported_content_type():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"])
|
||||
generator = doc_gen.create_generator(
|
||||
UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"]
|
||||
)
|
||||
for doc in generator:
|
||||
html_str = doc.render_html()
|
||||
assert "Unsupported Content Type: ContentType.TITLE" in html_str
|
||||
|
||||
|
||||
def test_custom_template_generation():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
|
||||
|
@ -53,6 +64,7 @@ def test_custom_template_generation():
|
|||
result = doc.render_html()
|
||||
assert result == str(CONTENT)
|
||||
|
||||
|
||||
def test_undefined_template_generation():
|
||||
doc_gen = DocumentGenerator()
|
||||
assert UNDEFINED_TEMPLATE_NAME not in doc_gen.template_list
|
||||
|
@ -60,6 +72,7 @@ def test_undefined_template_generation():
|
|||
with pytest.raises(FileNotFoundError):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_custom_style_template_generation():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
assert len(doc_gen.styles_to_generate) == 1
|
||||
|
@ -70,14 +83,16 @@ def test_custom_style_template_generation():
|
|||
result = doc.render_html()
|
||||
assert doc.styles["font_family"] == result
|
||||
|
||||
|
||||
def test_render_pdf_and_png():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
|
||||
for doc in generator:
|
||||
pdf_bytes = doc.render_pdf()
|
||||
png_bytes = doc.render_png()
|
||||
assert pdf_bytes != None
|
||||
assert png_bytes != None
|
||||
assert pdf_bytes is not None
|
||||
assert png_bytes is not None
|
||||
|
||||
|
||||
def test_save_document_as_png():
|
||||
if not os.path.exists(TEST_OUTPUT_DIR):
|
||||
|
@ -89,6 +104,7 @@ def test_save_document_as_png():
|
|||
# Check if the document is saved in filepath
|
||||
assert os.path.exists(FILE_DESTINATION)
|
||||
|
||||
|
||||
def test_save_document_as_separate_png():
|
||||
if not os.path.exists(TEST_OUTPUT_DIR):
|
||||
os.mkdir(TEST_OUTPUT_DIR)
|
||||
|
@ -102,6 +118,7 @@ def test_save_document_as_separate_png():
|
|||
printed_doc_name = FILE_DESTINATION.replace(".png", f"_pg_{page_num}.png")
|
||||
assert os.path.exists(printed_doc_name)
|
||||
|
||||
|
||||
def test_overwriting_style():
|
||||
new_font = "NewFontFamily"
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
|
|
|
@ -1,23 +1,25 @@
|
|||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.degradation.degrader import Degrader
|
||||
from genalog.degradation import effect
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
TEST_OUTPUT_DIR = "test_out/"
|
||||
SAMPLE_TXT = "Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , was picked on Thursday for the Scottish squad after a 20-month exile ."
|
||||
SAMPLE_TXT = """Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday ,
|
||||
was picked on Thursday for the Scottish squad after a 20-month exile ."""
|
||||
DEFAULT_TEMPLATE = "text_block.html.jinja"
|
||||
DEGRADATION_EFFECTS = [
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"alpha": 0.8}),
|
||||
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "plus"}),
|
||||
(
|
||||
"morphology",
|
||||
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"},
|
||||
),
|
||||
("morphology", {"operation": "close"}),
|
||||
("morphology", {"operation": "dilate"}),
|
||||
("morphology", {"operation": "erode"})
|
||||
("morphology", {"operation": "erode"}),
|
||||
]
|
||||
|
||||
|
||||
def test_generation_and_degradation():
|
||||
# Initiate content
|
||||
content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
|
||||
|
@ -32,6 +34,4 @@ def test_generation_and_degradation():
|
|||
# get the image in bytes in RGBA channels
|
||||
src = doc.render_array(resolution=100, channel="GRAYSCALE")
|
||||
# run each degradation effect
|
||||
dst = degrader.apply_effects(src)
|
||||
|
||||
|
||||
degrader.apply_effects(src)
|
||||
|
|
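The end-to-end test above compresses the whole generate-then-degrade flow into one loop. Spelled out as a standalone sketch, with an invented one-sentence text and a shortened effect list, it reads as follows.

from genalog.degradation.degrader import Degrader
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator

content = CompositeContent(["An example sentence ."], [ContentType.PARAGRAPH])
doc_gen = DocumentGenerator()
degrader = Degrader([("blur", {"radius": 5}), ("bleed_through", {"alpha": 0.8})])

for doc in doc_gen.create_generator(content, ["text_block.html.jinja"]):
    src = doc.render_array(resolution=100, channel="GRAYSCALE")  # grayscale page image
    dst = degrader.apply_effects(src)  # degraded copy, same shape and dtype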
|
@ -1,19 +1,20 @@
|
|||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pytest
|
||||
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
TEMPLATE_PATH = "tests/e2e/templates"
|
||||
TEST_OUT_FOLDER = "test_out/"
|
||||
SAMPLE_TXT = "foo"
|
||||
CONTENT = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def doc_generator():
|
||||
return DocumentGenerator(template_path=TEMPLATE_PATH)
|
||||
|
||||
|
||||
def test_red_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
|
@ -23,6 +24,7 @@ def test_red_channel(doc_generator):
|
|||
assert tuple(img_array[0][0]) == (0, 0, 255, 255)
|
||||
cv2.imwrite(TEST_OUT_FOLDER + "red.png", img_array)
|
||||
|
||||
|
||||
def test_green_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
|
@ -32,6 +34,7 @@ def test_green_channel(doc_generator):
|
|||
assert tuple(img_array[0][0]) == (0, 128, 0, 255)
|
||||
cv2.imwrite(TEST_OUT_FOLDER + "green.png", img_array)
|
||||
|
||||
|
||||
def test_blue_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
|
|
|
@ -1,30 +1,36 @@
|
|||
from genalog.ocr.rest_client import GrokRestClient
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from genalog.ocr.blob_client import GrokBlobClient
|
||||
from genalog.ocr.grok import Grok
|
||||
import requests
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv("tests/ocr/.env")
|
||||
|
||||
|
||||
class TestBlobClient:
|
||||
@pytest.mark.parametrize("use_async", [True, False])
|
||||
def test_upload_images(self, use_async):
|
||||
blob_client = GrokBlobClient.create_from_env_var()
|
||||
subfolder = "tests/ocr/data/img"
|
||||
file_prefix = subfolder.replace("/", "_")
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, use_async=use_async)
|
||||
subfolder.replace("/", "_")
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(
|
||||
subfolder, use_async=use_async
|
||||
)
|
||||
uploaded_items, _ = blob_client.list_blobs(dst_folder)
|
||||
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
|
||||
assert uploaded_items[0].name == f"{dst_folder}/0.png"
|
||||
assert uploaded_items[1].name == f"{dst_folder}/1.png"
|
||||
assert uploaded_items[2].name == f"{dst_folder}/11.png"
|
||||
blob_client.delete_blobs_folder(dst_folder)
|
||||
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
|
||||
assert (
|
||||
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
|
||||
), f"folder {dst_folder} was not deleted"
|
||||
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, "test_images", use_async=use_async)
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(
|
||||
subfolder, "test_images", use_async=use_async
|
||||
)
|
||||
assert dst_folder == "test_images"
|
||||
uploaded_items, _ = blob_client.list_blobs(dst_folder)
|
||||
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
|
||||
|
@ -32,17 +38,23 @@ class TestBlobClient:
|
|||
assert uploaded_items[1].name == f"{dst_folder}/1.png"
|
||||
assert uploaded_items[2].name == f"{dst_folder}/11.png"
|
||||
blob_client.delete_blobs_folder(dst_folder)
|
||||
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
|
||||
assert (
|
||||
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
|
||||
), f"folder {dst_folder} was not deleted"
|
||||
|
||||
|
||||
class TestGROKe2e:
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
def test_grok_e2e(self, tmpdir, use_async):
|
||||
grok = Grok.create_from_env_var()
|
||||
src_folder = "tests/ocr/data/img"
|
||||
grok.run_grok(src_folder, tmpdir, blob_dest_folder="testimages", use_async=use_async, cleanup=True)
|
||||
json_folder = "tests/ocr/data/json"
|
||||
json_hash = "521c38122f783673598856cd81d91c21"
|
||||
grok.run_grok(
|
||||
src_folder,
|
||||
tmpdir,
|
||||
blob_dest_folder="testimages",
|
||||
use_async=use_async,
|
||||
cleanup=True,
|
||||
)
|
||||
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
|
||||
assert json.load(open(f"{tmpdir}/1.json", "r"))[0]["text"]
|
||||
assert json.load(open(f"{tmpdir}/11.json", "r"))[0]["text"]
|
||||
|
||||
|
|
|
@ -1,32 +1,43 @@
|
|||
from genalog import pipeline
|
||||
|
||||
import pytest
|
||||
import glob
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog import pipeline
|
||||
|
||||
EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_analog_generator():
|
||||
return pipeline.AnalogDocumentGeneration()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_analog_generator():
|
||||
custom_styles = {"font_size": ["5px"]}
|
||||
custom_degradation = [("blur", {"radius": 3})]
|
||||
return pipeline.AnalogDocumentGeneration(
|
||||
styles=custom_styles,
|
||||
degradations=custom_degradation,
|
||||
resolution=300)
|
||||
styles=custom_styles, degradations=custom_degradation, resolution=300
|
||||
)
|
||||
|
||||
|
||||
def test_default_generate_img(default_analog_generator):
|
||||
example_template = default_analog_generator.list_templates()[0]
|
||||
img_array = default_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
|
||||
default_analog_generator.generate_img(
|
||||
EXAMPLE_TEXT_FILE, example_template, target_folder=None
|
||||
)
|
||||
|
||||
|
||||
def test_custom_generate_img(custom_analog_generator):
|
||||
example_template = custom_analog_generator.list_templates()[0]
|
||||
img_array = custom_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
|
||||
custom_analog_generator.generate_img(
|
||||
EXAMPLE_TEXT_FILE, example_template, target_folder=None
|
||||
)
|
||||
|
||||
|
||||
def test_generate_dataset_multiprocess():
|
||||
INPUT_TEXT_FILENAMES = glob.glob("tests/text/data/gt_*.txt")
|
||||
with pytest.deprecated_call():
|
||||
pipeline.generate_dataset_multiprocess(INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja")
|
||||
pipeline.generate_dataset_multiprocess(
|
||||
INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja"
|
||||
)
|
||||
|
|
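To round off the pipeline tests above, here is a hedged sketch of the same AnalogDocumentGeneration calls outside pytest. The style and degradation values mirror the custom fixture, the text file is the fixture path used above, and the assumption that generate_img returns the image array when target_folder is None follows the older test code shown in this diff.

from genalog import pipeline

generator = pipeline.AnalogDocumentGeneration(
    styles={"font_size": ["5px"]},
    degradations=[("blur", {"radius": 3})],
    resolution=300,
)
template = generator.list_templates()[0]  # e.g. "text_block.html.jinja"
img_array = generator.generate_img(
    "tests/text/data/gt_1.txt", template, target_folder=None
)  # degraded page as a numpy array when no target folder is given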
|
@@ -1,7 +1,8 @@
import os
import difflib
import os

from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits

from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR

def _compare_content(file1, file2):
    txt1 = open(file1, "r").read()
@@ -12,15 +13,32 @@ def _compare_content(file1, file2):
        str_diff = "\n".join(difflib.unified_diff(sentences_txt1, sentences_txt2))
        assert False, f"Delta between outputs: \n {str_diff}"


def test_splitter(tmpdir):
    # tmpdir = "test_out"
    os.makedirs(f"{tmpdir}/clean_labels")
    os.makedirs(f"{tmpdir}/clean_text")

    generate_splits("tests/e2e/data/splitter/example_conll2012.txt", tmpdir,
                    doc_seperator=CONLL2003_DOC_SEPERATOR, sentence_seperator="")
    generate_splits(
        "tests/e2e/data/splitter/example_conll2012.txt",
        tmpdir,
        doc_seperator=CONLL2003_DOC_SEPERATOR,
        sentence_seperator="",
    )

    _compare_content("tests/e2e/data/splitter/example_splits/clean_text/0.txt", f"{tmpdir}/clean_text/0.txt")
    _compare_content("tests/e2e/data/splitter/example_splits/clean_text/1.txt", f"{tmpdir}/clean_text/1.txt")
    _compare_content("tests/e2e/data/splitter/example_splits/clean_labels/0.txt", f"{tmpdir}/clean_labels/0.txt")
    _compare_content("tests/e2e/data/splitter/example_splits/clean_labels/1.txt", f"{tmpdir}/clean_labels/1.txt")
    _compare_content(
        "tests/e2e/data/splitter/example_splits/clean_text/0.txt",
        f"{tmpdir}/clean_text/0.txt",
    )
    _compare_content(
        "tests/e2e/data/splitter/example_splits/clean_text/1.txt",
        f"{tmpdir}/clean_text/1.txt",
    )
    _compare_content(
        "tests/e2e/data/splitter/example_splits/clean_labels/0.txt",
        f"{tmpdir}/clean_labels/0.txt",
    )
    _compare_content(
        "tests/e2e/data/splitter/example_splits/clean_labels/1.txt",
        f"{tmpdir}/clean_labels/1.txt",
    )
@@ -1,62 +1,76 @@
from genalog.generation.content import *

import pytest

from genalog.generation.content import CompositeContent, Content, ContentType
from genalog.generation.content import Paragraph, Title

CONTENT_LIST = ["foo", "bar"]
COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH]
TEXT = "foo bar"


@pytest.fixture
def content_base_class():
    return Content()


@pytest.fixture
def paragraph():
    return Paragraph(TEXT)


@pytest.fixture
def title():
    return Title(TEXT)


@pytest.fixture
def section():
    return CompositeContent(CONTENT_LIST, COMPOSITE_CONTENT_TYPE)


def test_content_set_content_type(content_base_class):
    with pytest.raises(TypeError):
        content_base_class.set_content_type("NOT VALID CONTENT TYPE")
    content_base_class.set_content_type(ContentType.PARAGRAPH)


def test_paragraph_init(paragraph):
    with pytest.raises(TypeError):
        Paragraph([])
    assert paragraph.content_type == ContentType.PARAGRAPH


def test_paragraph_print(paragraph):
    assert paragraph.__str__()


def test_paragraph_iterable_indexable(paragraph):
    for index, character in enumerate(paragraph):
        assert character == paragraph[index]


def test_title_init(title):
    with pytest.raises(TypeError):
        Title([])
    assert title.content_type == ContentType.TITLE


def test_title_iterable_indexable(title):
    for index, character in enumerate(title):
        assert character == title[index]


def test_composite_content_init(section):
    with pytest.raises(TypeError):
        CompositeContent((), [])
    assert section.content_type == ContentType.COMPOSITE


def test_composite_content_iterable(section):
    for index, content in enumerate(section):
        assert content.content_type == COMPOSITE_CONTENT_TYPE[index]


def test_composite_content_print(section):
    assert "foo" in section.__str__()
    assert "bar" in section.__str__()
@@ -1,8 +1,10 @@
from genalog.generation.document import Document, DocumentGenerator
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from unittest.mock import MagicMock, patch

import pytest
from unittest.mock import MagicMock, patch

from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from genalog.generation.document import Document, DocumentGenerator


FRENCH = "fr"
CONTENT = ["some text"]
@@ -21,53 +23,72 @@ DEFAULT_TEMPLATE_NAME = "text_block.html.jinja"
|
|||
DEFAULT_PACKAGE_NAME = "genalog.generation"
|
||||
DEFAULT_TEMPLATE_FOLDER = "templates"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_document():
|
||||
mock_jinja_template = MagicMock()
|
||||
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
|
||||
return Document(CONTENT, mock_jinja_template)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def french_document():
|
||||
mock_jinja_template = MagicMock()
|
||||
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
|
||||
return Document(CONTENT, mock_jinja_template, language=FRENCH)
|
||||
|
||||
|
||||
def test_document_init(default_document):
|
||||
assert default_document.styles == DEFAULT_DOCUMENT_STYLE
|
||||
assert default_document._document != None
|
||||
assert default_document.compiled_html != None
|
||||
assert default_document._document is not None
|
||||
assert default_document.compiled_html is not None
|
||||
|
||||
|
||||
def test_document_init_with_kwargs(french_document):
|
||||
assert french_document.styles["language"] == FRENCH
|
||||
assert french_document._document != None
|
||||
assert french_document.compiled_html != None
|
||||
assert french_document._document is not None
|
||||
assert french_document.compiled_html is not None
|
||||
|
||||
|
||||
def test_document_render_html(french_document):
|
||||
compiled_document = french_document.render_html()
|
||||
assert compiled_document == MOCK_COMPILED_DOCUMENT
|
||||
french_document.template.render.assert_called_with(content=CONTENT, **french_document.styles)
|
||||
french_document.template.render.assert_called_with(
|
||||
content=CONTENT, **french_document.styles
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_pdf(default_document):
|
||||
default_document._document = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_pdf(target=FILE_DESTINATION_PDF, zoom=2)
|
||||
default_document._document.write_pdf.assert_called_with(target=FILE_DESTINATION_PDF, zoom=2)
|
||||
default_document._document.write_pdf.assert_called_with(
|
||||
target=FILE_DESTINATION_PDF, zoom=2
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_png(default_document):
|
||||
default_document._document = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_png(target=FILE_DESTINATION_PNG, resolution=100)
|
||||
default_document._document.write_png.assert_called_with(target=FILE_DESTINATION_PNG, resolution=100)
|
||||
default_document._document.write_png.assert_called_with(
|
||||
target=FILE_DESTINATION_PNG, resolution=100
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_png_split_pages(default_document):
|
||||
default_document._document.copy = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_png(target=FILE_DESTINATION_PNG, split_pages=True, resolution=100)
|
||||
default_document.render_png(
|
||||
target=FILE_DESTINATION_PNG, split_pages=True, resolution=100
|
||||
)
|
||||
result_destination = FILE_DESTINATION_PNG.replace(".png", "_pg_0.png")
|
||||
# assertion
|
||||
document_copy = default_document._document.copy.return_value
|
||||
document_copy.write_png.assert_called_with(target=result_destination, resolution=100)
|
||||
document_copy.write_png.assert_called_with(
|
||||
target=result_destination, resolution=100
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_array_valid_args(default_document):
|
||||
# setup mock
|
||||
|
@@ -84,11 +105,13 @@ def test_document_render_array_valid_args(default_document):
|
|||
img_array = default_document.render_array(resolution=100, channel=channel_type)
|
||||
assert img_array.shape == expected_img_shape
|
||||
|
||||
|
||||
def test_document_render_array_invalid_args(default_document):
|
||||
invalid_channel_types = "INVALID"
|
||||
with pytest.raises(ValueError):
|
||||
default_document.render_array(resolution=100, channel=invalid_channel_types)
|
||||
|
||||
|
||||
def test_document_render_array_invalid_format(default_document):
|
||||
# setup mock
|
||||
mock_surface = MagicMock()
|
||||
|
@@ -99,6 +122,7 @@ def test_document_render_array_invalid_format(default_document):
|
|||
with pytest.raises(RuntimeError):
|
||||
default_document.render_array(resolution=100)
|
||||
|
||||
|
||||
def test_document_update_style(default_document):
|
||||
new_style = {"language": FRENCH, "new_property": "some value"}
|
||||
# Ensure that a new property is not already defined
|
||||
|
@@ -111,11 +135,13 @@ def test_document_update_style(default_document):
|
|||
# Ensure that a new property is added
|
||||
assert default_document.styles["new_property"] == new_style["new_property"]
|
||||
|
||||
|
||||
@patch("genalog.generation.document.Environment")
|
||||
@patch("genalog.generation.document.PackageLoader")
|
||||
@patch("genalog.generation.document.FileSystemLoader")
|
||||
def test_document_generator_init_default_setting(mock_file_system_loader,
|
||||
mock_package_loader, mock_environment):
|
||||
def test_document_generator_init_default_setting(
|
||||
mock_file_system_loader, mock_package_loader, mock_environment
|
||||
):
|
||||
# setup mock template environment
|
||||
mock_environment_instance = mock_environment.return_value
|
||||
mock_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
|
||||
|
@@ -123,15 +149,19 @@ def test_document_generator_init_default_setting(mock_file_system_loader,
|
|||
document_generator = DocumentGenerator()
|
||||
# Ensure the right loader is called
|
||||
mock_file_system_loader.assert_not_called()
|
||||
mock_package_loader.assert_called_with(DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER)
|
||||
mock_package_loader.assert_called_with(
|
||||
DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER
|
||||
)
|
||||
# Ensure that the default template in the package is loaded
|
||||
assert DEFAULT_TEMPLATE_NAME in document_generator.template_list
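These two loader assertions capture the selection rule the generator is expected to follow: with no template path given, templates come from the package's own bundle, and only an explicit path should be served through FileSystemLoader. A minimal Jinja2 sketch of that choice; the helper name and its template_path argument are illustrative, not genalog's actual signature.

from jinja2 import Environment, FileSystemLoader, PackageLoader

def make_template_env(template_path=None):
    # Default: load the templates bundled inside the package.
    if template_path is None:
        loader = PackageLoader("genalog.generation", "templates")
    else:
        # Custom: load templates from a folder on disk.
        loader = FileSystemLoader(template_path)
    return Environment(loader=loader)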
|
||||
|
||||
|
||||
@patch("genalog.generation.document.Environment")
|
||||
@patch("genalog.generation.document.PackageLoader")
|
||||
@patch("genalog.generation.document.FileSystemLoader")
|
||||
def test_document_generator_init_custom_template(mock_file_system_loader,
|
||||
mock_package_loader, mock_environment):
|
||||
def test_document_generator_init_custom_template(
|
||||
mock_file_system_loader, mock_package_loader, mock_environment
|
||||
):
|
||||
# setup mock template environment
|
||||
mock_environment_instance = mock_environment.return_value
|
||||
mock_environment_instance.list_templates.return_value = [CUSTOM_TEMPLATE_NAME]
|
||||
|
@@ -143,67 +173,79 @@ def test_document_generator_init_custom_template(mock_file_system_loader,
|
|||
# Ensure that the expected template is registered
|
||||
assert CUSTOM_TEMPLATE_NAME in document_generator.template_list
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_document_generator():
|
||||
with patch("genalog.generation.document.Environment") as MockEnvironment:
|
||||
template_environment_instance = MockEnvironment.return_value
|
||||
template_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
|
||||
template_environment_instance.list_templates.return_value = [
|
||||
DEFAULT_TEMPLATE_NAME
|
||||
]
|
||||
template_environment_instance.get_template.return_value = MOCK_TEMPLATE
|
||||
doc_gen = DocumentGenerator()
|
||||
return doc_gen
|
||||
|
||||
|
||||
def test_document_generator_create_generator(default_document_generator):
|
||||
available_templates = default_document_generator.template_list
|
||||
assert len(available_templates) < 2
|
||||
generator = default_document_generator.create_generator(CONTENT, available_templates)
|
||||
doc = next(generator)
|
||||
generator = default_document_generator.create_generator(
|
||||
CONTENT, available_templates
|
||||
)
|
||||
next(generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_document_generator_create_generator_(default_document_generator):
|
||||
# setup test case
|
||||
available_templates = default_document_generator.template_list
|
||||
undefined_template = "NOT A VALID TEMPLATE"
|
||||
assert undefined_template not in available_templates
|
||||
|
||||
generator = default_document_generator.create_generator(CONTENT, [undefined_template])
|
||||
generator = default_document_generator.create_generator(
|
||||
CONTENT, [undefined_template]
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
doc = next(generator)
|
||||
next(generator)
|
||||
|
||||
@pytest.mark.parametrize("template_name, expected_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template_name, expected_output",
|
||||
[
|
||||
("base.html.jinja", False),
|
||||
("text_block.html.jinja", True),
|
||||
("text_block.css.jinja", False),
|
||||
("macro/dimension.css.jinja", False)
|
||||
])
|
||||
("macro/dimension.css.jinja", False),
|
||||
],
|
||||
)
|
||||
def test__keep_templates(template_name, expected_output):
|
||||
output = DocumentGenerator._keep_template(template_name)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_set_styles_to_generate(default_document_generator):
|
||||
assert len(default_document_generator.styles_to_generate) == 1
|
||||
default_document_generator.set_styles_to_generate({"foo": ["bar", "bar"]})
|
||||
assert len(default_document_generator.styles_to_generate) == 2
|
||||
|
||||
@pytest.mark.parametrize("styles, expected_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"styles, expected_output",
|
||||
[
|
||||
({}, []), # empty case
|
||||
({"size": ["10px"], "color": [] }, []), #empty value will result in null combinations
|
||||
(
|
||||
{"size": ["10px"], "color": ["red"] },
|
||||
[{"size":"10px", "color":"red"}]
|
||||
),
|
||||
(
|
||||
{"size": ["5px", "10px"]},
|
||||
[{"size": "5px"}, {"size": "10px"}]
|
||||
),
|
||||
{"size": ["10px"], "color": []},
|
||||
[],
|
||||
), # empty value will result in null combinations
|
||||
({"size": ["10px"], "color": ["red"]}, [{"size": "10px", "color": "red"}]),
|
||||
({"size": ["5px", "10px"]}, [{"size": "5px"}, {"size": "10px"}]),
|
||||
(
|
||||
{"size": ["10px", "15px"], "color": ["blue"]},
|
||||
[
|
||||
{"size":"10px", "color": "blue"},
|
||||
{"size":"15px", "color": "blue"}
|
||||
]
|
||||
[{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
])
|
||||
def test_document_generator_expand_style_combinations(styles, expected_output):
|
||||
output = DocumentGenerator.expand_style_combinations(styles)
|
||||
assert output == expected_output
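The expected outputs above are simply the Cartesian product of the per-property value lists. A minimal sketch of that expansion, reproducing the behaviour these cases assert rather than genalog's actual implementation:

import itertools

def expand_style_combinations_sketch(styles):
    # {"size": ["10px", "15px"], "color": ["blue"]}
    #   -> [{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}]
    if not styles:
        return []
    keys = list(styles.keys())
    value_lists = [styles[key] for key in keys]
    # An empty value list yields no combinations, matching the empty-output case above.
    return [dict(zip(keys, combo)) for combo in itertools.product(*value_lists)]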
|
|
@@ -1,209 +1,608 @@
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_metrics, get_stats
from genalog.text.anchor import align_w_anchor
from genalog.text.alignment import GAP_CHAR, align
from genalog.text.ner_label import _find_gap_char_candidates
from pandas._testing import assert_frame_equal

import pytest

import genalog.ocr.metrics
import pandas as pd
import numpy as np
import json
import pickle
import os

from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_stats
from genalog.text.alignment import align, GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates


genalog.ocr.metrics.LOG_LEVEL = 0

@pytest.mark.parametrize("src_string, target, expected_stats",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats",
|
||||
[
|
||||
("a worn coat", "a wom coat",
|
||||
{'edit_insert': 1, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
(" ", "a",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("a", " ",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("a", "a",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("ab", "ac",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("ac", "ab",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("New York is big.", "N ewYork kis big.",
|
||||
{'edit_insert': 0, 'edit_delete': 1, 'edit_replace': 0, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
|
||||
("B oston grea t", "Boston is great",
|
||||
{'edit_insert': 0, 'edit_delete': 2, 'edit_replace': 0, 'edit_insert_spacing': 2, 'edit_delete_spacing': 1}),
|
||||
("New York is big.", "N ewyork kis big",
|
||||
{'edit_insert': 1, 'edit_delete': 1, 'edit_replace': 1, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
|
||||
("dog", "d@g", # Test against default gap_char "@"
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("some@one.com", "some@one.com",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0})
|
||||
])
|
||||
(
|
||||
"a worn coat",
|
||||
"a wom coat",
|
||||
{
|
||||
"edit_insert": 1,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
" ",
|
||||
"a",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
" ",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
"a",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"ac",
|
||||
"ab",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewYork kis big.",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 1,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 1,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"B oston grea t",
|
||||
"Boston is great",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 2,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 2,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"edit_insert": 1,
|
||||
"edit_delete": 1,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 1,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"dog",
|
||||
"d@g", # Test against default gap_char "@"
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_editops_stats(src_string, target, expected_stats):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(target, src_string)
|
||||
stats, actions = get_editops_stats(alignment, gap_char)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
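The gap-character selection at the top of this test recurs throughout the file: the aligner needs a padding character that appears in neither input string, so the default "@" is only kept when it is still a valid candidate (the "some@one.com" case is exactly the situation where it is not). A self-contained sketch of that selection, independent of the genalog helpers:

import string

def pick_gap_char(src_string, target, default="@"):
    # Candidates are printable, non-whitespace characters absent from both strings.
    used = set(src_string) | set(target)
    candidates = set(string.printable) - set(string.whitespace) - used
    # Keep the default when it is safe; otherwise fall back to any unused character.
    return default if default in candidates else candidates.pop()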
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats, expected_substitutions",
|
||||
[
|
||||
(
|
||||
"a worn coat", "a wom coat",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 11, 'total_words': 3,
|
||||
'matching_chars': 9, 'matching_words': 2,"matching_alnum_words" : 2, 'word_accuracy': 2/3, 'char_accuracy': 9/11},
|
||||
{('rn', 'm'): 1}
|
||||
"a worn coat",
|
||||
"a wom coat",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 11,
|
||||
"total_words": 3,
|
||||
"matching_chars": 9,
|
||||
"matching_words": 2,
|
||||
"matching_alnum_words": 2,
|
||||
"word_accuracy": 2 / 3,
|
||||
"char_accuracy": 9 / 11,
|
||||
},
|
||||
{("rn", "m"): 1},
|
||||
),
|
||||
(
|
||||
"a c", "def",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1 , 'spacing': 0, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 0, 'matching_words': 0,"matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 0},
|
||||
{('a c', 'def'): 1}
|
||||
"a c",
|
||||
"def",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 0,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 0,
|
||||
},
|
||||
{("a c", "def"): 1},
|
||||
),
|
||||
(
|
||||
"a", "a b",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 2,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 0.5, 'char_accuracy': 1/3},
|
||||
{}
|
||||
"a",
|
||||
"a b",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 2,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 0.5,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a b", "b",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
|
||||
{}
|
||||
"a b",
|
||||
"b",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a b", "a",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
|
||||
{}
|
||||
"a b",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"b ..", "a b ..",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 6, 'total_words': 3, 'total_alnum_words': 2,
|
||||
'matching_chars': 4, 'matching_words': 2, "matching_alnum_words" : 1, 'word_accuracy': 2/3, 'char_accuracy': 4/6},
|
||||
{}
|
||||
"b ..",
|
||||
"a b ..",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 6,
|
||||
"total_words": 3,
|
||||
"total_alnum_words": 2,
|
||||
"matching_chars": 4,
|
||||
"matching_words": 2,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 2 / 3,
|
||||
"char_accuracy": 4 / 6,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"taxi cab", "taxl c b",
|
||||
{'insert': 0, 'delete': 1, 'replace': 1, 'spacing': 1, 'total_chars': 9, 'total_words': 3,
|
||||
'matching_chars': 6, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 6/9},
|
||||
{('i','l'):1}
|
||||
"taxi cab",
|
||||
"taxl c b",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 1,
|
||||
"total_chars": 9,
|
||||
"total_words": 3,
|
||||
"matching_chars": 6,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 6 / 9,
|
||||
},
|
||||
{("i", "l"): 1},
|
||||
),
|
||||
(
|
||||
"taxl c b ri de", "taxi cab ride",
|
||||
{'insert': 1, 'delete': 0, 'replace': 1, 'spacing': 6, 'total_chars': 18, 'total_words': 3,
|
||||
'matching_chars': 11, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 11/18},
|
||||
{('l','i'):1}
|
||||
"taxl c b ri de",
|
||||
"taxi cab ride",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 6,
|
||||
"total_chars": 18,
|
||||
"total_words": 3,
|
||||
"matching_chars": 11,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 11 / 18,
|
||||
},
|
||||
{("l", "i"): 1},
|
||||
),
|
||||
(
|
||||
"ab", "ac",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.5},
|
||||
{}
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 0.5,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a", "a",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 1, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
|
||||
{}
|
||||
"a",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 1,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewYork kis big.",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 2, 'total_chars': 17, 'total_words': 4,
|
||||
'matching_chars': 15, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1/4, 'char_accuracy': 15/17},
|
||||
{}
|
||||
"New York is big.",
|
||||
"N ewYork kis big.",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 2,
|
||||
"total_chars": 17,
|
||||
"total_words": 4,
|
||||
"matching_chars": 15,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1 / 4,
|
||||
"char_accuracy": 15 / 17,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"B oston grea t", "Boston is great",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 3, 'total_chars': 15, 'total_words': 3,
|
||||
'matching_chars': 12, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.8},
|
||||
{}
|
||||
"B oston grea t",
|
||||
"Boston is great",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 3,
|
||||
"total_chars": 15,
|
||||
"total_words": 3,
|
||||
"matching_chars": 12,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 0.8,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewyork kis big",
|
||||
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
|
||||
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
|
||||
{('Y', 'y'): 1}
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 2,
|
||||
"total_chars": 16,
|
||||
"total_words": 4,
|
||||
"matching_chars": 13,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 13 / 16,
|
||||
},
|
||||
{("Y", "y"): 1},
|
||||
),
|
||||
(
|
||||
"dog", "d@g",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{('o', '@'): 1}
|
||||
"dog",
|
||||
"d@g",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{("o", "@"): 1},
|
||||
),
|
||||
(
|
||||
"some@one.com", "some@one.com",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, {}
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 12,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 12,
|
||||
"matching_alnum_words": 1,
|
||||
"matching_words": 1,
|
||||
"alnum_word_accuracy": 1.0,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
])
|
||||
def test_align_stats(src_string, target, expected_stats, expected_substitutions):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(src_string, target, gap_char=gap_char)
|
||||
stats, substitution_dict = get_align_stats(alignment, src_string, target, gap_char)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
|
||||
for k in expected_substitutions:
|
||||
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions, expected_actions", [
|
||||
(
|
||||
"ab", "a",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
|
||||
{},
|
||||
{1: 'D'}
|
||||
),
|
||||
(
|
||||
"ab", "abb",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{},
|
||||
{2: ('I', 'b')}
|
||||
),
|
||||
(
|
||||
"ab", "ac",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
|
||||
{('b', 'c'): 1},
|
||||
{1: ('R', 'c')}
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewyork kis big",
|
||||
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
|
||||
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
|
||||
{('Y', 'y'): 1},
|
||||
{1: ('I', ' '), 4: 'D', 5: ('R', 'y'), 10: ('I', 'k'), 17: 'D'}
|
||||
),
|
||||
(
|
||||
"dog", "d@g",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{('o', '@'): 1},
|
||||
{1: ('R', '@')}
|
||||
),
|
||||
(
|
||||
"some@one.com", "some@one.com",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
|
||||
{},
|
||||
{}
|
||||
assert substitution_dict[k] == expected_substitutions[k], (
|
||||
substitution_dict,
|
||||
expected_substitutions,
|
||||
)
|
||||
])
|
||||
def test_get_stats(src_string, target, expected_stats, expected_substitutions, expected_actions ):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats, expected_substitutions, expected_actions",
|
||||
[
|
||||
(
|
||||
"ab",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 1 / 2,
|
||||
},
|
||||
{},
|
||||
{1: "D"},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"abb",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{},
|
||||
{2: ("I", "b")},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 1 / 2,
|
||||
},
|
||||
{("b", "c"): 1},
|
||||
{1: ("R", "c")},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 2,
|
||||
"total_chars": 16,
|
||||
"total_words": 4,
|
||||
"matching_chars": 13,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 13 / 16,
|
||||
},
|
||||
{("Y", "y"): 1},
|
||||
{1: ("I", " "), 4: "D", 5: ("R", "y"), 10: ("I", "k"), 17: "D"},
|
||||
),
|
||||
(
|
||||
"dog",
|
||||
"d@g",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{("o", "@"): 1},
|
||||
{1: ("R", "@")},
|
||||
),
|
||||
(
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 12,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 12,
|
||||
"matching_alnum_words": 1,
|
||||
"matching_words": 1,
|
||||
"alnum_word_accuracy": 1.0,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_stats(
|
||||
src_string, target, expected_stats, expected_substitutions, expected_actions
|
||||
):
|
||||
stats, substitution_dict, actions = get_stats(target, src_string)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
|
||||
for k in expected_substitutions:
|
||||
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
|
||||
assert substitution_dict[k] == expected_substitutions[k], (
|
||||
substitution_dict,
|
||||
expected_substitutions,
|
||||
)
|
||||
for k in expected_actions:
|
||||
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_actions",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_actions",
|
||||
[
|
||||
("dog and cat", "g and at",
|
||||
{0: ('I', 'd'), 1: ('I', 'o'), 8:('I', 'c')}),
|
||||
])
|
||||
("dog and cat", "g and at", {0: ("I", "d"), 1: ("I", "o"), 8: ("I", "c")}),
|
||||
],
|
||||
)
|
||||
def test_actions_stats(src_string, target, expected_actions):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(target, src_string, gap_char=gap_char)
|
||||
_, actions = get_editops_stats(alignment, gap_char)
|
||||
print(actions)
|
||||
|
|
|
@@ -1,13 +1,15 @@
|
|||
from genalog.ocr.rest_client import GrokRestClient
|
||||
from genalog.ocr.blob_client import GrokBlobClient
|
||||
from genalog.ocr.grok import Grok
|
||||
import requests
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from genalog.ocr.rest_client import GrokRestClient
|
||||
|
||||
|
||||
load_dotenv("tests/ocr/.env")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_monkeypatch(monkeypatch):
|
||||
def mock_http(*args, **kwargs):
|
||||
|
@@ -21,7 +23,6 @@ def setup_monkeypatch(monkeypatch):
|
|||
|
||||
|
||||
class MockedResponse:
|
||||
|
||||
def __init__(self, args, kwargs):
|
||||
self.url = args[0]
|
||||
self.text = "response"
|
||||
|
@@ -34,10 +35,7 @@ class MockedResponse:
|
|||
|
||||
if "search.windows.net/indexers/" in self.url:
|
||||
if "status" in self.url:
|
||||
return {
|
||||
"lastResult": {"status": "success"},
|
||||
"status": "finished"
|
||||
}
|
||||
return {"lastResult": {"status": "success"}, "status": "finished"}
|
||||
return {}
|
||||
|
||||
if "search.windows.net/indexes/" in self.url:
|
||||
|
@@ -46,15 +44,30 @@ class MockedResponse:
|
|||
"value": [
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_0.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", "r"))
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_1.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", "r"))
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_11.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", "r"))
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
@@ -69,7 +82,6 @@ class MockedResponse:
|
|||
|
||||
|
||||
class TestGROK:
|
||||
|
||||
def test_creating_indexing_pipeline(self):
|
||||
grok_rest_client = GrokRestClient.create_from_env_var()
|
||||
grok_rest_client.create_indexing_pipeline()
|
||||
|
@@ -91,4 +103,3 @@ class TestGROK:
|
|||
indexer_status = grok_rest_client.poll_indexer_till_complete()
|
||||
assert indexer_status["lastResult"]["status"] == "success"
|
||||
grok_rest_client.delete_indexer_pipeline()
|
||||
|
||||
|
|
|
@@ -1,43 +1,59 @@
from genalog.text import alignment
from genalog.text.alignment import MATCH_REWARD, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXT_PENALTY
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES, ALIGNMENT_REGRESSION_TEST_CASES

import warnings
from random import randint
from unittest.mock import MagicMock

import pytest
import warnings

from genalog.text import alignment
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES

RANDOM_INT = randint(1, 100)
MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)]


# Set up mock for third-party library call
|
||||
@pytest.fixture
|
||||
def mock_pairwise2_align(monkeypatch):
|
||||
mock = MagicMock()
|
||||
|
||||
def mock_globalcs(*args, **kwargs):
|
||||
mock.globalcs(*args, **kwargs)
|
||||
return MOCK_ALIGNMENT_RESULT
|
||||
|
||||
# replace target method reference with the mock method
|
||||
monkeypatch.setattr("Bio.pairwise2.align.globalcs", mock_globalcs)
|
||||
return mock
|
||||
|
||||
|
||||
def test__align_seg(mock_pairwise2_align):
|
||||
# setup method input
|
||||
required_arg = ("A", "B")
|
||||
optional_arg = (alignment.MATCH_REWARD, alignment.MISMATCH_PENALTY, alignment.GAP_PENALTY, alignment.GAP_EXT_PENALTY)
|
||||
optional_kwarg = {"gap_char": alignment.GAP_CHAR, "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY}
|
||||
optional_arg = (
|
||||
alignment.MATCH_REWARD,
|
||||
alignment.MISMATCH_PENALTY,
|
||||
alignment.GAP_PENALTY,
|
||||
alignment.GAP_EXT_PENALTY,
|
||||
)
|
||||
optional_kwarg = {
|
||||
"gap_char": alignment.GAP_CHAR,
|
||||
"one_alignment_only": alignment.ONE_ALIGNMENT_ONLY,
|
||||
}
|
||||
# test method
|
||||
result = alignment._align_seg(*required_arg + optional_arg, **optional_kwarg)
|
||||
# assertion
|
||||
mock_pairwise2_align.globalcs.assert_called()
|
||||
assert result == MOCK_ALIGNMENT_RESULT
|
||||
|
||||
@pytest.mark.parametrize("alignments, target_num_tokens, raised_exception",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"alignments, target_num_tokens, raised_exception",
|
||||
[
|
||||
(MOCK_ALIGNMENT_RESULT, 1, None),
|
||||
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
|
||||
([("X", "XY", 0, 0, 1)], 1, ValueError)
|
||||
])
|
||||
([("X", "XY", 0, 0, 1)], 1, ValueError),
|
||||
],
|
||||
)
|
||||
def test__select_alignment_candidates(alignments, target_num_tokens, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@@ -46,7 +62,9 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
|
|||
result = alignment._select_alignment_candidates(alignments, target_num_tokens)
|
||||
assert result == MOCK_ALIGNMENT_RESULT[0]
|
||||
|
||||
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
|
@ -63,7 +81,8 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
|
|||
("t1 \t \n t2", 3, 7, None),
|
||||
# Gap char
|
||||
(" @", 0, 1, None),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test__find_token_start(s, index, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@ -72,7 +91,9 @@ def test__find_token_start(s, index, desired_output, raised_exception):
|
|||
output = alignment._find_token_start(s, index)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
|
@ -90,7 +111,8 @@ def test__find_token_start(s, index, desired_output, raised_exception):
|
|||
(".", 0, 0, None),
|
||||
# Gap char
|
||||
("@@ @", 0, 2, None),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test__find_token_end(s, index, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@ -99,7 +121,9 @@ def test__find_token_end(s, index, desired_output, raised_exception):
|
|||
output = alignment._find_token_end(s, index)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("s, start, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, start, desired_output",
|
||||
[
|
||||
("token", 0, (0, 4)),
|
||||
("token\t", 0, (0, 5)),
|
||||
|
@ -111,13 +135,16 @@ def test__find_token_end(s, index, desired_output, raised_exception):
|
|||
# single character string
|
||||
("s", 0, (0, 0)),
|
||||
# punctuation
|
||||
(" !,.: ", 0, (2,6))
|
||||
])
|
||||
(" !,.: ", 0, (2, 6)),
|
||||
],
|
||||
)
|
||||
def test__find_next_token(s, start, desired_output):
|
||||
output = alignment._find_next_token(s, start)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("token, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token, desired_output",
|
||||
[
|
||||
# Valid tokens
|
||||
("\n\t token.!,:\n\t ", True),
|
||||
|
@ -134,26 +161,41 @@ def test__find_next_token(s, start, desired_output):
|
|||
("\t\n@", False),
|
||||
(alignment.GAP_CHAR * 1, False),
|
||||
(alignment.GAP_CHAR * RANDOM_INT, False),
|
||||
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False)
|
||||
])
|
||||
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False),
|
||||
],
|
||||
)
|
||||
def test__is_valid_token(token, desired_output):
|
||||
result = alignment._is_valid_token(token)
|
||||
assert result == desired_output
|
||||
|
||||
@pytest.mark.parametrize("aligned_gt, aligned_noise," +
|
||||
"expected_gt_to_noise_map, expected_noise_to_gt_map",
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
def test_parse_alignment(aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map):
|
||||
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(aligned_gt, aligned_noise)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"aligned_gt, aligned_noise," + "expected_gt_to_noise_map, expected_noise_to_gt_map",
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_parse_alignment(
|
||||
aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map
|
||||
):
|
||||
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(
|
||||
aligned_gt, aligned_noise
|
||||
)
|
||||
assert gt_to_noise_map == expected_gt_to_noise_map
|
||||
assert noise_to_gt_map == expected_noise_to_gt_map
|
||||
|
||||
@pytest.mark.parametrize("gt_txt, noisy_txt," +
|
||||
"expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_txt, noisy_txt," + "expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_align(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
|
||||
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
|
||||
if aligned_gt != expected_aligned_gt:
|
||||
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
|
||||
expected_alignment = alignment._format_alignment(
|
||||
expected_aligned_gt, expected_aligned_noise
|
||||
)
|
||||
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,33 +1,40 @@
|
|||
import glob
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog.text import alignment, anchor, preprocess
|
||||
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
|
||||
|
||||
import glob
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("tokens, case_sensitive, desired_output", [
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, case_sensitive, desired_output",
|
||||
[
|
||||
([], True, set()),
|
||||
([], False, set()),
|
||||
(["a", "A"], True, set(["a", "A"])),
|
||||
(["a", "A"], False, set()),
|
||||
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
|
||||
(["An", "an", "ab"], False, set(["ab"])),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_get_unique_words(tokens, case_sensitive, desired_output):
|
||||
output = anchor.get_unique_words(tokens, case_sensitive=case_sensitive)
|
||||
assert desired_output == output
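The cases above pin down what "unique" means here: a token qualifies only if it occurs exactly once, and with case_sensitive=False tokens that collide after lower-casing eliminate each other (["a", "A"] yields an empty set). A rough Counter-based sketch of that behaviour, as an illustration rather than genalog's implementation:

from collections import Counter

def unique_words_sketch(tokens, case_sensitive=True):
    keyed = tokens if case_sensitive else [token.lower() for token in tokens]
    counts = Counter(keyed)
    return {token for token, key in zip(tokens, keyed) if counts[key] == 1}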
|
||||
|
||||
@pytest.mark.parametrize("tokens, desired_output", [
|
||||
([], 0),
|
||||
([""], 0),
|
||||
(["a", "b"], 2),
|
||||
(["abc.", "def!"], 8)
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, desired_output",
|
||||
[([], 0), ([""], 0), (["a", "b"], 2), (["abc.", "def!"], 8)],
|
||||
)
|
||||
def test_segment_len(tokens, desired_output):
|
||||
output = anchor.segment_len(tokens)
|
||||
assert desired_output == output
|
||||
|
||||
@pytest.mark.parametrize("unique_words, src_tokens, desired_output, raised_exception", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"unique_words, src_tokens, desired_output, raised_exception",
|
||||
[
|
||||
(set(), [], [], None),
|
||||
(set(), ["a"], [], None),
|
||||
(set("a"), [], [], ValueError), # unique word not in src_tokens
|
||||
|
@ -36,8 +43,14 @@ def test_segment_len(tokens, desired_output):
|
|||
(set("a"), ["an", "na", " a "], [], ValueError), # substring
|
||||
(set("a"), ["a"], [("a", 0)], None), # valid input
|
||||
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
|
||||
(set("ab"), ["c", "b", "a"], [("b", 1), ("a", 2)], None), # multiple matches ordered by index
|
||||
])
|
||||
(
|
||||
set("ab"),
|
||||
["c", "b", "a"],
|
||||
[("b", 1), ("a", 2)],
|
||||
None,
|
||||
), # multiple matches ordered by index
|
||||
],
|
||||
)
|
||||
def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@ -46,7 +59,10 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
|
|||
output = anchor.get_word_map(unique_words, src_tokens)
|
||||
assert desired_output == output
|
||||
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, desired_output",
|
||||
[
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no common unique words
|
||||
|
@ -54,77 +70,139 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
|
|||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
|
||||
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
|
||||
(["a", "b", "c"], ["a", "b", "c"], # common unique words
|
||||
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)])),
|
||||
(["a", "b", "c"], ["c", "b", "a"], # common unique words but not in same order
|
||||
([("b", 1)], [("b", 1)])),
|
||||
(["b", "a", "c"], ["c", "b", "a"], # LCS has multiple results
|
||||
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])),
|
||||
(["c", "a", "b"], ["c", "b", "a"],
|
||||
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)])),
|
||||
(["c", "a", "b"], ["a", "c", "b"], # LCS has multiple results
|
||||
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)])),
|
||||
])
|
||||
(
|
||||
["a", "b", "c"],
|
||||
["a", "b", "c"], # common unique words
|
||||
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)]),
|
||||
),
|
||||
(
|
||||
["a", "b", "c"],
|
||||
["c", "b", "a"], # common unique words but not in same order
|
||||
([("b", 1)], [("b", 1)]),
|
||||
),
|
||||
(
|
||||
["b", "a", "c"],
|
||||
["c", "b", "a"], # LCS has multiple results
|
||||
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)]),
|
||||
),
|
||||
(
|
||||
["c", "a", "b"],
|
||||
["c", "b", "a"],
|
||||
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)]),
|
||||
),
|
||||
(
|
||||
["c", "a", "b"],
|
||||
["a", "c", "b"], # LCS has multiple results
|
||||
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_anchor_map(gt_tokens, ocr_tokens, desired_output):
|
||||
desired_gt_map, desired_ocr_map = desired_output
|
||||
gt_map, ocr_map = anchor.get_anchor_map(gt_tokens, ocr_tokens)
|
||||
assert desired_gt_map == gt_map
|
||||
assert desired_ocr_map == ocr_map
|
||||
|
||||
|
||||
# max_seg_length does not change the following output
|
||||
@pytest.mark.parametrize("max_seg_length", [0, 1, 2, 3, 5, 4, 6])
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, desired_output",
|
||||
[
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no anchors
|
||||
(["a", "a"], ["a"], ([], [])),
|
||||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([0], [0])), # anchors exist
|
||||
("a1 w w w".split(), "a1 w w w".split(), # no anchors in the subsequence [w w w]
|
||||
([0], [0])),
|
||||
("a1 w w w a2".split(), "a1 w w w a2".split(),
|
||||
([0, 4], [0, 4])),
|
||||
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(),
|
||||
([0, 4], [0, 4])),
|
||||
("a1 a2 a3".split(), "a1 a2 a3".split(), # all words are anchors
|
||||
([0, 1, 2], [0, 1, 2])),
|
||||
("a1 a2 a3".split(), "A1 A2 A3".split(), # anchor words must be in the same casing
|
||||
([], [])),
|
||||
("a1 w w a2".split(), "a1 w W a2".split(), # unique words are case insensitive
|
||||
([0, 3], [0, 3])),
|
||||
("a1 w w a2".split(), "A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
|
||||
([], [])),
|
||||
])
|
||||
def test_find_anchor_recur_various_seg_len(max_seg_length, gt_tokens, ocr_tokens, desired_output):
|
||||
(
|
||||
"a1 w w w".split(),
|
||||
"a1 w w w".split(), # no anchors in the subsequence [w w w]
|
||||
([0], [0]),
|
||||
),
|
||||
("a1 w w w a2".split(), "a1 w w w a2".split(), ([0, 4], [0, 4])),
|
||||
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(), ([0, 4], [0, 4])),
|
||||
(
|
||||
"a1 a2 a3".split(),
|
||||
"a1 a2 a3".split(), # all words are anchors
|
||||
([0, 1, 2], [0, 1, 2]),
|
||||
),
|
||||
(
|
||||
"a1 a2 a3".split(),
|
||||
"A1 A2 A3".split(), # anchor words must be in the same casing
|
||||
([], []),
|
||||
),
|
||||
(
|
||||
"a1 w w a2".split(),
|
||||
"a1 w W a2".split(), # unique words are case insensitive
|
||||
([0, 3], [0, 3]),
|
||||
),
|
||||
(
|
||||
"a1 w w a2".split(),
|
||||
"A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
|
||||
([], []),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_anchor_recur_various_seg_len(
|
||||
max_seg_length, gt_tokens, ocr_tokens, desired_output
|
||||
):
|
||||
desired_gt_anchors, desired_ocr_anchors = desired_output
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
assert desired_gt_anchors == gt_anchors
|
||||
assert desired_ocr_anchors == ocr_anchors
|
||||
|
||||
|
||||
# Test the recursion behavior
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, max_seg_length, desired_output", [
|
||||
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6,
|
||||
([0, 3], [0, 3])),
|
||||
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 3, 4], [0, 3, 4])),
|
||||
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5])),
|
||||
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 w_ a3".split(), 2, # missing ocr token
|
||||
([0, 3, 6], [0, 2, 4])),
|
||||
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 W_ A3".split(), 2, # changing cases
|
||||
([0, 3], [0, 2])),
|
||||
])
|
||||
def test_find_anchor_recur_fixed_seg_len(gt_tokens, ocr_tokens, max_seg_length, desired_output):
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, max_seg_length, desired_output",
|
||||
[
|
||||
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6, ([0, 3], [0, 3])),
|
||||
(
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 3, 4], [0, 3, 4]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 w_ w_ a3".split(),
|
||||
"a1 w_ a2 w_ a3".split(),
|
||||
2, # missing ocr token
|
||||
([0, 3, 6], [0, 2, 4]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 w_ w_ a3".split(),
|
||||
"a1 w_ a2 W_ A3".split(),
|
||||
2, # changing cases
|
||||
([0, 3], [0, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_anchor_recur_fixed_seg_len(
|
||||
gt_tokens, ocr_tokens, max_seg_length, desired_output
|
||||
):
|
||||
desired_gt_anchors, desired_ocr_anchors = desired_output
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
assert desired_gt_anchors == gt_anchors
|
||||
assert desired_ocr_anchors == ocr_anchors
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_1.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_1.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_1.txt")),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("max_seg_length", [75])
|
||||
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
||||
|
@ -132,15 +210,27 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
|||
ocr_text = open(ocr_file, "r").read()
|
||||
gt_tokens = preprocess.tokenize(gt_text)
|
||||
ocr_tokens = preprocess.tokenize(ocr_text)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
|
||||
# Ensure that each anchor word is the same word in both text
|
||||
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
|
||||
|
||||
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_align_w_anchor(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
|
||||
aligned_gt, aligned_noise = anchor.align_w_anchor(gt_txt, noisy_txt)
|
||||
if aligned_gt != expected_aligned_gt:
|
||||
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
|
||||
expected_alignment = alignment._format_alignment(
|
||||
expected_aligned_gt, expected_aligned_noise
|
||||
)
|
||||
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
|
||||
)
|
||||
)
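
(Illustrative only, with hypothetical input strings: align_w_anchor returns the two aligned strings, and _format_alignment renders them the same way the warning above does. The real regression pairs live in tests/cases/text_alignment.py.)

from genalog.text import alignment, anchor

aligned_gt, aligned_noise = anchor.align_w_anchor("New York is big", "New Yerk is bg")
print(alignment._format_alignment(aligned_gt, aligned_noise))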
|
||||
|
|
|
@ -1,34 +1,75 @@
|
|||
from genalog.text import conll_format
|
||||
import itertools
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import itertools
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception", [
|
||||
from genalog.text import conll_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception",
|
||||
[
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
|
||||
(["w1", "w3"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal tokens
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w3"]], ["w", "w"], ValueError), # Unequal tokens
|
||||
(["w1", "w3"], ["l1", "l2"], [["w1"]],["w", "w"], ValueError), # Unequal length
|
||||
(["w1"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal length
|
||||
])
|
||||
def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception):
|
||||
(
|
||||
["w1", "w3"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w2"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal tokens
|
||||
(
|
||||
["w1", "w2"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w3"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal tokens
|
||||
(
|
||||
["w1", "w3"],
|
||||
["l1", "l2"],
|
||||
[["w1"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal length
|
||||
(
|
||||
["w1"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w2"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal length
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences_error(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception
|
||||
):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
else:
|
||||
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels",
|
||||
[
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1", "a2", "b2"], # ocr token
|
||||
[["a1", "b1"], ["a2","b2"]], [["l1", "l2"], ["l3", "l4"]] # desired output
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
[["l1", "l2"], ["l3", "l4"]], # desired output
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1"], # Missing sentence 2
|
||||
# Ideally we would expect [["a1", "b1"], []]
|
||||
|
@ -39,119 +80,207 @@ def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sent
|
|||
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@"
|
||||
# NOTE: this is an improper behavior but the best
|
||||
# solution to this corner case, as it preserves the number of OCR tokens.
|
||||
[["a1"], ["b1"]], [["l1"], ["l2"]]
|
||||
[["a1"], ["b1"]],
|
||||
[["l1"], ["l2"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
|
||||
[["a"], ["a2", "b2"]], [["l1"], ["l3", "l4"]]
|
||||
[["a"], ["a2", "b2"]],
|
||||
[["l1"], ["l3", "l4"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
|
||||
[["a1", "b1"], ["a2"]], [["l1", "l2"], ["l3"]]
|
||||
[["a1", "b1"], ["a2"]],
|
||||
[["l1", "l2"], ["l3"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
|
||||
[["b1", "a2"], ["b2"]], [["l2", "l3"], ["l4"]]
|
||||
[["b1", "a2"], ["b2"]],
|
||||
[["l2", "l3"], ["l4"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
["a1", "b1", "a2", "b2"], # ocr token (missing c1 token at middle of sentence)
|
||||
[["a1"], ["b1", "a2"], ["b2"]], [["l1"], ["l2", "l4"], ["l5"]]
|
||||
[
|
||||
"a1",
|
||||
"b1",
|
||||
"a2",
|
||||
"b2",
|
||||
], # ocr token (missing c1 token at middle of sentence)
|
||||
[["a1"], ["b1", "a2"], ["b2"]],
|
||||
[["l1"], ["l2", "l4"], ["l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1"], ["c1", "a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
|
||||
[["a1"], ["b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
|
||||
[[], ["a1", "c1", "a2"], ["b2"]], [[], ["l1", "l3", "l4"], ["l5"]]
|
||||
[[], ["a1", "c1", "a2"], ["b2"]],
|
||||
[[], ["l1", "l3", "l4"], ["l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
])
|
||||
def test_propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels):
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences(
|
||||
clean_tokens,
|
||||
clean_labels,
|
||||
clean_sentences,
|
||||
ocr_tokens,
|
||||
desired_sentences,
|
||||
desired_labels,
|
||||
):
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
|
||||
if desired_sentences != ocr_text_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
|
||||
)
|
||||
)
|
||||
if desired_labels != ocr_labels_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
|
||||
)
|
||||
)
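
(A usage sketch based on the first parametrized case above, assuming the API matches what the test exercises: labels follow their tokens into the OCR sentence segmentation.)

from genalog.text import conll_format

clean_tokens = "a1 b1 a2 b2".split()
clean_labels = "l1 l2 l3 l4".split()
clean_sentences = [["a1", "b1"], ["a2", "b2"]]
ocr_tokens = ["a1", "b1", "a2", "b2"]
ocr_sentences, ocr_labels = conll_format.propagate_labels_sentences(
    clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
print(ocr_sentences)  # expected: [["a1", "b1"], ["a2", "b2"]]
print(ocr_labels)     # expected: [["l1", "l2"], ["l3", "l4"]]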
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens," +
|
||||
"mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens,"
|
||||
+ "mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels",
|
||||
[
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a", "b"], ["c", "d"]],
|
||||
["a", "b"], # Sentence is empty
|
||||
[[0], [1], [], []],
|
||||
[[0], [1]],
|
||||
[["a", "b"], []],
|
||||
[["l1", "l2"], []]
|
||||
[["l1", "l2"], []],
|
||||
),
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
[["a", "b",], ["c", "d"]],
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[
|
||||
[
|
||||
"a",
|
||||
"b",
|
||||
],
|
||||
["c", "d"],
|
||||
],
|
||||
["a", "b", "d"], # Missing sentence start
|
||||
[[0], [1], [], [2]],
|
||||
[[0], [1], [3]],
|
||||
[["a", "b"], ["d"]],
|
||||
[["l1", "l2"], ["l4"]]
|
||||
[["l1", "l2"], ["l4"]],
|
||||
),
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
[["a", "b",], ["c", "d"]],
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[
|
||||
[
|
||||
"a",
|
||||
"b",
|
||||
],
|
||||
["c", "d"],
|
||||
],
|
||||
["a", "c", "d"], # Missing sentence end
|
||||
[[0], [], [1], [2]],
|
||||
[[0], [2], [3]],
|
||||
[["a"], ["c", "d"]],
|
||||
[["l1"], ["l3", "l4"]]
|
||||
[["l1"], ["l3", "l4"]],
|
||||
),
|
||||
])
|
||||
def test_propagate_labels_sentences_text_alignment_corner_cases(clean_tokens, clean_labels, clean_sentences, ocr_tokens,
|
||||
mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels):
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences_text_alignment_corner_cases(
|
||||
clean_tokens,
|
||||
clean_labels,
|
||||
clean_sentences,
|
||||
ocr_tokens,
|
||||
mock_gt_to_ocr_mapping,
|
||||
mock_ocr_to_gt_mapping,
|
||||
desired_sentences,
|
||||
desired_labels,
|
||||
):
|
||||
with patch("genalog.text.alignment.parse_alignment") as mock_alignment:
|
||||
mock_alignment.return_value = (mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping)
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
(
|
||||
ocr_text_sentences,
|
||||
ocr_labels_sentences,
|
||||
) = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
|
||||
if desired_sentences != ocr_text_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
|
||||
)
|
||||
)
|
||||
if desired_labels != ocr_labels_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("s, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, desired_output",
|
||||
[
|
||||
("", []),
|
||||
("\n\n", []),
|
||||
("a1\tb1\na2\tb2", [["a1", "a2"]]),
|
||||
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
|
||||
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_get_sentences_from_iob_format(s, desired_output):
|
||||
output = conll_format.get_sentences_from_iob_format(s.splitlines(True))
|
||||
assert desired_output == output
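
(A short sketch of get_sentences_from_iob_format on a case from the table above: blank lines separate sentences, and only the token column is kept.)

from genalog.text import conll_format

iob_str = "a1\tb1\n\na2\tb2"  # one "token<TAB>label" pair per line
print(conll_format.get_sentences_from_iob_format(iob_str.splitlines(True)))
# expected by the case above: [["a1"], ["a2"]]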
|
|
@ -1,20 +1,27 @@
|
|||
from genalog.text.lcs import LCS
|
||||
|
||||
import pytest
|
||||
|
||||
@pytest.fixture(params=[
|
||||
from genalog.text.lcs import LCS
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
("", ""), # empty
|
||||
("abcde", "ace"), # naive case
|
||||
])
|
||||
]
|
||||
)
|
||||
def lcs(request):
|
||||
str1, str2 = request.param
|
||||
return LCS(str1, str2)
|
||||
|
||||
|
||||
def test_lcs_init(lcs):
|
||||
assert lcs._lcs_len is not None
|
||||
assert lcs._lcs is not None
|
||||
|
||||
@pytest.mark.parametrize("str1, str2, expected_len, expected_lcs", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"str1, str2, expected_len, expected_lcs",
|
||||
[
|
||||
("", "", 0, ""), # empty
|
||||
("abc", "abc", 3, "abc"),
|
||||
("abcde", "ace", 3, "ace"), # naive case
|
||||
|
@ -22,16 +29,26 @@ def test_lcs_init(lcs):
|
|||
("abc", "cba", 1, "c"), # multiple cases
|
||||
("abcdgh", "aedfhr", 3, "adh"),
|
||||
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
|
||||
("New York @", "New @ York", len("New York"), "New York"), # with space-separated, tokens
|
||||
(
|
||||
"New York @",
|
||||
"New @ York",
|
||||
len("New York"),
|
||||
"New York",
|
||||
), # with space-separated tokens
|
||||
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
|
||||
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
|
||||
# mixed order with similar tokens
|
||||
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
|
||||
# casing
|
||||
("Is A Big City IS a", "IS a Big City Is A", len("I Big City I "), "I Big City I "),
|
||||
])
|
||||
(
|
||||
"Is A Big City IS a",
|
||||
"IS a Big City Is A",
|
||||
len("I Big City I "),
|
||||
"I Big City I ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_lcs_e2e(str1, str2, expected_len, expected_lcs):
|
||||
lcs = LCS(str1, str2)
|
||||
assert expected_lcs == lcs.get_str()
|
||||
assert expected_len == lcs.get_len()
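
(Minimal sketch of the LCS helper on the naive case above.)

from genalog.text.lcs import LCS

lcs = LCS("abcde", "ace")  # longest common subsequence of the two strings
print(lcs.get_str())       # "ace"
print(lcs.get_len())       # 3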
|
||||
|
|
@ -1,171 +1,295 @@
|
|||
import pytest
|
||||
|
||||
from genalog.text import ner_label
|
||||
from genalog.text import alignment
|
||||
from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES
|
||||
|
||||
import pytest
|
||||
import string
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("B-org", True), (" B-org ", True), #whitespae tolerant
|
||||
("B-org", True),
|
||||
(" B-org ", True), # whitespae tolerant
|
||||
("\tB-ORG\n", True),
|
||||
# Negative Cases
|
||||
("I-ORG", False), ("O", False), ("other-B-label", False),
|
||||
])
|
||||
("I-ORG", False),
|
||||
("O", False),
|
||||
("other-B-label", False),
|
||||
],
|
||||
)
|
||||
def test__is_begin_label(label, desired_output):
|
||||
output = ner_label._is_begin_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("I-ORG", True), (" \t I-ORG ", True),
|
||||
("I-ORG", True),
|
||||
(" \t I-ORG ", True),
|
||||
# Negative Cases
|
||||
("O", False), ("B-LOC", False),("B-ORG", False),
|
||||
])
|
||||
("O", False),
|
||||
("B-LOC", False),
|
||||
("B-ORG", False),
|
||||
],
|
||||
)
|
||||
def test__is_inside_label(label, desired_output):
|
||||
output = ner_label._is_inside_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("I-ORG", True), ("B-ORG", True),
|
||||
("I-ORG", True),
|
||||
("B-ORG", True),
|
||||
# Negative Cases
|
||||
("O", False)
|
||||
])
|
||||
("O", False),
|
||||
],
|
||||
)
|
||||
def test__is_multi_token_label(label, desired_output):
|
||||
output = ner_label._is_multi_token_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("I-Place", "B-Place"), (" \t I-place ", "B-place"),
|
||||
("I-Place", "B-Place"),
|
||||
(" \t I-place ", "B-place"),
|
||||
# Negative Cases
|
||||
("O", "O"), ("B-LOC", "B-LOC"), (" B-ORG ", " B-ORG ")
|
||||
])
|
||||
("O", "O"),
|
||||
("B-LOC", "B-LOC"),
|
||||
(" B-ORG ", " B-ORG "),
|
||||
],
|
||||
)
|
||||
def test__convert_to_begin_label(label, desired_output):
|
||||
output = ner_label._convert_to_begin_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("B-LOC", "I-LOC"),
|
||||
(" B-ORG ", "I-ORG"),
|
||||
# Negative Cases
|
||||
("", ""), ("O", "O"), ("I-Place", "I-Place"),
|
||||
(" \t I-place ", " \t I-place ")
|
||||
])
|
||||
("", ""),
|
||||
("O", "O"),
|
||||
("I-Place", "I-Place"),
|
||||
(" \t I-place ", " \t I-place "),
|
||||
],
|
||||
)
|
||||
def test__convert_to_inside_label(label, desired_output):
|
||||
output = ner_label._convert_to_inside_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("begin_label, inside_label, desired_output", [
|
||||
@pytest.mark.parametrize(
|
||||
"begin_label, inside_label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("", "I-LOC", True),
|
||||
("B-LOC", "I-ORG", True),
|
||||
("", "I-ORG", True),
|
||||
# Negative Cases
|
||||
("", "", False), ("O", "O", False), ("", "", False),
|
||||
("", "", False),
|
||||
("O", "O", False),
|
||||
("", "", False),
|
||||
("B-LOC", "O", False),
|
||||
("B-LOC", "B-ORG", False),
|
||||
("B-LOC", "I-LOC", False),
|
||||
(" B-ORG ", "I-ORG", False),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test__is_missing_begin_label(begin_label, inside_label, desired_output):
|
||||
output = ner_label._is_missing_begin_label(begin_label, inside_label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_input_char_set", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, desired_input_char_set",
|
||||
[
|
||||
(["a", "b"], ["c", "d"], set("abcd")),
|
||||
(["New", "York"], ["is", "big"], set("NewYorkisbig")),
|
||||
(["word1", "word2"], ["word1", "word2"], set("word12")),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test__find_gap_char_candidates(gt_tokens, ocr_tokens, desired_input_char_set):
|
||||
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
|
||||
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
|
||||
gt_tokens, ocr_tokens
|
||||
)
|
||||
assert input_char_set == desired_input_char_set
|
||||
assert ner_label.GAP_CHAR_SET.difference(input_char_set) == gap_char_candidates
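
(A sketch of the private helper as the test exercises it: the candidates are the characters in ner_label.GAP_CHAR_SET that never occur in the input tokens.)

from genalog.text import ner_label

gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
    ["New", "York"], ["is", "big"]
)
print(input_char_set == set("NewYorkisbig"))                           # True, per the case above
print(gap_char_candidates == ner_label.GAP_CHAR_SET - input_char_set)  # True, per the assertion above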
|
||||
|
||||
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
|
||||
[
|
||||
(["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token
|
||||
(["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token
|
||||
(["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR
|
||||
(["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens
|
||||
(["o"], ["@"], ["New"], ner_label.GapCharError), # invalid token with gap char only (gt_token)
|
||||
(["o"], ["New"], ["@"], ner_label.GapCharError), # invalid token with gap char only (ocr_token)
|
||||
(["o", "o"], ["New", "@"], ["New", "@"], ner_label.GapCharError), # invalid token (both)
|
||||
(["o"], [" \n\t@@"], ["New"], ner_label.GapCharError), # invalid token with gap char and space chars (gt_token)
|
||||
(["o"], ["New"], [" \n\t@"], ner_label.GapCharError), # invalid token with gap char and space chars (ocr_token)
|
||||
(
|
||||
["o"],
|
||||
["@"],
|
||||
["New"],
|
||||
ner_label.GapCharError,
|
||||
), # invalid token with gap char only (gt_token)
|
||||
(
|
||||
["o"],
|
||||
["New"],
|
||||
["@"],
|
||||
ner_label.GapCharError,
|
||||
), # invalid token with gap char only (ocr_token)
|
||||
(
|
||||
["o", "o"],
|
||||
["New", "@"],
|
||||
["New", "@"],
|
||||
ner_label.GapCharError,
|
||||
), # invalid token (both)
|
||||
(
|
||||
["o"],
|
||||
[" \n\t@@"],
|
||||
["New"],
|
||||
ner_label.GapCharError,
|
||||
), # invalid token with gap char and space chars (gt_token)
|
||||
(
|
||||
["o"],
|
||||
["New"],
|
||||
[" \n\t@"],
|
||||
ner_label.GapCharError,
|
||||
), # invalid token with gap char and space chars (ocr_token)
|
||||
(["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token)
|
||||
(["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token)
|
||||
(["o"], [" \n\t"], ["New"], ValueError), # invalid token: space characters only (gt_token)
|
||||
(["o"], ["New"], [" \n\t"], ValueError), # invalid token: space characters only (ocr_token)
|
||||
(
|
||||
["o"],
|
||||
[" \n\t"],
|
||||
["New"],
|
||||
ValueError,
|
||||
), # invalid token: space characters only (gt_token)
|
||||
(
|
||||
["o"],
|
||||
["New"],
|
||||
[" \n\t"],
|
||||
ValueError,
|
||||
), # invalid token: space characters only (ocr_token)
|
||||
(["o"], ["New"], ["New"], None), # positive case
|
||||
(["o"], ["New@"], ["New"], None), # positive case with gap char
|
||||
(["o"], ["New"], ["@@New"], None), # positive case with gap char
|
||||
])
|
||||
def test__propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
|
||||
],
|
||||
)
|
||||
def test__propagate_label_to_ocr_error(
|
||||
gt_labels, gt_tokens, ocr_tokens, raised_exception
|
||||
):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
|
||||
ner_label._propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
|
||||
)
|
||||
else:
|
||||
ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
|
||||
ner_label._propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test__propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
|
||||
gap_char_candidates, _ = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
|
||||
# run regression test for each GAP_CHAR candidate to make sure
|
||||
# label propagation is functioning correctly
|
||||
for gap_char in gap_char_candidates:
|
||||
ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char)
|
||||
ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char
|
||||
)
|
||||
assert ocr_labels == desired_ocr_labels
|
||||
|
||||
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
|
||||
[
|
||||
(["o"], ["New"], ["New"], None), # positive case
|
||||
(["o"], ["New@"], ["New"], None), # positive case with gap char
|
||||
(["o"], ["New"], ["@@New"], None), # positive case with gap char
|
||||
(["o"], list(ner_label.GAP_CHAR_SET), [""], ner_label.GapCharError), # input char set == GAP_CHAR_SET
|
||||
(["o"], [""], list(ner_label.GAP_CHAR_SET), ner_label.GapCharError), # input char set == GAP_CHAR_SET
|
||||
(
|
||||
["o"],
|
||||
list(ner_label.GAP_CHAR_SET),
|
||||
[""],
|
||||
ner_label.GapCharError,
|
||||
), # input char set == GAP_CHAR_SET
|
||||
(
|
||||
["o"],
|
||||
[""],
|
||||
list(ner_label.GAP_CHAR_SET),
|
||||
ner_label.GapCharError,
|
||||
), # input char set == GAP_CHAR_SET
|
||||
# all possible gap chars set split between ocr and gt tokens
|
||||
(["o"], list(ner_label.GAP_CHAR_SET)[:10], list(ner_label.GAP_CHAR_SET)[10:], ner_label.GapCharError),
|
||||
])
|
||||
def test_propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
|
||||
(
|
||||
["o"],
|
||||
list(ner_label.GAP_CHAR_SET)[:10],
|
||||
list(ner_label.GAP_CHAR_SET)[10:],
|
||||
ner_label.GapCharError,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_propagate_label_to_ocr_error(
|
||||
gt_labels, gt_tokens, ocr_tokens, raised_exception
|
||||
):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
|
||||
else:
|
||||
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
|
||||
|
||||
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
|
||||
ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
|
||||
ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens
|
||||
)
|
||||
assert ocr_labels == desired_ocr_labels
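
(Illustrative call with hypothetical tokens and labels; the real regression pairs live in tests/cases/label_propagation.py.)

from genalog.text import ner_label

gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["New", "Yerk", "is", "big"]
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)
print(ocr_labels)  # one label per OCR token, carried over from the ground-truth labels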
|
||||
|
||||
@pytest.mark.parametrize("tokens, labels, label_top, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, labels, label_top, desired_output",
|
||||
[
|
||||
(
|
||||
["New", "York", "is", "big"],
|
||||
["B-place", "I-place", "o", "o"],
|
||||
True,
|
||||
"B-place I-place o o \n" +
|
||||
"New York is big \n"
|
||||
"B-place I-place o o \n" + "New York is big \n",
|
||||
),
|
||||
(
|
||||
["New", "York", "is", "big"],
|
||||
["B-place", "I-place", "o", "o"],
|
||||
False,
|
||||
"New York is big \n" +
|
||||
"B-place I-place o o \n"
|
||||
"New York is big \n" + "B-place I-place o o \n",
|
||||
),
|
||||
],
|
||||
)
|
||||
])
|
||||
def test_format_label(tokens, labels, label_top, desired_output):
|
||||
output = ner_label.format_labels(tokens, labels, label_top=label_top)
|
||||
assert output == desired_output
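
(Sketch of format_labels on the tokens used above; label_top=True prints the label row above the token row, False swaps the order.)

from genalog.text import ner_label

tokens = ["New", "York", "is", "big"]
labels = ["B-place", "I-place", "o", "o"]
print(ner_label.format_labels(tokens, labels, label_top=True))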
|
||||
|
||||
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_format_gt_ocr_w_labels(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
|
||||
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
|
||||
ner_label.format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr)
|
||||
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens
|
||||
)
|
||||
ner_label.format_label_propagation(
|
||||
gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr
|
||||
)
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from genalog.text import preprocess
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("token, replacement, desired_output",
|
||||
from genalog.text import preprocess
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token, replacement, desired_output",
|
||||
[
|
||||
("", "_", ""), # Do nothing to empty string
|
||||
(" ", "_", " "), # Do nothing to whitespaces
|
||||
|
@ -11,49 +14,37 @@ import pytest
|
|||
("a s\nc\tii", "_", "a s\nc\tii"),
|
||||
("ascii·", "_", "ascii"), # Tokens with non-ASCII values
|
||||
("·", "_", "_"), # Tokens with non-ASCII values
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_remove_non_ascii(token, replacement, desired_output):
|
||||
for code in range(128, 1000): # non-ASCII values
|
||||
token.replace("·", chr(code))
|
||||
output = preprocess.remove_non_ascii(token, replacement)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("s, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, desired_output",
|
||||
[
|
||||
(
|
||||
" New \t \n",
|
||||
["New"]
|
||||
),
|
||||
(" New \t \n", ["New"]),
|
||||
# Mixed in gap char "@"
|
||||
(
|
||||
" @ @",
|
||||
["@", "@"]
|
||||
),
|
||||
(
|
||||
"New York is big",
|
||||
["New", "York", "is", "big"]
|
||||
),
|
||||
(" @ @", ["@", "@"]),
|
||||
("New York is big", ["New", "York", "is", "big"]),
|
||||
# Mixed multiple spaces and tabs
|
||||
(
|
||||
" New York \t is \t big",
|
||||
["New", "York", "is", "big"]
|
||||
),
|
||||
(" New York \t is \t big", ["New", "York", "is", "big"]),
|
||||
# Mixed in punctuation
|
||||
(
|
||||
"New .York is, big !",
|
||||
["New", ".York", "is,", "big", "!"]
|
||||
),
|
||||
("New .York is, big !", ["New", ".York", "is,", "big", "!"]),
|
||||
# Mixed in gap char "@"
|
||||
(
|
||||
"@N@ew York@@@is,\t big@@@@@",
|
||||
["@N@ew", "York@@@is,", "big@@@@@"]
|
||||
("@N@ew York@@@is,\t big@@@@@", ["@N@ew", "York@@@is,", "big@@@@@"]),
|
||||
],
|
||||
)
|
||||
])
|
||||
def test_tokenize(s, desired_output):
|
||||
output = preprocess.tokenize(s)
|
||||
assert output == desired_output
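
(Two calls taken from the cases above: tokenize splits on whitespace and keeps punctuation attached to its token.)

from genalog.text import preprocess

print(preprocess.tokenize("New .York is, big !"))     # ["New", ".York", "is,", "big", "!"]
print(preprocess.tokenize(" New York \t is \t big"))  # ["New", "York", "is", "big"]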
|
||||
|
||||
@pytest.mark.parametrize("tokens, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, desired_output",
|
||||
[
|
||||
(
|
||||
["New", "York", "is", "big"],
|
||||
|
@ -68,34 +59,54 @@ def test_tokenize(s, desired_output):
|
|||
(
|
||||
["@N@ew", "York@@@is,", "big@@@@@"],
|
||||
"@N@ew York@@@is, big@@@@@",
|
||||
),
|
||||
],
|
||||
)
|
||||
])
|
||||
def test_join_tokens(tokens, desired_output):
|
||||
output = preprocess.join_tokens(tokens)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("c, desired_output",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"c, desired_output",
|
||||
[
|
||||
# Gap char
|
||||
(GAP_CHAR, False),
|
||||
# Alphabet char
|
||||
('a', False), ('A', False),
|
||||
("a", False),
|
||||
("A", False),
|
||||
# Punctuation
|
||||
('.', False), ('!', False), (',', False), ('-', False),
|
||||
(".", False),
|
||||
("!", False),
|
||||
(",", False),
|
||||
("-", False),
|
||||
# Token separators
|
||||
(' ', True), ('\n', True), ('\t', True)
|
||||
])
|
||||
(" ", True),
|
||||
("\n", True),
|
||||
("\t", True),
|
||||
],
|
||||
)
|
||||
def test__is_spacing(c, desired_output):
|
||||
assert desired_output == preprocess._is_spacing(c)
|
||||
|
||||
@pytest.mark.parametrize("text, desired_output", [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, desired_output",
|
||||
[
|
||||
("", ""),
|
||||
("w .", "w ."), ("w !", "w !"), ("w ?", "w ?"),
|
||||
("w /.", "w /."), ("w /!", "w /!"), ("w /?", "w /?"),
|
||||
("w .", "w ."),
|
||||
("w !", "w !"),
|
||||
("w ?", "w ?"),
|
||||
("w /.", "w /."),
|
||||
("w /!", "w /!"),
|
||||
("w /?", "w /?"),
|
||||
("w1 , w2 .", "w1 , w2 ."),
|
||||
("w1 . w2 .", "w1 . \nw2 ."), ("w1 /. w2 /.", "w1 /. \nw2 /."),
|
||||
("w1 ! w2 .", "w1 ! \nw2 ."), ("w1 /! w2 /.", "w1 /! \nw2 /."),
|
||||
("w1 ? w2 .", "w1 ? \nw2 ."), ("w1 /? w2 /.", "w1 /? \nw2 /."),
|
||||
("w1 . w2 .", "w1 . \nw2 ."),
|
||||
("w1 /. w2 /.", "w1 /. \nw2 /."),
|
||||
("w1 ! w2 .", "w1 ! \nw2 ."),
|
||||
("w1 /! w2 /.", "w1 /! \nw2 /."),
|
||||
("w1 ? w2 .", "w1 ? \nw2 ."),
|
||||
("w1 /? w2 /.", "w1 /? \nw2 /."),
|
||||
("U.S. . w2 .", "U.S. . \nw2 ."),
|
||||
("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting
|
||||
("w1 !!! w2 .", "w1 !!! w2 ."),
|
||||
|
@ -108,17 +119,30 @@ def test__is_spacing(c, desired_output):
|
|||
("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
|
||||
("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
|
||||
("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_split_sentences(text, desired_output):
|
||||
assert desired_output == preprocess.split_sentences(text)
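
(From the cases above: a sentence break is inserted after a standalone terminator token, but not after runs such as "???" or "!!!".)

from genalog.text import preprocess

print(repr(preprocess.split_sentences("w1 . w2 .")))    # 'w1 . \nw2 .'
print(repr(preprocess.split_sentences("w1 ??? w2 .")))  # 'w1 ??? w2 .'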
|
||||
|
||||
@pytest.mark.parametrize("token, desired_output", [
|
||||
("", False), (" ", False), ("\n", False), ("\t", False),
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token, desired_output",
|
||||
[
|
||||
("", False),
|
||||
(" ", False),
|
||||
("\n", False),
|
||||
("\t", False),
|
||||
(" \n \t", False),
|
||||
("...", False),
|
||||
("???", False), ("!!!", False),
|
||||
(".", True), ("!", True), ("?", True),
|
||||
("/.", True), ("/!", True), ("/?", True),
|
||||
])
|
||||
("???", False),
|
||||
("!!!", False),
|
||||
(".", True),
|
||||
("!", True),
|
||||
("?", True),
|
||||
("/.", True),
|
||||
("/!", True),
|
||||
("/?", True),
|
||||
],
|
||||
)
|
||||
def test_is_sentence_separator(token, desired_output):
|
||||
assert desired_output == preprocess.is_sentence_separator(token)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import random
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog.text import alignment
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
|
||||
|
||||
|
||||
def random_utf8_char(byte_len=1):
|
||||
if byte_len == 1:
|
||||
return chr(random.randint(0, 0x007F))
|
||||
|
@ -16,25 +18,48 @@ def random_utf8_char(byte_len=1):
|
|||
elif byte_len == 4:
|
||||
return chr(random.randint(0xFFFF, 0x10FFFF))
|
||||
else:
|
||||
raise ValueError(f"Invalid byte length: {byte_len}." +
|
||||
"utf-8 does not encode characters with more than 4 bytes in length")
|
||||
raise ValueError(
|
||||
f"Invalid byte length: {byte_len}."
|
||||
+ "utf-8 does not encode characters with more than 4 bytes in length"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("num_utf_char_to_test", [100]) # Number of char per byte length
|
||||
@pytest.mark.parametrize("byte_len", [1,2,3,4]) # UTF does not encode with more than 4 bytes
|
||||
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
|
||||
|
||||
invalid_char = set(gt_txt).union(set(GAP_CHAR)) # the replacement character cannot be in this set
|
||||
@pytest.mark.parametrize(
|
||||
"num_utf_char_to_test", [100]
|
||||
) # Number of char per byte length
|
||||
@pytest.mark.parametrize(
|
||||
"byte_len", [1, 2, 3, 4]
|
||||
) # UTF does not encode with more than 4 bytes
|
||||
@pytest.mark.parametrize(
|
||||
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_align(
|
||||
num_utf_char_to_test,
|
||||
byte_len,
|
||||
gt_txt,
|
||||
noisy_txt,
|
||||
expected_aligned_gt,
|
||||
expected_aligned_noise,
|
||||
):
|
||||
|
||||
invalid_char = set(gt_txt).union(
|
||||
set(GAP_CHAR)
|
||||
) # the replacement character cannot be in this set
|
||||
for _ in range(num_utf_char_to_test):
|
||||
utf_char = random_utf8_char(byte_len)
|
||||
while utf_char in invalid_char: # find a utf char not in the input string and not GAP_CHAR
|
||||
while (
|
||||
utf_char in invalid_char
|
||||
): # find a utf char not in the input string and not GAP_CHAR
|
||||
utf_char = random_utf8_char(byte_len)
|
||||
char_to_replace = random.choice(list(invalid_char)) if gt_txt else ""
|
||||
|
||||
gt_txt_sub = gt_txt.replace(char_to_replace, utf_char)
|
||||
noisy_txt_sub = noisy_txt.replace(char_to_replace, utf_char)
|
||||
gt_txt.replace(char_to_replace, utf_char)
|
||||
noisy_txt.replace(char_to_replace, utf_char)
|
||||
expected_aligned_gt_sub = expected_aligned_gt.replace(char_to_replace, utf_char)
|
||||
expected_aligned_noise_sub = expected_aligned_noise.replace(char_to_replace, utf_char)
|
||||
expected_aligned_noise_sub = expected_aligned_noise.replace(
|
||||
char_to_replace, utf_char
|
||||
)
|
||||
|
||||
# Run alignment
|
||||
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
|
||||
|
@ -42,7 +67,12 @@ def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_align
|
|||
aligned_gt = aligned_gt.replace(char_to_replace, utf_char)
|
||||
aligned_noise = aligned_noise.replace(char_to_replace, utf_char)
|
||||
if aligned_gt != expected_aligned_gt_sub:
|
||||
expected_alignment = alignment._format_alignment(expected_aligned_gt_sub, expected_aligned_noise_sub)
|
||||
expected_alignment = alignment._format_alignment(
|
||||
expected_aligned_gt_sub, expected_aligned_noise_sub
|
||||
)
|
||||
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
|
||||
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
|
||||
)
|
||||
)
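
(Hypothetical inputs; align returns the two character-aligned strings, presumably with alignment.GAP_CHAR filling positions where one string has no counterpart, and _format_alignment renders them as in the warning above.)

from genalog.text import alignment

aligned_gt, aligned_noise = alignment.align("New York is big", "New Yrk is big")
print(alignment._format_alignment(aligned_gt, aligned_noise))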
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
; [tox]
|
||||
; envlist = flake8, py36 # add other python versions if necessary
|
||||
|
||||
; [testenv]
|
||||
; # Reading additional dependencies to run the test
|
||||
; # https://tox.readthedocs.io/en/latest/example/basic.html#depending-on-requirements-txt-or-defining-constraints
|
||||
; deps = -rdev-requirements.txt
|
||||
; commands =
|
||||
; pytest
|
||||
|
||||
; [testenv:flake8]
|
||||
; deps = flake8
|
||||
; skip_install = True
|
||||
; commands = flake8 .
|
||||
|
||||
; # Configurations for running pytest
|
||||
[pytest]
|
||||
junit_family=xunit2
|
||||
testpaths =
|
||||
tests
|
||||
addopts =
|
||||
-rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml
|
||||
|
||||
[flake8]
|
||||
# Configs for flake8-import-order, see https://pypi.org/project/flake8-import-order/ for more info.
|
||||
import-order-style=edited
|
||||
application-import-names=genalog, tests
|
||||
# Native flake8 configs
|
||||
max-line-length = 140
|
||||
exclude =
|
||||
build, dist
|
||||
.env*,.venv* # local virtual environments
|
||||
.tox
|
||||
|
||||
; [mypy]
|
||||
; ignore_missing_imports = True
|