Merge pull request #10 from microsoft/laserprec/use_linter

Use flake8 as linter and fix code format issues.
This commit is contained in:
Jianjie Liu 2021-01-25 17:04:13 -05:00 committed by GitHub
Parents ad504b4d7b 4c8445b874
Commit cda2c1e77d
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 3815 additions and 2116 deletions

View file

@ -35,31 +35,20 @@ steps:
displayName: 'Use Python $(python.version)'
- bash: |
python -m venv .venv
displayName: 'Create virtual environment'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
pip install --upgrade pip
pip install setuptools wheel
pip install -r requirements.txt
pip install pytest==5.3.5 pytest-cov==2.8.1
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Install dependencies'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python -m pytest tests --cov=genalog --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=html
python -m flake8
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Run Linter (flake8)'
- bash: |
python -m pytest tests
env:
BLOB_KEY : $(BLOB_KEY)
SEARCH_SERVICE_KEY: $(SEARCH_SERVICE_KEY)
@ -86,12 +75,6 @@ steps:
displayName: 'Publish test coverage'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python setup.py bdist_wheel --build-number $(Build.BuildNumber) --dist-dir dist
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Building wheel package'
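The updated pipeline installs the dev requirements and then runs flake8 alongside pytest. For local reproduction, a minimal sketch in Python (the helper below is hypothetical; only the python -m commands are taken from the pipeline above):

import subprocess
import sys

def run_ci_checks():
    """Mirror the CI steps: install dev dependencies, lint with flake8, run tests."""
    commands = [
        [sys.executable, "-m", "pip", "install", "-r", "requirements-dev.txt"],
        [sys.executable, "-m", "flake8"],
        [sys.executable, "-m", "pytest", "tests"],
    ]
    for cmd in commands:
        subprocess.run(cmd, check=True)  # stop on the first failing step

if __name__ == "__main__":
    run_ci_checks()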

View file

@ -1,16 +1,20 @@
from genalog.degradation import effect
from enum import Enum
import copy
import inspect
from enum import Enum
from genalog.degradation import effect
DEFAULT_METHOD_PARAM_TO_INCLUDE = "src"
class ImageState(Enum):
ORIGINAL_STATE = "ORIGINAL_STATE"
CURRENT_STATE = "CURRENT_STATE"
class Degrader():
class Degrader:
""" An object for applying multiple degradation effects onto an image"""
def __init__(self, effects):
"""Initialize a Degrader object
@ -73,14 +77,21 @@ class Degrader():
# Try to find corresponding degradation method in the module
method = getattr(effect, method_name)
except AttributeError:
raise ValueError(f"Method '{method_name}' is not defined in 'genalog.degradation.effect'")
raise ValueError(
f"Method '{method_name}' is not defined in 'genalog.degradation.effect'"
)
# Get the method signatures
method_sign = inspect.signature(method)
# Check if method parameters are valid
for param_name in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if not param_name in method_sign.parameters:
for (
param_name
) in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if param_name not in method_sign.parameters:
method_args = [param for param in method_sign.parameters]
raise ValueError(f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. Method parameter names are: {method_args}")
raise ValueError(
f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. " +
f"Method parameter names are: {method_args}"
)
def _add_default_method_param(self):
"""All methods in "genalog.degradation.effect" module have a required
@ -90,7 +101,9 @@ class Degrader():
for effect_tuple in self.effects_to_apply:
method_name, method_kwargs = effect_tuple
if DEFAULT_METHOD_PARAM_TO_INCLUDE not in method_kwargs:
method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] = ImageState.CURRENT_STATE
method_kwargs[
DEFAULT_METHOD_PARAM_TO_INCLUDE
] = ImageState.CURRENT_STATE
def apply_effects(self, src):
"""Apply degradation effects in sequence

View file

@ -1,6 +1,8 @@
from math import floor
import cv2
import numpy as np
from math import floor
def blur(src, radius=5):
"""Wrapper function for cv2.GaussianBlur
@ -17,6 +19,7 @@ def blur(src, radius=5):
"""
return cv2.GaussianBlur(src, (radius, radius), cv2.BORDER_DEFAULT)
def overlay_weighted(src, background, alpha, beta, gamma=0):
"""overlay two images together, pixels from each image is weighted as follow
@ -36,6 +39,7 @@ def overlay_weighted(src, background, alpha, beta, gamma=0):
"""
return cv2.addWeighted(src, alpha, background, beta, gamma).astype(np.uint8)
def overlay(src, background):
"""Overlay two images together via bitwise-and:
@ -50,6 +54,7 @@ def overlay(src, background):
"""
return cv2.bitwise_and(src, background).astype(np.uint8)
def translation(src, offset_x, offset_y):
"""Shift the image in x, y direction
@ -69,6 +74,7 @@ def translation(src, offset_x, offset_y):
dst = cv2.warpAffine(src, trans_matrix, (cols, rows), borderValue=255)
return dst.astype(np.uint8)
def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y=5):
"""Apply bleed through effect, background is flipped horizontally.
@ -96,6 +102,7 @@ def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y
beta = 1 - alpha
return overlay_weighted(src, background, alpha, beta, gamma)
def pepper(src, amount=0.05):
"""Randomly sprinkle dark pixels on src image.
Wrapper function for skimage.util.noise.random_noise().
@ -118,6 +125,7 @@ def pepper(src, amount=0.05):
dst[noise < amount] = 0
return dst.astype(np.uint8)
def salt(src, amount=0.3):
"""Randomly sprinkle white pixels on src image.
Wrapper function for skimage.util.noise.random_noise().
@ -140,6 +148,7 @@ def salt(src, amount=0.3):
dst[noise < amount] = 255
return dst.astype(np.uint8)
def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
"""Randomly add salt then add pepper onto the image.
@ -159,6 +168,7 @@ def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
salted = salt(src, amount=salt_amount)
return pepper(salted, amount=pepper_amount)
def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
"""Randomly add pepper then salt onto the image.
@ -178,6 +188,7 @@ def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
peppered = pepper(src, amount=pepper_amount)
return salt(peppered, amount=salt_amount)
def create_2D_kernel(kernel_shape, kernel_type="ones"):
"""Create 2D kernel for morphological operations.
@ -243,11 +254,21 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
elif kernel_type == "ellipse":
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape)
else:
valid_kernel_types = {"ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"}
raise ValueError(f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}")
valid_kernel_types = {
"ones",
"upper_triangle",
"lower_triangle",
"x",
"plus",
"ellipse",
}
raise ValueError(
f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}"
)
return kernel.astype(np.uint8)
def morphology(src, operation="open", kernel_shape=(3, 3), kernel_type="ones"):
"""Dynamic calls different morphological operations
("open", "close", "dilate" and "erode") with the given parameters
@ -280,7 +301,10 @@ def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
return erode(src, kernel)
else:
valid_operations = ["open", "close", "dilate", "erode"]
raise ValueError(f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}")
raise ValueError(
f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}"
)
def open(src, kernel):
""" "open" morphological operation. Like morphological "erosion", it removes
@ -299,6 +323,7 @@ def open(src, kernel):
"""
return cv2.morphologyEx(src, cv2.MORPH_OPEN, kernel)
def close(src, kernel):
""" "close" morphological operation. Like morphological "dilation", it grows the
boundary of the foreground (white pixels), however, it is less destructive than
@ -317,6 +342,7 @@ def close(src, kernel):
"""
return cv2.morphologyEx(src, cv2.MORPH_CLOSE, kernel)
def erode(src, kernel):
""" "erode" morphological operation. Erodes foreground pixels (white pixels).
For more information see:
@ -332,6 +358,7 @@ def erode(src, kernel):
"""
return cv2.erode(src, kernel)
def dilate(src, kernel):
""" "dilate" morphological operation. Grows foreground pixels (white pixels).
For more information see:
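The individual effect functions above can also be called directly. A minimal sketch using only the signatures visible in this diff (the image path is a placeholder):

import cv2
from genalog.degradation import effect

img = cv2.imread("sample_doc.png", cv2.IMREAD_GRAYSCALE)  # placeholder grayscale page

blurred = effect.blur(img, radius=5)
noisy = effect.salt_then_pepper(blurred, salt_amount=0.1, pepper_amount=0.05)
kernel = effect.create_2D_kernel((3, 3), kernel_type="plus")
opened = effect.open(noisy, kernel)                        # "open" morphological operation
faded = effect.bleed_through(opened, alpha=0.8, offset_y=5)
cv2.imwrite("sample_doc_faded.png", faded)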

View file

@ -1,4 +1,5 @@
from enum import Enum, auto
from enum import auto, Enum
class ContentType(Enum):
PARAGRAPH = auto()
@ -6,14 +7,17 @@ class ContentType(Enum):
IMAGE = auto()
COMPOSITE = auto()
class Content():
class Content:
def __init__(self):
self.iterable = True
self._content = None
def set_content_type(self, content_type):
if type(content_type) != ContentType:
raise TypeError(f"Invalid content type: {content_type}, valid types are {list(ContentType)}")
raise TypeError(
f"Invalid content type: {content_type}, valid types are {list(ContentType)}"
)
self.content_type = content_type
def validate_content(self):
@ -28,6 +32,7 @@ class Content():
def __getitem__(self, key):
return self._content.__getitem__(key)
class Paragraph(Content):
def __init__(self, content):
self.set_content_type(ContentType.PARAGRAPH)
@ -38,6 +43,7 @@ class Paragraph(Content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")
class Title(Content):
def __init__(self, content):
self.set_content_type(ContentType.TITLE)
@ -48,6 +54,7 @@ class Title(Content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")
class CompositeContent(Content):
def __init__(self, content_list, content_type_list):
self.set_content_type(ContentType.COMPOSITE)
@ -82,5 +89,5 @@ class CompositeContent(Content):
"""get a string transparent of the nested object types"""
transparent_str = "["
for content in self._content:
transparent_str += '"' + content.__str__() + "\", "
transparent_str += '"' + content.__str__() + '", '
return transparent_str + "]"

View file

@ -1,14 +1,12 @@
from jinja2 import PackageLoader, FileSystemLoader
from jinja2 import Environment, select_autoescape
from weasyprint import HTML
from cairocffi import FORMAT_ARGB32
import numpy as np
import itertools
import os
import cv2
import cv2
import numpy as np
from cairocffi import FORMAT_ARGB32
from jinja2 import Environment, select_autoescape
from jinja2 import FileSystemLoader, PackageLoader
from weasyprint import HTML
DEFAULT_DOCUMENT_STYLE = {
"language": "en_US",
@ -27,8 +25,10 @@ DEFAULT_STYLE_COMBINATION = {
"hyphenate": [False],
}
class Document(object):
""" A composite object that represents a document """
def __init__(self, content, template, **styles):
"""Initialize a Document object with source template and content
@ -97,21 +97,26 @@ class Document(object):
Arguments:
target -- a filename, file-like object, or None
split_pages {bool} -- true if save each document page as a separate file.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), PNG pixels match the CSS px unit.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default),
PNG pixels match the CSS px unit.
Returns:
The image as bytes if target is not provided or None, otherwise None (the PNG is written to target)
"""
if target != None and split_pages:
if target is not None and split_pages:
# get destination filename and extension
filename, ext = os.path.splitext(target)
for page_num, page in enumerate(self._document.pages):
page_name = filename + f"_pg_{page_num}" + ext
self._document.copy([page]).write_png(target=page_name, resolution=resolution)
self._document.copy([page]).write_png(
target=page_name, resolution=resolution
)
return None
elif target == None:
elif target is None:
# return image bytes string if no target is specified
png_bytes, png_width, png_height = self._document.write_png(target=target, resolution=resolution)
png_bytes, png_width, png_height = self._document.write_png(
target=target, resolution=resolution
)
return png_bytes
else:
return self._document.write_png(target=target, resolution=resolution)
@ -131,18 +136,23 @@ class Document(object):
"""
# Method below returns a cairocffi.ImageSurface object
# https://cairocffi.readthedocs.io/en/latest/api.html#cairocffi.ImageSurface
surface, width, height = self._document.write_image_surface(resolution=resolution)
surface, width, height = self._document.write_image_surface(
resolution=resolution
)
img_format = surface.get_format()
# This is BGRA channel in little endian (reverse)
if img_format != FORMAT_ARGB32:
raise RuntimeError(f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'")
raise RuntimeError(
f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}." +
"Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'"
)
img_buffer = surface.get_data()
# Returns image array in "BGRA" channel
img_array = np.ndarray(shape=(height, width, 4),
dtype=np.uint8,
buffer=img_buffer)
img_array = np.ndarray(
shape=(height, width, 4), dtype=np.uint8, buffer=img_buffer
)
if channel == "GRAYSCALE":
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)
elif channel == "RGBA":
@ -155,7 +165,9 @@ class Document(object):
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2BGR)
else:
valid_channels = ["GRAYSCALE", "RGB", "RGBA", "BGR", "BGRA"]
raise ValueError(f"Invalid channel code {channel}. Valid values are: {valid_channels}.")
raise ValueError(
f"Invalid channel code {channel}. Valid values are: {valid_channels}."
)
def update_style(self, **style):
"""Update template variables that controls the document style and re-compile the document to reflect the style change.
@ -175,11 +187,14 @@ class Document(object):
self.styles.update(style)
# Recompile the html template and the document obj
self.compiled_html = self.render_html()
self._document = HTML(string=self.compiled_html).render() # weasyprinter.document.Document object
self._document = HTML(
string=self.compiled_html
).render() # weasyprinter.document.Document object
class DocumentGenerator():
class DocumentGenerator:
""" Document generator class """
def __init__(self, template_path=None):
"""Initialize a DocumentGenerator class
@ -191,17 +206,19 @@ class DocumentGenerator():
if template_path:
self.template_env = Environment(
loader=FileSystemLoader(template_path),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
self.template_list = self.template_env.list_templates()
else:
# Loading built-in templates from the genalog package
self.template_env = Environment(
loader=PackageLoader("genalog.generation", "templates"),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)
# Remove macros and css templates from rendering
self.template_list = self.template_env.list_templates(filter_func=DocumentGenerator._keep_template)
self.template_list = self.template_env.list_templates(
filter_func=DocumentGenerator._keep_template
)
self.set_styles_to_generate(DEFAULT_STYLE_COMBINATION)
@ -251,7 +268,9 @@ class DocumentGenerator():
If this parameter is not provided, generator will use default document
styles: DEFAULT_STYLE_COMBINATION
"""
self.styles_to_generate = DocumentGenerator.expand_style_combinations(style_combinations)
self.styles_to_generate = DocumentGenerator.expand_style_combinations(
style_combinations
)
def create_generator(self, content, templates_to_render):
"""Create a Document generator
@ -266,7 +285,9 @@ class DocumentGenerator():
"""
for template_name in templates_to_render:
if template_name not in self.template_list:
raise FileNotFoundError(f"File '{template_name}' not found. Available templates are {self.template_list}")
raise FileNotFoundError(
f"File '{template_name}' not found. Available templates are {self.template_list}"
)
template = self.template_env.get_template(template_name)
for style in self.styles_to_generate:
yield Document(content, template, **style)
@ -308,8 +329,12 @@ class DocumentGenerator():
if not styles:
return []
# Python 2.x+ guarantees that the order in keys() and values() is preserved
style_properties = styles.keys() # ex) ["font_family", "font_size", "hyphenate"]
property_values = styles.values() # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]
style_properties = (
styles.keys()
) # ex) ["font_family", "font_size", "hyphenate"]
property_values = (
styles.values()
) # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]
# Generate all possible combinations:
# [("Calibri", "10px", True), ("Calibri", "12px", True), ...]

View file

@ -1,20 +1,20 @@
"""Uses the python sdk to make operation on Azure Blob storage.
see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
import os
import time
import asyncio
import base64
import hashlib
import json
import asyncio
import os
import random
from multiprocessing import Pool
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from tqdm import tqdm
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
import base64
import aiofiles
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from tqdm import tqdm
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
# maximum number of simultaneous requests
REQUEST_SEMAPHORE = asyncio.Semaphore(50)
@ -24,11 +24,19 @@ FILE_SEMAPHORE = asyncio.Semaphore(500)
MAX_RETRIES = 5
class GrokBlobClient:
"""This class is a client that is used to upload and delete files from Azure Blob storage
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
def __init__(self, datasource_container_name, blob_account_name, blob_key, projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME):
def __init__(
self,
datasource_container_name,
blob_account_name,
blob_key,
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
):
"""Creates the blob storage client given the key and storage account name
Args:
@ -42,8 +50,10 @@ class GrokBlobClient:
self.PROJECTIONS_CONTAINER_NAME = projections_container_name
self.BLOB_NAME = blob_account_name
self.BLOB_KEY = blob_key
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
self.BLOB_CONNECTION_STRING = (
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
)
@staticmethod
def create_from_env_var():
@ -55,11 +65,24 @@ class GrokBlobClient:
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
BLOB_NAME = os.environ["BLOB_NAME"]
BLOB_KEY = os.environ["BLOB_KEY"]
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
client = GrokBlobClient(DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY, projections_container_name=PROJECTIONS_CONTAINER_NAME)
PROJECTIONS_CONTAINER_NAME = os.environ.get(
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
)
client = GrokBlobClient(
DATASOURCE_CONTAINER_NAME,
BLOB_NAME,
BLOB_KEY,
projections_container_name=PROJECTIONS_CONTAINER_NAME,
)
return client
def upload_images_to_blob(self, src_folder_path, dest_folder_name=None,check_existing_cache=True, use_async=True):
def upload_images_to_blob(
self,
src_folder_path,
dest_folder_name=None,
check_existing_cache=True,
use_async=True,
):
"""Uploads images from the src_folder_path to blob storage at the destination folder.
The destination folder is created if it doesn't exist. If a destination folder is not
given, a folder named by the md5 hash of the files is created.
@ -73,7 +96,8 @@ class GrokBlobClient:
"""
self._create_container()
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
if dest_folder_name is None:
dest_folder_name = self.get_folder_hash(src_folder_path)
@ -96,28 +120,48 @@ class GrokBlobClient:
if check_existing_cache:
existing_blobs, _ = self.list_blobs(dest_folder_name or "")
existing_blobs = list(map(lambda blob: blob["name"], existing_blobs))
file_blob_names = filter(lambda file_blob_names: not file_blob_names[1] in existing_blobs, zip(files_to_upload, blob_names))
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in file_blob_names ]
file_blob_names = filter(
lambda file_blob_names: not file_blob_names[1] in existing_blobs,
zip(files_to_upload, blob_names),
)
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in file_blob_names
]
else:
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in zip(files_to_upload, blob_names)]
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in zip(files_to_upload, blob_names)
]
print("uploading ", len(job_args), "files")
if not use_async:
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
blob_container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
blob_container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
jobs = [(blob_container_client,) + x for x in job_args]
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)):
pass
else:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
async def async_upload():
async with async_blob_service_client:
async_blob_container_client = async_blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
async_blob_container_client = (
async_blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
)
jobs = [(async_blob_container_client,) + x for x in job_args]
for f in tqdm(asyncio.as_completed(map(_upload_worker_async,jobs)), total=len(jobs)):
for f in tqdm(
asyncio.as_completed(map(_upload_worker_async, jobs)),
total=len(jobs),
):
await f
loop = asyncio.get_event_loop()
@ -158,53 +202,83 @@ class GrokBlobClient:
blobs_list, blob_service_client = self.list_blobs(folder_name)
for blob in blobs_list:
blob_client = blob_service_client.get_blob_client(
container=self.DATASOURCE_CONTAINER_NAME, blob=blob)
container=self.DATASOURCE_CONTAINER_NAME, blob=blob
)
blob_client.delete_blob()
def list_blobs(self, folder_name):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
return container_client.list_blobs(name_starts_with=folder_name), blob_service_client
self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
return (
container_client.list_blobs(name_starts_with=folder_name),
blob_service_client,
)
def _create_container(self):
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist.
"""
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist."""
# Create the BlobServiceClient object which will be used to create a container client
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
try:
blob_service_client.create_container(
self.DATASOURCE_CONTAINER_NAME)
blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.DATASOURCE_CONTAINER_NAME)
# create the container for storing ocr projections
try:
print("creating projections storage container:", self.PROJECTIONS_CONTAINER_NAME)
blob_service_client.create_container(
self.PROJECTIONS_CONTAINER_NAME)
print(
"creating projections storage container:",
self.PROJECTIONS_CONTAINER_NAME,
)
blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.PROJECTIONS_CONTAINER_NAME)
def get_ocr_json(self, remote_path, output_folder, use_async=True):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
blobs_list = list(container_client.list_blobs(name_starts_with=remote_path))
container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}"
if use_async:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
async def async_download():
async with async_blob_service_client:
async_projection_container_client = async_blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob :(blob, async_projection_container_client, container_uri, output_folder), blobs_list ))
for f in tqdm(asyncio.as_completed(map(_download_worker_async,jobs)), total=len(jobs)):
async_projection_container_client = (
async_blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
)
jobs = list(
map(
lambda blob: (
blob,
async_projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
for f in tqdm(
asyncio.as_completed(map(_download_worker_async, jobs)),
total=len(jobs),
):
await f
loop = asyncio.get_event_loop()
if loop.is_running():
result = loop.create_task(async_download())
@ -212,12 +286,25 @@ class GrokBlobClient:
result = loop.run_until_complete(async_download())
return result
else:
projection_container_client = blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob : (blob, projection_container_client, container_uri, output_folder), blobs_list))
projection_container_client = blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
jobs = list(
map(
lambda blob: (
blob,
projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
print("downloading", len(jobs), "files")
for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)):
pass
def _get_projection_path(container_uri, blob):
blob_uri = f"{container_uri}/{blob.name}"
@ -229,10 +316,13 @@ def _get_projection_path(container_uri, blob):
projection_path = projection_path.replace("=", "") + str(projection_path.count("="))
return projection_path
def _download_worker_sync(args):
blob, projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri, blob)
blob_client = projection_container_client.get_blob_client(blob=f"{projection_path}/document.json")
blob_client = projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
doc = json.loads(blob_client.download_blob().readall())
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)
@ -242,10 +332,13 @@ def _download_worker_sync(args):
json.dump(text, open(output_file, "w", encoding="utf-8"), ensure_ascii=False)
return output_file
async def _download_worker_async(args):
blob, async_projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri, blob)
async_blob_client = async_projection_container_client.get_blob_client( blob=f"{projection_path}/document.json")
async_blob_client = async_projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)
for retry in range(MAX_RETRIES):
@ -260,7 +353,7 @@ async def _download_worker_async(args):
json.dump(text, open(output_file, "w"))
return output_file
except ResourceNotFoundError:
print(f"blob doesn't exist in OCR projection. try rerunning OCR", blob.name)
print(f"Blob '{blob.name}'' doesn't exist in OCR projection. try rerunning OCR")
return
except Exception as e:
print("error getting blob OCR projection", blob.name, e)
@ -268,6 +361,7 @@ async def _download_worker_async(args):
# sleep for a bit then retry
asyncio.sleep(2 * random.random())
async def _upload_worker_async(args):
async_blob_container_client, upload_file_path, blob_name = args
async with FILE_SEMAPHORE:
@ -276,22 +370,31 @@ async def _upload_worker_async(args):
for retry in range(MAX_RETRIES):
async with REQUEST_SEMAPHORE:
try:
await async_blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
await async_blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
return blob_name
except ResourceExistsError:
print("blob already exists:", blob_name)
return
except Exception as e:
print(f"blob upload error. retry count: {retry}/{MAX_RETRIES} :", blob_name, e)
print(
f"blob upload error. retry count: {retry}/{MAX_RETRIES} :",
blob_name,
e,
)
# sleep for a bit then retry
asyncio.sleep(2 * random.random())
return blob_name
def _upload_worker_sync(args):
blob_container_client, upload_file_path, blob_name = args
with open(upload_file_path, "rb") as data:
try:
blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
except ResourceExistsError:
print("blob already exists:", blob_name)
except Exception as e:
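A minimal sketch of the blob client above, driven by the same environment variables read in create_from_env_var(); the module path, credentials, and folder paths are placeholders:

import os
from genalog.ocr.blob_client import GrokBlobClient  # module path assumed

# Placeholder credentials; create_from_env_var() reads these variable names.
os.environ.setdefault("DATASOURCE_CONTAINER_NAME", "ocr-images")
os.environ.setdefault("BLOB_NAME", "<storage-account-name>")
os.environ.setdefault("BLOB_KEY", "<storage-account-key>")

client = GrokBlobClient.create_from_env_var()
folder_name, _ = client.upload_images_to_blob("data/img", use_async=True)
client.get_ocr_json(folder_name, "data/ocr_json", use_async=True)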

View file

@ -1,10 +1,10 @@
from .rest_client import GrokRestClient
from .blob_client import GrokBlobClient
import time
from .blob_client import GrokBlobClient
from .rest_client import GrokRestClient
class Grok:
@staticmethod
def create_from_env_var():
"""Initializes Grok based on keys in the environment variables.
@ -16,11 +16,20 @@ class Grok:
grok_blob_client = GrokBlobClient.create_from_env_var()
return Grok(grok_rest_client, grok_blob_client)
def __init__(self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient):
def __init__(
self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient
):
self.grok_rest_client = grok_rest_client
self.grok_blob_client = grok_blob_client
def run_grok(self, src_folder_path, dest_folder_path, blob_dest_folder=None,cleanup=False,use_async=True):
def run_grok(
self,
src_folder_path,
dest_folder_path,
blob_dest_folder=None,
cleanup=False,
use_async=True,
):
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
GROK OCR on this blob storage as a source, then downloads the OCR output json to the destination
folder. The resulting json files have the same name as the original images except prefixed
@ -40,7 +49,8 @@ class Grok:
"""
print("uploading images to blob")
blob_folder_name, _ = self.grok_blob_client.upload_images_to_blob(
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async)
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async
)
print(f"images upload under folder {blob_folder_name}")
try:
print("creating and running indexer")
@ -53,7 +63,10 @@ class Grok:
# if not already running start the indexer
print("indexer_status", indexer_status)
if indexer_status["lastResult"] == None or indexer_status["lastResult"]["status"] != "inProgress":
if (
indexer_status["lastResult"] is None
or indexer_status["lastResult"]["status"] != "inProgress"
):
self.grok_rest_client.run_indexer()
time.sleep(1)
@ -62,9 +75,13 @@ class Grok:
if indexer_status["lastResult"]["status"] == "success":
time.sleep(30)
print("fetching ocr json results.")
self.grok_blob_client.get_ocr_json(blob_folder_name, dest_folder_path, use_async=use_async)
self.grok_blob_client.get_ocr_json(
blob_folder_name, dest_folder_path, use_async=use_async
)
print(f"indexer status {indexer_status}")
print(f"finished running indexer. json files saved to {dest_folder_path}")
print(
f"finished running indexer. json files saved to {dest_folder_path}"
)
else:
print("GROK failed", indexer_status["status"])
raise RuntimeError("GROK failed", indexer_status["status"])

View file

@ -6,7 +6,8 @@ OCR Metrics
Accuracy = Correct Words/Total Words (in target strings)
2. Count of edit distance ops:
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". This is based on Levenshtein edit distance.
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing".
This is based on Levenshtein edit distance.
3. By looking at the gaps in alignment we also generate substitution dicts:
e.g.: if we have text "a worn coat" and the OCR is "a wom coat", "rn" -> "m" will be captured as a substitution
@ -14,29 +15,33 @@ since the rest of the segments align.The assumption here is that we do not expec
hence collecting and counting these substitutions will be manageable.
"""
import string
import argparse
import json
import multiprocessing
import os
import re
import json
import argparse
import multiprocessing
from multiprocessing import Pool
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool
from genalog.text.alignment import GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
from genalog.text.anchor import align_w_anchor
from genalog.text.ner_label import _find_gap_char_candidates
LOG_LEVEL = 0
WORKERS_PER_CPU = 2
def _log(*args, **kwargs):
if LOG_LEVEL:
print(args)
def _trim_whitespace(src_string):
return re.sub(r"\s+", " ", src_string.strip())
def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
"""Given two string that differ and have no alignment at all,
update the alignment dict and fill in substitution if replacements are found.
@ -70,13 +75,23 @@ def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
else:
align_stats["replace"] += 1
_log("replacing", source_substr, target_substr)
substitution_dict[source_substr, target_substr] = substitution_dict.get(
(source_substr, target_substr), 0) + 1
substitution_dict[source_substr, target_substr] = (
substitution_dict.get((source_substr, target_substr), 0) + 1
)
_log("spacing count", spacing_count)
align_stats["spacing"] += spacing_count
def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matching_chars_count, \
matching_words_count, matching_alnum_words_count):
def _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
):
"""Given two string segments that align. update the counts of matching words and characters
Args:
@ -93,7 +108,7 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
tuple(int,int,int): the updated matching_chars_count, matching_words_count, matching_alnum_words_count
"""
aligned_part = aligned_src[start:end]
matching_chars_count += (end-start)
matching_chars_count += end - start
# aligned_part = seq.strip()
_log("aligned", aligned_part, start, end)
if len(aligned_src) != len(aligned_target):
@ -112,10 +127,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
# to be compared with the full string to see if they have space before or after
if i == 0:
if start != 0 and (aligned_target[start] != " " or aligned_src[start] != " " ):
if start != 0 and (
aligned_target[start] != " " or aligned_src[start] != " "
):
# if this was the start of the string in the target or source
if not(aligned_src[:start].replace(gap_char,"").replace(" ","") == "" and aligned_target[start-1] == " ") and \
not(aligned_target[:start].replace(gap_char,"").replace(" ","") == "" and aligned_src[start-1] == " "):
if not (
aligned_src[:start].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[start - 1] == " "
) and not (
aligned_target[:start].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[start - 1] == " "
):
# beginning word not matching completely
_log("removing first match word from count", word, aligned_part)
matching_words_count -= 1
@ -124,10 +147,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
continue
if i == len(words) - 1:
if end != len(aligned_target) and (aligned_target[end] != " " or aligned_src[end] != " " ):
if end != len(aligned_target) and (
aligned_target[end] != " " or aligned_src[end] != " "
):
# this was not the end of the string in the src and not end of string in target
if not(aligned_src[end:].replace(gap_char,"").replace(" ","") == "" and aligned_target[end] == " ") and \
not(aligned_target[end:].replace(gap_char,"").replace(" ","") == ""and aligned_src[end] == " "):
if not (
aligned_src[end:].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[end] == " "
) and not (
aligned_target[end:].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[end] == " "
):
# last word not matching completely
_log("removing last match word from count", word, aligned_part)
matching_words_count -= 1
@ -138,6 +169,7 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
_log("matched alnum count", matching_alnum_words_count)
return matching_chars_count, matching_words_count, matching_alnum_words_count
def _get_align_stats(alignment, src_string, target, gap_char):
"""Given an alignment, this function get the align stats and substitution mapping to
transform the source string to the target string
@ -174,8 +206,15 @@ def _get_align_stats(alignment, src_string, target, gap_char):
matching_words_count = 0
matching_alnum_words_count = 0
align_stats = {"insert": 0, "delete": 0, "replace": 0, "spacing": 0,
"total_chars": char_count, "total_words": word_count, "total_alnum_words": alnum_words_count}
align_stats = {
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": char_count,
"total_words": word_count,
"total_alnum_words": alnum_words_count,
}
start = 0
_log("######### Alignment ############")
@ -190,35 +229,83 @@ def _get_align_stats(alignment, src_string, target, gap_char):
# since this substring aligns, simply count the number of matching words and chars and update
# the word stats
end = i
_log("sequences", aligned_src[start:end], aligned_target[start:end], start, end)
_log(
"sequences",
aligned_src[start:end],
aligned_target[start:end],
start,
end,
)
assert aligned_src[start:end] == aligned_target[start:end]
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)
start = end + 1
if gap_start is None:
gap_start = end
else:
gap_end = i
if not gap_start is None:
if gap_start is not None:
# since characters now match, gap_start:i contains a substring of the characters that didn't align before
# handle this gap alignment by calling _update_align_stats
_log("gap", aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], gap_start, gap_end)
_update_align_stats(aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], align_stats, substitution_dict, gap_char)
_log(
"gap",
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
gap_start,
gap_end,
)
_update_align_stats(
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
align_stats,
substitution_dict,
gap_char,
)
gap_start = None
# Now compare any leftover string segments from the for loop
if gap_start is not None:
# handle last alignment gap
_log("last gap", aligned_src[gap_start:], aligned_target[gap_start:])
_update_align_stats(aligned_src[gap_start:], aligned_target[gap_start:], align_stats, substitution_dict, gap_char)
_update_align_stats(
aligned_src[gap_start:],
aligned_target[gap_start:],
align_stats,
substitution_dict,
gap_char,
)
else:
# handle last aligned substring
seq = aligned_src[start:]
aligned_part = seq.strip()
end = len(aligned_src)
_log("last aligned", aligned_part)
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)
align_stats["matching_chars"] = matching_chars_count
align_stats["matching_alnum_words"] = matching_alnum_words_count
@ -248,11 +335,17 @@ def get_editops_stats(alignment, gap_char):
aligned_src, aligned_target = alignment
if aligned_src == "" or aligned_target == "":
raise ValueError("one of the input strings is empty")
stats = {"edit_insert": 0, "edit_delete": 0, "edit_replace": 0,
"edit_insert_spacing": 0, "edit_delete_spacing": 0}
stats = {
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
}
actions = {}
for i, (char_1, char_2) in enumerate(zip(aligned_src, aligned_target)):
if LOG_LEVEL > 1: _log(char_1, char_2)
if LOG_LEVEL > 1:
_log(char_1, char_2)
if char_1 == gap_char:
# insert
if char_2 == " ":
@ -266,12 +359,13 @@ def get_editops_stats(alignment, gap_char):
stats["edit_delete_spacing"] += 1
else:
stats["edit_delete"] += 1
actions[i] = ("D")
actions[i] = "D"
elif char_2 != char_1:
stats["edit_replace"] += 1
actions[i] = ("R", char_2)
return stats, actions
def get_align_stats(alignment, src_string, target, gap_char):
"""Get alignment stats
@ -293,9 +387,11 @@ def get_align_stats(alignment, src_string, target, gap_char):
_log("alignment results")
_log(alignment)
align_stats, substitution_dict = _get_align_stats(
alignment, src_string, target, gap_char)
alignment, src_string, target, gap_char
)
return align_stats, substitution_dict
def get_stats(target, src_string):
"""Get align stats, edit stats, and substitution mappings for transforming the
source string to the target string. Edit stats refers to character level edit operation
@ -311,16 +407,24 @@ def get_stats(target, src_string):
Returns:
tuple(dict, dict, dict): one dict containing the edit and align stats, one dict of the substitutions, and one dict of edit actions
"""
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align_w_anchor(src_string, target, gap_char=gap_char)
align_stats, substitution_dict = get_align_stats(alignment,src_string, target, gap_char)
align_stats, substitution_dict = get_align_stats(
alignment, src_string, target, gap_char
)
edit_stats, actions = get_editops_stats(alignment, gap_char)
_log("alignment", align_stats)
return {**edit_stats, **align_stats}, substitution_dict, actions
def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True):
def get_metrics(
src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True
):
"""Given a path to the folder containing the source text and a folder containing
the output OCR json, this generates the metrics for all files in the source folder.
This assumes that the files in the json folder have the same name as the text files except they
@ -338,7 +442,6 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
filename and the values are dicts of the substitution mappings for that file.
"""
rows = []
substitutions = {}
actions_map = {}
@ -347,16 +450,25 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
cpu_count = multiprocessing.cpu_count()
n_workers = WORKERS_PER_CPU * cpu_count
job_args = list(map(lambda f: (f, src_text_path, ocr_json_path, folder_hash) , os.listdir(src_text_path)))
job_args = list(
map(
lambda f: (f, src_text_path, ocr_json_path, folder_hash),
os.listdir(src_text_path),
)
)
if use_multiprocessing:
with Pool(n_workers) as pool:
for f, stats, actions, subs in tqdm(pool.imap_unordered(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
pool.imap_unordered(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)
else:
for f, stats, actions, subs in tqdm(map(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
map(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)
@ -364,16 +476,17 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
df = pd.DataFrame(rows)
return df, substitutions, actions_map
def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
src_filename = os.path.join(src_text_path, f)
if folder_hash:
ocr_filename = os.path.join(
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}")
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}"
)
else:
ocr_filename = os.path.join(
ocr_json_path, f"{f.split('txt')[0] + 'json'}")
ocr_filename = os.path.join(ocr_json_path, f"{f.split('txt')[0] + 'json'}")
try:
src_string = open(src_filename, "r", errors='ignore', encoding="utf8").read()
src_string = open(src_filename, "r", errors="ignore", encoding="utf8").read()
except FileNotFoundError:
print(f"File not found: {src_filename}, skipping this file.")
return f, {}, {}, {}
@ -395,10 +508,12 @@ def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
stats["filename"] = f
return f, stats, actions, subs
def _worker(args):
(f, src_text_path, ocr_json_path, folder_hash) = args
return get_file_metrics(f, src_text_path, ocr_json_path, folder_hash)
def _get_sorted_text(ocr_json):
if "lines" in ocr_json[0]:
lines = ocr_json[0]["lines"]
@ -407,22 +522,27 @@ def _get_sorted_text(ocr_json):
else:
return ocr_json[0]["text"]
def substitution_dict_to_json(substitution_dict):
"""Converts substitution dict to list of tuples of (source_substring, target_substring, count)
Args:
substitution_dict ([type]): [description]
"""
to_tuple = lambda x: [(k +(x[k],)) for k in x]
to_tuple = lambda x: [(k + (x[k],)) for k in x] # noqa: E731
out = {}
for filename in substitution_dict:
out[filename] = to_tuple(substitution_dict[filename])
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("src", help="path to folder with text files.")
parser.add_argument("ocr", help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.")
parser.add_argument(
"ocr",
help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.",
)
parser.add_argument("--ocr_prefix", help="the prefix of the ocr files")
parser.add_argument("--output", help="output names of metrics files")

View file

@ -1,17 +1,18 @@
"""Uses the REST api to perform operations on the search service.
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
"""
import requests
import json
import os
import pkgutil
import json
from dotenv import load_dotenv
import time
import sys
import time
from itertools import cycle
import requests
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
API_VERSION = '?api-version=2019-05-06-Preview'
API_VERSION = "?api-version=2019-05-06-Preview"
# 15 min schedule
SCHEDULE_INTERVAL = "PT15M"
@ -25,9 +26,20 @@ class GrokRestClient:
ongoing indexers. The indexing pipeline can allow you to run batch OCR enrichment of documents.
"""
def __init__(self, cognitive_service_key, search_service_key, search_service_name, skillset_name,
index_name, indexer_name, datasource_name, datasource_container_name, blob_account_name, blob_key,
projections_container_name = DEFAULT_PROJECTIONS_CONTAINER_NAME):
def __init__(
self,
cognitive_service_key,
search_service_key,
search_service_name,
skillset_name,
index_name,
indexer_name,
datasource_name,
datasource_container_name,
blob_account_name,
blob_key,
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
):
"""Creates the REST client
Args:
@ -70,8 +82,10 @@ class GrokRestClient:
self.API_VERSION = API_VERSION
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
self.BLOB_CONNECTION_STRING = (
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
)
@staticmethod
def create_from_env_var():
@ -85,51 +99,68 @@ class GrokRestClient:
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
BLOB_NAME = os.environ["BLOB_NAME"]
BLOB_KEY = os.environ["BLOB_KEY"]
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
PROJECTIONS_CONTAINER_NAME = os.environ.get(
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
)
client = GrokRestClient(COGNITIVE_SERVICE_KEY, SEARCH_SERVICE_KEY, SEARCH_SERVICE_NAME, SKILLSET_NAME, INDEX_NAME,
INDEXER_NAME, DATASOURCE_NAME, DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY,projections_container_name=PROJECTIONS_CONTAINER_NAME)
client = GrokRestClient(
COGNITIVE_SERVICE_KEY,
SEARCH_SERVICE_KEY,
SEARCH_SERVICE_NAME,
SKILLSET_NAME,
INDEX_NAME,
INDEXER_NAME,
DATASOURCE_NAME,
DATASOURCE_CONTAINER_NAME,
BLOB_NAME,
BLOB_KEY,
projections_container_name=PROJECTIONS_CONTAINER_NAME,
)
return client
def create_skillset(self):
"""Adds a skillset that performs OCR on images
"""
"""Adds a skillset that performs OCR on images"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
skillset_json = json.loads(pkgutil.get_data(
__name__, "templates/skillset.json"))
skillset_json = json.loads(
pkgutil.get_data(__name__, "templates/skillset.json")
)
skillset_json["name"] = self.SKILLSET_NAME
skillset_json["cognitiveServices"]["key"] = self.COGNITIVE_SERVICE_KEY
knowledge_store_json = json.loads(pkgutil.get_data(
__name__, "templates/knowledge_store.json"))
knowledge_store_json = json.loads(
pkgutil.get_data(__name__, "templates/knowledge_store.json")
)
knowledge_store_json["storageConnectionString"] = self.BLOB_CONNECTION_STRING
knowledge_store_json["projections"][0]["objects"][0]["storageContainer"] = self.PROJECTIONS_CONTAINER_NAME
knowledge_store_json["projections"][0]["objects"][0][
"storageContainer"
] = self.PROJECTIONS_CONTAINER_NAME
skillset_json["knowledgeStore"] = knowledge_store_json
print(skillset_json)
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(skillset_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(skillset_json), headers=headers
)
print("skillset response", r.text)
r.raise_for_status()
print("added skillset", self.SKILLSET_NAME, r)
def create_datasource(self):
"""Attaches the blob data store to the search service as a source for image documents
"""
"""Attaches the blob data store to the search service as a source for image documents"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
datasource_json = json.loads(pkgutil.get_data(
__name__, "templates/datasource.json"))
datasource_json = json.loads(
pkgutil.get_data(__name__, "templates/datasource.json")
)
datasource_json["name"] = self.DATASOURCE_NAME
datasource_json["credentials"]["connectionString"] = self.BLOB_CONNECTION_STRING
datasource_json["type"] = "azureblob"
@ -137,27 +168,27 @@ class GrokRestClient:
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(datasource_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(datasource_json), headers=headers
)
print("datasource response", r.text)
r.raise_for_status()
print("added datasource", self.DATASOURCE_NAME, r)
def create_index(self):
"""Create an index with the layoutText column to store OCR output from the enrichment
"""
"""Create an index with the layoutText column to store OCR output from the enrichment"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
index_json = json.loads(pkgutil.get_data(
__name__, "templates/index.json"))
index_json = json.loads(pkgutil.get_data(__name__, "templates/index.json"))
index_json["name"] = self.INDEX_NAME
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(index_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(index_json), headers=headers
)
print("index response", r.text)
r.raise_for_status()
print("created index", self.INDEX_NAME, r)
@ -167,24 +198,26 @@ class GrokRestClient:
The enriched results are pushed to the index.
"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
indexer_json = json.loads(pkgutil.get_data(
__name__, "templates/indexer.json"))
indexer_json = json.loads(pkgutil.get_data(__name__, "templates/indexer.json"))
indexer_json["name"] = self.INDEXER_NAME
indexer_json["skillsetName"] = self.SKILLSET_NAME
indexer_json["targetIndexName"] = self.INDEX_NAME
indexer_json["dataSourceName"] = self.DATASOURCE_NAME
indexer_json["schedule"] = {"interval": SCHEDULE_INTERVAL}
indexer_json["parameters"]["configuration"]["excludedFileNameExtensions"] = extension_to_exclude
indexer_json["parameters"]["configuration"][
"excludedFileNameExtensions"
] = extension_to_exclude
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(indexer_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(indexer_json), headers=headers
)
print("indexer response", r.text)
r.raise_for_status()
print("created indexer", self.INDEXER_NAME, r)
@ -203,14 +236,14 @@ class GrokRestClient:
created
"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoints = [
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}",
]
for endpoint in endpoints:
@ -220,8 +253,8 @@ class GrokRestClient:
def run_indexer(self):
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/run"
@ -239,7 +272,10 @@ class GrokRestClient:
request_json = self.get_indexer_status()
if request_json["status"] == "error":
raise RuntimeError("Indexer failed")
if request_json["lastResult"] and not request_json["lastResult"]["status"] == "inProgress":
if (
request_json["lastResult"]
and not request_json["lastResult"]["status"] == "inProgress"
):
print(request_json["lastResult"]["status"], self.INDEXER_NAME)
return request_json
@ -250,8 +286,8 @@ class GrokRestClient:
def get_indexer_status(self):
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/status"
response = requests.get(endpoint + self.API_VERSION, headers=headers)
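A minimal sketch of the REST client above, assuming the environment variables read by create_from_env_var() are set; note that creating the indexer itself happens in a method not visible in this excerpt:

from genalog.ocr.rest_client import GrokRestClient  # module path assumed

rest_client = GrokRestClient.create_from_env_var()
rest_client.create_skillset()    # OCR skillset
rest_client.create_datasource()  # blob container as the data source
rest_client.create_index()       # index with the layoutText column
# An indexer tying these together is created by another method of this class
# (not shown in this excerpt) before it can be run:
rest_client.run_indexer()
print(rest_client.get_indexer_status()["status"])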

View file

@ -1,14 +1,17 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.content import CompositeContent, ContentType
from genalog.degradation.degrader import Degrader, ImageState
from json import JSONEncoder
from tqdm import tqdm
import concurrent.futures
import timeit
import cv2
import os
import timeit
from json import JSONEncoder
import cv2
from tqdm import tqdm
from genalog.degradation.degrader import Degrader, ImageState
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.document import DocumentGenerator
class ImageStateEncoder(JSONEncoder):
def default(self, obj):
@ -16,8 +19,15 @@ class ImageStateEncoder(JSONEncoder):
return obj.value
return JSONEncoder.default(self, obj)
class AnalogDocumentGeneration(object):
def __init__(self, template_path=None, styles=DEFAULT_STYLE_COMBINATION, degradations=[], resolution=300):
def __init__(
self,
template_path=None,
styles=DEFAULT_STYLE_COMBINATION,
degradations=[],
resolution=300,
):
self.doc_generator = DocumentGenerator(template_path=template_path)
self.doc_generator.set_styles_to_generate(styles)
self.degrader = Degrader(degradations)
@ -67,45 +77,71 @@ class AnalogDocumentGeneration(object):
cv2.imwrite(img_dst_path, dst)
return
def _divide_batches(a, batch_size):
for i in range(0, len(a), batch_size):
yield a[i: i + batch_size]
def _setup_folder(output_folder):
os.makedirs(os.path.join(output_folder, "img"), exist_ok=True)
def batch_img_generate(args):
input_files, output_folder, styles, degradations, template, resolution = args
generator = AnalogDocumentGeneration(styles=styles, degradations=degradations, resolution=resolution)
generator = AnalogDocumentGeneration(
styles=styles, degradations=degradations, resolution=resolution
)
for file in input_files:
generator.generate_img(file, template, target_folder=output_folder)
def _set_batch_generate_args(file_batches, output_folder, styles, degradations, template, resolution):
return list(map(
lambda batch:
(batch, output_folder, styles, degradations, template, resolution),
file_batches
))
def _set_batch_generate_args(
file_batches, output_folder, styles, degradations, template, resolution
):
return list(
map(
lambda batch: (
batch,
output_folder,
styles,
degradations,
template,
resolution,
),
file_batches,
)
)
def generate_dataset_multiprocess(
input_text_files, output_folder, styles, degradations, template,
resolution=300, batch_size=25
input_text_files,
output_folder,
styles,
degradations,
template,
resolution=300,
batch_size=25,
):
_setup_folder(output_folder)
print(f"Storing generated images in {output_folder}")
batches = list(_divide_batches(input_text_files, batch_size))
print(f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}")
print(
f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}"
)
batch_img_generate_args = _set_batch_generate_args(batches, output_folder, styles, degradations, template, resolution)
batch_img_generate_args = _set_batch_generate_args(
batches, output_folder, styles, degradations, template, resolution
)
# Default to the number of processors on the machine
start_time = timeit.default_timer()
with concurrent.futures.ProcessPoolExecutor() as executor:
batch_iterator = executor.map(batch_img_generate, batch_img_generate_args)
for _ in tqdm(batch_iterator, total=len(batch_img_generate_args)): # wrapping tqdm for progress report
for _ in tqdm(
batch_iterator, total=len(batch_img_generate_args)
): # wrapping tqdm for progress report
pass
elapsed = timeit.default_timer() - start_time
print(f"Time to generate {len(input_text_files)} documents: {elapsed:.3f} sec")

View file

@ -1 +0,0 @@

View file

@ -1,25 +1,36 @@
from genalog.text.preprocess import _is_spacing, tokenize
from Bio import pairwise2
import re
from Bio import pairwise2
from genalog.text.preprocess import _is_spacing, tokenize
# Configuration params for global sequence alignment algorithm (Needleman-Wunsch)
MATCH_REWARD = 1
GAP_PENALTY = -0.5
GAP_EXT_PENALTY = -0.5
MISMATCH_PENALTY = -0.5
GAP_CHAR = '@'
GAP_CHAR = "@"
ONE_ALIGNMENT_ONLY = False
SPACE_MISMATCH_PENALTY = .1
SPACE_MISMATCH_PENALTY = 0.1
def _join_char_list(alignment_tuple):
""" Post-process alignment results for unicode support """
gt_char_list, noise_char_list, score, start, end = alignment_tuple
return "".join(gt_char_list), "".join(noise_char_list), score, start, end
def _align_seg(gt, noise,
match_reward=MATCH_REWARD, mismatch_pen=MISMATCH_PENALTY,
gap_pen=GAP_PENALTY, gap_ext_pen=GAP_EXT_PENALTY, space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
gap_char=GAP_CHAR, one_alignment_only=ONE_ALIGNMENT_ONLY):
def _align_seg(
gt,
noise,
match_reward=MATCH_REWARD,
mismatch_pen=MISMATCH_PENALTY,
gap_pen=GAP_PENALTY,
gap_ext_pen=GAP_EXT_PENALTY,
space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
gap_char=GAP_CHAR,
one_alignment_only=ONE_ALIGNMENT_ONLY,
):
"""Wrapper function for Bio.pairwise2.align.globalms(), which
calls the sequence alignment algorithm (Needleman-Wunsch)
@ -47,6 +58,7 @@ def _align_seg(gt, noise,
...
]
"""
def match_reward_fn(x, y):
if x == y:
return match_reward
@ -55,12 +67,21 @@ def _align_seg(gt, noise,
return mismatch_pen - space_mismatch_penalty
else:
return mismatch_pen
# NOTE: Work-around to support the full Unicode character set - passing the string as a list of characters
alignments = pairwise2.align.globalcs(list(gt), list(noise), match_reward_fn,
gap_pen, gap_ext_pen, gap_char=[gap_char], one_alignment_only=ONE_ALIGNMENT_ONLY)
alignments = pairwise2.align.globalcs(
list(gt),
list(noise),
match_reward_fn,
gap_pen,
gap_ext_pen,
gap_char=[gap_char],
one_alignment_only=ONE_ALIGNMENT_ONLY,
)
# Alignment result is a list of char instead of string because of the work-around
return list(map(_join_char_list, alignments))
def _select_alignment_candidates(alignments, target_num_gt_tokens):
"""Return an alignment that contains the desired number
of ground truth tokens from a list of possible alignments
@ -111,11 +132,16 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
if num_aligned_gt_tokens == target_num_gt_tokens:
# Invariant 1
if len(aligned_gt) != len(aligned_noise):
raise ValueError(f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n")
raise ValueError(
f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n"
)
# Returns the FIRST candidate that satisfies the invariant
return alignment
raise ValueError(f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}")
raise ValueError(
f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}"
)
def align(gt, noise, gap_char=GAP_CHAR):
"""Align two text segments via sequence alignment algorithm
@ -144,7 +170,7 @@ def align(gt, noise, gap_char=GAP_CHAR):
"""
if not gt and not noise: # Both inputs are empty string
return '', ''
return "", ""
elif not gt: # Either is empty
return gap_char * len(noise), noise
elif not noise:
@ -153,11 +179,16 @@ def align(gt, noise, gap_char=GAP_CHAR):
num_gt_tokens = len(tokenize(gt))
alignments = _align_seg(gt, noise, gap_char=gap_char)
try:
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(alignments, num_gt_tokens)
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(
alignments, num_gt_tokens
)
except ValueError as e:
raise ValueError(f"Error with input strings '{gt}' and '{noise}': \n{str(e)}")
raise ValueError(
f"Error with input strings '{gt}' and '{noise}': \n{str(e)}"
)
return aligned_gt, aligned_noise
def _format_alignment(align1, align2):
"""Wrapper function for Bio.pairwise2.format_alignment()
@ -178,11 +209,14 @@ def _format_alignment(align1, align2):
New Yerk@is big.
"
"""
formatted_str = pairwise2.format_alignment(align1, align2, 0, 0, len(align1), full_sequences=True)
formatted_str = pairwise2.format_alignment(
align1, align2, 0, 0, len(align1), full_sequences=True
)
# Remove the "Score=0" from the str
formatted_str_no_score = formatted_str.replace("\n Score=0", "")
return formatted_str_no_score
def _find_token_start(s, index):
"""Find the position of the start of token
@ -198,13 +232,16 @@ def _find_token_start(s, index):
IndexError: if is out-of-bound index
"""
max_index = len(s) - 1
if len(s) == 0: raise ValueError("Cannot search in an empty string")
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
if len(s) == 0:
raise ValueError("Cannot search in an empty string")
if index > max_index:
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
while index < max_index and _is_spacing(s[index]):
index += 1
return index
def _find_token_end(s, index):
"""Find the position of the end of a token
@ -224,13 +261,16 @@ def _find_token_end(s, index):
IndexError: if is out-of-bound index
"""
max_index = len(s) - 1
if len(s) == 0: raise ValueError("Cannot search in an empty string")
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
if len(s) == 0:
raise ValueError("Cannot search in an empty string")
if index > max_index:
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
while index < max_index and not _is_spacing(s[index]):
index += 1
return index
def _find_next_token(s, start):
"""Return the start and end index of a token in a string
@ -250,6 +290,7 @@ def _find_next_token(s, start):
token_end = _find_token_end(s, token_start)
return token_start, token_end
def _is_valid_token(token, gap_char=GAP_CHAR):
"""Returns true if token is valid (i.e. compose of non-gap characters)
Invalid tokens are
@ -269,9 +310,12 @@ def _is_valid_token(token, gap_char=GAP_CHAR):
bool-- True if is a valid token, false otherwise
"""
# Matches multiples of 'gap_char' that are padded with whitespace characters on either end
INVALID_TOKEN_REGEX = rf'^\s*{re.escape(gap_char)}*\s*$' # Escape special regex chars
INVALID_TOKEN_REGEX = (
rf"^\s*{re.escape(gap_char)}*\s*$" # Escape special regex chars
)
return not re.match(INVALID_TOKEN_REGEX, token)
def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
"""Parse alignment to pair ground truth tokens with noise tokens
@ -352,8 +396,8 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
total_noise_tokens = len(tokenize(aligned_noise))
# Initialization
aligned_gt += ' ' # add whitespace padding to prevent ptr overflow
aligned_noise += ' ' # add whitespace padding to prevent ptr overflow
aligned_gt += " " # add whitespace padding to prevent ptr overflow
aligned_noise += " " # add whitespace padding to prevent ptr overflow
tk_index_gt = tk_index_noise = 0
tk_start_gt, tk_end_gt = _find_next_token(aligned_gt, 0)
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, 0)
@ -364,8 +408,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
# If both tokens are aligned (one-to-one case)
if tk_end_gt == tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these tokens in the gt_to_noise_mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -381,8 +428,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
elif tk_end_gt < tk_end_noise:
while tk_end_gt < tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these tokens in the gt_to_noise_mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -397,8 +447,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
else:
while tk_end_gt > tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these token in the gt_to_noise mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -406,7 +459,9 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
index_row = noise_to_gt_mapping[tk_index_noise]
index_row.append(tk_index_gt)
# Find the next gt_token
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, tk_end_noise)
tk_start_noise, tk_end_noise = _find_next_token(
aligned_noise, tk_end_noise
)
# Increment index
tk_index_noise += 1
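For orientation, a small usage sketch of align() and parse_alignment(); the values mirror the "New York is big." regression case in the test data later in this diff:
from genalog.text import alignment
gt = "New York is big."
noise = "NewYork big"                      # noisy OCR output with a merged token and a dropped word
aligned_gt, aligned_noise = alignment.align(gt, noise)
# aligned_gt    -> "New York is big."
# aligned_noise -> "New@York @@@big@"      ('@' is the default GAP_CHAR)
gt_to_noise, noise_to_gt = alignment.parse_alignment(aligned_gt, aligned_noise)
# gt_to_noise   -> [[0], [0], [1], [1]]    (gt tokens "New" and "York" both map to noise token 0)
# noise_to_gt   -> [[0, 1], [2, 3]]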

View file

@ -14,15 +14,17 @@
"""
import itertools
from collections import Counter
from genalog.text import preprocess, alignment
from genalog.text.lcs import LCS
from genalog.text import alignment, preprocess
from genalog.text.alignment import GAP_CHAR
from genalog.text.lcs import LCS
# The recursive portion of the algorithm will run on
# segments longer than this value to find anchor points in
# the longer segment (to break it up further).
MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length
def get_unique_words(tokens, case_sensitive=False):
"""Get a set of unique words from a Counter dictionary of word occurrences
@ -44,6 +46,7 @@ def get_unique_words(tokens, case_sensitive=False):
word_count = Counter(tokens_lowercase)
return {tk for tk in tokens if word_count[tk.lower()] < 2}
def segment_len(tokens):
"""Get length of the segment
@ -54,6 +57,7 @@ def segment_len(tokens):
"""
return sum(map(len, tokens))
def get_word_map(unique_words, src_tokens):
"""Arrange the set of unique words by the order they original appear in the text
@ -73,6 +77,7 @@ def get_word_map(unique_words, src_tokens):
word_map.sort(key=lambda x: x[1]) # Re-arrange order by the index
return word_map
def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
"""Find the location of anchor words in both the gt and ocr text.
Anchor words are locations where we can split both the source gt
@ -129,18 +134,29 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
anchor_words = lcs_words.intersection(unique_words_common)
# 6. Filter the unique words to keep the anchor words ONLY
anchor_map_gt = list(filter(
anchor_map_gt = list(
filter(
# This is a list of (unique_word, unique_word_index)
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_gt
))
anchor_map_ocr = list(filter(
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_ocr
))
lambda word_coordinate: word_coordinate[0] in anchor_words,
unique_word_map_gt,
)
)
anchor_map_ocr = list(
filter(
lambda word_coordinate: word_coordinate[0] in anchor_words,
unique_word_map_ocr,
)
)
return anchor_map_gt, anchor_map_ocr
def find_anchor_recur(gt_tokens, ocr_tokens,
start_pos_gt=0, start_pos_ocr=0,
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
def find_anchor_recur(
gt_tokens,
ocr_tokens,
start_pos_gt=0,
start_pos_ocr=0,
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH,
):
"""Recursively find anchor positions in the gt and ocr text
Arguments:
@ -186,13 +202,22 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
# 4. Loop through each segment
for gt_seg, ocr_seg, gt_start, ocr_start in zip(gt_segments, ocr_segments, seg_start_gt, seg_start_ocr):
if segment_len(gt_seg) > max_seg_length or segment_len(ocr_seg) > max_seg_length:
for gt_seg, ocr_seg, gt_start, ocr_start in zip(
gt_segments, ocr_segments, seg_start_gt, seg_start_ocr
):
if (
segment_len(gt_seg) > max_seg_length
or segment_len(ocr_seg) > max_seg_length
):
# recur on the segment in between the two anchors.
# We assume the first token in the segment is an anchor word
gt_anchors, ocr_anchors = find_anchor_recur(gt_seg[1:], ocr_seg[1:],
start_pos_gt=gt_start + 1, start_pos_ocr=ocr_start + 1,
max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = find_anchor_recur(
gt_seg[1:],
ocr_seg[1:],
start_pos_gt=gt_start + 1,
start_pos_ocr=ocr_start + 1,
max_seg_length=max_seg_length,
)
# shift the token indices
# (these are indices of a subsequence and do not reflect the true position in the source sequence)
gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors))
@ -203,6 +228,7 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
return sorted(output_gt_anchors), sorted(output_ocr_anchors)
def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
"""A faster alignment scheme of two text segments. This method first
breaks the strings into smaller segments with anchor words.
@ -248,11 +274,17 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
ocr_tokens = preprocess.tokenize(ocr)
# 1. Find anchor positions
gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
# 2. Split into segments
start_n_end_gt = zip(itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None]))
start_n_end_ocr = zip(itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None]))
start_n_end_gt = zip(
itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])
)
start_n_end_ocr = zip(
itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])
)
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
@ -263,13 +295,15 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
gt_segment = preprocess.join_tokens(gt_segment)
noisy_segment = preprocess.join_tokens(noisy_segment)
# Run alignment algorithm
aligned_seg_gt, aligned_seg_ocr = alignment.align(gt_segment, noisy_segment, gap_char=gap_char)
aligned_seg_gt, aligned_seg_ocr = alignment.align(
gt_segment, noisy_segment, gap_char=gap_char
)
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
aligned_segments_gt.append(aligned_seg_gt)
aligned_segments_ocr.append(aligned_seg_ocr)
# Stitch all segments together
aligned_gt = ' '.join(aligned_segments_gt)
aligned_noise = ' '.join(aligned_segments_ocr)
aligned_gt = " ".join(aligned_segments_gt)
aligned_noise = " ".join(aligned_segments_ocr)
return aligned_gt, aligned_noise
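A hedged sketch of calling the anchor-based aligner; it has the same interface and output format as alignment.align(), but is intended for longer inputs (the file paths below are hypothetical):
from genalog.text import anchor
gt_text = open("clean_text/0.txt").read()     # hypothetical ground-truth document
ocr_text = open("ocr_text/0.txt").read()      # hypothetical OCR output of the same document
aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_text, ocr_text)
# The two aligned strings are equal in length and use the default GAP_CHAR '@' for gaps,
# so they can be passed on to alignment.parse_alignment() exactly as in the sketch above.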

View file

@ -34,21 +34,22 @@ example usage
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
--train_subset
"""
import itertools
import difflib
import argparse
import concurrent.futures
import difflib
import itertools
import json
import os
import sys
import timeit
import concurrent.futures
from tqdm import tqdm
from genalog.text import ner_label, ner_label, alignment
from genalog.text import alignment, ner_label
EMPTY_SENTENCE_SENTINEL = "<<<<EMPTY_OCR_SENTENCE>>>>"
EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O"
def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens):
"""
propagate_labels_sentences propagates clean labels for clean tokens to ocr tokens and splits ocr tokens into sentences
@ -73,9 +74,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
# Ensure equal number of tokens in both clean_tokens and clean_sentences
merged_sentences = list(itertools.chain(*clean_sentences))
if merged_sentences != clean_tokens:
delta = "\n".join(difflib.unified_diff(merged_sentences, clean_tokens, fromfile='merged_clean_sentences', tofile="clean_tokens"))
raise ValueError(f"Inconsistent tokens. " +
f"Delta between clean_text and clean_labels:\n{delta}")
delta = "\n".join(
difflib.unified_diff(
merged_sentences,
clean_tokens,
fromfile="merged_clean_sentences",
tofile="clean_tokens",
)
)
raise ValueError(
"Inconsistent tokens. "
+ f"Delta between clean_text and clean_labels:\n{delta}"
)
# Ensure that there's OCR result
if len(ocr_tokens) == 0:
raise ValueError("Empty OCR tokens.")
@ -83,14 +93,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
raise ValueError("Empty clean tokens.")
# 1. Propagate labels + align
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(clean_labels, clean_tokens, ocr_tokens)
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
clean_labels, clean_tokens, ocr_tokens
)
# 2. Parse alignment to get mapping
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_clean, aligned_ocr, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
aligned_clean, aligned_ocr, gap_char=gap_char
)
# 3. Find sentence breaks in clean text sentences
gt_to_ocr_mapping_is_empty = [len(mapping) == 0 for mapping in gt_to_ocr_mapping]
gt_to_ocr_mapping_is_empty_reverse = gt_to_ocr_mapping_is_empty[::-1]
sentence_index = []
sentence_token_counts = 0
for sentence in clean_sentences:
@ -135,11 +149,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
ocr_labels_sentences.append(ocr_sentence_labels)
return ocr_text_sentences, ocr_labels_sentences
def get_sentences_from_iob_format(iob_format_str):
sentences = []
sentence = []
for line in iob_format_str:
if line.strip() == '': # if line is empty (sentence separator)
if line.strip() == "": # if line is empty (sentence separator)
sentences.append(sentence)
sentence = []
else:
@ -149,44 +164,78 @@ def get_sentences_from_iob_format(iob_format_str):
# filter any empty sentences
return list(filter(lambda sentence: len(sentence) > 0, sentences))
def propagate_labels_sentence_single_file(arg):
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, input_filename = arg
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(clean_label_ext, ".txt")
def propagate_labels_sentence_single_file(arg):
(
clean_labels_dir,
output_text_dir,
output_labels_dir,
clean_label_ext,
input_filename,
) = arg
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(
clean_label_ext, ".txt"
)
ocr_text_file = os.path.join(output_text_dir, input_filename)
ocr_labels_file = os.path.join(output_labels_dir, input_filename)
if not os.path.exists(clean_labels_file):
print(f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index...")
print(
f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index..."
)
return
elif not os.path.exists(ocr_text_file):
print(f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index...")
print(
f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index..."
)
return
else:
with open(clean_labels_file, 'r', encoding='utf-8') as clf:
with open(clean_labels_file, "r", encoding="utf-8") as clf:
tokens_labels_str = clf.readlines()
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_tokens = [
line.split()[0].strip()
for line in tokens_labels_str
if len(line.split()) == 2
]
clean_labels = [
line.split()[1].strip()
for line in tokens_labels_str
if len(line.split()) == 2
]
clean_sentences = get_sentences_from_iob_format(tokens_labels_str)
# read ocr tokens
with open(ocr_text_file, 'r', encoding='utf-8') as otf:
ocr_text_str = ' '.join(otf.readlines())
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
with open(ocr_text_file, "r", encoding="utf-8") as otf:
ocr_text_str = " ".join(otf.readlines())
ocr_tokens = [
token.strip() for token in ocr_text_str.split()
] # already tokenized in data
try:
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
except Exception as e:
print(f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file...")
print(
f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file..."
)
return
# Write result to file
with open(ocr_labels_file, 'w', encoding="utf-8") as olf:
for ocr_tokens, ocr_labels in zip(ocr_tokens_sentences, ocr_labels_sentences):
with open(ocr_labels_file, "w", encoding="utf-8") as olf:
for ocr_tokens, ocr_labels in zip(
ocr_tokens_sentences, ocr_labels_sentences
):
if len(ocr_tokens) == 0: # if empty OCR sentences
olf.write(f'{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n')
olf.write(
f"{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n"
)
else:
for token, label in zip(ocr_tokens, ocr_labels):
olf.write(f"{token}\t{label}\n")
olf.write('\n')
olf.write("\n")
def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext):
def propagate_labels_sentences_multiprocess(
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext
):
"""
propagate_labels_sentences_multiprocess propagates labels and sentences for all files in the dataset
@ -204,16 +253,24 @@ def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, o
file extension of the clean_labels
"""
clean_label_files = os.listdir(clean_labels_dir)
args = list(map(
lambda clean_label_filename:
(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, clean_label_filename),
clean_label_files
))
args = list(
map(
lambda clean_label_filename: (
clean_labels_dir,
output_text_dir,
output_labels_dir,
clean_label_ext,
clean_label_filename,
),
clean_label_files,
)
)
with concurrent.futures.ProcessPoolExecutor() as executor:
iterator = executor.map(propagate_labels_sentence_single_file, args)
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
pass
def extract_ocr_text(input_file, output_file):
"""
extract_ocr_text from GROK json
@ -227,16 +284,17 @@ def extract_ocr_text(input_file, output_file):
"""
out_dir = os.path.dirname(output_file)
in_file_name = os.path.basename(input_file)
file_pre = in_file_name.split('_')[-1].split('.')[0]
output_file_name = '{}.txt'.format(file_pre)
file_pre = in_file_name.split("_")[-1].split(".")[0]
output_file_name = "{}.txt".format(file_pre)
output_file = os.path.join(out_dir, output_file_name)
with open(input_file, 'r', encoding='utf-8') as fin:
with open(input_file, "r", encoding="utf-8") as fin:
json_data = json.load(fin)
json_dict = json_data[0]
text = json_dict['text']
with open(output_file, 'wb') as fout:
text = json_dict["text"]
with open(output_file, "wb") as fout:
fout.write(text.encode("utf-8"))
def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
"""
check_n_sentences prints the file name if the number of sentences differs between the clean and OCR files
@ -253,21 +311,23 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
text_files = os.listdir(output_labels_dir)
skip_files = []
for text_filename in tqdm(text_files):
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(".txt", clean_label_ext)
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(
".txt", clean_label_ext
)
ocr_labels_file = os.path.join(output_labels_dir, text_filename)
remove_first_line(clean_labels_file, clean_labels_file)
remove_first_line(ocr_labels_file, ocr_labels_file)
remove_last_line(clean_labels_file, clean_labels_file)
remove_last_line(ocr_labels_file, ocr_labels_file)
with open(clean_labels_file, 'r', encoding='utf-8') as lf:
with open(clean_labels_file, "r", encoding="utf-8") as lf:
clean_tokens_labels = lf.readlines()
with open(ocr_labels_file, 'r', encoding='utf-8') as of:
with open(ocr_labels_file, "r", encoding="utf-8") as of:
ocr_tokens_labels = of.readlines()
error = False
n_clean_sentences = 0
nl = False
for line in clean_tokens_labels:
if line == '\n':
if line == "\n":
if nl is True:
error = True
else:
@ -278,7 +338,7 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
n_ocr_sentences = 0
nl = False
for line in ocr_tokens_labels:
if line == '\n':
if line == "\n":
if nl is True:
error = True
else:
@ -287,11 +347,14 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
else:
nl = False
if error or n_ocr_sentences != n_clean_sentences:
print(f"Warning: Inconsistent numbers of sentences in '{text_filename}''." +
f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}")
print(
f"Warning: Inconsistent numbers of sentences in '{text_filename}''."
+ f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}"
)
skip_files.append(text_filename)
return skip_files
def remove_first_line(input_file, output_file):
"""
remove_first_line from files (some clean CoNLL files have an empty first line)
@ -303,13 +366,14 @@ def remove_first_line(input_file, output_file):
output_file : str
output file path
"""
with open(input_file, 'r', encoding='utf-8') as in_f:
with open(input_file, "r", encoding="utf-8") as in_f:
lines = in_f.readlines()
if len(lines) > 1 and lines[0].strip() == '':
if len(lines) > 1 and lines[0].strip() == "":
# the clean CoNLL formatted files had a newline as the first line
with open(output_file, 'w', encoding='utf-8') as out_f:
with open(output_file, "w", encoding="utf-8") as out_f:
out_f.writelines(lines[1:])
def remove_last_line(input_file, output_file):
"""
remove_last_line from files (some clean CoNLL files have an empty last line)
@ -322,12 +386,13 @@ def remove_last_line(input_file, output_file):
output file path
"""
with open(input_file, 'r', encoding='utf-8') as in_f:
with open(input_file, "r", encoding="utf-8") as in_f:
lines = in_f.readlines()
if len(lines) > 1 and lines[-1].strip() == '':
with open(output_file, 'w', encoding='utf-8') as out_f:
if len(lines) > 1 and lines[-1].strip() == "":
with open(output_file, "w", encoding="utf-8") as out_f:
out_f.writelines(lines[:-1])
def for_all_files(input_dir, output_dir, func):
"""
for_all_files will apply a function to every file in a directory
@ -347,25 +412,34 @@ def for_all_files(input_dir, output_dir, func):
output_file = os.path.join(output_dir, text_filename)
func(input_file, output_file)
def main(args):
if not args.train_subset and not args.test_subset:
subsets = ['train', 'test']
subsets = ["train", "test"]
else:
subsets = []
if args.train_subset:
subsets.append('train')
subsets.append("train")
if args.test_subset:
subsets.append('test')
subsets.append("test")
for subset in subsets:
print("Processing {} subset...".format(subset))
clean_labels_dir = os.path.join(args.base_folder, args.gt_folder, subset,'clean_labels')
ocr_json_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr')
clean_labels_dir = os.path.join(
args.base_folder, args.gt_folder, subset, "clean_labels"
)
ocr_json_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr"
)
output_text_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_text')
output_labels_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_labels')
output_text_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr_text"
)
output_labels_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr_labels"
)
# remove first empty line of labels file, if exists
for_all_files(clean_labels_dir, clean_labels_dir, remove_first_line)
@ -380,21 +454,50 @@ def main(args):
os.mkdir(output_labels_dir)
# make ocr labels files by propagating clean labels to ocr_text and creating files in ocr_labels
propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext)
propagate_labels_sentences_multiprocess(
clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext
)
print("Validating number of sentences in gt and ocr labels")
check_n_sentences(clean_labels_dir, output_labels_dir, args.clean_label_ext) # check number of sentences and make sure same; print anomaly files
check_n_sentences(
clean_labels_dir, output_labels_dir, args.clean_label_ext
) # check number of sentences and make sure same; print anomaly files
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument("base_folder", help="base directory containing the collection of dataset")
parser.add_argument("degraded_folder", help="directory containing train and test subset for degradation")
parser.add_argument("--gt_folder", type=str, default="shared", help="directory containing the ground truth")
parser.add_argument("--clean_label_ext", type=str, default=".txt", help="file extension of the clean_labels files")
parser.add_argument('--train_subset', help="include if only train folder should be processed", action='store_true')
parser.add_argument('--test_subset', help="include if only test folder should be processed", action='store_true')
parser.add_argument(
"base_folder", help="base directory containing the collection of dataset"
)
parser.add_argument(
"degraded_folder",
help="directory containing train and test subset for degradation",
)
parser.add_argument(
"--gt_folder",
type=str,
default="shared",
help="directory containing the ground truth",
)
parser.add_argument(
"--clean_label_ext",
type=str,
default=".txt",
help="file extension of the clean_labels files",
)
parser.add_argument(
"--train_subset",
help="include if only train folder should be processed",
action="store_true",
)
parser.add_argument(
"--test_subset",
help="include if only test folder should be processed",
action="store_true",
)
return parser
if __name__ == '__main__':
if __name__ == "__main__":
start = timeit.default_timer()
parser = create_parser()
args = parser.parse_args()
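Beyond the command-line entry point shown in the module docstring, the core helper can also be driven directly from Python; a minimal sketch with invented tokens and labels:
from genalog.text.conll_format import propagate_labels_sentences
clean_tokens = ["New", "York", "is", "big", "."]
clean_labels = ["B-PLACE", "I-PLACE", "O", "O", "O"]
clean_sentences = [clean_tokens]               # must flatten back to clean_tokens
ocr_tokens = ["NewYork", "is", "big", "."]     # noisy OCR tokenization of the same text
ocr_sentences, ocr_label_sentences = propagate_labels_sentences(
    clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
# Both return values are lists of sentences: each OCR token carries the NER label
# propagated from the ground-truth token(s) it aligned with.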

View file

@ -1,7 +1,6 @@
from genalog.text import preprocess
class LCS():
class LCS:
""" Compute the Longest Common Subsequence (LCS) of two given string."""
def __init__(self, str_m, str_n):
self.str_m_len = len(str_m)
self.str_n_len = len(str_n)

View file

@ -1,8 +1,9 @@
from genalog.text import alignment, anchor
from genalog.text import preprocess
import itertools
import re
import string
import itertools
from genalog.text import alignment, anchor
from genalog.text import preprocess
# Both regex below has the following behavior:
# 1. whitespace-tolerant at both ends of the string
@ -10,34 +11,42 @@ import itertools
# For example, given a label 'B-PLACE'
# Group 1 (denoted by \1): Label Indicator (B-)
# Group 2 (denoted by \2): Label Name (PLACE)
MULTI_TOKEN_BEGIN_LABEL_REGEX = r'^\s*(B-)([a-z|A-Z]+)\s*$'
MULTI_TOKEN_INSIDE_LABEL_REGEX = r'^\s*(I-)([a-z|A-Z]+)\s*$'
MULTI_TOKEN_LABEL_REGEX = r'^\s*([B|I]-)([a-z|A-Z]+)\s*'
MULTI_TOKEN_BEGIN_LABEL_REGEX = r"^\s*(B-)([a-z|A-Z]+)\s*$"
MULTI_TOKEN_INSIDE_LABEL_REGEX = r"^\s*(I-)([a-z|A-Z]+)\s*$"
MULTI_TOKEN_LABEL_REGEX = r"^\s*([B|I]-)([a-z|A-Z]+)\s*"
# To avoid confusion in the Python interpreter,
# gap char should not be any of the following special characters
SPECIAL_CHAR = set(" \t\n'\x0b''\x0c''\r'") # Notice space characters (' ', '\t', '\n') are in this set.
SPECIAL_CHAR = set(
" \t\n'\x0b''\x0c''\r'"
) # Notice space characters (' ', '\t', '\n') are in this set.
GAP_CHAR_SET = set(string.printable).difference(SPECIAL_CHAR)
# GAP_CHAR_SET = '!"#$%&()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~'
class GapCharError(Exception):
pass
def _is_begin_label(label):
""" Return true if the NER label is a begin label (eg. B-PLACE) """
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) is not None
def _is_inside_label(label):
""" Return true if the NER label is an inside label (eg. I-PLACE) """
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) is not None
def _is_multi_token_label(label):
""" Return true if the NER label is a multi token label (eg. B-PLACE, I-PLACE) """
return re.match(MULTI_TOKEN_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_LABEL_REGEX, label) is not None
def _clean_multi_token_label(label):
""" Rid the multi-token-labels of whitespaces"""
return re.sub(MULTI_TOKEN_LABEL_REGEX, r'\1\2', label)
return re.sub(MULTI_TOKEN_LABEL_REGEX, r"\1\2", label)
def _convert_to_begin_label(label):
"""Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
@ -50,9 +59,10 @@ def _convert_to_begin_label(label):
"""
if _is_inside_label(label):
# Replace the Label Indicator to 'B-'(\1) and keep the Label Name (\2)
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r'B-\2', label)
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r"B-\2", label)
return label
def _convert_to_inside_label(label):
"""Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE)
@ -64,9 +74,10 @@ def _convert_to_inside_label(label):
"""
if _is_begin_label(label):
# Replace the Label Indicator to 'I-'(\1) and keep the Label Name (\2)
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r'I-\2', label)
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r"I-\2", label)
return label
def _is_missing_begin_label(begin_label, inside_label):
"""Validate a inside label given an begin label
@ -93,6 +104,7 @@ def _is_missing_begin_label(begin_label, inside_label):
else:
return True
def correct_ner_labels(labels):
"""Correct the given list of labels for the following case:
@ -119,6 +131,7 @@ def correct_ner_labels(labels):
cur_begin_label = ""
return labels
def _select_from_multiple_ner_labels(label_indices):
"""Private method to select a NER label from a list of candidate
@ -147,6 +160,7 @@ def _select_from_multiple_ner_labels(label_indices):
# TODO: may need a more sophisticated way to select from multiple NER labels
return label_indices[0]
def _find_gap_char_candidates(gt_tokens, ocr_tokens):
"""Find a set of suitable GAP_CHARs based not in the set of input characters
@ -159,12 +173,15 @@ def _find_gap_char_candidates(gt_tokens, ocr_tokens):
1. the set of suitable GAP_CHARs
2. the set of input characters
"""
input_char_set = set(''.join(itertools.chain(gt_tokens, ocr_tokens))) # The set of input characters
input_char_set = set(
"".join(itertools.chain(gt_tokens, ocr_tokens))
) # The set of input characters
gap_char_set = GAP_CHAR_SET # The set of possible GAP_CHARs
# Find a set of gap_char that is NOT in the set of input characters
gap_char_candidates = gap_char_set.difference(input_char_set)
return gap_char_candidates, input_char_set
def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
"""Propagate NER label for ground truth tokens to to ocr tokens.
@ -199,20 +216,29 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
`gap_char` is the char used to alignment for inserting gaps
"""
# Find a set of suitable GAP_CHAR that is not in the set of input characters
gap_char_candidates, input_char_set = _find_gap_char_candidates(gt_tokens, ocr_tokens)
gap_char_candidates, input_char_set = _find_gap_char_candidates(
gt_tokens, ocr_tokens
)
if len(gap_char_candidates) == 0:
raise GapCharError("Exhausted all possible GAP_CHAR candidates for alignment." +
" Consider reducing cardinality of the input character set.\n" +
f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" +
f"The set of input character is: '{''.join(sorted(input_char_set))}'")
raise GapCharError(
"Exhausted all possible GAP_CHAR candidates for alignment."
+ " Consider reducing cardinality of the input character set.\n"
+ f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n"
+ f"The set of input character is: '{''.join(sorted(input_char_set))}'"
)
else:
if alignment.GAP_CHAR in gap_char_candidates:
gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR
else:
gap_char = gap_char_candidates.pop()
return _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor)
return _propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor
)
def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True):
def _propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True
):
"""Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
NOTE: that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
@ -342,31 +368,43 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
# Sanity check:
if len(gt_tokens) != len(gt_labels):
raise ValueError(f"Unequal number of gt_tokens ({len(gt_tokens)})" +
f"to that of gt_labels ({len(gt_labels)})")
raise ValueError(
f"Unequal number of gt_tokens ({len(gt_tokens)})"
+ f"to that of gt_labels ({len(gt_labels)})"
)
for tk in (gt_tokens + ocr_tokens):
for tk in gt_tokens + ocr_tokens:
if len(preprocess.tokenize(tk)) > 1:
raise ValueError(f"Invalid token '{tk}'. Tokens must be atomic.")
if not alignment._is_valid_token(tk, gap_char=gap_char):
if re.search(rf'{re.escape(gap_char)}+', tk): # Escape special regex chars
raise GapCharError(f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'")
if re.search(rf"{re.escape(gap_char)}+", tk): # Escape special regex chars
raise GapCharError(
f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'"
)
else:
raise ValueError(f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)")
raise ValueError(
f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)"
)
# Stitch tokens together into one string for alignment
gt_txt = preprocess.join_tokens(gt_tokens)
ocr_txt = preprocess.join_tokens(ocr_tokens)
# Align the ground truth and ocr text first
if use_anchor:
aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_txt, ocr_txt, gap_char=gap_char)
aligned_gt, aligned_ocr = anchor.align_w_anchor(
gt_txt, ocr_txt, gap_char=gap_char
)
else:
aligned_gt, aligned_ocr = alignment.align(gt_txt, ocr_txt, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_gt, aligned_ocr, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
aligned_gt, aligned_ocr, gap_char=gap_char
)
# Check invariant
if len(gt_to_ocr_mapping) != len(gt_tokens):
raise ValueError(f"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " +
f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment().")
raise ValueError(
"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: "
+ f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment()."
)
ocr_labels = []
# STEP 1: naively propagate NER label based on text-alignment
@ -374,7 +412,9 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
# if is not mapping to missing a token (Case 4)
if ocr_to_gt_token_relationship:
# Find the corresponding gt_token it is aligned to
ner_label_index = _select_from_multiple_ner_labels(ocr_to_gt_token_relationship)
ner_label_index = _select_from_multiple_ner_labels(
ocr_to_gt_token_relationship
)
# Get the NER label for that particular gt_token
ocr_labels.append(gt_labels[ner_label_index])
@ -429,18 +469,26 @@ def format_labels(tokens, labels, label_top=True):
len_diff = abs(len(label) - len(token))
# Add padding spaces for whichever is shorter
if len(label) > len(token):
formatted_labels += label + ' '
formatted_tokens += token + ' '*len_diff + ' '
formatted_labels += label + " "
formatted_tokens += token + " " * len_diff + " "
else:
formatted_labels += label + ' '*len_diff + ' '
formatted_tokens += token + ' '
formatted_labels += label + " " * len_diff + " "
formatted_tokens += token + " "
if label_top:
return formatted_labels + '\n' + formatted_tokens + '\n'
return formatted_labels + "\n" + formatted_tokens + "\n"
else:
return formatted_tokens + '\n' + formatted_labels + '\n'
return formatted_tokens + "\n" + formatted_labels + "\n"
def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
aligned_gt, aligned_ocr, show_alignment=True):
def format_label_propagation(
gt_tokens,
gt_labels,
ocr_tokens,
ocr_labels,
aligned_gt,
aligned_ocr,
show_alignment=True,
):
"""Format label propagation for display
Arguments:
@ -483,4 +531,3 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
return gt_label_str + alignment_str + label_str
else:
return gt_label_str + label_str
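A short, illustrative sketch of the public entry point in this module (token and label values are invented):
from genalog.text import ner_label
gt_tokens = ["New", "York", "is", "big", "."]
gt_labels = ["B-PLACE", "I-PLACE", "O", "O", "O"]
ocr_tokens = ["NewYork", "is", "big", "."]
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)
print(ner_label.format_labels(ocr_tokens, ocr_labels))    # labels printed above their tokens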

View file

@ -1,8 +1,9 @@
import re
END_OF_TOKEN = {' ', '\t', '\n'}
END_OF_TOKEN = {" ", "\t", "\n"}
NON_ASCII_REPLACEMENT = "_"
def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
"""Remove non ascii characters in a token
@ -15,12 +16,13 @@ def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
str -- a word token with non-ASCII characters removed
"""
# Remove non-ASCII characters in the token
ascii_token = str(token.encode('utf-8').decode('ascii', 'ignore'))
ascii_token = str(token.encode("utf-8").decode("ascii", "ignore"))
# If token becomes an empty string as a result
if len(ascii_token) == 0 and len(token) != 0:
ascii_token = replacement # replace with a default character
return ascii_token
def tokenize(s):
"""Tokenize string
@ -33,6 +35,7 @@ def tokenize(s):
# split alignment tokens by spaces, tabs and newline (and excluding them in the tokens)
return s.split()
def join_tokens(tokens):
"""Join a list of tokens into a string
@ -44,14 +47,17 @@ def join_tokens(tokens):
"""
return " ".join(tokens)
def _is_spacing(c):
""" Determine if the character is ignorable """
return True if c in END_OF_TOKEN else False
def split_sentences(text, delimiter="\n"):
""" Split a text into sentences with a delimiter"""
return re.sub(r'(( /?[.!?])+ )', rf'\1{delimiter}', text)
return re.sub(r"(( /?[.!?])+ )", rf"\1{delimiter}", text)
def is_sentence_separator(token):
""" Returns true if the token is a sentence splitter """
return re.match(r'^/?[.!?]$', token) != None
return re.match(r"^/?[.!?]$", token) is not None
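A quick illustration of the helpers in this module; the outputs in the comments follow directly from the definitions above:
from genalog.text import preprocess
tokens = preprocess.tokenize("New York is big .")
# -> ["New", "York", "is", "big", "."]
text = preprocess.join_tokens(tokens)
# -> "New York is big ."
print(preprocess.split_sentences("It is big . It is new . "))
# a newline is inserted after each sentence-ending ". " -> "It is big . \nIt is new . \n"
print(preprocess.is_sentence_separator("."))   # -> True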

View file

@ -1,5 +1,5 @@
"""This is a utility tool to split CoNLL formated files. It has the capability to pack sentences into generated
pages more tightly.
"""This is a utility tool to split CoNLL formated files.
It has the capability to pack sentences into generated pages more tightly.
usage: splitter.py [-h] [--doc_sep DOC_SEP] [--line_sep LINE_SEP]
[--force_doc_sep]
@ -22,16 +22,17 @@ example usage:
python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train
"""
import re
import os
import multiprocessing
import argparse
from tqdm import tqdm
from genalog.text import preprocess
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import multiprocessing
import os
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
from genalog.text import preprocess
# default buffer. Preferably set this to something large
# It holds the lines read from the CoNLL file
BUFFER_SIZE = 50000
@ -48,6 +49,7 @@ WORKERS_PER_CPU = 2
default_generator = DocumentGenerator()
def unwrap(size, accumulator):
words = []
labels = []
@ -58,7 +60,10 @@ def unwrap(size, accumulator):
labels.append((word, tok))
return words, labels
def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='text_block.html.jinja'):
def find_split_position(
accumulator, start_pos, iters=SPLIT_ITERS, template_name="text_block.html.jinja"
):
"""Run a few iterations of binary search to find the best split point
from the start to pack sentences into a page without overflow.
@ -75,7 +80,10 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
best = None
count = 0
while start <= end:
if count==0 and (STARTING_SPLIT_GUESS+start_pos > start and STARTING_SPLIT_GUESS + start_pos < end):
if count == 0 and (
STARTING_SPLIT_GUESS + start_pos > start
and STARTING_SPLIT_GUESS + start_pos < end
):
split_point = STARTING_SPLIT_GUESS
else:
split_point = (start + end) // 2
@ -101,8 +109,15 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
return best
def generate_splits(input_file, output_folder, sentence_seperator="", doc_seperator=None, pool=None,
force_doc_sep=False, ext="txt"):
def generate_splits(
input_file,
output_folder,
sentence_seperator="",
doc_seperator=None,
pool=None,
force_doc_sep=False,
ext="txt",
):
"""Processes the file line by line and add sentences to the buffer for processing.
Args:
@ -132,7 +147,9 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
continue
start_pos = 0
while start_pos < len(accumulator):
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
start_pos = next_doc(
accumulator, doc_id, start_pos, output_folder, pool
)
doc_id += 1
progress_bar.update(1)
accumulator = []
@ -145,17 +162,20 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
# process any left over lines
start_pos = 0
if len(sentence) > 0 : accumulator.append(sentence)
if len(sentence) > 0:
accumulator.append(sentence)
while start_pos < len(accumulator):
start_pos = next_doc(accumulator, doc_id, start_pos, output_folder, pool)
doc_id += 1
progress_bar.update(1)
def next_doc(accumulator, doc_id, start_pos, output_folder, pool, ext="txt"):
split_pos, doc, labels, text = find_split_position(accumulator, start_pos)
handle_doc(doc, labels, doc_id, text, output_folder, pool, ext)
return split_pos
def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=False):
if write_png:
@ -165,45 +185,66 @@ def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=Fal
text += " " # adding a space at EOF
text = preprocess.split_sentences(text)
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as l:
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as fp:
for idx, (token, label) in enumerate(labels):
l.write(token + "\t" + label)
fp.write(token + "\t" + label)
next_token, _ = labels[(idx + 1) % len(labels)]
if preprocess.is_sentence_separator(token) and not \
preprocess.is_sentence_separator(next_token):
l.write("\n")
if preprocess.is_sentence_separator(
token
) and not preprocess.is_sentence_separator(next_token):
fp.write("\n")
if idx == len(labels): # Reach the end of the document
l.write("\n")
fp.write("\n")
with open(f"{output_folder}/clean_text/{doc_id}.txt", "w") as text_file:
text_file.write(text)
return f"wrote: doc id: {doc_id}"
def _error_callback(err):
raise RuntimeError(err)
def handle_doc(doc, labels, doc_id, text, output_folder, pool, ext="txt"):
if pool:
pool.apply_async(write_doc, args=(doc, doc_id, labels, text, output_folder,ext), error_callback=_error_callback)
pool.apply_async(
write_doc,
args=(doc, doc_id, labels, text, output_folder, ext),
error_callback=_error_callback,
)
else:
write_doc(doc, doc_id, labels, text, output_folder)
def setup_folder(output_folder):
os.makedirs(os.path.join(output_folder, "clean_text"), exist_ok=True)
os.makedirs(os.path.join(output_folder, "clean_labels"), exist_ok=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_file", default="CoNLL-2012_train.txt", help="path to input CoNLL formated file.")
parser.add_argument("output_folder", default="conll2012_train", help="folder to write results to.")
parser.add_argument(
"input_file",
default="CoNLL-2012_train.txt",
help="path to input CoNLL formated file.",
)
parser.add_argument(
"output_folder", default="conll2012_train", help="folder to write results to."
)
parser.add_argument("--doc_sep", help="CoNLL doc seperator")
parser.add_argument("--ext", help="file extension", default="txt")
parser.add_argument("--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator")
parser.add_argument("--force_doc_sep", default=False, action="store_true",
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)")
parser.add_argument(
"--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator"
)
parser.add_argument(
"--force_doc_sep",
default=False,
action="store_true",
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)",
)
args = parser.parse_args()
unescape = lambda s: s.encode('utf-8').decode('unicode_escape') if s else None
unescape = lambda s: s.encode("utf-8").decode("unicode_escape") if s else None # noqa: E731
input_file = args.input_file
output_folder = args.output_folder
@ -215,6 +256,14 @@ if __name__ == "__main__":
n_workers = WORKERS_PER_CPU * multiprocessing.cpu_count()
with ThreadPool(processes=n_workers) as pool:
generate_splits(input_file, output_folder, line_sep, doc_seperator=doc_sep, pool=pool, force_doc_sep=False, ext=args.ext)
generate_splits(
input_file,
output_folder,
line_sep,
doc_seperator=doc_sep,
pool=pool,
force_doc_sep=False,
ext=args.ext,
)
pool.close()
pool.join()
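For orientation, a hedged sketch of the on-disk layout the splitter produces, based on setup_folder() and write_doc() above (the input file name is the example from the module docstring):
#   python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train
#
#   conll2012_train/
#       clean_text/0.txt        # sentence-split plain text for document 0
#       clean_labels/0.txt      # token<TAB>label lines, blank line between sentences
#       clean_text/1.txt
#       clean_labels/1.txt
#       ...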

View file

@ -1,2 +0,0 @@
[pytest]
junit_family=xunit1

requirements-dev.txt Normal file
View file

@ -0,0 +1,4 @@
pytest
pytest-cov
flake8
flake8-import-order

View file

@ -1,6 +1,7 @@
import setuptools
import os
import setuptools
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'VERSION.txt')) as version_file:
BUILD_VERSION = version_file.read().strip()

View file

@ -114,4 +114,6 @@ ns_tokens = [preprocess.tokenize(txt) for txt in ns_txt]
# test function expect params in tuple of
# (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels)
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels))
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(
zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels)
)

View file

@ -1,4 +1,3 @@
# Initializing test cases
# For extensibility, all parameters in a case are appended to the following arrays
gt_txt = []
@ -23,17 +22,21 @@ gt_to_noise_maps.append(
[0, 1],
[1],
[2],
[3]
])
[3],
]
)
noise_to_gt_maps.append([
noise_to_gt_maps.append(
[
[0],
# Similarly, the following shows that the second token in noise "ewYork" maps to the
# first ("New") and second ("York") token in gt
[0, 1],
[2],
[3]
])
[3],
]
)
##############################################################################################
@ -52,18 +55,9 @@ gt_txt.append("Boston is great")
aligned_ns.append("B oston@@@ grea t")
aligned_gt.append("B@oston is grea@t")
gt_to_noise_maps.append([
[0,1],
[1], # 'is' is also mapped to 'oston'
[2,3]
])
gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) # 'is' is also mapped to 'oston'
noise_to_gt_maps.append([
[0],
[0,1], # 'oston' is to 'Boston' and 'is'
[2],
[2]
])
noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) # 'oston' is to 'Boston' and 'is'
############################################################################################
# Empty Cases:
@ -95,18 +89,9 @@ ns_txt.append("B oston bi g")
aligned_gt.append("B@oston is bi@g")
aligned_ns.append("B oston@@@ bi g")
gt_to_noise_maps.append([
[0,1],
[1],
[2,3]
])
gt_to_noise_maps.append([[0, 1], [1], [2, 3]])
noise_to_gt_maps.append([
[0],
[0,1],
[2],
[2]
])
noise_to_gt_maps.append([[0], [0, 1], [2], [2]])
############################################################################################
gt_txt.append("New York is big.")
@ -115,17 +100,9 @@ ns_txt.append("NewYork big")
aligned_gt.append("New York is big.")
aligned_ns.append("New@York @@@big@")
gt_to_noise_maps.append([
[0],
[0],
[1],
[1]
])
gt_to_noise_maps.append([[0], [0], [1], [1]])
noise_to_gt_maps.append([
[0, 1],
[2, 3]
])
noise_to_gt_maps.append([[0, 1], [2, 3]])
#############################################################################################
gt_txt.append("politicians who lag superfluous on the")
@ -134,23 +111,9 @@ ns_txt.append("politicians who kg superfluous on the")
aligned_gt.append("politicians who lag superfluous on the")
aligned_ns.append("politicians who @kg superfluous on the")
gt_to_noise_maps.append([
[0],
[1],
[2],
[3],
[4],
[5]
])
gt_to_noise_maps.append([[0], [1], [2], [3], [4], [5]])
noise_to_gt_maps.append([
[0],
[1],
[2],
[3],
[4],
[5]
])
noise_to_gt_maps.append([[0], [1], [2], [3], [4], [5]])
############################################################################################
@ -160,20 +123,9 @@ ns_txt.append("faithei uifoimtdon the subject")
aligned_gt.append("farther @informed on the subject.")
aligned_ns.append("faithei ui@foimtd@on the subject@")
gt_to_noise_maps.append([
[0],
[1],
[1],
[2],
[3]
])
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
noise_to_gt_maps.append([
[0],
[1,2],
[3],
[4]
])
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
############################################################################################
@ -183,20 +135,9 @@ ns_txt.append("New Yorkis big .")
aligned_gt.append("New York is big .")
aligned_ns.append("New York@is big .")
gt_to_noise_maps.append([
[0],
[1],
[1],
[2],
[3]
])
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
noise_to_gt_maps.append([
[0],
[1,2],
[3],
[4]
])
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
############################################################################################
@ -206,23 +147,14 @@ ns_txt.append("New Yo rk is big.")
aligned_gt.append("New Yo@rk is big.")
aligned_ns.append("New Yo rk is big.")
gt_to_noise_maps.append([
[0],
[1,2],
[3],
[4]
])
gt_to_noise_maps.append([[0], [1, 2], [3], [4]])
noise_to_gt_maps.append([
[0],
[1],
[1],
[2],
[3]
])
noise_to_gt_maps.append([[0], [1], [1], [2], [3]])
# Format tests for pytest
# Each test expects input in the following format
# (aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(
aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps
)
ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns))
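To make the index-map convention concrete, here is the "NewYork big" case above read back by hand (a worked restatement of data already listed, not additional test input):

gt_tokens = ["New", "York", "is", "big."]   # from gt_txt "New York is big."
ns_tokens = ["NewYork", "big"]              # from ns_txt "NewYork big"

gt_to_noise = [[0], [0], [1], [1]]  # "New" and "York" both map to "NewYork"; "is" and "big." map to "big"
noise_to_gt = [[0, 1], [2, 3]]      # "NewYork" covers gt tokens 0-1; "big" covers gt tokens 2-3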

View file

@ -1,20 +1,24 @@
from genalog.degradation.degrader import Degrader, ImageState
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
import copy
from unittest.mock import patch
import numpy as np
import pytest
import copy
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
from genalog.degradation.degrader import Degrader, ImageState
MOCK_IMAGE_SHAPE = (4, 3)
MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE)
@pytest.fixture
def empty_degrader():
effects = []
return Degrader(effects)
@pytest.fixture(params=[
@pytest.fixture(
params=[
[("blur", {"radius": 5})],
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
@ -26,34 +30,45 @@ def empty_degrader():
],
[
("blur", {"radius": 5}),
("bleed_through", {
(
"bleed_through",
{
"src": ImageState.CURRENT_STATE,
"alpha": 0.7,
"background": ImageState.ORIGINAL_STATE,
}),
("morphology", {
"operation": "open",
"kernel_shape": (3,3),
"kernel_type": "ones"
}),
},
),
(
"morphology",
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"},
),
],
]
])
)
def degrader(request):
effects = request.param
return Degrader(effects)
def test_degrader_init(empty_degrader):
def test_empty_degrader_init(empty_degrader):
assert empty_degrader.effects_to_apply == []
def test_degrader_init(degrader):
assert degrader.effects_to_apply is not []
for effect_tuple in degrader.effects_to_apply:
method_name, method_kwargs = effect_tuple
assert DEFAULT_METHOD_PARAM_TO_INCLUDE in method_kwargs
param_value = method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE]
assert param_value is ImageState.ORIGINAL_STATE or param_value is ImageState.CURRENT_STATE
assert (
param_value is ImageState.ORIGINAL_STATE
or param_value is ImageState.CURRENT_STATE
)
@pytest.mark.parametrize("effects, error_thrown", [
@pytest.mark.parametrize(
"effects, error_thrown",
[
([], None), # Empty effect
(None, TypeError),
([("blur", {"radius": 5})], None), # Validate input
@ -64,17 +79,20 @@ def test_degrader_init(degrader):
[
("blur", {"radius": 5}),
("bleed_through", {"alpha": "0.8"}),
("morphology", {"operation": "open"})
], None
("morphology", {"operation": "open"}),
],
None,
), # Multiple effects
(
[
("blur", {"radius": 5}),
("bleed_through", {"not_argument": "0.8"}),
("morphology", {"missing value"})
], ValueError
("morphology", {"missing value"}),
],
ValueError,
), # Multiple effects
])
],
)
def test_degrader_validate_effects(effects, error_thrown):
if error_thrown:
with pytest.raises(error_thrown):
@ -82,23 +100,26 @@ def test_degrader_validate_effects(effects, error_thrown):
else:
Degrader.validate_effects(effects)
def test_degrader_apply_effects(degrader):
method_names = [effect[0] for effect in degrader.effects_to_apply]
with patch("genalog.degradation.effect") as mock_effect:
degraded = degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
for method in method_names:
assert mock_effect[method].is_called()
# assert degraded.shape == MOCK_IMAGE_SHAPE
def test_degrader_apply_effects_e2e(degrader):
degraded = degrader.apply_effects(MOCK_IMAGE)
assert degraded.shape == MOCK_IMAGE_SHAPE
assert degraded.dtype == np.uint8
def test_degrader_instructions(degrader):
original_instruction = copy.deepcopy(degrader.effects_to_apply)
degraded1 = degrader.apply_effects(MOCK_IMAGE)
degraded2 = degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
# Make sure the degradation instructions are not altered
assert len(original_instruction) == len(degrader.effects_to_apply)
for i in range(len(original_instruction)):
@ -107,5 +128,5 @@ def test_degrader_instructions(degrader):
assert org_method_name == method_name
assert len(org_method_arg) == len(method_arg)
for key in org_method_arg.keys():
assert type(org_method_arg[key]) == type(method_arg[key])
assert isinstance(org_method_arg[key], type(method_arg[key]))
assert org_method_arg[key] == method_arg[key]

View file

@ -1,18 +1,21 @@
from genalog.degradation import effect
from unittest.mock import patch
import numpy as np
import pytest
from genalog.degradation import effect
NEW_IMG_SHAPE = (100, 100)
MOCK_IMG_SHAPE = (100, 120)
MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8)
def test_blur():
dst = effect.blur(MOCK_IMG, radius=3)
assert dst.dtype == np.uint8  # preserve dtype
assert dst.shape == MOCK_IMG_SHAPE  # preserve image size
def test_translation():
offset_x = offset_y = 1
# Test that border pixels are not white (<255)
@ -25,6 +28,7 @@ def test_translation():
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_overlay_weighted():
src = MOCK_IMG.copy()
src[0][0] = 10
@ -34,6 +38,7 @@ def test_overlay_weighted():
assert dst.shape == MOCK_IMG_SHAPE
assert dst[0][0] == src[0][0] * alpha + src[0][0] * beta
def test_overlay():
src1 = MOCK_IMG.copy()
src2 = MOCK_IMG.copy()
@ -45,6 +50,7 @@ def test_overlay():
assert dst[0][0] == 0
assert dst[0][1] == 1
@patch("genalog.degradation.effect.translation")
def test_bleed_through_default(mock_translation):
mock_translation.return_value = MOCK_IMG
@ -53,11 +59,15 @@ def test_bleed_through_default(mock_translation):
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.mark.parametrize("foreground, background, error_thrown", [
@pytest.mark.parametrize(
"foreground, background, error_thrown",
[
(MOCK_IMG, MOCK_IMG, None),
# Test unmatched shape
(MOCK_IMG, MOCK_IMG[:, :-1], Exception),
])
],
)
def test_bleed_through_kwargs(foreground, background, error_thrown):
if error_thrown:
assert foreground.shape != background.shape
@ -68,36 +78,43 @@ def test_bleed_through_kwargs(foreground, background, error_thrown):
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_pepper():
dst = effect.pepper(MOCK_IMG, amount=0.1)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_salt():
dst = effect.salt(MOCK_IMG, amount=0.1)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_salt_then_pepper():
dst = effect.salt_then_pepper(MOCK_IMG, 0.5, 0.001)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_pepper_then_salt():
dst = effect.pepper_then_salt(MOCK_IMG, 0.001, 0.5)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.mark.parametrize("kernel_shape, kernel_type", [
((3,3), "NOT_VALID_TYPE"),
(1, "ones"),
((1,2,3), "ones")
])
@pytest.mark.parametrize(
"kernel_shape, kernel_type",
[((3, 3), "NOT_VALID_TYPE"), (1, "ones"), ((1, 2, 3), "ones")],
)
def test_create_2D_kernel_error(kernel_shape, kernel_type):
with pytest.raises(Exception):
effect.create_2D_kernel(kernel_shape, kernel_type)
@pytest.mark.parametrize("kernel_shape, kernel_type, expected_kernel", [
@pytest.mark.parametrize(
"kernel_shape, kernel_type, expected_kernel",
[
((2, 2), "ones", np.array([[1, 1], [1, 1]])), # sq kernel
((1, 2), "ones", np.array([[1, 1]])), # horizontal
((2, 1), "ones", np.array([[1], [1]])), # vertical
@ -108,56 +125,77 @@ def test_create_2D_kernel_error(kernel_shape, kernel_type):
((2, 2), "plus", np.array([[0, 1], [1, 1]])),
((3, 3), "plus", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
((3, 3), "ellipse", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
((5,5), "ellipse",
np.array([
(
(5, 5),
"ellipse",
np.array(
[
[0, 0, 1, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 0, 1, 0, 0]
])),
])
[0, 0, 1, 0, 0],
]
),
),
],
)
def test_create_2D_kernel(kernel_shape, kernel_type, expected_kernel):
kernel = effect.create_2D_kernel(kernel_shape, kernel_type)
assert np.array_equal(kernel, expected_kernel)
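Pulled straight from the parametrized cases above, a quick standalone check of one kernel shape (same import path as these tests):

import numpy as np

from genalog.degradation import effect

kernel = effect.create_2D_kernel((3, 3), "plus")
assert np.array_equal(kernel, np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]))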
def test_morphology_with_error():
INVALID_OPERATION = "NOT_A_OPERATION"
with pytest.raises(ValueError):
effect.morphology(MOCK_IMG, operation=INVALID_OPERATION)
@pytest.mark.parametrize("operation, kernel_shape, kernel_type", [
@pytest.mark.parametrize(
"operation, kernel_shape, kernel_type",
[
("open", (3, 3), "ones"),
("close", (3, 3), "ones"),
("dilate", (3, 3), "ones"),
("erode", (3, 3), "ones"),
])
],
)
def test_morphology(operation, kernel_shape, kernel_type):
dst = effect.morphology(MOCK_IMG,
operation=operation, kernel_shape=kernel_shape,
kernel_type=kernel_type)
dst = effect.morphology(
MOCK_IMG,
operation=operation,
kernel_shape=kernel_shape,
kernel_type=kernel_type,
)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.fixture(params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"])
@pytest.fixture(
params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"]
)
def kernel(request):
return effect.create_2D_kernel((5, 5), request.param)
def test_open(kernel):
dst = effect.open(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_close(kernel):
dst = effect.close(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_erode(kernel):
dst = effect.erode(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_dilate(kernel):
dst = effect.dilate(MOCK_IMG, kernel)
assert dst.dtype == np.uint8

View file

@ -1,15 +1,18 @@
from genalog.text import alignment, anchor, preprocess
import glob
import pytest
import difflib
import glob
import warnings
@pytest.mark.parametrize("gt_file, ocr_file",
import pytest
from genalog.text import alignment, anchor, preprocess
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_*.txt")),
sorted(glob.glob("tests/text/data/ocr_*.txt"))
)
sorted(glob.glob("tests/text/data/ocr_*.txt")),
),
)
def test_align_w_anchor_and_align(gt_file, ocr_file):
gt_text = open(gt_file, "r").read()
@ -21,17 +24,22 @@ def test_align_w_anchor_and_align(gt_file, ocr_file):
aligned_anchor_gt = aligned_anchor_gt.split(".")
aligned_gt = aligned_gt.split(".")
str_diff = "\n".join(difflib.unified_diff(aligned_gt, aligned_anchor_gt))
warnings.warn(UserWarning(
"\n"+ f"{str_diff}" +
f"\n\n**** Inconsistent Alignment Results between align() and " +
f"align_w_anchor(). Ignore this if the delta is not significant. ****\n"))
warnings.warn(
UserWarning(
"\n"
+ f"{str_diff}"
+ "\n\n**** Inconsistent Alignment Results between align() and "
+ "align_w_anchor(). Ignore this if the delta is not significant. ****\n"
)
)
@pytest.mark.parametrize("gt_file, ocr_file",
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_*.txt")),
sorted(glob.glob("tests/text/data/ocr_*.txt"))
)
sorted(glob.glob("tests/text/data/ocr_*.txt")),
),
)
@pytest.mark.parametrize("max_seg_length", [25, 50, 75, 100, 150])
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
@ -39,7 +47,9 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
ocr_text = open(ocr_file, "r").read()
gt_tokens = preprocess.tokenize(gt_text)
ocr_tokens = preprocess.tokenize(ocr_text)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
# Ensure that each anchor word is the same word in both text
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]

View file

@ -1,47 +1,62 @@
from genalog.text import conll_format
from unittest import mock
import argparse
import pytest
import glob
import itertools
@pytest.mark.parametrize("required_args", [
(["tests/e2e/data/synthetic_dataset", "test_version"])
])
@pytest.mark.parametrize("optional_args", [
import pytest
from genalog.text import conll_format
@pytest.mark.parametrize(
"required_args", [(["tests/e2e/data/synthetic_dataset", "test_version"])]
)
@pytest.mark.parametrize(
"optional_args",
[
(["--train_subset"]),
(["--test_subset"]),
(["--gt_folder", "shared"]),
])
],
)
def test_conll_format(required_args, optional_args):
parser = conll_format.create_parser()
arg_list = required_args + optional_args
args = parser.parse_args(args=arg_list)
conll_format.main(args)
basepath = "tests/e2e/data/conll_formatter/"
@pytest.mark.parametrize("clean_label_filename, ocr_text_filename",
@pytest.mark.parametrize(
"clean_label_filename, ocr_text_filename",
zip(
sorted(glob.glob("tests/e2e/data/conll_formatter/clean_labels/*.txt")),
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt"))
)
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")),
),
)
def test_propagate_labels_sentence_single_file(clean_label_filename, ocr_text_filename):
with open(clean_label_filename, 'r', encoding='utf-8') as clf:
with open(clean_label_filename, "r", encoding="utf-8") as clf:
tokens_labels_str = clf.readlines()
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_tokens = [
line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2
]
clean_labels = [
line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2
]
clean_sentences = conll_format.get_sentences_from_iob_format(tokens_labels_str)
# read ocr tokens
with open(ocr_text_filename, 'r', encoding='utf-8') as otf:
ocr_text_str = ' '.join(otf.readlines())
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
with open(ocr_text_filename, "r", encoding="utf-8") as otf:
ocr_text_str = " ".join(otf.readlines())
ocr_tokens = [
token.strip() for token in ocr_text_str.split()
] # already tokenized in data
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens

View file

@ -1,10 +1,13 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import pytest
import os
CONTENT = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH])
import pytest
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
CONTENT = CompositeContent(
["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH]
)
UNSUPPORTED_CONTENT_FORMAT = ["foo bar"]
UNSUPPORTED_CONTENT_TYPE = CompositeContent(["foo"], [ContentType.TITLE])
@ -21,9 +24,10 @@ CUSTOM_STYLE = {
"font_family": ["Calibri", "Times"],
"font_size": ["10px"],
"text_align": ["right"],
"hyphenate": [True, False]
"hyphenate": [True, False],
}
def test_default_template_generation():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(CONTENT, doc_gen.template_list)
@ -32,20 +36,27 @@ def test_default_template_generation():
assert "Unsupported Content Type:" not in html_str
assert "No content loaded" not in html_str
def test_default_template_generation_w_unsupported_content_format():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list)
generator = doc_gen.create_generator(
UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list
)
for doc in generator:
html_str = doc.render_html()
assert "No content loaded" in html_str
def test_default_template_generation_w_unsupported_content_type():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"])
generator = doc_gen.create_generator(
UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"]
)
for doc in generator:
html_str = doc.render_html()
assert "Unsupported Content Type: ContentType.TITLE" in html_str
def test_custom_template_generation():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
@ -53,6 +64,7 @@ def test_custom_template_generation():
result = doc.render_html()
assert result == str(CONTENT)
def test_undefined_template_generation():
doc_gen = DocumentGenerator()
assert UNDEFINED_TEMPLATE_NAME not in doc_gen.template_list
@ -60,6 +72,7 @@ def test_undefined_template_generation():
with pytest.raises(FileNotFoundError):
next(generator)
def test_custom_style_template_generation():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
assert len(doc_gen.styles_to_generate) == 1
@ -70,14 +83,16 @@ def test_custom_style_template_generation():
result = doc.render_html()
assert doc.styles["font_family"] == result
def test_render_pdf_and_png():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
for doc in generator:
pdf_bytes = doc.render_pdf()
png_bytes = doc.render_png()
assert pdf_bytes != None
assert png_bytes != None
assert pdf_bytes is not None
assert png_bytes is not None
def test_save_document_as_png():
if not os.path.exists(TEST_OUTPUT_DIR):
@ -89,6 +104,7 @@ def test_save_document_as_png():
# Check if the document is saved in filepath
assert os.path.exists(FILE_DESTINATION)
def test_save_document_as_separate_png():
if not os.path.exists(TEST_OUTPUT_DIR):
os.mkdir(TEST_OUTPUT_DIR)
@ -102,6 +118,7 @@ def test_save_document_as_separate_png():
printed_doc_name = FILE_DESTINATION.replace(".png", f"_pg_{page_num}.png")
assert os.path.exists(printed_doc_name)
def test_overwriting_style():
new_font = "NewFontFamily"
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)

View file

@ -1,23 +1,25 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
from genalog.degradation.degrader import Degrader
from genalog.degradation import effect
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
import numpy as np
import cv2
TEST_OUTPUT_DIR = "test_out/"
SAMPLE_TXT = "Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , was picked on Thursday for the Scottish squad after a 20-month exile ."
SAMPLE_TXT = """Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday ,
was picked on Thursday for the Scottish squad after a 20-month exile ."""
DEFAULT_TEMPLATE = "text_block.html.jinja"
DEGRADATION_EFFECTS = [
("blur", {"radius": 5}),
("bleed_through", {"alpha": 0.8}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "plus"}),
(
"morphology",
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"},
),
("morphology", {"operation": "close"}),
("morphology", {"operation": "dilate"}),
("morphology", {"operation": "erode"})
("morphology", {"operation": "erode"}),
]
def test_generation_and_degradation():
# Initiate content
content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
@ -32,6 +34,4 @@ def test_generation_and_degradation():
# get the image in bytes in RGBA channels
src = doc.render_array(resolution=100, channel="GRAYSCALE")
# run each degradation effect
dst = degrader.apply_effects(src)
degrader.apply_effects(src)
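Each tuple in DEGRADATION_EFFECTS above pairs a function name from genalog.degradation.effect with its keyword arguments; as a rough, self-contained sketch (toy array instead of the rendered document used in the test), a single entry corresponds to a direct call such as:

import numpy as np

from genalog.degradation import effect

src = np.full((100, 120), 255, dtype=np.uint8)  # toy grayscale page
dst = effect.blur(src, radius=5)                # what the ("blur", {"radius": 5}) entry applies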

View file

@ -1,19 +1,20 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import pytest
import numpy as np
import cv2
import pytest
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
TEMPLATE_PATH = "tests/e2e/templates"
TEST_OUT_FOLDER = "test_out/"
SAMPLE_TXT = "foo"
CONTENT = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
@pytest.fixture
def doc_generator():
return DocumentGenerator(template_path=TEMPLATE_PATH)
def test_red_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:
@ -23,6 +24,7 @@ def test_red_channel(doc_generator):
assert tuple(img_array[0][0]) == (0, 0, 255, 255)
cv2.imwrite(TEST_OUT_FOLDER + "red.png", img_array)
def test_green_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:
@ -32,6 +34,7 @@ def test_green_channel(doc_generator):
assert tuple(img_array[0][0]) == (0, 128, 0, 255)
cv2.imwrite(TEST_OUT_FOLDER + "green.png", img_array)
def test_blue_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:

View file

@ -1,30 +1,36 @@
from genalog.ocr.rest_client import GrokRestClient
import json
import pytest
from dotenv import load_dotenv
from genalog.ocr.blob_client import GrokBlobClient
from genalog.ocr.grok import Grok
import requests
import pytest
import time
import json
import os
from dotenv import load_dotenv
load_dotenv("tests/ocr/.env")
class TestBlobClient:
@pytest.mark.parametrize("use_async", [True, False])
def test_upload_images(self, use_async):
blob_client = GrokBlobClient.create_from_env_var()
subfolder = "tests/ocr/data/img"
file_prefix = subfolder.replace("/", "_")
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, use_async=use_async)
subfolder.replace("/", "_")
dst_folder, _ = blob_client.upload_images_to_blob(
subfolder, use_async=use_async
)
uploaded_items, _ = blob_client.list_blobs(dst_folder)
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
assert uploaded_items[0].name == f"{dst_folder}/0.png"
assert uploaded_items[1].name == f"{dst_folder}/1.png"
assert uploaded_items[2].name == f"{dst_folder}/11.png"
blob_client.delete_blobs_folder(dst_folder)
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
assert (
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
), f"folder {dst_folder} was not deleted"
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, "test_images", use_async=use_async)
dst_folder, _ = blob_client.upload_images_to_blob(
subfolder, "test_images", use_async=use_async
)
assert dst_folder == "test_images"
uploaded_items, _ = blob_client.list_blobs(dst_folder)
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
@ -32,17 +38,23 @@ class TestBlobClient:
assert uploaded_items[1].name == f"{dst_folder}/1.png"
assert uploaded_items[2].name == f"{dst_folder}/11.png"
blob_client.delete_blobs_folder(dst_folder)
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
assert (
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
), f"folder {dst_folder} was not deleted"
class TestGROKe2e:
@pytest.mark.parametrize("use_async", [False, True])
def test_grok_e2e(self, tmpdir, use_async):
grok = Grok.create_from_env_var()
src_folder = "tests/ocr/data/img"
grok.run_grok(src_folder, tmpdir, blob_dest_folder="testimages", use_async=use_async, cleanup=True)
json_folder = "tests/ocr/data/json"
json_hash = "521c38122f783673598856cd81d91c21"
grok.run_grok(
src_folder,
tmpdir,
blob_dest_folder="testimages",
use_async=use_async,
cleanup=True,
)
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
assert json.load(open(f"{tmpdir}/1.json", "r"))[0]["text"]
assert json.load(open(f"{tmpdir}/11.json", "r"))[0]["text"]

View file

@ -1,32 +1,43 @@
from genalog import pipeline
import pytest
import glob
import pytest
from genalog import pipeline
EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt"
@pytest.fixture
def default_analog_generator():
return pipeline.AnalogDocumentGeneration()
@pytest.fixture
def custom_analog_generator():
custom_styles = {"font_size": ["5px"]}
custom_degradation = [("blur", {"radius": 3})]
return pipeline.AnalogDocumentGeneration(
styles=custom_styles,
degradations=custom_degradation,
resolution=300)
styles=custom_styles, degradations=custom_degradation, resolution=300
)
def test_default_generate_img(default_analog_generator):
example_template = default_analog_generator.list_templates()[0]
img_array = default_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
default_analog_generator.generate_img(
EXAMPLE_TEXT_FILE, example_template, target_folder=None
)
def test_custom_generate_img(custom_analog_generator):
example_template = custom_analog_generator.list_templates()[0]
img_array = custom_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
custom_analog_generator.generate_img(
EXAMPLE_TEXT_FILE, example_template, target_folder=None
)
def test_generate_dataset_multiprocess():
INPUT_TEXT_FILENAMES = glob.glob("tests/text/data/gt_*.txt")
with pytest.deprecated_call():
pipeline.generate_dataset_multiprocess(INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja")
pipeline.generate_dataset_multiprocess(
INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja"
)

View file

@ -1,7 +1,8 @@
import os
import difflib
import os
from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits
from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR
def _compare_content(file1, file2):
txt1 = open(file1, "r").read()
@ -12,15 +13,32 @@ def _compare_content(file1, file2):
str_diff = "\n".join(difflib.unified_diff(sentences_txt1, sentences_txt2))
assert False, f"Delta between outputs: \n {str_diff}"
def test_splitter(tmpdir):
# tmpdir = "test_out"
os.makedirs(f"{tmpdir}/clean_labels")
os.makedirs(f"{tmpdir}/clean_text")
generate_splits("tests/e2e/data/splitter/example_conll2012.txt", tmpdir,
doc_seperator=CONLL2003_DOC_SEPERATOR, sentence_seperator="")
generate_splits(
"tests/e2e/data/splitter/example_conll2012.txt",
tmpdir,
doc_seperator=CONLL2003_DOC_SEPERATOR,
sentence_seperator="",
)
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/0.txt", f"{tmpdir}/clean_text/0.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/1.txt", f"{tmpdir}/clean_text/1.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/0.txt", f"{tmpdir}/clean_labels/0.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/1.txt", f"{tmpdir}/clean_labels/1.txt")
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_text/0.txt",
f"{tmpdir}/clean_text/0.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_text/1.txt",
f"{tmpdir}/clean_text/1.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_labels/0.txt",
f"{tmpdir}/clean_labels/0.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_labels/1.txt",
f"{tmpdir}/clean_labels/1.txt",
)

View file

@ -1,62 +1,76 @@
from genalog.generation.content import *
import pytest
from genalog.generation.content import CompositeContent, Content, ContentType
from genalog.generation.content import Paragraph, Title
CONTENT_LIST = ["foo", "bar"]
COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH]
TEXT = "foo bar"
@pytest.fixture
def content_base_class():
return Content()
@pytest.fixture
def paragraph():
return Paragraph(TEXT)
@pytest.fixture
def title():
return Title(TEXT)
@pytest.fixture
def section():
return CompositeContent(CONTENT_LIST, COMPOSITE_CONTENT_TYPE)
def test_content_set_content_type(content_base_class):
with pytest.raises(TypeError):
content_base_class.set_content_type("NOT VALID CONTENT TYPE")
content_base_class.set_content_type(ContentType.PARAGRAPH)
def test_paragraph_init(paragraph):
with pytest.raises(TypeError):
Paragraph([])
assert paragraph.content_type == ContentType.PARAGRAPH
def test_paragraph_print(paragraph):
assert paragraph.__str__()
def test_paragraph_iterable_indexable(paragraph):
for index, character in enumerate(paragraph):
assert character == paragraph[index]
def test_title_init(title):
with pytest.raises(TypeError):
Title([])
assert title.content_type == ContentType.TITLE
def test_title_iterable_indexable(title):
for index, character in enumerate(title):
assert character == title[index]
def test_composite_content_init(section):
with pytest.raises(TypeError):
CompositeContent((), [])
assert section.content_type == ContentType.COMPOSITE
def test_composite_content_iterable(section):
for index, content in enumerate(section):
assert content.content_type == COMPOSITE_CONTENT_TYPE[index]
def test_composite_content_print(section):
assert "foo" in section.__str__()
assert "bar" in section.__str__()

View file

@ -1,8 +1,10 @@
from genalog.generation.document import Document, DocumentGenerator
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import MagicMock, patch
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from genalog.generation.document import Document, DocumentGenerator
FRENCH = "fr"
CONTENT = ["some text"]
@ -21,53 +23,72 @@ DEFAULT_TEMPLATE_NAME = "text_block.html.jinja"
DEFAULT_PACKAGE_NAME = "genalog.generation"
DEFAULT_TEMPLATE_FOLDER = "templates"
@pytest.fixture
def default_document():
mock_jinja_template = MagicMock()
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
return Document(CONTENT, mock_jinja_template)
@pytest.fixture
def french_document():
mock_jinja_template = MagicMock()
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
return Document(CONTENT, mock_jinja_template, language=FRENCH)
def test_document_init(default_document):
assert default_document.styles == DEFAULT_DOCUMENT_STYLE
assert default_document._document != None
assert default_document.compiled_html != None
assert default_document._document is not None
assert default_document.compiled_html is not None
def test_document_init_with_kwargs(french_document):
assert french_document.styles["language"] == FRENCH
assert french_document._document != None
assert french_document.compiled_html != None
assert french_document._document is not None
assert french_document.compiled_html is not None
def test_document_render_html(french_document):
compiled_document = french_document.render_html()
assert compiled_document == MOCK_COMPILED_DOCUMENT
french_document.template.render.assert_called_with(content=CONTENT, **french_document.styles)
french_document.template.render.assert_called_with(
content=CONTENT, **french_document.styles
)
def test_document_render_pdf(default_document):
default_document._document = MagicMock()
# run tested function
default_document.render_pdf(target=FILE_DESTINATION_PDF, zoom=2)
default_document._document.write_pdf.assert_called_with(target=FILE_DESTINATION_PDF, zoom=2)
default_document._document.write_pdf.assert_called_with(
target=FILE_DESTINATION_PDF, zoom=2
)
def test_document_render_png(default_document):
default_document._document = MagicMock()
# run tested function
default_document.render_png(target=FILE_DESTINATION_PNG, resolution=100)
default_document._document.write_png.assert_called_with(target=FILE_DESTINATION_PNG, resolution=100)
default_document._document.write_png.assert_called_with(
target=FILE_DESTINATION_PNG, resolution=100
)
def test_document_render_png_split_pages(default_document):
default_document._document.copy = MagicMock()
# run tested function
default_document.render_png(target=FILE_DESTINATION_PNG, split_pages=True, resolution=100)
default_document.render_png(
target=FILE_DESTINATION_PNG, split_pages=True, resolution=100
)
result_destination = FILE_DESTINATION_PNG.replace(".png", "_pg_0.png")
# assertion
document_copy = default_document._document.copy.return_value
document_copy.write_png.assert_called_with(target=result_destination, resolution=100)
document_copy.write_png.assert_called_with(
target=result_destination, resolution=100
)
def test_document_render_array_valid_args(default_document):
# setup mock
@ -84,11 +105,13 @@ def test_document_render_array_valid_args(default_document):
img_array = default_document.render_array(resolution=100, channel=channel_type)
assert img_array.shape == expected_img_shape
def test_document_render_array_invalid_args(default_document):
invalid_channel_types = "INVALID"
with pytest.raises(ValueError):
default_document.render_array(resolution=100, channel=invalid_channel_types)
def test_document_render_array_invalid_format(default_document):
# setup mock
mock_surface = MagicMock()
@ -99,6 +122,7 @@ def test_document_render_array_invalid_format(default_document):
with pytest.raises(RuntimeError):
default_document.render_array(resolution=100)
def test_document_update_style(default_document):
new_style = {"language": FRENCH, "new_property": "some value"}
# Ensure that a new property is not already defined
@ -111,11 +135,13 @@ def test_document_update_style(default_document):
# Ensure that a new property is added
assert default_document.styles["new_property"] == new_style["new_property"]
@patch("genalog.generation.document.Environment")
@patch("genalog.generation.document.PackageLoader")
@patch("genalog.generation.document.FileSystemLoader")
def test_document_generator_init_default_setting(mock_file_system_loader,
mock_package_loader, mock_environment):
def test_document_generator_init_default_setting(
mock_file_system_loader, mock_package_loader, mock_environment
):
# setup mock template environment
mock_environment_instance = mock_environment.return_value
mock_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
@ -123,15 +149,19 @@ def test_document_generator_init_default_setting(mock_file_system_loader,
document_generator = DocumentGenerator()
# Ensure the right loader is called
mock_file_system_loader.assert_not_called()
mock_package_loader.assert_called_with(DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER)
mock_package_loader.assert_called_with(
DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER
)
# Ensure that the default template in the package is loaded
assert DEFAULT_TEMPLATE_NAME in document_generator.template_list
@patch("genalog.generation.document.Environment")
@patch("genalog.generation.document.PackageLoader")
@patch("genalog.generation.document.FileSystemLoader")
def test_document_generator_init_custom_template(mock_file_system_loader,
mock_package_loader, mock_environment):
def test_document_generator_init_custom_template(
mock_file_system_loader, mock_package_loader, mock_environment
):
# setup mock template environment
mock_environment_instance = mock_environment.return_value
mock_environment_instance.list_templates.return_value = [CUSTOM_TEMPLATE_NAME]
@ -143,67 +173,79 @@ def test_document_generator_init_custom_template(mock_file_system_loader,
# Ensure that the expected template is registered
assert CUSTOM_TEMPLATE_NAME in document_generator.template_list
@pytest.fixture
def default_document_generator():
with patch("genalog.generation.document.Environment") as MockEnvironment:
template_environment_instance = MockEnvironment.return_value
template_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
template_environment_instance.list_templates.return_value = [
DEFAULT_TEMPLATE_NAME
]
template_environment_instance.get_template.return_value = MOCK_TEMPLATE
doc_gen = DocumentGenerator()
return doc_gen
def test_document_generator_create_generator(default_document_generator):
available_templates = default_document_generator.template_list
assert len(available_templates) < 2
generator = default_document_generator.create_generator(CONTENT, available_templates)
doc = next(generator)
generator = default_document_generator.create_generator(
CONTENT, available_templates
)
next(generator)
with pytest.raises(StopIteration):
next(generator)
def test_document_generator_create_generator_(default_document_generator):
# setup test case
available_templates = default_document_generator.template_list
undefined_template = "NOT A VALID TEMPLATE"
assert undefined_template not in available_templates
generator = default_document_generator.create_generator(CONTENT, [undefined_template])
generator = default_document_generator.create_generator(
CONTENT, [undefined_template]
)
with pytest.raises(FileNotFoundError):
doc = next(generator)
next(generator)
@pytest.mark.parametrize("template_name, expected_output", [
@pytest.mark.parametrize(
"template_name, expected_output",
[
("base.html.jinja", False),
("text_block.html.jinja", True),
("text_block.css.jinja", False),
("macro/dimension.css.jinja", False)
])
("macro/dimension.css.jinja", False),
],
)
def test__keep_templates(template_name, expected_output):
output = DocumentGenerator._keep_template(template_name)
assert output == expected_output
def test_set_styles_to_generate(default_document_generator):
assert len(default_document_generator.styles_to_generate) == 1
default_document_generator.set_styles_to_generate({"foo": ["bar", "bar"]})
assert len(default_document_generator.styles_to_generate) == 2
@pytest.mark.parametrize("styles, expected_output", [
@pytest.mark.parametrize(
"styles, expected_output",
[
({}, []), # empty case
({"size": ["10px"], "color": [] }, []), #empty value will result in null combinations
(
{"size": ["10px"], "color": ["red"] },
[{"size":"10px", "color":"red"}]
),
(
{"size": ["5px", "10px"]},
[{"size": "5px"}, {"size": "10px"}]
),
{"size": ["10px"], "color": []},
[],
), # empty value will result in null combinations
({"size": ["10px"], "color": ["red"]}, [{"size": "10px", "color": "red"}]),
({"size": ["5px", "10px"]}, [{"size": "5px"}, {"size": "10px"}]),
(
{"size": ["10px", "15px"], "color": ["blue"]},
[
{"size":"10px", "color": "blue"},
{"size":"15px", "color": "blue"}
]
[{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}],
),
],
)
])
def test_document_generator_expand_style_combinations(styles, expected_output):
output = DocumentGenerator.expand_style_combinations(styles)
assert output == expected_output
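Pulling one case out of the table above, the expansion is effectively a Cartesian product over the listed style values:

from genalog.generation.document import DocumentGenerator

styles = {"size": ["10px", "15px"], "color": ["blue"]}
combos = DocumentGenerator.expand_style_combinations(styles)
# Per the parametrized case above:
# combos == [{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}]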

View file

@ -1,209 +1,608 @@
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_metrics, get_stats
from genalog.text.anchor import align_w_anchor
from genalog.text.alignment import GAP_CHAR, align
from genalog.text.ner_label import _find_gap_char_candidates
from pandas._testing import assert_frame_equal
import pytest
import genalog.ocr.metrics
import pandas as pd
import numpy as np
import json
import pickle
import os
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_stats
from genalog.text.alignment import align, GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
genalog.ocr.metrics.LOG_LEVEL = 0
@pytest.mark.parametrize("src_string, target, expected_stats",
@pytest.mark.parametrize(
"src_string, target, expected_stats",
[
("a worn coat", "a wom coat",
{'edit_insert': 1, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
(" ", "a",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("a", " ",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("a", "a",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("ab", "ac",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("ac", "ab",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("New York is big.", "N ewYork kis big.",
{'edit_insert': 0, 'edit_delete': 1, 'edit_replace': 0, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
("B oston grea t", "Boston is great",
{'edit_insert': 0, 'edit_delete': 2, 'edit_replace': 0, 'edit_insert_spacing': 2, 'edit_delete_spacing': 1}),
("New York is big.", "N ewyork kis big",
{'edit_insert': 1, 'edit_delete': 1, 'edit_replace': 1, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
("dog", "d@g", # Test against default gap_char "@"
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("some@one.com", "some@one.com",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0})
])
(
"a worn coat",
"a wom coat",
{
"edit_insert": 1,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
" ",
"a",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"a",
" ",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"a",
"a",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"ab",
"ac",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"ac",
"ab",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"New York is big.",
"N ewYork kis big.",
{
"edit_insert": 0,
"edit_delete": 1,
"edit_replace": 0,
"edit_insert_spacing": 1,
"edit_delete_spacing": 1,
},
),
(
"B oston grea t",
"Boston is great",
{
"edit_insert": 0,
"edit_delete": 2,
"edit_replace": 0,
"edit_insert_spacing": 2,
"edit_delete_spacing": 1,
},
),
(
"New York is big.",
"N ewyork kis big",
{
"edit_insert": 1,
"edit_delete": 1,
"edit_replace": 1,
"edit_insert_spacing": 1,
"edit_delete_spacing": 1,
},
),
(
"dog",
"d@g", # Test against default gap_char "@"
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"some@one.com",
"some@one.com",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
],
)
def test_editops_stats(src_string, target, expected_stats):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(target, src_string)
stats, actions = get_editops_stats(alignment, gap_char)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions",
@pytest.mark.parametrize(
"src_string, target, expected_stats, expected_substitutions",
[
(
"a worn coat", "a wom coat",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 11, 'total_words': 3,
'matching_chars': 9, 'matching_words': 2,"matching_alnum_words" : 2, 'word_accuracy': 2/3, 'char_accuracy': 9/11},
{('rn', 'm'): 1}
"a worn coat",
"a wom coat",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 11,
"total_words": 3,
"matching_chars": 9,
"matching_words": 2,
"matching_alnum_words": 2,
"word_accuracy": 2 / 3,
"char_accuracy": 9 / 11,
},
{("rn", "m"): 1},
),
(
"a c", "def",
{'insert': 0, 'delete': 0, 'replace': 1 , 'spacing': 0, 'total_chars': 3, 'total_words': 1,
'matching_chars': 0, 'matching_words': 0,"matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 0},
{('a c', 'def'): 1}
"a c",
"def",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"matching_chars": 0,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 0,
},
{("a c", "def"): 1},
),
(
"a", "a b",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 2,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 0.5, 'char_accuracy': 1/3},
{}
"a",
"a b",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 2,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 0.5,
"char_accuracy": 1 / 3,
},
{},
),
(
"a b", "b",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
{}
"a b",
"b",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1,
"char_accuracy": 1 / 3,
},
{},
),
(
"a b", "a",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
{}
"a b",
"a",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1,
"char_accuracy": 1 / 3,
},
{},
),
(
"b ..", "a b ..",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 6, 'total_words': 3, 'total_alnum_words': 2,
'matching_chars': 4, 'matching_words': 2, "matching_alnum_words" : 1, 'word_accuracy': 2/3, 'char_accuracy': 4/6},
{}
"b ..",
"a b ..",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 1,
"total_chars": 6,
"total_words": 3,
"total_alnum_words": 2,
"matching_chars": 4,
"matching_words": 2,
"matching_alnum_words": 1,
"word_accuracy": 2 / 3,
"char_accuracy": 4 / 6,
},
{},
),
(
"taxi cab", "taxl c b",
{'insert': 0, 'delete': 1, 'replace': 1, 'spacing': 1, 'total_chars': 9, 'total_words': 3,
'matching_chars': 6, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 6/9},
{('i','l'):1}
"taxi cab",
"taxl c b",
{
"insert": 0,
"delete": 1,
"replace": 1,
"spacing": 1,
"total_chars": 9,
"total_words": 3,
"matching_chars": 6,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 6 / 9,
},
{("i", "l"): 1},
),
(
"taxl c b ri de", "taxi cab ride",
{'insert': 1, 'delete': 0, 'replace': 1, 'spacing': 6, 'total_chars': 18, 'total_words': 3,
'matching_chars': 11, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 11/18},
{('l','i'):1}
"taxl c b ri de",
"taxi cab ride",
{
"insert": 1,
"delete": 0,
"replace": 1,
"spacing": 6,
"total_chars": 18,
"total_words": 3,
"matching_chars": 11,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 11 / 18,
},
{("l", "i"): 1},
),
(
"ab", "ac",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1,
'matching_chars': 1, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.5},
{}
"ab",
"ac",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"matching_chars": 1,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0.0,
"char_accuracy": 0.5,
},
{},
),
(
"a", "a",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 1, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
{}
"a",
"a",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 1,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
),
(
"New York is big.", "N ewYork kis big.",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 2, 'total_chars': 17, 'total_words': 4,
'matching_chars': 15, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1/4, 'char_accuracy': 15/17},
{}
"New York is big.",
"N ewYork kis big.",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 2,
"total_chars": 17,
"total_words": 4,
"matching_chars": 15,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1 / 4,
"char_accuracy": 15 / 17,
},
{},
),
(
"B oston grea t", "Boston is great",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 3, 'total_chars': 15, 'total_words': 3,
'matching_chars': 12, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.8},
{}
"B oston grea t",
"Boston is great",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 3,
"total_chars": 15,
"total_words": 3,
"matching_chars": 12,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0.0,
"char_accuracy": 0.8,
},
{},
),
(
"New York is big.", "N ewyork kis big",
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
{('Y', 'y'): 1}
"New York is big.",
"N ewyork kis big",
{
"insert": 1,
"delete": 1,
"replace": 1,
"spacing": 2,
"total_chars": 16,
"total_words": 4,
"matching_chars": 13,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 13 / 16,
},
{("Y", "y"): 1},
),
(
"dog", "d@g",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{('o', '@'): 1}
"dog",
"d@g",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{("o", "@"): 1},
),
(
"some@one.com", "some@one.com",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, {}
"some@one.com",
"some@one.com",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 12,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 12,
"matching_alnum_words": 1,
"matching_words": 1,
"alnum_word_accuracy": 1.0,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
),
],
)
])
def test_align_stats(src_string, target, expected_stats, expected_substitutions):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(src_string, target, gap_char=gap_char)
stats, substitution_dict = get_align_stats(alignment, src_string, target, gap_char)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
for k in expected_substitutions:
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions, expected_actions", [
(
"ab", "a",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
{},
{1: 'D'}
),
(
"ab", "abb",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{},
{2: ('I', 'b')}
),
(
"ab", "ac",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
{('b', 'c'): 1},
{1: ('R', 'c')}
),
(
"New York is big.", "N ewyork kis big",
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
{('Y', 'y'): 1},
{1: ('I', ' '), 4: 'D', 5: ('R', 'y'), 10: ('I', 'k'), 17: 'D'}
),
(
"dog", "d@g",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{('o', '@'): 1},
{1: ('R', '@')}
),
(
"some@one.com", "some@one.com",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
{},
{}
assert substitution_dict[k] == expected_substitutions[k], (
substitution_dict,
expected_substitutions,
)
])
def test_get_stats(src_string, target, expected_stats, expected_substitutions, expected_actions ):
@pytest.mark.parametrize(
"src_string, target, expected_stats, expected_substitutions, expected_actions",
[
(
"ab",
"a",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 1,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 1 / 2,
},
{},
{1: "D"},
),
(
"ab",
"abb",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{},
{2: ("I", "b")},
),
(
"ab",
"ac",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 1,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 1 / 2,
},
{("b", "c"): 1},
{1: ("R", "c")},
),
(
"New York is big.",
"N ewyork kis big",
{
"insert": 1,
"delete": 1,
"replace": 1,
"spacing": 2,
"total_chars": 16,
"total_words": 4,
"matching_chars": 13,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 13 / 16,
},
{("Y", "y"): 1},
{1: ("I", " "), 4: "D", 5: ("R", "y"), 10: ("I", "k"), 17: "D"},
),
(
"dog",
"d@g",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{("o", "@"): 1},
{1: ("R", "@")},
),
(
"some@one.com",
"some@one.com",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 12,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 12,
"matching_alnum_words": 1,
"matching_words": 1,
"alnum_word_accuracy": 1.0,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
{},
),
],
)
def test_get_stats(
src_string, target, expected_stats, expected_substitutions, expected_actions
):
stats, substitution_dict, actions = get_stats(target, src_string)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
for k in expected_substitutions:
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
assert substitution_dict[k] == expected_substitutions[k], (
substitution_dict,
expected_substitutions,
)
for k in expected_actions:
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
@pytest.mark.parametrize("src_string, target, expected_actions",
@pytest.mark.parametrize(
"src_string, target, expected_actions",
[
("dog and cat", "g and at",
{0: ('I', 'd'), 1: ('I', 'o'), 8:('I', 'c')}),
])
("dog and cat", "g and at", {0: ("I", "d"), 1: ("I", "o"), 8: ("I", "c")}),
],
)
def test_actions_stats(src_string, target, expected_actions):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(target, src_string, gap_char=gap_char)
_, actions = get_editops_stats(alignment, gap_char)
print(actions)
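
For readers skimming this diff, a minimal usage sketch of the metric helpers exercised above. The import path is an assumption (the test module's imports fall outside this hunk); the call signature get_stats(target, src_string) and the expected values come straight from the parametrized cases.

from genalog.ocr.metrics import get_stats  # assumed import path

src = "New York is big."     # ground-truth text
target = "N ewyork kis big"  # noisy (OCR) text

stats, substitutions, actions = get_stats(target, src)
print(stats["char_accuracy"])  # 13/16 for this pair, per the case above
print(substitutions)           # {('Y', 'y'): 1}, ground-truth char mapped to noisy char
print(actions)                 # positional edit ops, e.g. {4: 'D', 5: ('R', 'y'), ...}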

View file

@ -1,13 +1,15 @@
from genalog.ocr.rest_client import GrokRestClient
from genalog.ocr.blob_client import GrokBlobClient
from genalog.ocr.grok import Grok
import requests
import pytest
import time
import json
import pytest
import requests
from dotenv import load_dotenv
from genalog.ocr.rest_client import GrokRestClient
load_dotenv("tests/ocr/.env")
@pytest.fixture(autouse=True)
def setup_monkeypatch(monkeypatch):
def mock_http(*args, **kwargs):
@ -21,7 +23,6 @@ def setup_monkeypatch(monkeypatch):
class MockedResponse:
def __init__(self, args, kwargs):
self.url = args[0]
self.text = "response"
@ -34,10 +35,7 @@ class MockedResponse:
if "search.windows.net/indexers/" in self.url:
if "status" in self.url:
return {
"lastResult": {"status": "success"},
"status": "finished"
}
return {"lastResult": {"status": "success"}, "status": "finished"}
return {}
if "search.windows.net/indexes/" in self.url:
@ -46,15 +44,30 @@ class MockedResponse:
"value": [
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_0.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", "r"))
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json",
"r",
)
),
},
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_1.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", "r"))
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json",
"r",
)
),
},
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_11.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", "r"))
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json",
"r",
)
),
},
]
}
@ -69,7 +82,6 @@ class MockedResponse:
class TestGROK:
def test_creating_indexing_pipeline(self):
grok_rest_client = GrokRestClient.create_from_env_var()
grok_rest_client.create_indexing_pipeline()
@ -91,4 +103,3 @@ class TestGROK:
indexer_status = grok_rest_client.poll_indexer_till_complete()
assert indexer_status["lastResult"]["status"] == "success"
grok_rest_client.delete_indexer_pipeline()
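
The fixture above routes every requests call through a canned MockedResponse so the Grok pipeline tests run without live Azure services. A generic, self-contained sketch of that pattern (the names FakeResponse and poll_status are illustrative, not part of genalog):

import requests

class FakeResponse:
    # Stand-in exposing only the members the code under test reads.
    status_code = 200
    text = "response"

    def json(self):
        return {"lastResult": {"status": "success"}, "status": "finished"}

def poll_status(url):
    # Toy function standing in for the Grok client call being tested.
    return requests.get(url).json()["status"]

def test_poll_status(monkeypatch):
    # Route every requests.get through the fake, regardless of URL.
    monkeypatch.setattr(requests, "get", lambda *args, **kwargs: FakeResponse())
    assert poll_status("https://example.search.windows.net/indexers/x/status") == "finished"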

View file

@ -1,43 +1,59 @@
from genalog.text import alignment
from genalog.text.alignment import MATCH_REWARD, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXT_PENALTY
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES, ALIGNMENT_REGRESSION_TEST_CASES
import warnings
from random import randint
from unittest.mock import MagicMock
import pytest
import warnings
from genalog.text import alignment
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES
RANDOM_INT = randint(1, 100)
MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)]
# Set up mock for third-party library call
@pytest.fixture
def mock_pairwise2_align(monkeypatch):
mock = MagicMock()
def mock_globalcs(*args, **kwargs):
mock.globalcs(*args, **kwargs)
return MOCK_ALIGNMENT_RESULT
# replace target method reference with the mock method
monkeypatch.setattr("Bio.pairwise2.align.globalcs", mock_globalcs)
return mock
def test__align_seg(mock_pairwise2_align):
# setup method input
required_arg = ("A", "B")
optional_arg = (alignment.MATCH_REWARD, alignment.MISMATCH_PENALTY, alignment.GAP_PENALTY, alignment.GAP_EXT_PENALTY)
optional_kwarg = {"gap_char": alignment.GAP_CHAR, "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY}
optional_arg = (
alignment.MATCH_REWARD,
alignment.MISMATCH_PENALTY,
alignment.GAP_PENALTY,
alignment.GAP_EXT_PENALTY,
)
optional_kwarg = {
"gap_char": alignment.GAP_CHAR,
"one_alignment_only": alignment.ONE_ALIGNMENT_ONLY,
}
# test method
result = alignment._align_seg(*required_arg + optional_arg, **optional_kwarg)
# assertion
mock_pairwise2_align.globalcs.assert_called()
assert result == MOCK_ALIGNMENT_RESULT
@pytest.mark.parametrize("alignments, target_num_tokens, raised_exception",
@pytest.mark.parametrize(
"alignments, target_num_tokens, raised_exception",
[
(MOCK_ALIGNMENT_RESULT, 1, None),
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
([("X", "XY", 0, 0, 1)], 1, ValueError)
])
([("X", "XY", 0, 0, 1)], 1, ValueError),
],
)
def test__select_alignment_candidates(alignments, target_num_tokens, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -46,7 +62,9 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
result = alignment._select_alignment_candidates(alignments, target_num_tokens)
assert result == MOCK_ALIGNMENT_RESULT[0]
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
@pytest.mark.parametrize(
"s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
@ -63,7 +81,8 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
("t1 \t \n t2", 3, 7, None),
# Gap char
(" @", 0, 1, None),
])
],
)
def test__find_token_start(s, index, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -72,7 +91,9 @@ def test__find_token_start(s, index, desired_output, raised_exception):
output = alignment._find_token_start(s, index)
assert output == desired_output
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
@pytest.mark.parametrize(
"s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
@ -90,7 +111,8 @@ def test__find_token_start(s, index, desired_output, raised_exception):
(".", 0, 0, None),
# Gap char
("@@ @", 0, 2, None),
])
],
)
def test__find_token_end(s, index, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -99,7 +121,9 @@ def test__find_token_end(s, index, desired_output, raised_exception):
output = alignment._find_token_end(s, index)
assert output == desired_output
@pytest.mark.parametrize("s, start, desired_output",
@pytest.mark.parametrize(
"s, start, desired_output",
[
("token", 0, (0, 4)),
("token\t", 0, (0, 5)),
@ -111,13 +135,16 @@ def test__find_token_end(s, index, desired_output, raised_exception):
# single character string
("s", 0, (0, 0)),
# punctuation
(" !,.: ", 0, (2,6))
])
(" !,.: ", 0, (2, 6)),
],
)
def test__find_next_token(s, start, desired_output):
output = alignment._find_next_token(s, start)
assert output == desired_output
@pytest.mark.parametrize("token, desired_output",
@pytest.mark.parametrize(
"token, desired_output",
[
# Valid tokens
("\n\t token.!,:\n\t ", True),
@ -134,26 +161,41 @@ def test__find_next_token(s, start, desired_output):
("\t\n@", False),
(alignment.GAP_CHAR * 1, False),
(alignment.GAP_CHAR * RANDOM_INT, False),
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False)
])
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False),
],
)
def test__is_valid_token(token, desired_output):
result = alignment._is_valid_token(token)
assert result == desired_output
@pytest.mark.parametrize("aligned_gt, aligned_noise," +
"expected_gt_to_noise_map, expected_noise_to_gt_map",
PARSE_ALIGNMENT_REGRESSION_TEST_CASES)
def test_parse_alignment(aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map):
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(aligned_gt, aligned_noise)
@pytest.mark.parametrize(
"aligned_gt, aligned_noise," + "expected_gt_to_noise_map, expected_noise_to_gt_map",
PARSE_ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_parse_alignment(
aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map
):
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(
aligned_gt, aligned_noise
)
assert gt_to_noise_map == expected_gt_to_noise_map
assert noise_to_gt_map == expected_noise_to_gt_map
@pytest.mark.parametrize("gt_txt, noisy_txt," +
"expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_txt, noisy_txt," + "expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
if aligned_gt != expected_aligned_gt:
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
expected_alignment = alignment._format_alignment(
expected_aligned_gt, expected_aligned_noise
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)
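
For context, the public API these regression tests cover is a plain two-string alignment; a minimal sketch follows (the padding character is alignment.GAP_CHAR, "@" by default, and _format_alignment is the same pretty-printer the warning above uses):

from genalog.text import alignment

gt = "New York is big."
noise = "New Yerk is big"

aligned_gt, aligned_noise = alignment.align(gt, noise)
# Both returned strings have equal length; positions present in only one
# input are padded with alignment.GAP_CHAR.
print(alignment._format_alignment(aligned_gt, aligned_noise))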

View file

@ -1,33 +1,40 @@
import glob
import warnings
import pytest
from genalog.text import alignment, anchor, preprocess
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
import glob
import pytest
import warnings
@pytest.mark.parametrize("tokens, case_sensitive, desired_output", [
@pytest.mark.parametrize(
"tokens, case_sensitive, desired_output",
[
([], True, set()),
([], False, set()),
(["a", "A"], True, set(["a", "A"])),
(["a", "A"], False, set()),
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
(["An", "an", "ab"], False, set(["ab"])),
])
],
)
def test_get_unique_words(tokens, case_sensitive, desired_output):
output = anchor.get_unique_words(tokens, case_sensitive=case_sensitive)
assert desired_output == output
@pytest.mark.parametrize("tokens, desired_output", [
([], 0),
([""], 0),
(["a", "b"], 2),
(["abc.", "def!"], 8)
])
@pytest.mark.parametrize(
"tokens, desired_output",
[([], 0), ([""], 0), (["a", "b"], 2), (["abc.", "def!"], 8)],
)
def test_segment_len(tokens, desired_output):
output = anchor.segment_len(tokens)
assert desired_output == output
@pytest.mark.parametrize("unique_words, src_tokens, desired_output, raised_exception", [
@pytest.mark.parametrize(
"unique_words, src_tokens, desired_output, raised_exception",
[
(set(), [], [], None),
(set(), ["a"], [], None),
(set("a"), [], [], ValueError), # unique word not in src_tokens
@ -36,8 +43,14 @@ def test_segment_len(tokens, desired_output):
(set("a"), ["an", "na", " a "], [], ValueError), # substring
(set("a"), ["a"], [("a", 0)], None), # valid input
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
(set("ab"), ["c", "b", "a"], [("b", 1), ("a", 2)], None), # multiple matches ordered by index
])
(
set("ab"),
["c", "b", "a"],
[("b", 1), ("a", 2)],
None,
), # multiple matches ordered by index
],
)
def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -46,7 +59,10 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
output = anchor.get_word_map(unique_words, src_tokens)
assert desired_output == output
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_output",
[
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no common unique words
@ -54,77 +70,139 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
(["a", "b", "c"], ["a", "b", "c"], # common unique words
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)])),
(["a", "b", "c"], ["c", "b", "a"], # common unique words but not in same order
([("b", 1)], [("b", 1)])),
(["b", "a", "c"], ["c", "b", "a"], # LCS has multiple results
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])),
(["c", "a", "b"], ["c", "b", "a"],
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)])),
(["c", "a", "b"], ["a", "c", "b"], # LCS has multiple results
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)])),
])
(
["a", "b", "c"],
["a", "b", "c"], # common unique words
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)]),
),
(
["a", "b", "c"],
["c", "b", "a"], # common unique words but not in same order
([("b", 1)], [("b", 1)]),
),
(
["b", "a", "c"],
["c", "b", "a"], # LCS has multiple results
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)]),
),
(
["c", "a", "b"],
["c", "b", "a"],
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)]),
),
(
["c", "a", "b"],
["a", "c", "b"], # LCS has multiple results
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)]),
),
],
)
def test_get_anchor_map(gt_tokens, ocr_tokens, desired_output):
desired_gt_map, desired_ocr_map = desired_output
gt_map, ocr_map = anchor.get_anchor_map(gt_tokens, ocr_tokens)
assert desired_gt_map == gt_map
assert desired_ocr_map == ocr_map
# max_seg_length does not change the following output
@pytest.mark.parametrize("max_seg_length", [0, 1, 2, 3, 5, 4, 6])
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_output",
[
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no anchors
(["a", "a"], ["a"], ([], [])),
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([0], [0])), # anchors exist
("a1 w w w".split(), "a1 w w w".split(), # no anchors in the subsequence [w w w]
([0], [0])),
("a1 w w w a2".split(), "a1 w w w a2".split(),
([0, 4], [0, 4])),
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(),
([0, 4], [0, 4])),
("a1 a2 a3".split(), "a1 a2 a3".split(), # all words are anchors
([0, 1, 2], [0, 1, 2])),
("a1 a2 a3".split(), "A1 A2 A3".split(), # anchor words must be in the same casing
([], [])),
("a1 w w a2".split(), "a1 w W a2".split(), # unique words are case insensitive
([0, 3], [0, 3])),
("a1 w w a2".split(), "A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
([], [])),
])
def test_find_anchor_recur_various_seg_len(max_seg_length, gt_tokens, ocr_tokens, desired_output):
(
"a1 w w w".split(),
"a1 w w w".split(), # no anchors in the subsequence [w w w]
([0], [0]),
),
("a1 w w w a2".split(), "a1 w w w a2".split(), ([0, 4], [0, 4])),
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(), ([0, 4], [0, 4])),
(
"a1 a2 a3".split(),
"a1 a2 a3".split(), # all words are anchors
([0, 1, 2], [0, 1, 2]),
),
(
"a1 a2 a3".split(),
"A1 A2 A3".split(), # anchor words must be in the same casing
([], []),
),
(
"a1 w w a2".split(),
"a1 w W a2".split(), # unique words are case insensitive
([0, 3], [0, 3]),
),
(
"a1 w w a2".split(),
"A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
([], []),
),
],
)
def test_find_anchor_recur_various_seg_len(
max_seg_length, gt_tokens, ocr_tokens, desired_output
):
desired_gt_anchors, desired_ocr_anchors = desired_output
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
assert desired_gt_anchors == gt_anchors
assert desired_ocr_anchors == ocr_anchors
# Test the recursion behavior
@pytest.mark.parametrize("gt_tokens, ocr_tokens, max_seg_length, desired_output", [
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6,
([0, 3], [0, 3])),
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 3, 4], [0, 3, 4])),
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5])),
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 w_ a3".split(), 2, # missing ocr token
([0, 3, 6], [0, 2, 4])),
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 W_ A3".split(), 2, # changing cases
([0, 3], [0, 2])),
])
def test_find_anchor_recur_fixed_seg_len(gt_tokens, ocr_tokens, max_seg_length, desired_output):
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, max_seg_length, desired_output",
[
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6, ([0, 3], [0, 3])),
(
"a1 w_ w_ a2 a3 a2".split(),
"a1 w_ w_ a2 a3 a2".split(),
4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 3, 4], [0, 3, 4]),
),
(
"a1 w_ w_ a2 a3 a2".split(),
"a1 w_ w_ a2 a3 a2".split(),
2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5]),
),
(
"a1 w_ w_ a2 w_ w_ a3".split(),
"a1 w_ a2 w_ a3".split(),
2, # missing ocr token
([0, 3, 6], [0, 2, 4]),
),
(
"a1 w_ w_ a2 w_ w_ a3".split(),
"a1 w_ a2 W_ A3".split(),
2, # changing cases
([0, 3], [0, 2]),
),
],
)
def test_find_anchor_recur_fixed_seg_len(
gt_tokens, ocr_tokens, max_seg_length, desired_output
):
desired_gt_anchors, desired_ocr_anchors = desired_output
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
assert desired_gt_anchors == gt_anchors
assert desired_ocr_anchors == ocr_anchors
@pytest.mark.parametrize("gt_file, ocr_file",
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_1.txt")),
sorted(glob.glob("tests/text/data/ocr_1.txt"))
)
sorted(glob.glob("tests/text/data/ocr_1.txt")),
),
)
@pytest.mark.parametrize("max_seg_length", [75])
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
@ -132,15 +210,27 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
ocr_text = open(ocr_file, "r").read()
gt_tokens = preprocess.tokenize(gt_text)
ocr_tokens = preprocess.tokenize(ocr_text)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
# Ensure that each anchor word is the same word in both text
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align_w_anchor(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
aligned_gt, aligned_noise = anchor.align_w_anchor(gt_txt, noisy_txt)
if aligned_gt != expected_aligned_gt:
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
expected_alignment = alignment._format_alignment(
expected_aligned_gt, expected_aligned_noise
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)
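
A small sketch of the anchor-based variant tested above. align_w_anchor mirrors alignment.align's signature but first pins shared unique "anchor" words; the expected anchor indices below come from the parametrized case with tokens "a1 w w w a2":

from genalog.text import anchor, preprocess

gt = "a1 w w w a2"
ocr = "a1 w w w a2"

aligned_gt, aligned_ocr = anchor.align_w_anchor(gt, ocr)

gt_anchors, ocr_anchors = anchor.find_anchor_recur(
    preprocess.tokenize(gt), preprocess.tokenize(ocr), max_seg_length=5
)
print(gt_anchors, ocr_anchors)  # [0, 4] [0, 4], matching the case above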

View file

@ -1,34 +1,75 @@
from genalog.text import conll_format
import itertools
import warnings
from unittest.mock import patch
import itertools
import pytest
import warnings
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception", [
from genalog.text import conll_format
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception",
[
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
(["w1", "w3"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal tokens
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w3"]], ["w", "w"], ValueError), # Unequal tokens
(["w1", "w3"], ["l1", "l2"], [["w1"]],["w", "w"], ValueError), # Unequal length
(["w1"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal length
])
def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception):
(
["w1", "w3"],
["l1", "l2"],
[["w1"], ["w2"]],
["w", "w"],
ValueError,
), # Unequal tokens
(
["w1", "w2"],
["l1", "l2"],
[["w1"], ["w3"]],
["w", "w"],
ValueError,
), # Unequal tokens
(
["w1", "w3"],
["l1", "l2"],
[["w1"]],
["w", "w"],
ValueError,
), # Unequal length
(
["w1"],
["l1", "l2"],
[["w1"], ["w2"]],
["w", "w"],
ValueError,
), # Unequal length
],
)
def test_propagate_labels_sentences_error(
clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
else:
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", [
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels",
[
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1", "a2", "b2"], # ocr token
[["a1", "b1"], ["a2","b2"]], [["l1", "l2"], ["l3", "l4"]] # desired output
[["a1", "b1"], ["a2", "b2"]],
[["l1", "l2"], ["l3", "l4"]], # desired output
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1"], # Missing sentence 2
# Ideally we would expect [["a1", "b1"], []]
@ -39,119 +80,207 @@ def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sent
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@"
# NOTE: this is improper behavior, but it is the best
# solution to this corner case, as it preserves the number of OCR tokens.
[["a1"], ["b1"]], [["l1"], ["l2"]]
[["a1"], ["b1"]],
[["l1"], ["l2"]],
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
[["a"], ["a2", "b2"]], [["l1"], ["l3", "l4"]]
[["a"], ["a2", "b2"]],
[["l1"], ["l3", "l4"]],
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
[["a1", "b1"], ["a2"]], [["l1", "l2"], ["l3"]]
[["a1", "b1"], ["a2"]],
[["l1", "l2"], ["l3"]],
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
[["b1", "a2"], ["b2"]], [["l2", "l3"], ["l4"]]
[["b1", "a2"], ["b2"]],
[["l2", "l3"], ["l4"]],
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
["a1", "b1", "a2", "b2"], # ocr token (missing c1 token at middle of sentence)
[["a1"], ["b1", "a2"], ["b2"]], [["l1"], ["l2", "l4"], ["l5"]]
[
"a1",
"b1",
"a2",
"b2",
], # ocr token (missing c1 token at middle of sentence)
[["a1"], ["b1", "a2"], ["b2"]],
[["l1"], ["l2", "l4"], ["l5"]],
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1"], ["c1", "a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
[["a1"], ["b1", "b2"]], [["l1"], ["l2", "l5"]]
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
[[], ["a1", "c1", "a2"], ["b2"]], [[], ["l1", "l3", "l4"], ["l5"]]
[[], ["a1", "c1", "a2"], ["b2"]],
[[], ["l1", "l3", "l4"], ["l5"]],
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
])
def test_propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels):
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
],
)
def test_propagate_labels_sentences(
clean_tokens,
clean_labels,
clean_sentences,
ocr_tokens,
desired_sentences,
desired_labels,
):
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens
if desired_sentences != ocr_text_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
)
)
if desired_labels != ocr_labels_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
)
)
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens," +
"mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", [
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens,"
+ "mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels",
[
(
"a b c d".split(), "l1 l2 l3 l4".split(),
"a b c d".split(),
"l1 l2 l3 l4".split(),
[["a", "b"], ["c", "d"]],
["a", "b"], # Sentence is empty
[[0], [1], [], []],
[[0], [1]],
[["a", "b"], []],
[["l1", "l2"], []]
[["l1", "l2"], []],
),
(
"a b c d".split(), "l1 l2 l3 l4".split(),
[["a", "b",], ["c", "d"]],
"a b c d".split(),
"l1 l2 l3 l4".split(),
[
[
"a",
"b",
],
["c", "d"],
],
["a", "b", "d"], # Missing sentence start
[[0], [1], [], [2]],
[[0], [1], [3]],
[["a", "b"], ["d"]],
[["l1", "l2"], ["l4"]]
[["l1", "l2"], ["l4"]],
),
(
"a b c d".split(), "l1 l2 l3 l4".split(),
[["a", "b",], ["c", "d"]],
"a b c d".split(),
"l1 l2 l3 l4".split(),
[
[
"a",
"b",
],
["c", "d"],
],
["a", "c", "d"], # Missing sentence end
[[0], [], [1], [2]],
[[0], [2], [3]],
[["a"], ["c", "d"]],
[["l1"], ["l3", "l4"]]
[["l1"], ["l3", "l4"]],
),
])
def test_propagate_labels_sentences_text_alignment_corner_cases(clean_tokens, clean_labels, clean_sentences, ocr_tokens,
mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels):
],
)
def test_propagate_labels_sentences_text_alignment_corner_cases(
clean_tokens,
clean_labels,
clean_sentences,
ocr_tokens,
mock_gt_to_ocr_mapping,
mock_ocr_to_gt_mapping,
desired_sentences,
desired_labels,
):
with patch("genalog.text.alignment.parse_alignment") as mock_alignment:
mock_alignment.return_value = (mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping)
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
(
ocr_text_sentences,
ocr_labels_sentences,
) = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens
if desired_sentences != ocr_text_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
)
)
if desired_labels != ocr_labels_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
)
)
@pytest.mark.parametrize("s, desired_output", [
@pytest.mark.parametrize(
"s, desired_output",
[
("", []),
("\n\n", []),
("a1\tb1\na2\tb2", [["a1", "a2"]]),
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
])
],
)
def test_get_sentences_from_iob_format(s, desired_output):
output = conll_format.get_sentences_from_iob_format(s.splitlines(True))
assert desired_output == output
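
A minimal sketch of the propagation call these tests parametrize, using the first (fully matching) case as input; the expected outputs are the desired values from that case:

from genalog.text import conll_format

clean_tokens = "a1 b1 a2 b2".split()
clean_labels = "l1 l2 l3 l4".split()
clean_sentences = [["a1", "b1"], ["a2", "b2"]]
ocr_tokens = ["a1", "b1", "a2", "b2"]

ocr_sentences, ocr_labels = conll_format.propagate_labels_sentences(
    clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
print(ocr_sentences)  # [['a1', 'b1'], ['a2', 'b2']]
print(ocr_labels)     # [['l1', 'l2'], ['l3', 'l4']]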

View file

@ -1,20 +1,27 @@
from genalog.text.lcs import LCS
import pytest
@pytest.fixture(params=[
from genalog.text.lcs import LCS
@pytest.fixture(
params=[
("", ""), # empty
("abcde", "ace"), # naive case
])
]
)
def lcs(request):
str1, str2 = request.param
return LCS(str1, str2)
def test_lcs_init(lcs):
assert lcs._lcs_len is not None
assert lcs._lcs is not None
@pytest.mark.parametrize("str1, str2, expected_len, expected_lcs", [
@pytest.mark.parametrize(
"str1, str2, expected_len, expected_lcs",
[
("", "", 0, ""), # empty
("abc", "abc", 3, "abc"),
("abcde", "ace", 3, "ace"), # naive case
@ -22,16 +29,26 @@ def test_lcs_init(lcs):
("abc", "cba", 1, "c"), # multiple cases
("abcdgh", "aedfhr", 3, "adh"),
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
("New York @", "New @ York", len("New York"), "New York"), # with space-separated, tokens
(
"New York @",
"New @ York",
len("New York"),
"New York",
), # with space-separated tokens
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
# mixed order with similar tokens
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
# casing
("Is A Big City IS a", "IS a Big City Is A", len("I Big City I "), "I Big City I "),
])
(
"Is A Big City IS a",
"IS a Big City Is A",
len("I Big City I "),
"I Big City I ",
),
],
)
def test_lcs_e2e(str1, str2, expected_len, expected_lcs):
lcs = LCS(str1, str2)
assert expected_lcs == lcs.get_str()
assert expected_len == lcs.get_len()
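
The LCS helper itself is a two-call API; a minimal sketch using one of the cases above:

from genalog.text.lcs import LCS

lcs = LCS("New York @", "New @ York")
print(lcs.get_len())  # 8, i.e. len("New York")
print(lcs.get_str())  # "New York"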

View file

@ -1,171 +1,295 @@
import pytest
from genalog.text import ner_label
from genalog.text import alignment
from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES
import pytest
import string
@pytest.mark.parametrize("label, desired_output", [
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("B-org", True), (" B-org ", True), #whitespae tolerant
("B-org", True),
(" B-org ", True), # whitespae tolerant
("\tB-ORG\n", True),
# Negative Cases
("I-ORG", False), ("O", False), ("other-B-label", False),
])
("I-ORG", False),
("O", False),
("other-B-label", False),
],
)
def test__is_begin_label(label, desired_output):
output = ner_label._is_begin_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-ORG", True), (" \t I-ORG ", True),
("I-ORG", True),
(" \t I-ORG ", True),
# Negative Cases
("O", False), ("B-LOC", False),("B-ORG", False),
])
("O", False),
("B-LOC", False),
("B-ORG", False),
],
)
def test__is_inside_label(label, desired_output):
output = ner_label._is_inside_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-ORG", True), ("B-ORG", True),
("I-ORG", True),
("B-ORG", True),
# Negative Cases
("O", False)
])
("O", False),
],
)
def test__is_multi_token_label(label, desired_output):
output = ner_label._is_multi_token_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-Place", "B-Place"), (" \t I-place ", "B-place"),
("I-Place", "B-Place"),
(" \t I-place ", "B-place"),
# Negative Cases
("O", "O"), ("B-LOC", "B-LOC"), (" B-ORG ", " B-ORG ")
])
("O", "O"),
("B-LOC", "B-LOC"),
(" B-ORG ", " B-ORG "),
],
)
def test__convert_to_begin_label(label, desired_output):
output = ner_label._convert_to_begin_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("B-LOC", "I-LOC"),
(" B-ORG ", "I-ORG"),
# Negative Cases
("", ""), ("O", "O"), ("I-Place", "I-Place"),
(" \t I-place ", " \t I-place ")
])
("", ""),
("O", "O"),
("I-Place", "I-Place"),
(" \t I-place ", " \t I-place "),
],
)
def test__convert_to_inside_label(label, desired_output):
output = ner_label._convert_to_inside_label(label)
assert output == desired_output
@pytest.mark.parametrize("begin_label, inside_label, desired_output", [
@pytest.mark.parametrize(
"begin_label, inside_label, desired_output",
[
# Positive Cases
("", "I-LOC", True),
("B-LOC", "I-ORG", True),
("", "I-ORG", True),
# Negative Cases
("", "", False), ("O", "O", False), ("", "", False),
("", "", False),
("O", "O", False),
("", "", False),
("B-LOC", "O", False),
("B-LOC", "B-ORG", False),
("B-LOC", "I-LOC", False),
(" B-ORG ", "I-ORG", False),
])
],
)
def test__is_missing_begin_label(begin_label, inside_label, desired_output):
output = ner_label._is_missing_begin_label(begin_label, inside_label)
assert output == desired_output
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_input_char_set", [
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_input_char_set",
[
(["a", "b"], ["c", "d"], set("abcd")),
(["New", "York"], ["is", "big"], set("NewYorkisbig")),
(["word1", "word2"], ["word1", "word2"], set("word12")),
])
],
)
def test__find_gap_char_candidates(gt_tokens, ocr_tokens, desired_input_char_set):
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
gt_tokens, ocr_tokens
)
assert input_char_set == desired_input_char_set
assert ner_label.GAP_CHAR_SET.difference(input_char_set) == gap_char_candidates
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception",
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
[
(["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token
(["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token
(["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR
(["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens
(["o"], ["@"], ["New"], ner_label.GapCharError), # invalid token with gap char only (gt_token)
(["o"], ["New"], ["@"], ner_label.GapCharError), # invalid token with gap char only (ocr_token)
(["o", "o"], ["New", "@"], ["New", "@"], ner_label.GapCharError), # invalid token (both)
(["o"], [" \n\t@@"], ["New"], ner_label.GapCharError), # invalid token with gap char and space chars (gt_token)
(["o"], ["New"], [" \n\t@"], ner_label.GapCharError), # invalid token with gap char and space chars (ocr_token)
(
["o"],
["@"],
["New"],
ner_label.GapCharError,
), # invalid token with gap char only (gt_token)
(
["o"],
["New"],
["@"],
ner_label.GapCharError,
), # invalid token with gap char only (ocr_token)
(
["o", "o"],
["New", "@"],
["New", "@"],
ner_label.GapCharError,
), # invalid token (both)
(
["o"],
[" \n\t@@"],
["New"],
ner_label.GapCharError,
), # invalid token with gap char and space chars (gt_token)
(
["o"],
["New"],
[" \n\t@"],
ner_label.GapCharError,
), # invalid token with gap char and space chars (ocr_token)
(["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token)
(["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token)
(["o"], [" \n\t"], ["New"], ValueError), # invalid token: space characters only (gt_token)
(["o"], ["New"], [" \n\t"], ValueError), # invalid token: space characters only (ocr_token)
(
["o"],
[" \n\t"],
["New"],
ValueError,
), # invalid token: space characters only (gt_token)
(
["o"],
["New"],
[" \n\t"],
ValueError,
), # invalid token: space characters only (ocr_token)
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
])
def test__propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
],
)
def test__propagate_label_to_ocr_error(
gt_labels, gt_tokens, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
)
else:
ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
)
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test__propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
gap_char_candidates, _ = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
# run regression test for each GAP_CHAR candidate to make sure
# label propagation is functioning correctly
for gap_char in gap_char_candidates:
ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char)
ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char
)
assert ocr_labels == desired_ocr_labels
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", [
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
[
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
(["o"], list(ner_label.GAP_CHAR_SET), [""], ner_label.GapCharError), # input char set == GAP_CHAR_SET
(["o"], [""], list(ner_label.GAP_CHAR_SET), ner_label.GapCharError), # input char set == GAP_CHAR_SET
(
["o"],
list(ner_label.GAP_CHAR_SET),
[""],
ner_label.GapCharError,
), # input char set == GAP_CHAR_SET
(
["o"],
[""],
list(ner_label.GAP_CHAR_SET),
ner_label.GapCharError,
), # input char set == GAP_CHAR_SET
# all possible gap chars set split between ocr and gt tokens
(["o"], list(ner_label.GAP_CHAR_SET)[:10], list(ner_label.GAP_CHAR_SET)[10:], ner_label.GapCharError),
])
def test_propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
(
["o"],
list(ner_label.GAP_CHAR_SET)[:10],
list(ner_label.GAP_CHAR_SET)[10:],
ner_label.GapCharError,
),
],
)
def test_propagate_label_to_ocr_error(
gt_labels, gt_tokens, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
else:
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens
)
assert ocr_labels == desired_ocr_labels
@pytest.mark.parametrize("tokens, labels, label_top, desired_output",
@pytest.mark.parametrize(
"tokens, labels, label_top, desired_output",
[
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
True,
"B-place I-place o o \n" +
"New York is big \n"
"B-place I-place o o \n" + "New York is big \n",
),
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
False,
"New York is big \n" +
"B-place I-place o o \n"
"New York is big \n" + "B-place I-place o o \n",
),
],
)
])
def test_format_label(tokens, labels, label_top, desired_output):
output = ner_label.format_labels(tokens, labels, label_top=label_top)
assert output == desired_output
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_format_gt_ocr_w_labels(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
ner_label.format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr)
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens
)
ner_label.format_label_propagation(
gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr
)
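
A short end-to-end sketch of label propagation plus the formatting helper tested above; the OCR tokens here are hypothetical and only meant to show the call shape:

from genalog.text import ner_label

gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["New", "Yerk", "is", "big"]  # hypothetical OCR output

ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)
# Render tokens with their propagated labels on top, as in test_format_label.
print(ner_label.format_labels(ocr_tokens, ocr_labels, label_top=True))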

View file

@ -1,8 +1,11 @@
from genalog.text import preprocess
from genalog.text.alignment import GAP_CHAR
import pytest
@pytest.mark.parametrize("token, replacement, desired_output",
from genalog.text import preprocess
from genalog.text.alignment import GAP_CHAR
@pytest.mark.parametrize(
"token, replacement, desired_output",
[
("", "_", ""), # Do nothing to empty string
(" ", "_", " "), # Do nothing to whitespaces
@ -11,49 +14,37 @@ import pytest
("a s\nc\tii", "_", "a s\nc\tii"),
("ascii·", "_", "ascii"), # Tokens with non-ASCII values
("·", "_", "_"), # Tokens with non-ASCII values
])
],
)
def test_remove_non_ascii(token, replacement, desired_output):
for code in range(128, 1000): # non-ASCII values
token.replace("·", chr(code))
output = preprocess.remove_non_ascii(token, replacement)
assert output == desired_output
@pytest.mark.parametrize("s, desired_output",
@pytest.mark.parametrize(
"s, desired_output",
[
(
" New \t \n",
["New"]
),
(" New \t \n", ["New"]),
# Mixed in gap char "@"
(
" @ @",
["@", "@"]
),
(
"New York is big",
["New", "York", "is", "big"]
),
(" @ @", ["@", "@"]),
("New York is big", ["New", "York", "is", "big"]),
# Mixed multiple spaces and tabs
(
" New York \t is \t big",
["New", "York", "is", "big"]
),
(" New York \t is \t big", ["New", "York", "is", "big"]),
# Mixed in punctuation
(
"New .York is, big !",
["New", ".York", "is,", "big", "!"]
),
("New .York is, big !", ["New", ".York", "is,", "big", "!"]),
# Mixed in gap char "@"
(
"@N@ew York@@@is,\t big@@@@@",
["@N@ew", "York@@@is,", "big@@@@@"]
("@N@ew York@@@is,\t big@@@@@", ["@N@ew", "York@@@is,", "big@@@@@"]),
],
)
])
def test_tokenize(s, desired_output):
output = preprocess.tokenize(s)
assert output == desired_output
@pytest.mark.parametrize("tokens, desired_output",
@pytest.mark.parametrize(
"tokens, desired_output",
[
(
["New", "York", "is", "big"],
@ -68,34 +59,54 @@ def test_tokenize(s, desired_output):
(
["@N@ew", "York@@@is,", "big@@@@@"],
"@N@ew York@@@is, big@@@@@",
),
],
)
])
def test_join_tokens(tokens, desired_output):
output = preprocess.join_tokens(tokens)
assert output == desired_output
@pytest.mark.parametrize("c, desired_output",
@pytest.mark.parametrize(
"c, desired_output",
[
# Gap char
(GAP_CHAR, False),
# Alphabet char
('a', False), ('A', False),
("a", False),
("A", False),
# Punctuation
('.', False), ('!', False), (',', False), ('-', False),
(".", False),
("!", False),
(",", False),
("-", False),
# Token separators
(' ', True), ('\n', True), ('\t', True)
])
(" ", True),
("\n", True),
("\t", True),
],
)
def test__is_spacing(c, desired_output):
assert desired_output == preprocess._is_spacing(c)
@pytest.mark.parametrize("text, desired_output", [
@pytest.mark.parametrize(
"text, desired_output",
[
("", ""),
("w .", "w ."), ("w !", "w !"), ("w ?", "w ?"),
("w /.", "w /."), ("w /!", "w /!"), ("w /?", "w /?"),
("w .", "w ."),
("w !", "w !"),
("w ?", "w ?"),
("w /.", "w /."),
("w /!", "w /!"),
("w /?", "w /?"),
("w1 , w2 .", "w1 , w2 ."),
("w1 . w2 .", "w1 . \nw2 ."), ("w1 /. w2 /.", "w1 /. \nw2 /."),
("w1 ! w2 .", "w1 ! \nw2 ."), ("w1 /! w2 /.", "w1 /! \nw2 /."),
("w1 ? w2 .", "w1 ? \nw2 ."), ("w1 /? w2 /.", "w1 /? \nw2 /."),
("w1 . w2 .", "w1 . \nw2 ."),
("w1 /. w2 /.", "w1 /. \nw2 /."),
("w1 ! w2 .", "w1 ! \nw2 ."),
("w1 /! w2 /.", "w1 /! \nw2 /."),
("w1 ? w2 .", "w1 ? \nw2 ."),
("w1 /? w2 /.", "w1 /? \nw2 /."),
("U.S. . w2 .", "U.S. . \nw2 ."),
("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting
("w1 !!! w2 .", "w1 !!! w2 ."),
@ -108,17 +119,30 @@ def test__is_spacing(c, desired_output):
("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
])
],
)
def test_split_sentences(text, desired_output):
assert desired_output == preprocess.split_sentences(text)
@pytest.mark.parametrize("token, desired_output", [
("", False), (" ", False), ("\n", False), ("\t", False),
@pytest.mark.parametrize(
"token, desired_output",
[
("", False),
(" ", False),
("\n", False),
("\t", False),
(" \n \t", False),
("...", False),
("???", False), ("!!!", False),
(".", True), ("!", True), ("?", True),
("/.", True), ("/!", True), ("/?", True),
])
("???", False),
("!!!", False),
(".", True),
("!", True),
("?", True),
("/.", True),
("/!", True),
("/?", True),
],
)
def test_is_sentence_separator(token, desired_output):
assert desired_output == preprocess.is_sentence_separator(token)
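
A quick tour of the preprocess helpers covered above, with outputs taken from the parametrized cases:

from genalog.text import preprocess

tokens = preprocess.tokenize("New .York is, big !")
print(tokens)                                   # ['New', '.York', 'is,', 'big', '!']
print(preprocess.join_tokens(tokens))           # New .York is, big !
print(preprocess.split_sentences("w1 . w2 ."))  # 'w1 . \nw2 .' (newline between sentences)
print(preprocess.is_sentence_separator("/."))   # True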

View file

@ -1,11 +1,13 @@
import random
import pytest
import warnings
import pytest
from genalog.text import alignment
from genalog.text.alignment import GAP_CHAR
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
def random_utf8_char(byte_len=1):
if byte_len == 1:
return chr(random.randint(0, 0x007F))
@ -16,25 +18,48 @@ def random_utf8_char(byte_len=1):
elif byte_len == 4:
return chr(random.randint(0xFFFF, 0x10FFFF))
else:
raise ValueError(f"Invalid byte length: {byte_len}." +
"utf-8 does not encode characters with more than 4 bytes in length")
raise ValueError(
f"Invalid byte length: {byte_len}."
+ "utf-8 does not encode characters with more than 4 bytes in length"
)
@pytest.mark.parametrize("num_utf_char_to_test", [100]) # Number of char per byte length
@pytest.mark.parametrize("byte_len", [1,2,3,4]) # UTF does not encode with more than 4 bytes
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
invalid_char = set(gt_txt).union(set(GAP_CHAR)) # character to replace to cannot be in this set
@pytest.mark.parametrize(
"num_utf_char_to_test", [100]
) # Number of char per byte length
@pytest.mark.parametrize(
"byte_len", [1, 2, 3, 4]
) # UTF-8 does not encode with more than 4 bytes
@pytest.mark.parametrize(
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align(
num_utf_char_to_test,
byte_len,
gt_txt,
noisy_txt,
expected_aligned_gt,
expected_aligned_noise,
):
invalid_char = set(gt_txt).union(
set(GAP_CHAR)
) # character to replace to cannot be in this set
for _ in range(num_utf_char_to_test):
utf_char = random_utf8_char(byte_len)
while utf_char in invalid_char: # find a utf char not in the input string and not GAP_CHAR
while (
utf_char in invalid_char
): # find a utf char not in the input string and not GAP_CHAR
utf_char = random_utf8_char(byte_len)
char_to_replace = random.choice(list(invalid_char)) if gt_txt else ""
gt_txt_sub = gt_txt.replace(char_to_replace, utf_char)
noisy_txt_sub = noisy_txt.replace(char_to_replace, utf_char)
gt_txt.replace(char_to_replace, utf_char)
noisy_txt.replace(char_to_replace, utf_char)
expected_aligned_gt_sub = expected_aligned_gt.replace(char_to_replace, utf_char)
expected_aligned_noise_sub = expected_aligned_noise.replace(char_to_replace, utf_char)
expected_aligned_noise_sub = expected_aligned_noise.replace(
char_to_replace, utf_char
)
# Run alignment
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
@ -42,7 +67,12 @@ def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_align
aligned_gt = aligned_gt.replace(char_to_replace, utf_char)
aligned_noise = aligned_noise.replace(char_to_replace, utf_char)
if aligned_gt != expected_aligned_gt_sub:
expected_alignment = alignment._format_alignment(expected_aligned_gt_sub, expected_aligned_noise_sub)
expected_alignment = alignment._format_alignment(
expected_aligned_gt_sub, expected_aligned_noise_sub
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)

36
tox.ini Normal file
View file

@ -0,0 +1,36 @@
; [tox]
; envlist = flake8, py36 # add other python versions if necessary
; [testenv]
; # Reading additional dependencies to run the test
; # https://tox.readthedocs.io/en/latest/example/basic.html#depending-on-requirements-txt-or-defining-constraints
; deps = -rdev-requirements.txt
; commands =
; pytest
; [testenv:flake8]
; deps = flake8
; skip_install = True
; commands = flake8 .
; # Configurations for running pytest
[pytest]
junit_family=xunit2
testpaths =
tests
addopts =
-rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml
[flake8]
# Configs for flake8-import-order, see https://pypi.org/project/flake8-import-order/ for more info.
import-order-style=edited
application-import-names=genalog, tests
# Native flake8 configs
max-line-length = 140
exclude =
build, dist
.env*,.venv* # local virtual environments
.tox
; [mypy]
; ignore_missing_imports = True