Merge pull request #10 from microsoft/laserprec/use_linter

Use flake8 as linter and fix code format issues.
Jianjie Liu 2021-01-25 17:04:13 -05:00 committed by GitHub
Parent ad504b4d7b 4c8445b874
Commit cda2c1e77d
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 3815 additions and 2116 deletions

View file

@ -35,31 +35,20 @@ steps:
displayName: 'Use Python $(python.version)'
- bash: |
python -m venv .venv
displayName: 'Create virtual environment'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
pip install --upgrade pip
pip install setuptools wheel
pip install -r requirements.txt
pip install pytest==5.3.5 pytest-cov==2.8.1
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Install dependencies'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python -m pytest tests --cov=genalog --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=html
python -m flake8
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Run Linter (flake8)'
- bash: |
python -m pytest tests
env:
BLOB_KEY : $(BLOB_KEY)
SEARCH_SERVICE_KEY: $(SEARCH_SERVICE_KEY)
@ -86,12 +75,6 @@ steps:
displayName: 'Publish test coverage'
- bash: |
if [[ '$(Agent.OS)' == Windows* ]]
then
source .venv/Scripts/activate
else
source .venv/bin/activate
fi
python setup.py bdist_wheel --build-number $(Build.BuildNumber) --dist-dir dist
workingDirectory: $(Build.SourcesDirectory)
displayName: 'Building wheel package'

View file

@ -1,25 +1,29 @@
from genalog.degradation import effect
from enum import Enum
import copy
import inspect
from enum import Enum
from genalog.degradation import effect
DEFAULT_METHOD_PARAM_TO_INCLUDE = "src"
class ImageState(Enum):
ORIGINAL_STATE = "ORIGINAL_STATE"
CURRENT_STATE = "CURRENT_STATE"
class Degrader():
class Degrader:
""" An object for applying multiple degradation effects onto an image"""
def __init__(self, effects):
""" Initialize a Degrader object
"""Initialize a Degrader object
Arguments:
effects {list} -- a list of 2-element tuple that defines:
(method_name, method_kwargs)
1. method_name: the name of the degradation method
1. method_name: the name of the degradation method
(method must be defined in 'genalog.degradation.effect')
2. method_kwargs: the keyword arguments of the corresponding method
@ -28,10 +32,10 @@ class Degrader():
[
("blur", {"radius": 3}),
("bleed_through", {"alpha": 0.8}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
]
The example above will apply degradation effects to the images
The example above will apply degradation effects to the images
in the following sequence:
blur -> bleed_through -> morphological operation (open)
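For illustration only (not part of the diff): a minimal sketch of how the Degrader pipeline above might be driven, assuming a grayscale image array as input (module path taken from this file):

import numpy as np
from genalog.degradation.degrader import Degrader

degrader = Degrader([
    ("blur", {"radius": 3}),
    ("bleed_through", {"alpha": 0.8}),
    ("morphology", {"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"}),
])
src = np.full((300, 300), 255, dtype=np.uint8)  # placeholder white page; normally a scanned image
degraded = degrader.apply_effects(src)          # effects run in the listed order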
@ -42,14 +46,14 @@ class Degrader():
@staticmethod
def validate_effects(effects):
""" Validate the effects list
"""Validate the effects list
Arguments:
effects {list} -- a list of 2-element tuple that defines:
(method_name, method_kwargs)
1. method_name: the name of the degradation method
1. method_name: the name of the degradation method
(method must be defined in 'genalog.degradation.effect')
2. method_kwargs: the keyword arguments of the corresponding method
@ -58,11 +62,11 @@ class Degrader():
[
("blur", {"radius": "3"}),
("bleed_through", {"alpha":"0.8"}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
]
Raises:
ValueError: raise this error when
ValueError: raise this error when
1. method_name not defined in "genalog.degradation.effect"
2. method_kwargs is not a valid keyword arguments in the
corresponding method
@ -73,37 +77,46 @@ class Degrader():
# Try to find corresponding degradation method in the module
method = getattr(effect, method_name)
except AttributeError:
raise ValueError(f"Method '{method_name}' is not defined in 'genalog.degradation.effect'")
raise ValueError(
f"Method '{method_name}' is not defined in 'genalog.degradation.effect'"
)
# Get the method signatures
method_sign = inspect.signature(method)
# Check if method parameters are valid
for param_name in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if not param_name in method_sign.parameters:
for (
param_name
) in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
if param_name not in method_sign.parameters:
method_args = [param for param in method_sign.parameters]
raise ValueError(f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. Method parameter names are: {method_args}")
raise ValueError(
f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. " +
f"Method parameter names are: {method_args}"
)
def _add_default_method_param(self):
""" All methods in "genalog.degradation.effect" module have a required
"""All methods in "genalog.degradation.effect" module have a required
method parameter named "src". This parameter will be included if not provided
by the input keyword argument dictionary.
"""
for effect_tuple in self.effects_to_apply:
method_name, method_kwargs = effect_tuple
if DEFAULT_METHOD_PARAM_TO_INCLUDE not in method_kwargs:
method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] = ImageState.CURRENT_STATE
method_kwargs[
DEFAULT_METHOD_PARAM_TO_INCLUDE
] = ImageState.CURRENT_STATE
def apply_effects(self, src):
""" Apply degradation effects in sequence
"""Apply degradation effects in sequence
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Returns:
a copy of the source image {numpy.ndarray} after apply the effects
"""
self.original_state = src
self.current_state = src
# Preserve the original effect instructions
# Preserve the original effect instructions
effects_to_apply = copy.deepcopy(self.effects_to_apply)
for effect_tuple in effects_to_apply:
method_name, method_kwargs = effect_tuple
@ -115,7 +128,7 @@ class Degrader():
return self.current_state
def insert_image_state(self, kwargs):
""" Replace the enumeration (ImageState) with the actual image in
"""Replace the enumeration (ImageState) with the actual image in
the keyword argument dictionary
Arguments:
@ -124,12 +137,12 @@ class Degrader():
Ex: {"src": ImageState.ORIGINAL_STATE, "radius": 5}
Returns:
return keyword argument dictionary replaced with
reference to the image
return keyword argument dictionary replaced with
reference to the image
"""
for keyword, argument in kwargs.items():
if argument is ImageState.ORIGINAL_STATE:
kwargs[keyword] = self.original_state.copy()
if argument is ImageState.CURRENT_STATE:
kwargs[keyword] = self.current_state.copy()
return kwargs
return kwargs
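For illustration (not part of the diff): the ImageState enum above lets an effect's kwargs point back at the untouched input instead of the running result, e.g. a hypothetical pipeline that blends the blurred output with the original:

from genalog.degradation.degrader import Degrader, ImageState

degrader = Degrader([
    ("blur", {"radius": 5}),
    ("overlay_weighted", {
        "src": ImageState.CURRENT_STATE,          # the blurred image so far
        "background": ImageState.ORIGINAL_STATE,  # the untouched input
        "alpha": 0.5,
        "beta": 0.5,
    }),
])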

View file

@ -1,31 +1,34 @@
import cv2
import numpy as np
from math import floor
import cv2
import numpy as np
def blur(src, radius=5):
""" Wrapper function for cv2.GaussianBlur
"""Wrapper function for cv2.GaussianBlur
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Keyword Arguments:
radius {int} -- size of the square kernel,
radius {int} -- size of the square kernel,
MUST be an odd integer (default: {5})
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""
return cv2.GaussianBlur(src, (radius, radius), cv2.BORDER_DEFAULT)
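As an aside (illustration, not part of the diff), the wrapper above is called with an odd radius on a grayscale array:

import numpy as np
from genalog.degradation import effect

img = np.full((100, 200), 255, dtype=np.uint8)   # synthetic white image
blurred = effect.blur(img, radius=5)             # radius must be odd for cv2.GaussianBlur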
def overlay_weighted(src, background, alpha, beta, gamma=0):
""" overlay two images together, pixels from each image is weighted as follow
"""overlay two images together, pixels from each image is weighted as follow
dst[i] = alpha*src[i] + beta*background[i] + gamma
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
background {numpy.ndarray} -- background image. Must be in same shape as `src`
alpha {float} -- transparent factor for the foreground
alpha {float} -- transparent factor for the foreground
beta {float} -- transparent factor for the background
Keyword Arguments:
@ -36,8 +39,9 @@ def overlay_weighted(src, background, alpha, beta, gamma=0):
"""
return cv2.addWeighted(src, alpha, background, beta, gamma).astype(np.uint8)
def overlay(src, background):
""" Overlay two images together via bitwise-and:
"""Overlay two images together via bitwise-and:
dst[i] = src[i] & background[i]
@ -50,33 +54,35 @@ def overlay(src, background):
"""
return cv2.bitwise_and(src, background).astype(np.uint8)
def translation(src, offset_x, offset_y):
""" Shift the image in x, y direction
"""Shift the image in x, y direction
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
offset_x {int} -- pixels in the x direction.
offset_x {int} -- pixels in the x direction.
Positive value shifts right and negative shifts left.
offset_y {int} -- pixels in the y direction.
Positive value shifts down and negative shifts up.
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""
rows, cols = src.shape
trans_matrix = np.float32([[1,0,offset_x], [0,1,offset_y]])
trans_matrix = np.float32([[1, 0, offset_x], [0, 1, offset_y]])
# size of the output image should be in the form of (width, height)
dst = cv2.warpAffine(src, trans_matrix, (cols, rows), borderValue=255)
return dst.astype(np.uint8)
def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y=5):
""" Apply bleed through effect, background is flipped horizontally.
"""Apply bleed through effect, background is flipped horizontally.
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Keyword Arguments:
background {numpy.ndarray} -- background image. Must be in same
background {numpy.ndarray} -- background image. Must be in same
shape as foreground (default: {None})
alpha {float} -- transparent factor for the foreground (default: {0.8})
gamma {int} -- luminance constant (default: {0})
@ -84,30 +90,31 @@ def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y
Positive value shifts right and negative shifts left.
offset_y {int} -- background translation offset (default: {5})
Positive value shifts down and negative shifts up.
Returns:
a copy of the source image {numpy.ndarray} after apply the effect.
Pixel value ranges [0, 255]
"""
if background is None:
background = src.copy()
background = cv2.flip(background, 1) # flipped horizontally
background = cv2.flip(background, 1) # flipped horizontally
background = translation(background, offset_x, offset_y)
beta = 1 - alpha
return overlay_weighted(src, background, alpha, beta, gamma)
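Illustration (not part of the diff): because background defaults to a horizontally flipped copy of the source, a single call is enough:

import numpy as np
from genalog.degradation import effect

img = np.full((100, 200), 255, dtype=np.uint8)              # synthetic white image
ghosted = effect.bleed_through(img, alpha=0.8, offset_y=5)  # flips a copy of img and shifts it down 5 px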
def pepper(src, amount=0.05):
""" Randomly sprinkle dark pixels on src image.
"""Randomly sprinkle dark pixels on src image.
Wrapper function for skimage.util.noise.random_noise().
See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Keyword Arguments:
amount {float} -- proportion of pixels in range [0, 1] to apply the effect
(default: {0.05})
Returns:
a copy of the source image {numpy.ndarray} after apply the effect.
Pixel value ranges [0, 255] as uint8.
@ -118,18 +125,19 @@ def pepper(src, amount=0.05):
dst[noise < amount] = 0
return dst.astype(np.uint8)
def salt(src, amount=0.3):
""" Randomly sprinkle white pixels on src image.
"""Randomly sprinkle white pixels on src image.
Wrapper function for skimage.util.noise.random_noise().
See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Keyword Arguments:
amount {float} -- proportion of pixels in range [0, 1] to apply the effect
(default: {0.05})
Returns:
a copy of the source image {numpy.ndarray} after apply the effect.
Pixel value ranges [0, 255]
@ -140,18 +148,19 @@ def salt(src, amount=0.3):
dst[noise < amount] = 255
return dst.astype(np.uint8)
def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
""" Randomly add salt then add pepper onto the image.
"""Randomly add salt then add pepper onto the image.
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
salt_amount {float} -- proportion of pixels in range [0, 1] to
salt_amount {float} -- proportion of pixels in range [0, 1] to
apply the salt effect
(default: {0.1})
pepper_amount {float} -- proportion of pixels in range [0, 1] to
apply the pepper effect
pepper_amount {float} -- proportion of pixels in range [0, 1] to
apply the pepper effect
(default: {0.05})
Returns:
a copy of the source image {numpy.ndarray} after apply the effect.
Pixel value ranges [0, 255] as uint8.
@ -159,18 +168,19 @@ def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
salted = salt(src, amount=salt_amount)
return pepper(salted, amount=pepper_amount)
def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
""" Randomly add pepper then salt onto the image.
"""Randomly add pepper then salt onto the image.
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
pepper_amount {float} -- proportion of pixels in range [0, 1] to
pepper_amount {float} -- proportion of pixels in range [0, 1] to
apply the pepper effect.
(default: {0.05})
salt_amount {float} -- proportion of pixels in range [0, 1] to
(default: {0.05})
salt_amount {float} -- proportion of pixels in range [0, 1] to
apply the salt effect.
(default: {0.1})
Returns:
a copy of the source image {numpy.ndarray} after apply the effect.
Pixel value ranges [0, 255] as uint8.
@ -178,16 +188,17 @@ def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
peppered = pepper(src, amount=pepper_amount)
return salt(peppered, amount=salt_amount)
def create_2D_kernel(kernel_shape, kernel_type="ones"):
""" Create 2D kernel for morphological operations.
"""Create 2D kernel for morphological operations.
Arguments:
kernel_shape {tuple} -- shape of the kernel (rows, cols)
Keyword Arguments:
kernel_type {str} -- type of kernel (default: {"ones"}).
kernel_type {str} -- type of kernel (default: {"ones"}).
All supported kernel types are below:
"ones": kernel is filled with all 1s in shape (rows, cols)
[[1,1,1],
[1,1,1],
@ -215,9 +226,9 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
[1, 1, 1, 1, 1],
[0, 0, 1, 0, 0]]
Raises:
ValueError: if kernel is not a 2-element tuple or
ValueError: if kernel is not a 2-element tuple or
kernel_type is not one of the supported values
Returns:
a 2D array {numpy.ndarray} of shape `kernel_shape`.
"""
@ -233,37 +244,47 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
elif kernel_type == "x":
diagonal = np.eye(kernel_rows, kernel_cols)
kernel = np.add(diagonal, np.fliplr(diagonal))
kernel[kernel>1] = 1
kernel[kernel > 1] = 1
elif kernel_type == "plus":
kernel = np.zeros(kernel_shape)
center_col = floor(kernel.shape[0]/2)
center_row = floor(kernel.shape[1]/2)
center_col = floor(kernel.shape[0] / 2)
center_row = floor(kernel.shape[1] / 2)
kernel[:, center_col] = 1
kernel[center_row, :] = 1
elif kernel_type == "ellipse":
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape)
else:
valid_kernel_types = {"ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"}
raise ValueError(f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}")
valid_kernel_types = {
"ones",
"upper_triangle",
"lower_triangle",
"x",
"plus",
"ellipse",
}
raise ValueError(
f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}"
)
return kernel.astype(np.uint8)
def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
""" Dynamic calls different morphological operations
def morphology(src, operation="open", kernel_shape=(3, 3), kernel_type="ones"):
"""Dynamic calls different morphological operations
("open", "close", "dilate" and "erode") with the given parameters
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
Keyword Arguments:
operation {str} -- name of a morphological operation:
("open", "close", "dilate", "erode")
(default: {"open"})
kernel_shape {tuple} -- shape of the kernel (rows, cols)
kernel_shape {tuple} -- shape of the kernel (rows, cols)
(default: {(3,3)})
kernel_type {str} -- type of kernel (default: {"ones"})
Supported kernel_types are:
["ones", "upper_triangle", "lower_triangle",
["ones", "upper_triangle", "lower_triangle",
"x", "plus", "ellipse"]
Returns:
@ -280,7 +301,10 @@ def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
return erode(src, kernel)
else:
valid_operations = ["open", "close", "dilate", "erode"]
raise ValueError(f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}")
raise ValueError(
f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}"
)
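Illustration (not part of the diff): the two entry points above, building a structuring element directly and letting morphology() build it from a shape and type:

import numpy as np
from genalog.degradation import effect

img = np.full((64, 64), 255, dtype=np.uint8)
img[20:44, 20:44] = 0                                   # dark block for the operation to act on
kernel = effect.create_2D_kernel((3, 3), kernel_type="plus")
opened_direct = effect.open(img, kernel)                # explicit kernel
opened = effect.morphology(img, operation="open",
                           kernel_shape=(3, 3), kernel_type="plus")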
def open(src, kernel):
""" "open" morphological operation. Like morphological "erosion", it removes
@ -289,59 +313,62 @@ def open(src, kernel):
For more information see:
1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/open.htm
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""
return cv2.morphologyEx(src, cv2.MORPH_OPEN, kernel)
def close(src, kernel):
""" "close" morphological operation. Like morphological "dilation", it grows the
""" "close" morphological operation. Like morphological "dilation", it grows the
boundary of the foreground (white pixels), however, it is less destructive than
dilation of the original boundary shape.
For more information see:
1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/close.htm
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""
return cv2.morphologyEx(src, cv2.MORPH_CLOSE, kernel)
def erode(src, kernel):
""" "erode" morphological operation. Erodes foreground pixels (white pixels).
For more information see:
1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/erode.htm
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""
return cv2.erode(src, kernel)
def dilate(src, kernel):
""" "dilate" morphological operation. Grows foreground pixels (white pixels).
""" "dilate" morphological operation. Grows foreground pixels (white pixels).
For more information see:
1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/dilate.htm
Arguments:
src {numpy.ndarray} -- source image of shape (rows, cols)
kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect
Returns:
a copy of the source image {numpy.ndarray} after apply the effect
"""

View file

@ -1,4 +1,5 @@
from enum import Enum, auto
from enum import auto, Enum
class ContentType(Enum):
PARAGRAPH = auto()
@ -6,14 +7,17 @@ class ContentType(Enum):
IMAGE = auto()
COMPOSITE = auto()
class Content():
class Content:
def __init__(self):
self.iterable = True
self._content = None
def set_content_type(self, content_type):
if type(content_type) != ContentType:
raise TypeError(f"Invalid content type: {content_type}, valid types are {list(ContentType)}")
raise TypeError(
f"Invalid content type: {content_type}, valid types are {list(ContentType)}"
)
self.content_type = content_type
def validate_content(self):
@ -21,13 +25,14 @@ class Content():
def __str__(self):
return self._content.__str__()
def __iter__(self):
return self._content.__iter__()
def __getitem__(self, key):
return self._content.__getitem__(key)
class Paragraph(Content):
def __init__(self, content):
self.set_content_type(ContentType.PARAGRAPH)
@ -38,16 +43,18 @@ class Paragraph(Content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")
class Title(Content):
def __init__(self, content):
self.set_content_type(ContentType.TITLE)
self.validate_content(content)
self._content = content
def validate_content(self, content):
if not isinstance(content, str):
raise TypeError(f"Expect a str, but got {type(content)}")
class CompositeContent(Content):
def __init__(self, content_list, content_type_list):
self.set_content_type(ContentType.COMPOSITE)
@ -68,7 +75,7 @@ class CompositeContent(Content):
self._content.append(Paragraph(content))
else:
raise NotImplementedError(f"{content_type} is not currently supported")
def insert_content(self, new_content, index):
NotImplementedError
@ -82,5 +89,5 @@ class CompositeContent(Content):
"""get a string transparent of the nested object types"""
transparent_str = "["
for content in self._content:
transparent_str += '"' + content.__str__() + "\", "
transparent_str += '"' + content.__str__() + '", '
return transparent_str + "]"
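Illustration (not part of the diff): constructing the content classes reformatted above; CompositeContent pairs a list of strings with a matching list of ContentType values:

from genalog.generation.content import CompositeContent, ContentType

sections = ["A Mock Title", "First paragraph of text.", "Second paragraph of text."]
types = [ContentType.TITLE, ContentType.PARAGRAPH, ContentType.PARAGRAPH]
content = CompositeContent(sections, types)
print(content)   # prints the nested contents as one bracketed string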

View file

@ -1,14 +1,12 @@
from jinja2 import PackageLoader, FileSystemLoader
from jinja2 import Environment, select_autoescape
from weasyprint import HTML
from cairocffi import FORMAT_ARGB32
import numpy as np
import itertools
import os
import cv2
import cv2
import numpy as np
from cairocffi import FORMAT_ARGB32
from jinja2 import Environment, select_autoescape
from jinja2 import FileSystemLoader, PackageLoader
from weasyprint import HTML
DEFAULT_DOCUMENT_STYLE = {
"language": "en_US",
@ -27,19 +25,21 @@ DEFAULT_STYLE_COMBINATION = {
"hyphenate": [False],
}
class Document(object):
""" A composite object that represents a document """
def __init__(self, content, template, **styles):
"""Initialize a Document object with source template and content
Arguments:
content {CompositeContent} -- an iterable object whose elements
template {Template} -- a jinja2.Template object
Optional Argument:
styles [dict] -- a kwargs dictionary (context) whose keys and values are
styles [dict] -- a kwargs dictionary (context) whose keys and values are
the template variable and their respective values
Example:
{
"font_family": "Calibri",
@ -48,24 +48,24 @@ class Document(object):
}
Note that this assumes that "font_family", "font_size", "hyphenate" are valid
variables declared in the loaded template. There will be **NO SIDE-EFFECT**
variables declared in the loaded template. There will be **NO SIDE-EFFECT**
providing a variable undefined in the template.
You can also provide these key-value pairs via Python keyword arguments:
Document(content, template, font_family="Calibri", font_size="10px", hyphenate=True)
"""
self.content = content
self.template = template
self.styles = DEFAULT_DOCUMENT_STYLE.copy()
# This is a rendered document ready to be painted on a cairo surface
self._document = None # weasyprint.document.Document object
self._document = None # weasyprint.document.Document object
self.compiled_html = None
# Update the default styles and initialize self._document object
self.update_style(**styles)
def render_html(self):
""" Wrapper function for Jinjia2.Template.render(). Each template
"""Wrapper function for Jinja2.Template.render(). Each template
declare its template variables. This method assigns each variable to
its respective value and compiles the template.
@ -77,8 +77,8 @@ class Document(object):
return self.template.render(content=self.content, **self.styles)
def render_pdf(self, target=None, zoom=1):
""" Wrapper function for WeasyPrint.Document.write_pdf
"""Wrapper function for WeasyPrint.Document.write_pdf
Arguments:
target -- a filename, file-like object, or None
split_pages {bool} -- true if saving each document page as a separate file.
@ -92,32 +92,37 @@ class Document(object):
return self._document.write_pdf(target=target, zoom=zoom)
def render_png(self, target=None, split_pages=False, resolution=300):
""" Wrapper function for WeasyPrint.Document.write_png
"""Wrapper function for WeasyPrint.Document.write_png
Arguments:
target -- a filename, file-like object, or None
split_pages {bool} -- true if save each document page as a separate file.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), PNG pixels match the CSS px unit.
resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default),
PNG pixels match the CSS px unit.
Returns:
The image as bytes if target is not provided or None, otherwise None (the PNG is written to target)
"""
if target != None and split_pages:
if target is not None and split_pages:
# get destination filename and extension
filename, ext = os.path.splitext(target)
for page_num, page in enumerate(self._document.pages):
page_name = filename + f"_pg_{page_num}" + ext
self._document.copy([page]).write_png(target=page_name, resolution=resolution)
self._document.copy([page]).write_png(
target=page_name, resolution=resolution
)
return None
elif target == None:
elif target is None:
# return image bytes string if no target is specified
png_bytes, png_width, png_height = self._document.write_png(target=target, resolution=resolution)
png_bytes, png_width, png_height = self._document.write_png(
target=target, resolution=resolution
)
return png_bytes
else:
return self._document.write_png(target=target, resolution=resolution)
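For context (illustration, not part of the diff), the three target modes handled above, assuming doc is an already-built Document instance:

png_bytes = doc.render_png()                          # no target: the PNG bytes are returned
doc.render_png(target="out.png")                      # single file written to disk
doc.render_png(target="out.png", split_pages=True)    # one file per page: out_pg_0.png, out_pg_1.png, ...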
def render_array(self, resolution=300, channel="GRAYSCALE"):
""" Render document as a numpy.ndarray.
"""Render document as a numpy.ndarray.
Keyword Arguments:
resolution {int} -- in units dpi (default: {300})
@ -125,24 +130,29 @@ class Document(object):
available values are: "GRAYSCALE", "RGB", "RGBA", "BGRA", "BGR"
Note that "RGB" is 3-channel, "RGBA" is 4-channel and "GRAYSCALE" is single channel
Returns:
A numpy.ndarray representation of the document.
"""
# Method below returns a cairocffi.ImageSurface object
# https://cairocffi.readthedocs.io/en/latest/api.html#cairocffi.ImageSurface
surface, width, height = self._document.write_image_surface(resolution=resolution)
surface, width, height = self._document.write_image_surface(
resolution=resolution
)
img_format = surface.get_format()
# This is BGRA channel in little endian (reverse)
if img_format != FORMAT_ARGB32:
raise RuntimeError(f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'")
raise RuntimeError(
f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. " +
"Please check the underlying implementation of 'weasyprint.document.Document.write_image_surface()'"
)
img_buffer = surface.get_data()
# Returns image array in "BGRA" channel
img_array = np.ndarray(shape=(height, width, 4),
dtype=np.uint8,
buffer=img_buffer)
img_array = np.ndarray(
shape=(height, width, 4), dtype=np.uint8, buffer=img_buffer
)
if channel == "GRAYSCALE":
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)
elif channel == "RGBA":
@ -155,13 +165,15 @@ class Document(object):
return cv2.cvtColor(img_array, cv2.COLOR_BGRA2BGR)
else:
valid_channels = ["GRAYSCALE", "RGB", "RGBA", "BGR", "BGRA"]
raise ValueError(f"Invalid channel code {channel}. Valid values are: {valid_channels}.")
raise ValueError(
f"Invalid channel code {channel}. Valid values are: {valid_channels}."
)
def update_style(self, **style):
""" Update template variables that controls the document style and re-compile the document to reflect the style change.
"""Update template variables that controls the document style and re-compile the document to reflect the style change.
Optional Arguments:
style {dict} -- a kwargs dictionary whose keys and values are
style {dict} -- a kwargs dictionary whose keys and values are
the template variable and their respective values
Example:
@ -175,44 +187,49 @@ class Document(object):
self.styles.update(style)
# Recompile the html template and the document obj
self.compiled_html = self.render_html()
self._document = HTML(string=self.compiled_html).render() # weasyprinter.document.Document object
self._document = HTML(
string=self.compiled_html
).render() # weasyprinter.document.Document object
class DocumentGenerator():
class DocumentGenerator:
""" Document generator class """
def __init__(self, template_path=None):
""" Initialize a DocumentGenerator class
"""Initialize a DocumentGenerator class
Keyword Arguments:
template_path {str} -- filepath of custom templates (default: {None})
*** Important *** if not set, will use the default templates from the
*** Important *** if not set, will use the default templates from the
package "genalog.generation.templates".
"""
if template_path:
self.template_env = Environment(
loader=FileSystemLoader(template_path),
autoescape=select_autoescape(['html', 'xml'])
)
loader=FileSystemLoader(template_path),
autoescape=select_autoescape(["html", "xml"]),
)
self.template_list = self.template_env.list_templates()
else:
# Loading built-in templates from the genalog package
self.template_env = Environment(
loader=PackageLoader("genalog.generation", "templates"),
autoescape=select_autoescape(['html', 'xml'])
)
loader=PackageLoader("genalog.generation", "templates"),
autoescape=select_autoescape(["html", "xml"]),
)
# Remove macros and css templates from rendering
self.template_list = self.template_env.list_templates(filter_func=DocumentGenerator._keep_template)
self.template_list = self.template_env.list_templates(
filter_func=DocumentGenerator._keep_template
)
self.set_styles_to_generate(DEFAULT_STYLE_COMBINATION)
@staticmethod
def _keep_template(template_name):
""" Auxiliary function for Jinja2.Environment.list_templates().
"""Auxiliary function for Jinja2.Environment.list_templates().
This function filters out non-html templates and base templates
Arguments:
template_name {str} -- target of the template
Returns:
[bool] -- True if keeping the template in the list. False otherwise.
"""
@ -220,16 +237,16 @@ class DocumentGenerator():
if any(name in template_name for name in TEMPLATES_TO_REMOVE):
return False
return True
def set_styles_to_generate(self, style_combinations):
"""
Set new styles to generate.
Arguments:
style_combination {dict} -- a dictionary {str: list} enlisting the combinations
of values to generate per style property
style_combination {dict} -- a dictionary {str: list} enlisting the combinations
of values to generate per style property
(default: {None})
Example:
{
"font_family": ["Calibri", "Times"],
@ -248,14 +265,16 @@ class DocumentGenerator():
variables declared in the loaded template. There will be NO side-effect providing
a variable UNDEFINED in the template.
If this parameter is not provided, generator will use default document
If this parameter is not provided, generator will use default document
styles: DEFAULT_STYLE_COMBINATION
"""
self.styles_to_generate = DocumentGenerator.expand_style_combinations(style_combinations)
self.styles_to_generate = DocumentGenerator.expand_style_combinations(
style_combinations
)
def create_generator(self, content, templates_to_render):
""" Create a Document generator
"""Create a Document generator
Arguments:
content {list} -- a list [str] of string to populate the template
templates_to_render {list} -- a list [str] or templates to render
@ -266,15 +285,17 @@ class DocumentGenerator():
"""
for template_name in templates_to_render:
if template_name not in self.template_list:
raise FileNotFoundError(f"File '{template_name}' not found. Available templates are {self.template_list}")
raise FileNotFoundError(
f"File '{template_name}' not found. Available templates are {self.template_list}"
)
template = self.template_env.get_template(template_name)
for style in self.styles_to_generate:
yield Document(content, template, **style)
@staticmethod
def expand_style_combinations(styles):
""" Expand the list of style values into all possible style combinations
"""Expand the list of style values into all possible style combinations
Example:
styles =
{
@ -291,12 +312,12 @@ class DocumentGenerator():
{"font_family": "Times", "font_size": "12px", "hyphenate":True }
]
The result dictionaries are intended to be used as a kwargs to initialize a
The result dictionaries are intended to be used as a kwargs to initialize a
Document Object:
Example:
Document(template, content, **{"font_family": "Calibri", "font_size": ...})
Arguments:
styles {dict} -- a dictionary {str: list} enlisting the combinations of values
to generate per style property
@ -308,10 +329,14 @@ class DocumentGenerator():
if not styles:
return []
# Python 2.x+ guarantees that the order in keys() and values() is preserved
style_properties = styles.keys() # ex) ["font_family", "font_size", "hyphenate"]
property_values = styles.values() # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]
# Generate all possible combinations:
style_properties = (
styles.keys()
) # ex) ["font_family", "font_size", "hyphenate"]
property_values = (
styles.values()
) # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]
# Generate all possible combinations:
# [("Calibri", "10px", True), ("Calibri", "12px", True), ...]
property_value_combinations = itertools.product(*property_values)
@ -328,4 +353,4 @@ class DocumentGenerator():
style_dict[style_property] = property_value
style_combinations.append(style_dict)
return style_combinations
return style_combinations
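Illustration (not part of the diff): wiring DocumentGenerator to content; the template name below is hypothetical, the real names come from generator.template_list:

from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator

content = CompositeContent(["Sample paragraph."], [ContentType.PARAGRAPH])
generator = DocumentGenerator()          # default templates from genalog.generation.templates
generator.set_styles_to_generate({"font_family": ["Times"], "font_size": ["12px"], "hyphenate": [True]})
for doc in generator.create_generator(content, ["example_template.html.jinja"]):   # hypothetical name
    doc.render_png(target="sample.png", resolution=300)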

View file

@ -1,20 +1,20 @@
"""Uses the python sdk to make operation on Azure Blob storage.
see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
import os
import time
import asyncio
import base64
import hashlib
import json
import asyncio
import os
import random
from multiprocessing import Pool
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from tqdm import tqdm
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
import base64
import aiofiles
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from tqdm import tqdm
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
# maximum number of simultaneous requests
REQUEST_SEMAPHORE = asyncio.Semaphore(50)
@ -24,13 +24,21 @@ FILE_SEMAPHORE = asyncio.Semaphore(500)
MAX_RETRIES = 5
class GrokBlobClient:
"""This class is a client that is used to upload and delete files from Azure Blob storage
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
def __init__(self, datasource_container_name, blob_account_name, blob_key, projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME):
def __init__(
self,
datasource_container_name,
blob_account_name,
blob_key,
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
):
"""Creates the blob storage client given the key and storage account name
Args:
datasource_container_name (str): container name. This container does not need to be existing
projections_container_name (str): projections container to store ocr projections.
@ -42,38 +50,54 @@ class GrokBlobClient:
self.PROJECTIONS_CONTAINER_NAME = projections_container_name
self.BLOB_NAME = blob_account_name
self.BLOB_KEY = blob_key
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
self.BLOB_CONNECTION_STRING = (
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
)
@staticmethod
def create_from_env_var():
"""Created the blob client using values in the environment variables
Returns:
GrokBlobClient: the new blob client
"""
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
BLOB_NAME = os.environ["BLOB_NAME"]
BLOB_KEY = os.environ["BLOB_KEY"]
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
client = GrokBlobClient(DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY, projections_container_name=PROJECTIONS_CONTAINER_NAME)
PROJECTIONS_CONTAINER_NAME = os.environ.get(
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
)
client = GrokBlobClient(
DATASOURCE_CONTAINER_NAME,
BLOB_NAME,
BLOB_KEY,
projections_container_name=PROJECTIONS_CONTAINER_NAME,
)
return client
def upload_images_to_blob(self, src_folder_path, dest_folder_name=None,check_existing_cache=True, use_async=True):
def upload_images_to_blob(
self,
src_folder_path,
dest_folder_name=None,
check_existing_cache=True,
use_async=True,
):
"""Uploads images from the src_folder_path to blob storage at the destination folder.
The destination folder is created if it doesn't exist. If a destination folder is not
given a folder is created named by the md5 hash of the files.
Args:
src_folder_path (str): path to local folder that has images
dest_folder_name (str, optional): destination folder name. Defaults to None.
Returns:
str: the destination folder name
"""
self._create_container()
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
if dest_folder_name is None:
dest_folder_name = self.get_folder_hash(src_folder_path)
@ -94,30 +118,50 @@ class GrokBlobClient:
return (upload_file_path, blob_name)
if check_existing_cache:
existing_blobs,_ = self.list_blobs(dest_folder_name or "")
existing_blobs = list(map(lambda blob: blob["name"] , existing_blobs))
file_blob_names = filter(lambda file_blob_names: not file_blob_names[1] in existing_blobs, zip(files_to_upload, blob_names))
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in file_blob_names ]
existing_blobs, _ = self.list_blobs(dest_folder_name or "")
existing_blobs = list(map(lambda blob: blob["name"], existing_blobs))
file_blob_names = filter(
lambda file_blob_names: not file_blob_names[1] in existing_blobs,
zip(files_to_upload, blob_names),
)
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in file_blob_names
]
else:
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in zip(files_to_upload, blob_names)]
job_args = [
get_job_args(file_path, blob_name)
for file_path, blob_name in zip(files_to_upload, blob_names)
]
print("uploading ", len(job_args), "files")
if not use_async:
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
blob_container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
blob_container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
jobs = [(blob_container_client,) + x for x in job_args]
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs) ):
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)):
pass
else:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
async def async_upload():
async with async_blob_service_client:
async_blob_container_client = async_blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
async_blob_container_client = (
async_blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
)
jobs = [(async_blob_container_client,) + x for x in job_args]
for f in tqdm(asyncio.as_completed(map(_upload_worker_async,jobs)), total=len(jobs)):
for f in tqdm(
asyncio.as_completed(map(_upload_worker_async, jobs)),
total=len(jobs),
):
await f
loop = asyncio.get_event_loop()
@ -150,7 +194,7 @@ class GrokBlobClient:
def delete_blobs_folder(self, folder_name):
"""Deletes all blobs in a folder
Args:
folder_name (str): folder to delete
"""
@ -158,53 +202,83 @@ class GrokBlobClient:
blobs_list, blob_service_client = self.list_blobs(folder_name)
for blob in blobs_list:
blob_client = blob_service_client.get_blob_client(
container=self.DATASOURCE_CONTAINER_NAME, blob=blob)
container=self.DATASOURCE_CONTAINER_NAME, blob=blob
)
blob_client.delete_blob()
def list_blobs(self, folder_name):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
return container_client.list_blobs(name_starts_with=folder_name), blob_service_client
self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
return (
container_client.list_blobs(name_starts_with=folder_name),
blob_service_client,
)
def _create_container(self):
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist.
"""
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist."""
# Create the BlobServiceClient object which will be used to create a container client
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
try:
blob_service_client.create_container(
self.DATASOURCE_CONTAINER_NAME)
blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.DATASOURCE_CONTAINER_NAME)
# create the container for storing ocr projections
try:
print("creating projections storage container:", self.PROJECTIONS_CONTAINER_NAME)
blob_service_client.create_container(
self.PROJECTIONS_CONTAINER_NAME)
print(
"creating projections storage container:",
self.PROJECTIONS_CONTAINER_NAME,
)
blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME)
except ResourceExistsError:
print("container already exists:", self.PROJECTIONS_CONTAINER_NAME)
def get_ocr_json(self, remote_path, output_folder, use_async=True):
blob_service_client = BlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
self.BLOB_CONNECTION_STRING
)
container_client = blob_service_client.get_container_client(
self.DATASOURCE_CONTAINER_NAME
)
blobs_list = list(container_client.list_blobs(name_starts_with=remote_path))
container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}"
if use_async:
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
self.BLOB_CONNECTION_STRING)
self.BLOB_CONNECTION_STRING
)
async def async_download():
async with async_blob_service_client:
async_projection_container_client = async_blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob :(blob, async_projection_container_client, container_uri, output_folder), blobs_list ))
for f in tqdm(asyncio.as_completed(map(_download_worker_async,jobs)), total=len(jobs)):
async_projection_container_client = (
async_blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
)
jobs = list(
map(
lambda blob: (
blob,
async_projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
for f in tqdm(
asyncio.as_completed(map(_download_worker_async, jobs)),
total=len(jobs),
):
await f
loop = asyncio.get_event_loop()
if loop.is_running():
result = loop.create_task(async_download())
@ -212,11 +286,24 @@ class GrokBlobClient:
result = loop.run_until_complete(async_download())
return result
else:
projection_container_client = blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
jobs = list(map(lambda blob : (blob, projection_container_client, container_uri, output_folder), blobs_list))
projection_container_client = blob_service_client.get_container_client(
self.PROJECTIONS_CONTAINER_NAME
)
jobs = list(
map(
lambda blob: (
blob,
projection_container_client,
container_uri,
output_folder,
),
blobs_list,
)
)
print("downloading", len(jobs), "files")
for _ in tqdm(map( _download_worker_sync, jobs), total=len(jobs)):
pass
for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)):
pass
def _get_projection_path(container_uri, blob):
blob_uri = f"{container_uri}/{blob.name}"
@ -226,13 +313,16 @@ def _get_projection_path(container_uri, blob):
# hopefully this doesn't change soon otherwise we will have to do linear search over all docs to find
# the projections we want
projection_path = base64.b64encode(blob_uri.encode()).decode()
projection_path = projection_path.replace("=","") + str(projection_path.count("="))
projection_path = projection_path.replace("=", "") + str(projection_path.count("="))
return projection_path
def _download_worker_sync(args):
blob, projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri,blob)
blob_client = projection_container_client.get_blob_client(blob=f"{projection_path}/document.json")
projection_path = _get_projection_path(container_uri, blob)
blob_client = projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
doc = json.loads(blob_client.download_blob().readall())
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)
@ -242,10 +332,13 @@ def _download_worker_sync(args):
json.dump(text, open(output_file, "w", encoding="utf-8"), ensure_ascii=False)
return output_file
async def _download_worker_async(args):
blob, async_projection_container_client, container_uri, output_folder = args
projection_path = _get_projection_path(container_uri,blob)
async_blob_client = async_projection_container_client.get_blob_client( blob=f"{projection_path}/document.json")
projection_path = _get_projection_path(container_uri, blob)
async_blob_client = async_projection_container_client.get_blob_client(
blob=f"{projection_path}/document.json"
)
file_name = os.path.basename(blob.name)
base_name, ext = os.path.splitext(file_name)
for retry in range(MAX_RETRIES):
@ -260,14 +353,15 @@ async def _download_worker_async(args):
json.dump(text, open(output_file, "w"))
return output_file
except ResourceNotFoundError:
print(f"blob doesn't exist in OCR projection. try rerunning OCR", blob.name)
print(f"Blob '{blob.name}' doesn't exist in OCR projection. Try rerunning OCR.")
return
except Exception as e:
print("error getting blob OCR projection", blob.name, e)
# sleep for a bit then retry
asyncio.sleep(2 * random.random())
async def _upload_worker_async(args):
async_blob_container_client, upload_file_path, blob_name = args
async with FILE_SEMAPHORE:
@ -276,24 +370,33 @@ async def _upload_worker_async(args):
for retry in range(MAX_RETRIES):
async with REQUEST_SEMAPHORE:
try:
await async_blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
await async_blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
return blob_name
except ResourceExistsError:
print("blob already exists:", blob_name)
return
return
except Exception as e:
print(f"blob upload error. retry count: {retry}/{MAX_RETRIES} :", blob_name, e)
print(
f"blob upload error. retry count: {retry}/{MAX_RETRIES} :",
blob_name,
e,
)
# sleep for a bit then retry
asyncio.sleep(2 * random.random())
return blob_name
def _upload_worker_sync(args):
blob_container_client, upload_file_path, blob_name = args
with open(upload_file_path, "rb") as data:
try:
blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
blob_container_client.upload_blob(
name=blob_name, max_concurrency=8, data=data
)
except ResourceExistsError:
print("blob already exists:", blob_name)
except Exception as e:
print("blob upload error:", blob_name, e)
return blob_name
return blob_name
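Illustration (not part of the diff): the environment-variable path through the blob client above; folder paths are hypothetical:

from genalog.ocr.blob_client import GrokBlobClient

# expects DATASOURCE_CONTAINER_NAME, BLOB_NAME and BLOB_KEY in the environment
client = GrokBlobClient.create_from_env_var()
dest_folder, _ = client.upload_images_to_blob("images/", use_async=True)
client.get_ocr_json(dest_folder, "ocr_json/", use_async=True)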

View file

@ -1 +1 @@
DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections"
DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections"

View file

@ -1,14 +1,14 @@
from .rest_client import GrokRestClient
from .blob_client import GrokBlobClient
import time
from .blob_client import GrokBlobClient
from .rest_client import GrokRestClient
class Grok:
@staticmethod
def create_from_env_var():
"""Initializes Grok based on keys in the environment variables.
Returns:
Grok: the Grok client
"""
@ -16,31 +16,41 @@ class Grok:
grok_blob_client = GrokBlobClient.create_from_env_var()
return Grok(grok_rest_client, grok_blob_client)
def __init__(self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient):
def __init__(
self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient
):
self.grok_rest_client = grok_rest_client
self.grok_blob_client = grok_blob_client
def run_grok(self, src_folder_path, dest_folder_path, blob_dest_folder=None,cleanup=False,use_async=True):
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
def run_grok(
self,
src_folder_path,
dest_folder_path,
blob_dest_folder=None,
cleanup=False,
use_async=True,
):
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
GROK OCR on this blob storage as a source, then downloads the OCR output json to the destination
folder. There resulting json files are of the same name as the original images except prefixed
folder. The resulting json files have the same names as the original images, except prefixed
with the name of their folder on the blob storage and suffixed with the .json extension.
Args:
src_folder_path (str): Path to folder holding the images. This folder must only contain png or jpg files
dest_folder_path (str): Path to folder where OCR json files will be placed
blob_dest_folder (str, optional): Folder tag to use on the blob storage. If set to None, a hash is generated
based on the names of files in the src folder. Defaults to None.
cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are
cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are
deleted from blob after running. Defaults to True.
use_multiprocessing (bool, optional): If set to True, this will use multiprocessing to increase blob transfer speed.
Returns:
indexer_status json, blob folder name
"""
print("uploading images to blob")
blob_folder_name,_ = self.grok_blob_client.upload_images_to_blob(
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async)
blob_folder_name, _ = self.grok_blob_client.upload_images_to_blob(
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async
)
print(f"images upload under folder {blob_folder_name}")
try:
print("creating and running indexer")
@ -50,10 +60,13 @@ class Grok:
indexer_status = self.grok_rest_client.get_indexer_status()
if indexer_status["status"] == "error":
raise RuntimeError(f"indexer error: {indexer_status}")
# if not already running start the indexer
print("indexer_status", indexer_status)
if indexer_status["lastResult"] == None or indexer_status["lastResult"]["status"] != "inProgress":
if (
indexer_status["lastResult"] is None
or indexer_status["lastResult"]["status"] != "inProgress"
):
self.grok_rest_client.run_indexer()
time.sleep(1)
@ -62,9 +75,13 @@ class Grok:
if indexer_status["lastResult"]["status"] == "success":
time.sleep(30)
print("fetching ocr json results.")
self.grok_blob_client.get_ocr_json(blob_folder_name, dest_folder_path, use_async=use_async)
self.grok_blob_client.get_ocr_json(
blob_folder_name, dest_folder_path, use_async=use_async
)
print(f"indexer status {indexer_status}")
print(f"finished running indexer. json files saved to {dest_folder_path}")
print(
f"finished running indexer. json files saved to {dest_folder_path}"
)
else:
print("GROK failed", indexer_status["status"])
raise RuntimeError("GROK failed", indexer_status["status"])
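Illustration (not part of the diff): the end-to-end call that the run_grok docstring above describes, with hypothetical local paths:

from genalog.ocr.grok import Grok

grok = Grok.create_from_env_var()       # blob and search-service keys come from the environment
indexer_status, blob_folder = grok.run_grok(
    "images/",                          # local folder of png/jpg files
    "ocr_json/",                        # where the OCR json output is written
    cleanup=True,
    use_async=True,
)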
@ -77,7 +94,7 @@ class Grok:
def cleanup(self, folder_name):
"""Deletes the indexing pipeline (index, indexer, datasource, skillset) from the search service.
Deletes uploaded files from the blob
Args:
folder_name (str): blob folder name tag to remove
"""

View file

@ -2,54 +2,59 @@
Utility functions to support getting OCR metrics
OCR Metrics
1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412.
1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412.
Accuracy = Correct Words/Total Words (in target strings)
2. Count of edit distance ops:
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". This is based on Levenshtein edit distance.
2. Count of edit distance ops:
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing".
This is based on Levenshtein edit distance.
3. By looking at the gaps in alignment we also generate substitution dicts:
e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution
3. By looking at the gaps in alignment we also generate substitution dicts:
e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution
since the rest of the segments align. The assumption here is that we do not expect to have very long gaps in alignment,
hence collecting and counting these substitutions will be manageable.
"""
import string
import argparse
import json
import multiprocessing
import os
import re
import json
import argparse
import multiprocessing
from multiprocessing import Pool
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool
from genalog.text.alignment import GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
from genalog.text.anchor import align_w_anchor
from genalog.text.ner_label import _find_gap_char_candidates
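For illustration (not part of the diff), the word-accuracy figure described in the module docstring is simply correct words over total target words; a toy version of that definition using the docstring's own example:

target_words = "a worn coat".split()
ocr_words = "a wom coat".split()
correct = sum(1 for t, o in zip(target_words, ocr_words) if t == o)
word_accuracy = correct / len(target_words)   # 2/3 here; the real metric counts matches via alignment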
LOG_LEVEL = 0
WORKERS_PER_CPU = 2
def _log(*args, **kwargs):
if LOG_LEVEL:
print(args)
def _trim_whitespace(src_string):
return re.sub(r"\s+", " ", src_string.strip())
def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
"""Given two string that differ and have no alignment at all,
update the alignment dict and fill in substitution if replacements are found.
update alignment stats with counts of the edit operation to transform the source
string to the target
Args:
src (str): source string
target (str): target string at the
target (str): target string at the
align_stats (dict): key-value dictionary that stores the counts of inserts, deletes,
spacing and replacements
substitution_dict (dict): store the counts of mapping from one substring to another of
the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2
the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2
times this will be { ('rm','m'): 2}
gap_char (str): gap character used in alignment
"""
@ -70,30 +75,40 @@ def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
else:
align_stats["replace"] += 1
_log("replacing", source_substr, target_substr)
substitution_dict[source_substr, target_substr] = substitution_dict.get(
(source_substr, target_substr), 0) + 1
substitution_dict[source_substr, target_substr] = (
substitution_dict.get((source_substr, target_substr), 0) + 1
)
_log("spacing count", spacing_count)
align_stats["spacing"] += spacing_count
def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matching_chars_count, \
matching_words_count, matching_alnum_words_count):
def _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
):
"""Given two string segments that align. update the counts of matching words and characters
Args:
aligned_src (str): full source string
aligned_target (str): full target string
gap_char (str): gap character used in alignment
gap_char (str): gap character used in alignment
start (int): start position of alignment
end (int): end position of alignment
matching_chars_count (int): current count of matching characters
matching_words_count (int): current count of matching words
matching_alnum_words_count (int): current count of alphanumeric matching words
Returns:
tuple(int,int,int): the updated matching_chars_count, matching_words_count, matching_alnum_words_count
"""
aligned_part = aligned_src[start:end]
matching_chars_count += (end-start)
matching_chars_count += end - start
# aligned_part = seq.strip()
_log("aligned", aligned_part, start, end)
if len(aligned_src) != len(aligned_target):
@ -112,10 +127,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
# to be compared with the full string to see if they have space before or after
if i == 0:
if start != 0 and (aligned_target[start] != " " or aligned_src[start] != " " ):
if start != 0 and (
aligned_target[start] != " " or aligned_src[start] != " "
):
# if this was the start of the string in the target or source
if not(aligned_src[:start].replace(gap_char,"").replace(" ","") == "" and aligned_target[start-1] == " ") and \
not(aligned_target[:start].replace(gap_char,"").replace(" ","") == "" and aligned_src[start-1] == " "):
if not (
aligned_src[:start].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[start - 1] == " "
) and not (
aligned_target[:start].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[start - 1] == " "
):
# beginning word not matching completely
_log("removing first match word from count", word, aligned_part)
matching_words_count -= 1
@ -123,34 +146,43 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
matching_alnum_words_count -= 1
continue
if i == len(words)-1:
if end != len(aligned_target) and (aligned_target[end] != " " or aligned_src[end] != " " ):
if i == len(words) - 1:
if end != len(aligned_target) and (
aligned_target[end] != " " or aligned_src[end] != " "
):
# this was not the end of the string in the src and not end of string in target
if not(aligned_src[end:].replace(gap_char,"").replace(" ","") == "" and aligned_target[end] == " ") and \
not(aligned_target[end:].replace(gap_char,"").replace(" ","") == ""and aligned_src[end] == " "):
if not (
aligned_src[end:].replace(gap_char, "").replace(" ", "") == ""
and aligned_target[end] == " "
) and not (
aligned_target[end:].replace(gap_char, "").replace(" ", "")
== ""
and aligned_src[end] == " "
):
# last word not matching completely
_log("removing last match word from count",word, aligned_part)
_log("removing last match word from count", word, aligned_part)
matching_words_count -= 1
if re.search(r"\w", word):
matching_alnum_words_count -= 1
_log("matched count", matching_words_count)
_log("matched alnum count", matching_alnum_words_count)
return matching_chars_count, matching_words_count, matching_alnum_words_count
return matching_chars_count, matching_words_count, matching_alnum_words_count
def _get_align_stats(alignment, src_string, target, gap_char):
"""Given an alignment, this function get the align stats and substitution mapping to
"""Given an alignment, this function get the align stats and substitution mapping to
transform the source string to the target string
Args:
alignment (tuple(str, str)): the result of calling align on the two strings
src_source (str): the source string
target (str) : the target string
gap_char (str) : the gap character used in alignment
Raises:
ValueError: if any of the aligned strings are empty
Returns:
tuple(dict, dict): align stats dict, substitution mappings dict
"""
@ -167,17 +199,24 @@ def _get_align_stats(alignment, src_string, target, gap_char):
word_count = len(words)
# alphanumeric words are defined here as words with at least one alphanumeric character
alnum_words_count = len(list(filter(lambda x: re.search(r"\w",x) , words )))
alnum_words_count = len(list(filter(lambda x: re.search(r"\w", x), words)))
char_count = max(len(target), len(src_string))
matching_chars_count = 0
matching_words_count = 0
matching_alnum_words_count = 0
align_stats = {"insert": 0, "delete": 0, "replace": 0, "spacing": 0,
"total_chars": char_count, "total_words": word_count, "total_alnum_words": alnum_words_count}
align_stats = {
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": char_count,
"total_words": word_count,
"total_alnum_words": alnum_words_count,
}
start = 0
_log("######### Alignment ############")
_log(aligned_src)
_log(aligned_target)
@ -190,57 +229,105 @@ def _get_align_stats(alignment, src_string, target, gap_char):
# since this substring aligns, simply count the number of matching words and chars and update
# the word stats
end = i
_log("sequences", aligned_src[start:end], aligned_target[start:end], start, end)
_log(
"sequences",
aligned_src[start:end],
aligned_target[start:end],
start,
end,
)
assert aligned_src[start:end] == aligned_target[start:end]
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)
start = end + 1
if gap_start is None:
gap_start = end
else:
gap_end = i
if not gap_start is None:
if gap_start is not None:
# since characters now match, gap_start:i contains a substring of the characters that didn't align before
# handle this gap alignment by calling _update_align_stats
_log("gap", aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], gap_start, gap_end)
_update_align_stats(aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], align_stats, substitution_dict, gap_char)
_log(
"gap",
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
gap_start,
gap_end,
)
_update_align_stats(
aligned_src[gap_start:gap_end],
aligned_target[gap_start:gap_end],
align_stats,
substitution_dict,
gap_char,
)
gap_start = None
# Now compare any leftover string segments from the for loop
if gap_start is not None:
# handle last alignment gap
_log("last gap", aligned_src[gap_start:], aligned_target[gap_start:])
_update_align_stats(aligned_src[gap_start:], aligned_target[gap_start:], align_stats, substitution_dict, gap_char)
_update_align_stats(
aligned_src[gap_start:],
aligned_target[gap_start:],
align_stats,
substitution_dict,
gap_char,
)
else:
# handle last aligned substring
seq = aligned_src[start:]
aligned_part = seq.strip()
end = len(aligned_src)
_log("last aligned", aligned_part)
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
(
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
) = _update_word_stats(
aligned_src,
aligned_target,
gap_char,
start,
end,
matching_chars_count,
matching_words_count,
matching_alnum_words_count,
)
align_stats["matching_chars"] = matching_chars_count
align_stats["matching_alnum_words"] = matching_alnum_words_count
align_stats["matching_words"] = matching_words_count
align_stats["alnum_word_accuracy"] = matching_alnum_words_count/alnum_words_count
align_stats["word_accuracy"] = matching_words_count/word_count
align_stats["char_accuracy"] = matching_chars_count/char_count
align_stats["alnum_word_accuracy"] = matching_alnum_words_count / alnum_words_count
align_stats["word_accuracy"] = matching_words_count / word_count
align_stats["char_accuracy"] = matching_chars_count / char_count
return align_stats, substitution_dict
def get_editops_stats(alignment, gap_char):
"""Get stats for character level edit operations that need to be done to
"""Get stats for character level edit operations that need to be done to
transform the source string to the target string. Inputs must not be empty
and must be the result of calling the runing the align function.
Args:
alignment (tuple(str, str)): the results from the string alignment biopy function
gap_char (str): gap character used in alignment
Raises:
ValueError: If any of the string in the alignment are empty
Returns:
tuple(dict, dict): the counts of character-level edit operations, and a dict mapping character positions to the edit actions
"""
@ -248,42 +335,49 @@ def get_editops_stats(alignment, gap_char):
aligned_src, aligned_target = alignment
if aligned_src == "" or aligned_target == "":
raise ValueError("one of the input strings is empty")
stats = {"edit_insert": 0, "edit_delete": 0, "edit_replace": 0,
"edit_insert_spacing": 0, "edit_delete_spacing": 0}
stats = {
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
}
actions = {}
for i, (char_1, char_2) in enumerate(zip(aligned_src, aligned_target)):
if LOG_LEVEL > 1: _log(char_1, char_2)
if LOG_LEVEL > 1:
_log(char_1, char_2)
if char_1 == gap_char:
# insert
if char_2 == " ":
stats["edit_insert_spacing"] += 1
else:
stats["edit_insert"] += 1
actions[i] = ("I",char_2)
actions[i] = ("I", char_2)
elif char_2 == gap_char:
# delete
if char_1 == " ":
stats["edit_delete_spacing"] += 1
else:
stats["edit_delete"] += 1
actions[i] = ("D")
actions[i] = "D"
elif char_2 != char_1:
stats["edit_replace"] += 1
actions[i] = ("R",char_2)
actions[i] = ("R", char_2)
return stats, actions
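# A minimal usage sketch of get_editops_stats on a hand-made alignment;
# '@' is used as the gap character, as in genalog.text.alignment.
example_alignment = ("New York@is big", "New Yerk is@big")  # (aligned_src, aligned_target)
example_stats, example_actions = get_editops_stats(example_alignment, "@")
# example_stats counts edit_insert/edit_delete/edit_replace (plus *_spacing variants);
# example_actions maps a character position to ("I", char), "D", or ("R", char).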
def get_align_stats(alignment, src_string, target, gap_char):
"""Get alignment stats
"""Get alignment stats
Args:
alignment (tuple(str,str)): the result of calling the align function
src_string (str): the original source string
target (str): the original target string
gap_char (str): the gap character used in alignment
Raises:
ValueError: if any of the strings are empty
Returns:
tuple(dict, dict): dict of the align stats and dict of the substitution mappings
"""
@ -293,52 +387,61 @@ def get_align_stats(alignment, src_string, target, gap_char):
_log("alignment results")
_log(alignment)
align_stats, substitution_dict = _get_align_stats(
alignment, src_string, target, gap_char)
alignment, src_string, target, gap_char
)
return align_stats, substitution_dict
def get_stats(target, src_string):
"""Get align stats, edit stats, and substitution mappings for transforming the
"""Get align stats, edit stats, and substitution mappings for transforming the
source string into the target string. Edit stats refers to the character level edit operations
required to transform the source into the target. Align stats refers to substring level operations
and has the keys insert, replace, delete and the special key spacing, which counts spacing
differences between the two strings. Edit stats has the keys edit_insert, edit_replace and
edit_delete, which count the character level edits.
Args:
src_string (str): the source string
target (str): the target string
Returns:
tuple(dict, dict, dict): a dict containing the edit and align stats, a dict containing the substitutions, and a dict of the edit actions
"""
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align_w_anchor(src_string, target, gap_char=gap_char)
align_stats, substitution_dict = get_align_stats(alignment,src_string, target, gap_char)
align_stats, substitution_dict = get_align_stats(
alignment, src_string, target, gap_char
)
edit_stats, actions = get_editops_stats(alignment, gap_char)
_log("alignment", align_stats)
return {**edit_stats, **align_stats}, substitution_dict, actions
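# A minimal usage sketch of get_stats; the two strings below are made up for illustration.
example_stats, example_subs, example_actions = get_stats("New York is big", "New Yerk is bg")
# example_stats merges the edit_* counts with the align stats (word/char accuracy, etc.);
# example_subs counts (source_substring, target_substring) replacement pairs.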
def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True):
"""Given a path to the folder containing the source text and a folder containing
the output OCR json, this generates the metrics for all files in the source folder.
def get_metrics(
src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True
):
"""Given a path to the folder containing the source text and a folder containing
the output OCR json, this generates the metrics for all files in the source folder.
This assumes that the files in the json folder have the same names as the text files, except that
they are prefixed by the folder_hash parameter followed by an underscore and suffixed by .png.json.
Args:
src_text_path (str): path to source txt files
ocr_json_path (str): path to OCR json files
folder_hash (str): prefix for OCR json files
use_multiprocessing (bool): use multiprocessing
Returns:
tuple(pandas.DataFrame, dict, dict): a pandas dataframe of the metrics with one file per row,
a dict containing the substitution mappings for each file (the key is the filename and the
values are dicts of the substitution mappings for that file), and a dict of edit actions per file.
"""
rows = []
substitutions = {}
actions_map = {}
@ -347,16 +450,25 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
cpu_count = multiprocessing.cpu_count()
n_workers = WORKERS_PER_CPU * cpu_count
job_args = list(map(lambda f: (f, src_text_path, ocr_json_path, folder_hash) , os.listdir(src_text_path)))
job_args = list(
map(
lambda f: (f, src_text_path, ocr_json_path, folder_hash),
os.listdir(src_text_path),
)
)
if use_multiprocessing:
with Pool(n_workers) as pool:
for f, stats, actions, subs in tqdm(pool.imap_unordered(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
pool.imap_unordered(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)
rows.append(stats)
else:
for f, stats, actions, subs in tqdm(map(_worker, job_args), total=len(job_args)):
for f, stats, actions, subs in tqdm(
map(_worker, job_args), total=len(job_args)
):
substitutions[f] = subs
actions_map[f] = actions
rows.append(stats)
@ -364,16 +476,17 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
df = pd.DataFrame(rows)
return df, substitutions, actions_map
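# A usage sketch of get_metrics with placeholder folder names and prefix; each text file
# under the source folder is paired with its OCR json (named per the folder_hash convention above).
example_df, example_subs, example_actions = get_metrics(
    "data/src_txt", "data/ocr_json", folder_hash="run1", use_multiprocessing=False
)
example_df.to_csv("ocr_metrics.csv", index=False)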
def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
src_filename = os.path.join(src_text_path, f)
if folder_hash:
ocr_filename = os.path.join(
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}")
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}"
)
else:
ocr_filename = os.path.join(
ocr_json_path, f"{f.split('txt')[0] + 'json'}")
ocr_filename = os.path.join(ocr_json_path, f"{f.split('txt')[0] + 'json'}")
try:
src_string = open(src_filename, "r", errors='ignore', encoding="utf8").read()
src_string = open(src_filename, "r", errors="ignore", encoding="utf8").read()
except FileNotFoundError:
print(f"File not found: {src_filename}, skipping this file.")
return f, {}, {}, {}
@ -395,34 +508,41 @@ def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
stats["filename"] = f
return f, stats, actions, subs
def _worker(args):
(f, src_text_path, ocr_json_path, folder_hash) = args
return get_file_metrics(f, src_text_path, ocr_json_path, folder_hash)
def _get_sorted_text(ocr_json):
if "lines" in ocr_json[0]:
lines = ocr_json[0]["lines"]
sorted_lines = sorted(lines, key=lambda line : line["boundingBox"][0]["y"])
sorted_lines = sorted(lines, key=lambda line: line["boundingBox"][0]["y"])
return " ".join([line["text"] for line in sorted_lines])
else:
return ocr_json[0]["text"]
def substitution_dict_to_json(substitution_dict):
"""Converts substitution dict to list of tuples of (source_substring, target_substring, count)
Args:
substitution_dict (dict): mapping from (source_substring, target_substring) tuples to replacement counts
"""
to_tuple = lambda x: [(k +(x[k],)) for k in x]
to_tuple = lambda x: [(k + (x[k],)) for k in x] # noqa: E731
out = {}
for filename in substitution_dict:
out[filename] = to_tuple(substitution_dict[filename])
return out
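# A tiny illustrative input for substitution_dict_to_json.
example_subs = {"doc1.txt": {("rm", "m"): 2}}
substitution_dict_to_json(example_subs)  # -> {"doc1.txt": [("rm", "m", 2)]}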
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("src", help="path to folder with text files.")
parser.add_argument("ocr", help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.")
parser.add_argument(
"ocr",
help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.",
)
parser.add_argument("--ocr_prefix", help="the prefix of the ocr files")
parser.add_argument("--output", help="output names of metrics files")


@ -1,33 +1,45 @@
"""Uses the REST api to perform operations on the search service.
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
"""
import requests
import json
import os
import pkgutil
import json
from dotenv import load_dotenv
import time
import sys
import time
from itertools import cycle
import requests
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
API_VERSION = '?api-version=2019-05-06-Preview'
API_VERSION = "?api-version=2019-05-06-Preview"
# 15 min schedule
SCHEDULE_INTERVAL= "PT15M"
SCHEDULE_INTERVAL = "PT15M"
class GrokRestClient:
"""This is a REST client. It is a wrapper around the REST api for the Azure Search Service
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
This class can be used to create an indexing pipeline and can be used to run and monitor
This class can be used to create an indexing pipeline and can be used to run and monitor
ongoing indexers. The indexing pipeline allows you to run batch OCR enrichment of documents.
"""
def __init__(self, cognitive_service_key, search_service_key, search_service_name, skillset_name,
index_name, indexer_name, datasource_name, datasource_container_name, blob_account_name, blob_key,
projections_container_name = DEFAULT_PROJECTIONS_CONTAINER_NAME):
def __init__(
self,
cognitive_service_key,
search_service_key,
search_service_name,
skillset_name,
index_name,
indexer_name,
datasource_name,
datasource_container_name,
blob_account_name,
blob_key,
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
):
"""Creates the REST client
Args:
@ -36,7 +48,7 @@ class GrokRestClient:
search_service_name (str): name of the search service account
skillset_name (str): name of the skillset
index_name (str): name of the index
indexer_name (str): the name of indexer
indexer_name (str): the name of indexer
datasource_name (str): the name to give the attached blob storage source
datasource_container_name (str): the container in the blob storage that hosts the files
blob_account_name (str): blob storage account name that will host the documents to push through the pipeline
@ -70,8 +82,10 @@ class GrokRestClient:
self.API_VERSION = API_VERSION
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
self.BLOB_CONNECTION_STRING = (
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
)
@staticmethod
def create_from_env_var():
@ -85,51 +99,68 @@ class GrokRestClient:
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
BLOB_NAME = os.environ["BLOB_NAME"]
BLOB_KEY = os.environ["BLOB_KEY"]
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
PROJECTIONS_CONTAINER_NAME = os.environ.get(
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
)
client = GrokRestClient(COGNITIVE_SERVICE_KEY, SEARCH_SERVICE_KEY, SEARCH_SERVICE_NAME, SKILLSET_NAME, INDEX_NAME,
INDEXER_NAME, DATASOURCE_NAME, DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY,projections_container_name=PROJECTIONS_CONTAINER_NAME)
client = GrokRestClient(
COGNITIVE_SERVICE_KEY,
SEARCH_SERVICE_KEY,
SEARCH_SERVICE_NAME,
SKILLSET_NAME,
INDEX_NAME,
INDEXER_NAME,
DATASOURCE_NAME,
DATASOURCE_CONTAINER_NAME,
BLOB_NAME,
BLOB_KEY,
projections_container_name=PROJECTIONS_CONTAINER_NAME,
)
return client
def create_skillset(self):
"""Adds a skillset that performs OCR on images
"""
"""Adds a skillset that performs OCR on images"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
skillset_json = json.loads(pkgutil.get_data(
__name__, "templates/skillset.json"))
skillset_json = json.loads(
pkgutil.get_data(__name__, "templates/skillset.json")
)
skillset_json["name"] = self.SKILLSET_NAME
skillset_json["cognitiveServices"]["key"] = self.COGNITIVE_SERVICE_KEY
knowledge_store_json = json.loads(pkgutil.get_data(
__name__, "templates/knowledge_store.json"))
knowledge_store_json = json.loads(
pkgutil.get_data(__name__, "templates/knowledge_store.json")
)
knowledge_store_json["storageConnectionString"] = self.BLOB_CONNECTION_STRING
knowledge_store_json["projections"][0]["objects"][0]["storageContainer"] = self.PROJECTIONS_CONTAINER_NAME
knowledge_store_json["projections"][0]["objects"][0][
"storageContainer"
] = self.PROJECTIONS_CONTAINER_NAME
skillset_json["knowledgeStore"] = knowledge_store_json
print(skillset_json)
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(skillset_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(skillset_json), headers=headers
)
print("skillset response", r.text)
r.raise_for_status()
print("added skillset", self.SKILLSET_NAME, r)
def create_datasource(self):
"""Attaches the blob data store to the search service as a source for image documents
"""
"""Attaches the blob data store to the search service as a source for image documents"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
datasource_json = json.loads(pkgutil.get_data(
__name__, "templates/datasource.json"))
datasource_json = json.loads(
pkgutil.get_data(__name__, "templates/datasource.json")
)
datasource_json["name"] = self.DATASOURCE_NAME
datasource_json["credentials"]["connectionString"] = self.BLOB_CONNECTION_STRING
datasource_json["type"] = "azureblob"
@ -137,27 +168,27 @@ class GrokRestClient:
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(datasource_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(datasource_json), headers=headers
)
print("datasource response", r.text)
r.raise_for_status()
print("added datasource", self.DATASOURCE_NAME, r)
def create_index(self):
"""Create an index with the layoutText column to store OCR output from the enrichment
"""
"""Create an index with the layoutText column to store OCR output from the enrichment"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
index_json = json.loads(pkgutil.get_data(
__name__, "templates/index.json"))
index_json = json.loads(pkgutil.get_data(__name__, "templates/index.json"))
index_json["name"] = self.INDEX_NAME
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(index_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(index_json), headers=headers
)
print("index response", r.text)
r.raise_for_status()
print("created index", self.INDEX_NAME, r)
@ -167,24 +198,26 @@ class GrokRestClient:
The enriched results are pushed to the index.
"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
indexer_json = json.loads(pkgutil.get_data(
__name__, "templates/indexer.json"))
indexer_json = json.loads(pkgutil.get_data(__name__, "templates/indexer.json"))
indexer_json["name"] = self.INDEXER_NAME
indexer_json["skillsetName"] = self.SKILLSET_NAME
indexer_json["targetIndexName"] = self.INDEX_NAME
indexer_json["dataSourceName"] = self.DATASOURCE_NAME
indexer_json["schedule"] = {"interval": SCHEDULE_INTERVAL}
indexer_json["parameters"]["configuration"]["excludedFileNameExtensions"] = extension_to_exclude
indexer_json["parameters"]["configuration"][
"excludedFileNameExtensions"
] = extension_to_exclude
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}"
r = requests.put(endpoint + self.API_VERSION,
json.dumps(indexer_json), headers=headers)
r = requests.put(
endpoint + self.API_VERSION, json.dumps(indexer_json), headers=headers
)
print("indexer response", r.text)
r.raise_for_status()
print("created indexer", self.INDEXER_NAME, r)
@ -203,14 +236,14 @@ class GrokRestClient:
created
"""
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoints = [
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}",
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}",
]
for endpoint in endpoints:
@ -220,8 +253,8 @@ class GrokRestClient:
def run_indexer(self):
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/run"
@ -235,23 +268,26 @@ class GrokRestClient:
i = 0
while True:
# attempt a call every 100 steps
if i % 100 == 0:
if i % 100 == 0:
request_json = self.get_indexer_status()
if request_json["status"] == "error":
raise RuntimeError("Indexer failed")
if request_json["lastResult"] and not request_json["lastResult"]["status"] == "inProgress":
if (
request_json["lastResult"]
and not request_json["lastResult"]["status"] == "inProgress"
):
print(request_json["lastResult"]["status"], self.INDEXER_NAME)
return request_json
sys.stdout.write(next(progress))
sys.stdout.flush()
time.sleep(0.05)
i = (1+i) % 1000 # to avoid overflow
i = (1 + i) % 1000 # to avoid overflow
def get_indexer_status(self):
headers = {
'Content-Type': 'application/json',
'api-key': self.SEARCH_SERVICE_KEY,
"Content-Type": "application/json",
"api-key": self.SEARCH_SERVICE_KEY,
}
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/status"
response = requests.get(endpoint + self.API_VERSION, headers=headers)
@ -259,5 +295,5 @@ class GrokRestClient:
return response.json()
def _checkArg(self, name, value):
if not(value):
if not (value):
raise ValueError(f"argument {name} is not set")


@ -1,14 +1,17 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.content import CompositeContent, ContentType
from genalog.degradation.degrader import Degrader, ImageState
from json import JSONEncoder
from tqdm import tqdm
import concurrent.futures
import timeit
import cv2
import os
import timeit
from json import JSONEncoder
import cv2
from tqdm import tqdm
from genalog.degradation.degrader import Degrader, ImageState
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
from genalog.generation.document import DocumentGenerator
class ImageStateEncoder(JSONEncoder):
def default(self, obj):
@ -16,8 +19,15 @@ class ImageStateEncoder(JSONEncoder):
return obj.value
return JSONEncoder.default(self, obj)
class AnalogDocumentGeneration(object):
def __init__(self, template_path=None, styles=DEFAULT_STYLE_COMBINATION, degradations=[], resolution=300):
def __init__(
self,
template_path=None,
styles=DEFAULT_STYLE_COMBINATION,
degradations=[],
resolution=300,
):
self.doc_generator = DocumentGenerator(template_path=template_path)
self.doc_generator.set_styles_to_generate(styles)
self.degrader = Degrader(degradations)
@ -25,7 +35,7 @@ class AnalogDocumentGeneration(object):
self.resolution = resolution
def list_templates(self):
""" List available templates to generate documents from
"""List available templates to generate documents from
Returns:
list -- a list of template names
@ -33,7 +43,7 @@ class AnalogDocumentGeneration(object):
return self.doc_generator.template_list
def generate_img(self, full_text_path, template, target_folder=None):
""" Generate synthetic images given the filepath of a text document
"""Generate synthetic images given the filepath of a text document
Arguments:
full_text_path {str} -- full filepath of a text document (e.g. /dataset/doc.txt)
@ -44,18 +54,18 @@ class AnalogDocumentGeneration(object):
target_folder {str} -- folder path in which the generated images are stored
(default: {None})
resolution {int} -- resolution in dpi (default: {300})
"""
with open(full_text_path, "r",encoding="utf8") as f: # read file
"""
with open(full_text_path, "r", encoding="utf8") as f: # read file
text = f.read()
content = CompositeContent([text], [ContentType.PARAGRAPH])
generator = self.doc_generator.create_generator(content, [template])
# Generate the image
doc = next(generator)
src = doc.render_array(resolution=self.resolution, channel="GRAYSCALE")
# Degrade the image
dst = self.degrader.apply_effects(src)
if not target_folder:
# return the analog document as numpy.ndarray
return dst
@ -67,45 +77,71 @@ class AnalogDocumentGeneration(object):
cv2.imwrite(img_dst_path, dst)
return
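# A usage sketch of AnalogDocumentGeneration; the template name and path are placeholders.
# With target_folder=None the degraded page is returned as a numpy.ndarray.
example_gen = AnalogDocumentGeneration(degradations=[("blur", {"radius": 3})], resolution=300)
print(example_gen.list_templates())  # available document templates
example_img = example_gen.generate_img("dataset/doc.txt", "columns.html.jinja", target_folder=None)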
def _divide_batches(a, batch_size):
for i in range(0, len(a), batch_size):
yield a[i:i + batch_size]
for i in range(0, len(a), batch_size):
yield a[i: i + batch_size]
def _setup_folder(output_folder):
os.makedirs(os.path.join(output_folder, "img"), exist_ok=True)
def batch_img_generate(args):
input_files, output_folder, styles, degradations, template, resolution = args
generator = AnalogDocumentGeneration(styles=styles, degradations=degradations, resolution=resolution)
generator = AnalogDocumentGeneration(
styles=styles, degradations=degradations, resolution=resolution
)
for file in input_files:
generator.generate_img(file, template, target_folder=output_folder)
def _set_batch_generate_args(file_batches, output_folder, styles, degradations, template, resolution):
return list(map(
lambda batch:
(batch, output_folder, styles, degradations, template, resolution),
file_batches
))
def _set_batch_generate_args(
file_batches, output_folder, styles, degradations, template, resolution
):
return list(
map(
lambda batch: (
batch,
output_folder,
styles,
degradations,
template,
resolution,
),
file_batches,
)
)
def generate_dataset_multiprocess(
input_text_files, output_folder, styles, degradations, template,
resolution=300, batch_size=25
):
input_text_files,
output_folder,
styles,
degradations,
template,
resolution=300,
batch_size=25,
):
_setup_folder(output_folder)
print(f"Storing generated images in {output_folder}")
print(f"Storing generated images in {output_folder}")
batches = list(_divide_batches(input_text_files, batch_size))
print(f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}")
batches = list(_divide_batches(input_text_files, batch_size))
print(
f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}"
)
batch_img_generate_args = _set_batch_generate_args(batches, output_folder, styles, degradations, template, resolution)
batch_img_generate_args = _set_batch_generate_args(
batches, output_folder, styles, degradations, template, resolution
)
# Default to the number of processors on the machine
start_time = timeit.default_timer()
with concurrent.futures.ProcessPoolExecutor() as executor:
batch_iterator = executor.map(batch_img_generate, batch_img_generate_args)
for _ in tqdm(batch_iterator, total=len(batch_img_generate_args)): # wrapping tqdm for progress report
for _ in tqdm(
batch_iterator, total=len(batch_img_generate_args)
): # wrapping tqdm for progress report
pass
elapsed = timeit.default_timer() - start_time
print(f"Time to generate {len(input_text_files)} documents: {elapsed:.3f} sec")


@ -1 +0,0 @@


@ -1,42 +1,53 @@
from genalog.text.preprocess import _is_spacing, tokenize
from Bio import pairwise2
import re
from Bio import pairwise2
from genalog.text.preprocess import _is_spacing, tokenize
# Configuration params for global sequence alignment algorithm (Needleman-Wunsch)
MATCH_REWARD = 1
GAP_PENALTY = -0.5
GAP_EXT_PENALTY = -0.5
MISMATCH_PENALTY = -0.5
GAP_CHAR = '@'
GAP_CHAR = "@"
ONE_ALIGNMENT_ONLY = False
SPACE_MISMATCH_PENALTY = .1
SPACE_MISMATCH_PENALTY = 0.1
def _join_char_list(alignment_tuple):
""" Post-process alignment results for unicode support """
gt_char_list, noise_char_list, score, start, end = alignment_tuple
return "".join(gt_char_list), "".join(noise_char_list), score, start, end
def _align_seg(gt, noise,
match_reward=MATCH_REWARD, mismatch_pen=MISMATCH_PENALTY,
gap_pen=GAP_PENALTY, gap_ext_pen=GAP_EXT_PENALTY, space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
gap_char=GAP_CHAR, one_alignment_only=ONE_ALIGNMENT_ONLY):
""" Wrapper function for Bio.pairwise2.align.globalms(), which
def _align_seg(
gt,
noise,
match_reward=MATCH_REWARD,
mismatch_pen=MISMATCH_PENALTY,
gap_pen=GAP_PENALTY,
gap_ext_pen=GAP_EXT_PENALTY,
space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
gap_char=GAP_CHAR,
one_alignment_only=ONE_ALIGNMENT_ONLY,
):
"""Wrapper function for Bio.pairwise2.align.globalms(), which
calls the sequence alignment algorithm (Needleman-Wunsch)
Arguments:
gt {str} -- a ground truth string
noise {str} -- a string with ocr noise
Keyword Arguments:
match_reward {int} -- reward for matching characters (default: {MATCH_REWARD})
mismatch_pen {int} -- penalty for mistmatching characters (default: {MISMATCH_PENALTY})
gap_pen {int} -- penalty for creating a gap (default: {GAP_PENALTY})
gap_ext_pen {int} -- penalty for extending a gap (default: {GAP_EXT_PENALTY})
Returns:
list -- a list of alignment tuples. Each alignment tuple
is one possible alignment candidate.
is one possible alignment candidate.
A tuple (str, str, int, int, int) contains the following information:
(aligned_gt, aligned_noise, alignment_score, alignment_start, alignment_end)
@ -47,22 +58,32 @@ def _align_seg(gt, noise,
...
]
"""
def match_reward_fn (x,y) :
def match_reward_fn(x, y):
if x == y:
return match_reward
elif x == " " or y == " ":
elif x == " " or y == " ":
# mismatch of a character with a space get a stronger penalty
return mismatch_pen - space_mismatch_penalty
return mismatch_pen - space_mismatch_penalty
else:
return mismatch_pen
# NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters
alignments = pairwise2.align.globalcs(list(gt), list(noise), match_reward_fn,
gap_pen, gap_ext_pen, gap_char=[gap_char], one_alignment_only=ONE_ALIGNMENT_ONLY)
# NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters
alignments = pairwise2.align.globalcs(
list(gt),
list(noise),
match_reward_fn,
gap_pen,
gap_ext_pen,
gap_char=[gap_char],
one_alignment_only=ONE_ALIGNMENT_ONLY,
)
# Alignment result is a list of char instead of string because of the work-around
return list(map(_join_char_list, alignments))
def _select_alignment_candidates(alignments, target_num_gt_tokens):
""" Return an alignment that contains the desired number
"""Return an alignment that contains the desired number
of ground truth tokens from a list of possible alignments
Case Analysis:
@ -70,20 +91,20 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
be guaranteed by the nature of text alignment.
Invariant 2: we should not expect alignment introducing
additional ground truth tokens.
However, in some cases, the alignment algorithm can
introduce a group of GAP_CHARs as a separate token at the
However, in some cases, the alignment algorithm can
introduce a group of GAP_CHARs as a separate token at the
end of string, especially if there are lingering whitespaces.
E.g:
gt: "Boston is big " (num_tokens = 3)
noise: "B oston bi g"
noise: "B oston bi g"
aligned_gt: "B@oston is big @" (num_tokens = 4)
aligned_noise: "B oston @@@bi@ g"
Remember, the example above is just one out of the many possible alignment
Remember, the example above is just one out of the many possible alignment
candidates, and we need to search for the one with the target number of gt_tokens
E.g:
gt: "Boston is big " (num_tokens = 3)
noise: "B oston bi g"
noise: "B oston bi g"
aligned_gt: "B@oston is bi@g " (num_tokens = 3)
aligned_noise: "B oston @@@bi g@"
@ -93,12 +114,12 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
alignments {list} -- a list of alignment tuples as follows:
[(str1, str2, alignment_score, alignment_start, alignment_end), (str1, str2, ...), ...]
target_num_gt_tokens {int} -- the number of token in the aligned ground truth string should have
Raises:
ValueError: raises this error if
ValueError: raises this error if
1. all the alignment candidates do NOT have the target number of tokens OR
2. the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length
2. the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length
Returns:
an alignment tuple (str, str, int, int, int) with following information:
(str1, str2, alignment_score, alignment_start, alignment_end)
@ -111,29 +132,34 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
if num_aligned_gt_tokens == target_num_gt_tokens:
# Invariant 1
if len(aligned_gt) != len(aligned_noise):
raise ValueError(f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n")
raise ValueError(
f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n"
)
# Returns the FIRST candidate that satisfies the invariant
return alignment
raise ValueError(f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}")
raise ValueError(
f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}"
)
def align(gt, noise, gap_char=GAP_CHAR):
"""Align two text segments via sequence alignment algorithm
NOTE: this algorithm is O(N^2) and is NOT efficient for longer text.
NOTE: this algorithm is O(N^2) and is NOT efficient for longer text.
Please refer to `genalog.text.anchor` for faster alignment on longer strings.
Arguments:
gt {str} -- ground true text (should not contain GAP_CHAR)
noise {str} -- str with ocr noise (should not contain GAP_CHAR)
Keyword Arguments:
gap_char {char} -- gap char used in alignment algorithm (default: {GAP_CHAR})
Returns:
a tuple (str, str) of aligned ground truth and noise:
(aligned_gt, aligned_noise)
Invariants:
The returned aligned strings will satisfy the following invariants:
1. len(aligned_gt) == len(aligned_noise)
@ -143,34 +169,39 @@ def align(gt, noise, gap_char=GAP_CHAR):
aligned_gt: "N@ew @@York @is big@@" (num_tokens = 4)
"""
if not gt and not noise: # Both inputs are empty string
return '', ''
elif not gt: # Either is empty
return gap_char*len(noise), noise
if not gt and not noise: # Both inputs are empty string
return "", ""
elif not gt: # Either is empty
return gap_char * len(noise), noise
elif not noise:
return gt, gap_char*len(gt)
return gt, gap_char * len(gt)
else:
num_gt_tokens = len(tokenize(gt))
alignments = _align_seg(gt, noise, gap_char=gap_char)
try:
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(alignments, num_gt_tokens)
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(
alignments, num_gt_tokens
)
except ValueError as e:
raise ValueError(f"Error with input strings '{gt}' and '{noise}': \n{str(e)}")
raise ValueError(
f"Error with input strings '{gt}' and '{noise}': \n{str(e)}"
)
return aligned_gt, aligned_noise
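# A minimal sketch of align() on short strings (it is O(N^2), so prefer
# genalog.text.anchor.align_w_anchor for longer text); '@' marks gaps in both outputs.
example_gt, example_noise = align("New York is big.", "New Yerk is big.")
assert len(example_gt) == len(example_noise)  # invariant 1 above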
def _format_alignment(align1, align2):
"""Wrapper function for Bio.pairwise2.format_alignment()
Arguments:
align1 {str} -- alignment str
align2 {str} -- second str for alignment
Returns:
a string with formatted alignment.
a string with formatted alignment.
'|' is for matching character
'.' is for substitution
'-' indicates gap
For example:
"
New York is big.
@ -178,44 +209,50 @@ def _format_alignment(align1, align2):
New Yerk@is big.
"
"""
formatted_str = pairwise2.format_alignment(align1, align2, 0, 0, len(align1), full_sequences=True)
formatted_str = pairwise2.format_alignment(
align1, align2, 0, 0, len(align1), full_sequences=True
)
# Remove the "Score=0" from the str
formatted_str_no_score = formatted_str.replace("\n Score=0", "")
return formatted_str_no_score
def _find_token_start(s, index):
"""Find the position of the start of token
"""Find the position of the start of token
Arguments:
s {str} -- string to search in
index {int} -- index to begin search from
Returns:
- position {int} of the first non-whitespace character
Raises:
ValueError: if input s is an empty string
IndexError: if is out-of-bound index
"""
max_index = len(s) - 1
if len(s) == 0: raise ValueError("Cannot search in an empty string")
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
if len(s) == 0:
raise ValueError("Cannot search in an empty string")
if index > max_index:
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
while index < max_index and _is_spacing(s[index]):
index += 1
return index
def _find_token_end(s, index):
"""Find the position of the end of a token
*** Important ***
This method ALWAYS returns an index within the bounds of the string.
So, for single character string (eg. "c"), it will return 0.
So, for single character string (eg. "c"), it will return 0.
Arguments:
s {str} -- string to search in
index {int} -- index to begin search from
Returns:
- position {int} of the end of the token, i.e. the first whitespace character at or after index (bounded by the last index of the string)
@ -224,40 +261,44 @@ def _find_token_end(s, index):
IndexError: if is out-of-bound index
"""
max_index = len(s) - 1
if len(s) == 0: raise ValueError("Cannot search in an empty string")
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
if len(s) == 0:
raise ValueError("Cannot search in an empty string")
if index > max_index:
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
while index < max_index and not _is_spacing(s[index]):
index += 1
return index
def _find_next_token(s, start):
""" Return the start and end index of a token in a string
"""Return the start and end index of a token in a string
*** Important ***
This method ALWAYS returns indices within the bounds of the string.
So, for single character string (eg. "c"), it will return (0,0)
Arguments:
s {str} -- the string to search token in
start {int} -- the starting index to start search in
start {int} -- the starting index to start search in
Returns:
a tuple of (int, int) responding to the start and end indices of
a token in the given s.
a token in the given s.
"""
token_start = _find_token_start(s, start)
token_end = _find_token_end(s, token_start)
return token_start, token_end
def _is_valid_token(token, gap_char=GAP_CHAR):
""" Returns true if token is valid (i.e. compose of non-gap characters)
Invalid tokens are
"""Returns true if token is valid (i.e. compose of non-gap characters)
Invalid tokens are
1. multiple occurrences of the GAP_CHAR (e.g. '@@@')
2. empty string ("")
3. string with spaces (" ")
**Important: this method expects one token and not multiple space-separated tokens
**Important: this method expects one token and not multiple space-separated tokens
Arguments:
token {str} -- input string token
@ -269,12 +310,15 @@ def _is_valid_token(token, gap_char=GAP_CHAR):
bool-- True if is a valid token, false otherwise
"""
# Matches multiples of 'gap_char' that are padded with whitespace characters on either end
INVALID_TOKEN_REGEX = rf'^\s*{re.escape(gap_char)}*\s*$' # Escape special regex chars
INVALID_TOKEN_REGEX = (
rf"^\s*{re.escape(gap_char)}*\s*$" # Escape special regex chars
)
return not re.match(INVALID_TOKEN_REGEX, token)
def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
"""Parse alignment to pair ground truth tokens with noise tokens
Case 1: Case 2: Case 3: Case 4: Case 5:
one-to-many many-to-one many-to-many missing tokens one-to-one
(Case 1&2 Composite)
@ -295,10 +339,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
Returns:
a tuple (list, list) of two 2D int arrays as follows:
(gt_to_noise_mapping, noise_to_gt_mapping)
where each array defines the mapping between aligned gt tokens
where each array defines the mapping between aligned gt tokens
to noise tokens and vice versa.
For example:
@ -316,44 +360,44 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
# tk_start_gt=12 tk_index_gt = 4 total_tokens = 4
# | tk_end_gt=15 tk_index_noise = 3 total_tokens = 3
# | |
# "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]]
# "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]]
# "New@york @@ big " noise_token:big noise_to_gt_mapping: [[0][][3]]
# | |
# | tk_end_noise=15 INVALID TOKENS: @*
# tk_start_noise=12
# 1. Initialization:
#1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow
#2. find the first gt_token and the first noise_token
#3. tk_index_gt = tk_index_noise = 0
# 1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow
# 2. find the first gt_token and the first noise_token
# 3. tk_index_gt = tk_index_noise = 0
# 2. While tk_index_gt < total_tk_gt and tk_index_noise < total_tk_noise:
#1. if tk_end_gt == tk_end_noise (1-1 case)
#1. check if the two tokens are valid
#1. if so, register tokens in mapping
#2. find next gt_token token and next noise_token
#3. tk_index_gt ++, tk_index_noise ++
#3. if tk_end_gt < tk_end_noise (many-1 case)
#1. while tk_end_gt < tk_end_noise
#1. check if gt_token and noise_token are BOTH valid
#1. if so register tokens in mapping
#2. find next gt_token
#3. tk_index_gt ++
#4. if tk_end_gt > tk_end_noise (1-many case)
#1. while tk_end_gt > tk_end_noise
#1. check if gt_token and noise_token are BOTH valid
#1. if so register tokens in mapping
#2. find next noise token
#3. tk_index_noise ++
# 1. if tk_end_gt == tk_end_noise (1-1 case)
# 1. check if the two tokens are valid
# 1. if so, register tokens in mapping
# 2. find next gt_token token and next noise_token
# 3. tk_index_gt ++, tk_index_noise ++
# 3. if tk_end_gt < tk_end_noise (many-1 case)
# 1. while tk_end_gt < tk_end_noise
# 1. check if gt_token and noise_token are BOTH valid
# 1. if so register tokens in mapping
# 2. find next gt_token
# 3. tk_index_gt ++
# 4. if tk_end_gt > tk_end_noise (1-many case)
# 1. while tk_end_gt > tk_end_noise
# 1. check if gt_token and noise_token are BOTH valid
# 1. if so register tokens in mapping
# 2. find next noise token
# 3. tk_index_noise ++
# sanity check
if len(aligned_gt) != len(aligned_noise):
raise ValueError("Aligned strings are not equal in length")
raise ValueError("Aligned strings are not equal in length")
total_gt_tokens = len(tokenize(aligned_gt))
total_noise_tokens = len(tokenize(aligned_noise))
# Initialization
aligned_gt += ' ' # add whitespace padding to prevent ptr overflow
aligned_noise += ' ' # add whitespace padding to prevent ptr overflow
aligned_gt += " " # add whitespace padding to prevent ptr overflow
aligned_noise += " " # add whitespace padding to prevent ptr overflow
tk_index_gt = tk_index_noise = 0
tk_start_gt, tk_end_gt = _find_next_token(aligned_gt, 0)
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, 0)
@ -364,8 +408,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
# If both tokens are aligned (one-to-one case)
if tk_end_gt == tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these tokens in the gt_to_noise_mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -381,8 +428,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
elif tk_end_gt < tk_end_noise:
while tk_end_gt < tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these tokens in the gt_to_noise_mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -397,8 +447,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
else:
while tk_end_gt > tk_end_noise:
# if both gt_token and noise_token are valid (missing token case)
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
if _is_valid_token(
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
) and _is_valid_token(
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
):
# register the index of these token in the gt_to_noise mapping
index_row = gt_to_noise_mapping[tk_index_gt]
index_row.append(tk_index_noise)
@ -406,8 +459,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
index_row = noise_to_gt_mapping[tk_index_noise]
index_row.append(tk_index_gt)
# Find the next noise_token
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, tk_end_noise)
tk_start_noise, tk_end_noise = _find_next_token(
aligned_noise, tk_end_noise
)
# Increment index
tk_index_noise += 1
return gt_to_noise_mapping, noise_to_gt_mapping
return gt_to_noise_mapping, noise_to_gt_mapping
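# A minimal sketch of parse_alignment on the already-aligned pair from the docstring example;
# entry i of gt_to_noise lists the noise-token indices aligned with gt token i, and vice versa.
example_gt_to_noise, example_noise_to_gt = parse_alignment("New York is big ", "New@york @@ big ")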


@ -1,36 +1,38 @@
"""
Baseline alignment algorithm is slow on long documents.
The idea is to break down the longer text into smaller fragments
for quicker alignment on individual pieces. We refer to these points of
breakage as "anchor words".
The bulk of this algorithm is to identify these "anchor words".
This is a re-implementation of the algorithm in the paper
"A Fast Alignment Scheme for Automatic OCR Evaluation of Books"
(https://ieeexplore.ieee.org/document/6065412)
We rely on `genalog.text.alignment` to align the subsequences.
"""
import itertools
from collections import Counter
from genalog.text import preprocess, alignment
from genalog.text.lcs import LCS
from genalog.text.alignment import GAP_CHAR
# The recursively portion of the algorithm will run on
# segments longer than this value to find anchor points in
from genalog.text import alignment, preprocess
from genalog.text.alignment import GAP_CHAR
from genalog.text.lcs import LCS
# The recursive portion of the algorithm will run on
# segments longer than this value to find anchor points in
# the longer segment (to break it up further).
MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length
MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length
def get_unique_words(tokens, case_sensitive=False):
""" Get a set of unique words from a Counter dictionary of word occurrences
"""Get a set of unique words from a Counter dictionary of word occurrences
Arguments:
d {dict} -- a Counter dictionary of word occurrences
Keyword Arguments:
case_sensitive {bool} -- whether unique words are case sensitive
case_sensitive {bool} -- whether unique words are case sensitive
(default: {False})
Returns:
@ -38,14 +40,15 @@ def get_unique_words(tokens, case_sensitive=False):
"""
if case_sensitive:
word_count = Counter(tokens)
return {word for word, count in word_count.items() if count < 2 }
return {word for word, count in word_count.items() if count < 2}
else:
tokens_lowercase = [tk.lower() for tk in tokens]
word_count = Counter(tokens_lowercase)
return {tk for tk in tokens if word_count[tk.lower()] < 2 }
return {tk for tk in tokens if word_count[tk.lower()] < 2}
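# A quick illustration: "boston" repeats case-insensitively, so only the remaining tokens are unique.
get_unique_words(["Boston", "is", "boston", "big"])  # -> {"is", "big"}
get_unique_words(["Boston", "is", "boston", "big"], case_sensitive=True)  # -> all four tokens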
def segment_len(tokens):
""" Get length of the segment
"""Get length of the segment
Arguments:
tokens {list} -- a list of tokens
@ -54,8 +57,9 @@ def segment_len(tokens):
"""
return sum(map(len, tokens))
def get_word_map(unique_words, src_tokens):
""" Arrange the set of unique words by the order they original appear in the text
"""Arrange the set of unique words by the order they original appear in the text
Arguments:
unique_words {set} -- a set of unique words
@ -70,18 +74,19 @@ def get_word_map(unique_words, src_tokens):
# Find the indices of the unique words in the source text
unique_word_indices = map(src_tokens.index, unique_words)
word_map = list(zip(unique_words, unique_word_indices))
word_map.sort(key = lambda x: x[1]) # Re-arrange order by the index
word_map.sort(key=lambda x: x[1]) # Re-arrange order by the index
return word_map
def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
""" Find the location of anchor words in both the gt and ocr text.
Anchor words are location where we can split both the source gt
"""Find the location of anchor words in both the gt and ocr text.
Anchor words are location where we can split both the source gt
and ocr text into smaller text fragment for faster alignment.
Arguments:
gt_tokens {list} -- a list of ground truth tokens
ocr_tokens {list} -- a list of tokens from OCR'ed document
Keyword Arguments:
min_anchor_len {int} -- minimum len of the anchor word
(default: {2})
@ -91,9 +96,9 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
(anchor_map_gt, anchor_map_ocr)
1. `anchor_map_gt` is a `word_map` that locates all the anchor words in the gt tokens
2. `anchor_map_ocr` is a `word_map` that locates all the anchor words in the ocr tokens
For example:
Input:
Input:
gt_tokens: ["b", "a", "c"]
ocr_tokens: ["c", "b", "a"]
Output:
@ -113,15 +118,15 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
unique_word_map_ocr = get_word_map(unique_words_common, ocr_tokens)
# Unzip to get the ordered unique_words
ordered_unique_words_gt, _ = zip(*unique_word_map_gt)
ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr)
ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr)
# Join words into a space-separated string for finding LCS
unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt)
unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt)
unique_words_ocr_str = preprocess.join_tokens(ordered_unique_words_ocr)
# 3. Find the LCS between the two ordered list of unique words
lcs = LCS(unique_words_gt_str, unique_words_ocr_str)
lcs_str = lcs.get_str()
# 4. Break up the LCS string into tokens
lcs_words = set(preprocess.tokenize(lcs_str))
@ -129,19 +134,30 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
anchor_words = lcs_words.intersection(unique_words_common)
# 6. Filter the unique words to keep the anchor words ONLY
anchor_map_gt = list(filter(
# This is a list of (unique_word, unique_word_index)
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_gt
))
anchor_map_ocr = list(filter(
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_ocr
))
anchor_map_gt = list(
filter(
# This is a list of (unique_word, unique_word_index)
lambda word_coordinate: word_coordinate[0] in anchor_words,
unique_word_map_gt,
)
)
anchor_map_ocr = list(
filter(
lambda word_coordinate: word_coordinate[0] in anchor_words,
unique_word_map_ocr,
)
)
return anchor_map_gt, anchor_map_ocr
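# A sketch reusing the docstring example above; min_anchor_len=1 lets the
# single-character tokens qualify as anchors. Each returned map is a list of
# (anchor_word, token_index) pairs ordered by position.
example_map_gt, example_map_ocr = get_anchor_map(["b", "a", "c"], ["c", "b", "a"], min_anchor_len=1)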
def find_anchor_recur(gt_tokens, ocr_tokens,
start_pos_gt=0, start_pos_ocr=0,
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
""" Recursively find anchor positions in the gt and ocr text
def find_anchor_recur(
gt_tokens,
ocr_tokens,
start_pos_gt=0,
start_pos_ocr=0,
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH,
):
"""Recursively find anchor positions in the gt and ocr text
Arguments:
gt_tokens {list} -- a list of ground truth tokens
@ -150,12 +166,12 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
Keyword Arguments:
start_pos_gt, start_pos_ocr {int} -- constants added to all the resulting gt and ocr indices respectively
(default: {0})
max_seg_length {int} -- trigger recursion if any text segment is larger than this
max_seg_length {int} -- trigger recursion if any text segment is larger than this
(default: {MAX_ALIGN_SEGMENT_LENGTH})
Raises:
ValueError: when there is a different number of anchor points in gt and ocr.
Returns:
tuple -- two lists of token indices:
(output_gt_anchors, output_ocr_anchors)
@ -165,7 +181,7 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
# 1. Try to find anchor words
anchor_word_map_gt, anchor_word_map_ocr = get_anchor_map(gt_tokens, ocr_tokens)
# 2. Check invariant
# 2. Check invariant
if len(anchor_word_map_gt) != len(anchor_word_map_ocr):
raise ValueError("Unequal number of anchor points across gt and ocr string")
# Return empty if no anchor word found
@ -182,17 +198,26 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
seg_start_gt = list(itertools.chain([0], anchor_indices_gt))
seg_start_ocr = list(itertools.chain([0], anchor_indices_ocr))
start_n_end_gt = zip(seg_start_gt, itertools.chain(anchor_indices_gt, [None]))
start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None]))
start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None]))
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
# 4. Loop through each segment
for gt_seg, ocr_seg, gt_start, ocr_start in zip(gt_segments, ocr_segments, seg_start_gt, seg_start_ocr):
if segment_len(gt_seg) > max_seg_length or segment_len(ocr_seg) > max_seg_length:
# 4. Loop through each segment
for gt_seg, ocr_seg, gt_start, ocr_start in zip(
gt_segments, ocr_segments, seg_start_gt, seg_start_ocr
):
if (
segment_len(gt_seg) > max_seg_length
or segment_len(ocr_seg) > max_seg_length
):
# recur on the segment in between the two anchors.
# We assume the first token in the segment is an anchor word
gt_anchors, ocr_anchors = find_anchor_recur(gt_seg[1:], ocr_seg[1:],
start_pos_gt=gt_start + 1, start_pos_ocr=ocr_start + 1,
max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = find_anchor_recur(
gt_seg[1:],
ocr_seg[1:],
start_pos_gt=gt_start + 1,
start_pos_ocr=ocr_start + 1,
max_seg_length=max_seg_length,
)
# shift the token indices
# (these are indices of a subsequence and do not reflect true positions in the source sequence)
gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors))
@ -200,12 +225,13 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
# merge recursion results
output_gt_anchors = output_gt_anchors.union(gt_anchors)
output_ocr_anchors = output_ocr_anchors.union(ocr_anchors)
return sorted(output_gt_anchors), sorted(output_ocr_anchors)
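A hedged sketch of driving the recursive anchor search directly (the strings reuse the noisy example from the align_w_anchor docstring below; most callers go through align_w_anchor instead):

from genalog.text import anchor, preprocess

gt = "The planet Mars, I scarcely need remind the reader"
ocr = "The plamet Maris, I scacely neee remind te reader"

gt_anchors, ocr_anchors = anchor.find_anchor_recur(
    preprocess.tokenize(gt), preprocess.tokenize(ocr), max_seg_length=4
)
# Both lists hold sorted token indices; gt_anchors[i] and ocr_anchors[i] point at the
# same anchor word in the two token sequences.
print(gt_anchors, ocr_anchors)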
def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
"""A faster alignment scheme of two text segments. This method first
breaks the strings into smaller segments with anchor words.
breaks the strings into smaller segments with anchor words.
Then these smaller segments are aligned.
NOTE: this function shares the same contract as `genalog.text.alignment.align()`
@ -222,7 +248,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
"The planet Mar, "
"The plamet Maris, "
"I scarcely need "
"I scacely neee "
@ -230,7 +256,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
"remind te reader,"
And run sequence alignment on each pair.
Arguments:
gt {str} -- ground truth text
ocr {str} -- text with ocr noise
@ -248,11 +274,17 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
ocr_tokens = preprocess.tokenize(ocr)
# 1. Find anchor positions
gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
# 2. Split into segments
start_n_end_gt = zip(itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None]))
start_n_end_ocr = zip(itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None]))
start_n_end_gt = zip(
itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])
)
start_n_end_ocr = zip(
itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])
)
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
@ -263,13 +295,15 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
gt_segment = preprocess.join_tokens(gt_segment)
noisy_segment = preprocess.join_tokens(noisy_segment)
# Run alignment algorithm
aligned_seg_gt, aligned_seg_ocr = alignment.align(gt_segment, noisy_segment, gap_char=gap_char)
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
aligned_seg_gt, aligned_seg_ocr = alignment.align(
gt_segment, noisy_segment, gap_char=gap_char
)
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
aligned_segments_gt.append(aligned_seg_gt)
aligned_segments_ocr.append(aligned_seg_ocr)
# Stitch all segments together
aligned_gt = ' '.join(aligned_segments_gt)
aligned_noise = ' '.join(aligned_segments_ocr)
return aligned_gt, aligned_noise
# Stitch all segments together
aligned_gt = " ".join(aligned_segments_gt)
aligned_noise = " ".join(aligned_segments_ocr)
return aligned_gt, aligned_noise
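A short usage sketch (the aligned strings in the comments are only indicative; the exact gap placement depends on the alignment scoring and the chosen gap character):

from genalog.text import anchor

aligned_gt, aligned_noise = anchor.align_w_anchor("New York is big", "New Yrk is bg")
print(aligned_gt)     # e.g. "New York is big"
print(aligned_noise)  # e.g. "New Y@rk is b@g" -- gaps filled with the default GAP_CHAR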

View file

@ -4,7 +4,7 @@
usage: conll_format.py [-h] [--train_subset] [--test_subset]
[--gt_folder GT_FOLDER]
base_folder degraded_folder
base_folder degraded_folder
positional arguments:
base_folder base directory containing the collection of dataset
@ -18,41 +18,42 @@ optional arguments:
optional arguments:
-h, --help show this help message and exit
example usage
example usage
(to run for specified degradation of the dataset on both train and test)
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
(to run for specified degradation of the dataset and ground truth)
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
--gt_folder='shared'
(to run for specified degradation of the dataset on only test subset)
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
--test_subset
(to run for specified degradation of the dataset on only train subset)
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
--train_subset
"""
import itertools
import difflib
import argparse
import concurrent.futures
import difflib
import itertools
import json
import os
import sys
import timeit
import concurrent.futures
from tqdm import tqdm
from genalog.text import ner_label, ner_label, alignment
from genalog.text import alignment, ner_label
EMPTY_SENTENCE_SENTINEL = "<<<<EMPTY_OCR_SENTENCE>>>>"
EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O"
def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens):
"""
propagate_labels_sentences propagates clean labels for clean tokens to ocr tokens and splits ocr tokens into sentences
Parameters
----------
clean_tokens : list
@ -63,7 +64,7 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
list of sentences (each sentence is a list of tokens)
ocr_tokens : list
list of tokens in ocr text
Returns
-------
list, list
@ -73,9 +74,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
# Ensure equal number of tokens in both clean_tokens and clean_sentences
merged_sentences = list(itertools.chain(*clean_sentences))
if merged_sentences != clean_tokens:
delta = "\n".join(difflib.unified_diff(merged_sentences, clean_tokens, fromfile='merged_clean_sentences', tofile="clean_tokens"))
raise ValueError(f"Inconsistent tokens. " +
f"Delta between clean_text and clean_labels:\n{delta}")
delta = "\n".join(
difflib.unified_diff(
merged_sentences,
clean_tokens,
fromfile="merged_clean_sentences",
tofile="clean_tokens",
)
)
raise ValueError(
"Inconsistent tokens. "
+ f"Delta between clean_text and clean_labels:\n{delta}"
)
# Ensure that there's OCR result
if len(ocr_tokens) == 0:
raise ValueError("Empty OCR tokens.")
@ -83,14 +93,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
raise ValueError("Empty clean tokens.")
# 1. Propagate labels + align
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(clean_labels, clean_tokens, ocr_tokens)
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
clean_labels, clean_tokens, ocr_tokens
)
# 2. Parse alignment to get mapping
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_clean, aligned_ocr, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
aligned_clean, aligned_ocr, gap_char=gap_char
)
# 3. Find sentence breaks in clean text sentences
gt_to_ocr_mapping_is_empty = [len(mapping) == 0 for mapping in gt_to_ocr_mapping]
gt_to_ocr_mapping_is_empty_reverse = gt_to_ocr_mapping_is_empty[::-1]
sentence_index = []
sentence_token_counts = 0
for sentence in clean_sentences:
@ -108,12 +122,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
ocr_start = 0
# if gt token at sentence break is not mapped to any ocr token
elif len(gt_to_ocr_mapping[gt_start]) < 1:
try: # finding next gt token that is mapped to an ocr token
try: # finding next gt token that is mapped to an ocr token
new_gt_start = gt_to_ocr_mapping_is_empty.index(False, gt_start)
ocr_start = gt_to_ocr_mapping[new_gt_start][0]
# If no valid token mapping in the remaining gt tokens
except ValueError:
ocr_start = len(ocr_tokens) # use the last ocr token
ocr_start = len(ocr_tokens) # use the last ocr token
else:
ocr_start = gt_to_ocr_mapping[gt_start][0]
@ -121,12 +135,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
if gt_end >= len(gt_to_ocr_mapping):
ocr_end = len(ocr_tokens)
elif len(gt_to_ocr_mapping[gt_end]) < 1:
try: # finding next gt token that is mapped to an ocr token
try: # finding next gt token that is mapped to an ocr token
new_gt_end = gt_to_ocr_mapping_is_empty.index(False, gt_end)
ocr_end = gt_to_ocr_mapping[new_gt_end][0]
# If no valid token mapping in the remaining gt tokens
except ValueError:
ocr_end = len(ocr_tokens) # use the last ocr token
ocr_end = len(ocr_tokens) # use the last ocr token
else:
ocr_end = gt_to_ocr_mapping[gt_end][0]
ocr_sentence = ocr_tokens[ocr_start:ocr_end]
@ -135,11 +149,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
ocr_labels_sentences.append(ocr_sentence_labels)
return ocr_text_sentences, ocr_labels_sentences
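To make the contract concrete, a hedged sketch of a direct call (the tokens, labels and sentence split are fabricated; in the pipeline they come from the clean CoNLL files and the OCR output):

from genalog.text.conll_format import propagate_labels_sentences

clean_tokens = ["New", "York", "is", "big", ".", "It", "is", "great", "."]
clean_labels = ["B-LOC", "I-LOC", "O", "O", "O", "O", "O", "O", "O"]
clean_sentences = [clean_tokens[:5], clean_tokens[5:]]  # two sentences
ocr_tokens = ["New", "Yorkis", "big", ".", "It", "is", "great", "."]

ocr_sentences, ocr_label_sentences = propagate_labels_sentences(
    clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
# ocr_sentences mirrors the clean sentence boundaries on the OCR side, and
# ocr_label_sentences carries the propagated NER labels for each OCR sentence.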
def get_sentences_from_iob_format(iob_format_str):
sentences = []
sentence = []
for line in iob_format_str:
if line.strip() == '': # if line is empty (sentence separator)
if line.strip() == "": # if line is empty (sentence separator)
sentences.append(sentence)
sentence = []
else:
@ -149,47 +164,81 @@ def get_sentences_from_iob_format(iob_format_str):
# filter any empty sentences
return list(filter(lambda sentence: len(sentence) > 0, sentences))
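A tiny illustration with fabricated IOB lines (the result shown is indicative, relying on the blank-line sentence-separator convention of the clean label files):

from genalog.text.conll_format import get_sentences_from_iob_format

iob_lines = [
    "New\tB-LOC\n",
    "York\tI-LOC\n",
    "\n",  # a blank line separates sentences
    "is\tO\n",
    "big\tO\n",
    "\n",
]
print(get_sentences_from_iob_format(iob_lines))
# roughly: [["New", "York"], ["is", "big"]]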
def propagate_labels_sentence_single_file(arg):
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, input_filename = arg
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(clean_label_ext, ".txt")
def propagate_labels_sentence_single_file(arg):
(
clean_labels_dir,
output_text_dir,
output_labels_dir,
clean_label_ext,
input_filename,
) = arg
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(
clean_label_ext, ".txt"
)
ocr_text_file = os.path.join(output_text_dir, input_filename)
ocr_labels_file = os.path.join(output_labels_dir, input_filename)
if not os.path.exists(clean_labels_file):
print(f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index...")
print(
f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index..."
)
return
elif not os.path.exists(ocr_text_file):
print(f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index...")
print(
f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index..."
)
return
else:
with open(clean_labels_file, 'r', encoding='utf-8') as clf:
with open(clean_labels_file, "r", encoding="utf-8") as clf:
tokens_labels_str = clf.readlines()
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_tokens = [
line.split()[0].strip()
for line in tokens_labels_str
if len(line.split()) == 2
]
clean_labels = [
line.split()[1].strip()
for line in tokens_labels_str
if len(line.split()) == 2
]
clean_sentences = get_sentences_from_iob_format(tokens_labels_str)
# read ocr tokens
with open(ocr_text_file, 'r', encoding='utf-8') as otf:
ocr_text_str = ' '.join(otf.readlines())
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
with open(ocr_text_file, "r", encoding="utf-8") as otf:
ocr_text_str = " ".join(otf.readlines())
ocr_tokens = [
token.strip() for token in ocr_text_str.split()
] # already tokenized in data
try:
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
except Exception as e:
print(f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file...")
print(
f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file..."
)
return
# Write result to file
with open(ocr_labels_file, 'w', encoding="utf-8") as olf:
for ocr_tokens, ocr_labels in zip(ocr_tokens_sentences, ocr_labels_sentences):
if len(ocr_tokens) == 0: # if empty OCR sentences
olf.write(f'{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n')
else:
for token, label in zip(ocr_tokens, ocr_labels):
olf.write(f"{token}\t{label}\n")
olf.write('\n')
# Write result to file
with open(ocr_labels_file, "w", encoding="utf-8") as olf:
for ocr_tokens, ocr_labels in zip(
ocr_tokens_sentences, ocr_labels_sentences
):
if len(ocr_tokens) == 0: # if empty OCR sentences
olf.write(
f"{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n"
)
else:
for token, label in zip(ocr_tokens, ocr_labels):
olf.write(f"{token}\t{label}\n")
olf.write("\n")
def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext):
def propagate_labels_sentences_multiprocess(
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext
):
"""
propagate_labels_sentences_multiprocess propagates labels and sentences for all files in the dataset
Parameters
----------
clean_labels_dir : str
@ -204,20 +253,28 @@ def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, o
file extension of the clean_labels
"""
clean_label_files = os.listdir(clean_labels_dir)
args = list(map(
lambda clean_label_filename:
(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, clean_label_filename),
clean_label_files
))
args = list(
map(
lambda clean_label_filename: (
clean_labels_dir,
output_text_dir,
output_labels_dir,
clean_label_ext,
clean_label_filename,
),
clean_label_files,
)
)
with concurrent.futures.ProcessPoolExecutor() as executor:
iterator = executor.map(propagate_labels_sentence_single_file, args)
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
pass
def extract_ocr_text(input_file, output_file):
"""
extract_ocr_text from GROK json
Parameters
----------
input_file : str
@ -227,16 +284,17 @@ def extract_ocr_text(input_file, output_file):
"""
out_dir = os.path.dirname(output_file)
in_file_name = os.path.basename(input_file)
file_pre = in_file_name.split('_')[-1].split('.')[0]
output_file_name = '{}.txt'.format(file_pre)
file_pre = in_file_name.split("_")[-1].split(".")[0]
output_file_name = "{}.txt".format(file_pre)
output_file = os.path.join(out_dir, output_file_name)
with open(input_file, 'r', encoding='utf-8') as fin:
with open(input_file, "r", encoding="utf-8") as fin:
json_data = json.load(fin)
json_dict = json_data[0]
text = json_dict['text']
with open(output_file, 'wb') as fout:
text = json_dict["text"]
with open(output_file, "wb") as fout:
fout.write(text.encode("utf-8"))
def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
"""
check_n_sentences prints file name if number of sentences is different in clean and OCR files
@ -253,49 +311,54 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
text_files = os.listdir(output_labels_dir)
skip_files = []
for text_filename in tqdm(text_files):
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(".txt", clean_label_ext)
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(
".txt", clean_label_ext
)
ocr_labels_file = os.path.join(output_labels_dir, text_filename)
remove_first_line(clean_labels_file, clean_labels_file)
remove_first_line(ocr_labels_file, ocr_labels_file)
remove_last_line(clean_labels_file, clean_labels_file)
remove_last_line(ocr_labels_file, ocr_labels_file)
with open(clean_labels_file, 'r', encoding='utf-8') as lf:
with open(clean_labels_file, "r", encoding="utf-8") as lf:
clean_tokens_labels = lf.readlines()
with open(ocr_labels_file, 'r', encoding='utf-8') as of:
with open(ocr_labels_file, "r", encoding="utf-8") as of:
ocr_tokens_labels = of.readlines()
error = False
n_clean_sentences = 0
nl = False
for line in clean_tokens_labels:
if line == '\n':
if line == "\n":
if nl is True:
error = True
else:
nl=True
nl = True
n_clean_sentences += 1
else:
nl = False
n_ocr_sentences = 0
nl = False
for line in ocr_tokens_labels:
if line == '\n':
if line == "\n":
if nl is True:
error = True
else:
nl=True
nl = True
n_ocr_sentences += 1
else:
nl = False
if error or n_ocr_sentences != n_clean_sentences:
print(f"Warning: Inconsistent numbers of sentences in '{text_filename}''." +
f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}")
print(
f"Warning: Inconsistent numbers of sentences in '{text_filename}''."
+ f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}"
)
skip_files.append(text_filename)
return skip_files
def remove_first_line(input_file, output_file):
"""
remove_first_line from files (some clean CoNLL files have an empty first line)
Parameters
----------
input_file : str
@ -303,17 +366,18 @@ def remove_first_line(input_file, output_file):
output_file : str
output file path
"""
with open(input_file, 'r', encoding='utf-8') as in_f:
with open(input_file, "r", encoding="utf-8") as in_f:
lines = in_f.readlines()
if len(lines) > 1 and lines[0].strip() == '':
if len(lines) > 1 and lines[0].strip() == "":
# the clean CoNLL formatted files had a newline as the first line
with open(output_file, 'w', encoding='utf-8') as out_f:
with open(output_file, "w", encoding="utf-8") as out_f:
out_f.writelines(lines[1:])
def remove_last_line(input_file, output_file):
"""
remove_last_line from files (some clean CoNLL files have an empty last line)
Parameters
----------
input_file : str
@ -322,16 +386,17 @@ def remove_last_line(input_file, output_file):
output file path
"""
with open(input_file, 'r', encoding='utf-8') as in_f:
with open(input_file, "r", encoding="utf-8") as in_f:
lines = in_f.readlines()
if len(lines) > 1 and lines[-1].strip() == '':
with open(output_file, 'w', encoding='utf-8') as out_f:
if len(lines) > 1 and lines[-1].strip() == "":
with open(output_file, "w", encoding="utf-8") as out_f:
out_f.writelines(lines[:-1])
def for_all_files(input_dir, output_dir, func):
"""
for_all_files will apply a function to every file in a directory
Parameters
----------
input_dir : str
@ -347,25 +412,34 @@ def for_all_files(input_dir, output_dir, func):
output_file = os.path.join(output_dir, text_filename)
func(input_file, output_file)
def main(args):
if not args.train_subset and not args.test_subset:
subsets = ['train', 'test']
subsets = ["train", "test"]
else:
subsets = []
if args.train_subset:
subsets.append('train')
subsets.append("train")
if args.test_subset:
subsets.append('test')
subsets.append("test")
for subset in subsets:
print("Processing {} subset...".format(subset))
clean_labels_dir = os.path.join(args.base_folder, args.gt_folder, subset,'clean_labels')
ocr_json_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr')
clean_labels_dir = os.path.join(
args.base_folder, args.gt_folder, subset, "clean_labels"
)
ocr_json_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr"
)
output_text_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_text')
output_labels_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_labels')
output_text_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr_text"
)
output_labels_dir = os.path.join(
args.base_folder, args.degraded_folder, subset, "ocr_labels"
)
# remove first empty line of labels file, if exists
for_all_files(clean_labels_dir, clean_labels_dir, remove_first_line)
@ -380,21 +454,50 @@ def main(args):
os.mkdir(output_labels_dir)
# make ocr labels files by propagating clean labels to ocr_text and creating files in ocr_labels
propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext)
propagate_labels_sentences_multiprocess(
clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext
)
print("Validating number of sentences in gt and ocr labels")
check_n_sentences(clean_labels_dir, output_labels_dir, args.clean_label_ext) # check number of sentences and make sure same; print anomaly files
check_n_sentences(
clean_labels_dir, output_labels_dir, args.clean_label_ext
) # check number of sentences and make sure same; print anomaly files
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument("base_folder", help="base directory containing the collection of dataset")
parser.add_argument("degraded_folder", help="directory containing train and test subset for degradation")
parser.add_argument("--gt_folder", type=str, default="shared", help="directory containing the ground truth")
parser.add_argument("--clean_label_ext", type=str, default=".txt", help="file extension of the clean_labels files")
parser.add_argument('--train_subset', help="include if only train folder should be processed", action='store_true')
parser.add_argument('--test_subset', help="include if only test folder should be processed", action='store_true')
parser.add_argument(
"base_folder", help="base directory containing the collection of dataset"
)
parser.add_argument(
"degraded_folder",
help="directory containing train and test subset for degradation",
)
parser.add_argument(
"--gt_folder",
type=str,
default="shared",
help="directory containing the ground truth",
)
parser.add_argument(
"--clean_label_ext",
type=str,
default=".txt",
help="file extension of the clean_labels files",
)
parser.add_argument(
"--train_subset",
help="include if only train folder should be processed",
action="store_true",
)
parser.add_argument(
"--test_subset",
help="include if only test folder should be processed",
action="store_true",
)
return parser
if __name__ == '__main__':
if __name__ == "__main__":
start = timeit.default_timer()
parser = create_parser()
args = parser.parse_args()

View file

@ -1,14 +1,13 @@
from genalog.text import preprocess
class LCS():
class LCS:
""" Compute the Longest Common Subsequence (LCS) of two given string."""
def __init__(self, str_m, str_n):
self.str_m_len = len(str_m)
self.str_n_len = len(str_n)
dp_table = self._construct_dp_table(str_m, str_n)
self._lcs_len = dp_table[self.str_m_len][self.str_n_len]
self._lcs = self._find_lcs_str(str_m, str_n, dp_table)
def _construct_dp_table(self, str_m, str_n):
m = self.str_m_len
n = self.str_n_len
@ -16,14 +15,14 @@ class LCS():
# Initialize DP table
dp = [[0 for j in range(n + 1)] for i in range(m + 1)]
for i in range(1, m+1):
for j in range(1, n+1):
for i in range(1, m + 1):
for j in range(1, n + 1):
# Case 1: if char1 == char2
if str_m[i-1] == str_n[j-1]:
dp[i][j] = 1 + dp[i-1][j-1]
if str_m[i - 1] == str_n[j - 1]:
dp[i][j] = 1 + dp[i - 1][j - 1]
# Case 2: take the max of the values in the top and left cell
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
return dp
def _find_lcs_str(self, str_m, str_n, dp_table):
@ -32,13 +31,13 @@ class LCS():
lcs = ""
while m > 0 and n > 0:
# same char
if str_m[m-1] == str_n[n-1]:
if str_m[m - 1] == str_n[n - 1]:
# prepend the character
lcs = str_m[m - 1] + lcs
lcs = str_m[m - 1] + lcs
m -= 1
n -= 1
# top cell > left cell
elif dp_table[m-1][n] > dp_table[m][n-1]:
elif dp_table[m - 1][n] > dp_table[m][n - 1]:
m -= 1
else:
n -= 1
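A minimal sketch of the class in use (the import path is not visible in this diff, so it is an assumption; adjust it to wherever the class is defined):

from genalog.text.lcs import LCS  # assumed module path -- not shown in this diff

lcs = LCS("ab cd ef", "ab ef")
print(lcs.get_str())  # "ab ef" -- the longest common subsequence of characters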

View file

@ -1,8 +1,9 @@
from genalog.text import alignment, anchor
from genalog.text import preprocess
import itertools
import re
import string
import itertools
from genalog.text import alignment, anchor
from genalog.text import preprocess
# Both regexes below have the following behavior:
# 1. whitespace-tolerant at both ends of the string
@ -10,71 +11,81 @@ import itertools
# For example, given a label 'B-PLACE'
# Group 1 (denoted by \1): Label Indicator (B-)
# Group 2 (denoted by \2): Label Name (PLACE)
MULTI_TOKEN_BEGIN_LABEL_REGEX = r'^\s*(B-)([a-z|A-Z]+)\s*$'
MULTI_TOKEN_INSIDE_LABEL_REGEX = r'^\s*(I-)([a-z|A-Z]+)\s*$'
MULTI_TOKEN_LABEL_REGEX = r'^\s*([B|I]-)([a-z|A-Z]+)\s*'
MULTI_TOKEN_BEGIN_LABEL_REGEX = r"^\s*(B-)([a-z|A-Z]+)\s*$"
MULTI_TOKEN_INSIDE_LABEL_REGEX = r"^\s*(I-)([a-z|A-Z]+)\s*$"
MULTI_TOKEN_LABEL_REGEX = r"^\s*([B|I]-)([a-z|A-Z]+)\s*"
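A quick REPL-style check of the two capture groups described above (the printed values are illustrative; outside this module the constants live in genalog.text.ner_label):

import re

from genalog.text.ner_label import MULTI_TOKEN_BEGIN_LABEL_REGEX

m = re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, "  B-PLACE ")
print(m.group(1), m.group(2))  # -> "B-" "PLACE"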
# To avoid confusion in the Python interpreter,
# gap char should not be any of the following special characters
SPECIAL_CHAR = set(" \t\n'\x0b''\x0c''\r'") # Notice space characters (' ', '\t', '\n') are in this set.
# gap char should not be any of the following special characters
SPECIAL_CHAR = set(
" \t\n'\x0b''\x0c''\r'"
) # Notice space characters (' ', '\t', '\n') are in this set.
GAP_CHAR_SET = set(string.printable).difference(SPECIAL_CHAR)
# GAP_CHAR_SET = '!"#$%&()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~'
class GapCharError(Exception):
pass
def _is_begin_label(label):
""" Return true if the NER label is a begin label (eg. B-PLACE) """
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) is not None
def _is_inside_label(label):
""" Return true if the NER label is an inside label (eg. I-PLACE) """
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) is not None
def _is_multi_token_label(label):
""" Return true if the NER label is a multi token label (eg. B-PLACE, I-PLACE) """
return re.match(MULTI_TOKEN_LABEL_REGEX, label) != None
return re.match(MULTI_TOKEN_LABEL_REGEX, label) is not None
def _clean_multi_token_label(label):
""" Rid the multi-token-labels of whitespaces"""
return re.sub(MULTI_TOKEN_LABEL_REGEX, r'\1\2', label)
return re.sub(MULTI_TOKEN_LABEL_REGEX, r"\1\2", label)
def _convert_to_begin_label(label):
""" Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
"""Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
Arguments:
label {str} -- an NER label
Returns:
an NER label. This method DOES NOT alter the label unless it is an inside label
"""
if _is_inside_label(label):
# Replace the Label Indicator to 'B-'(\1) and keep the Label Name (\2)
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r'B-\2', label)
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r"B-\2", label)
return label
def _convert_to_inside_label(label):
""" Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE)
"""Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE)
Arguments:
label {str} -- an NER label
Returns:
an NER label. This method DOES NOT alter the label unless it is a begin label
"""
if _is_begin_label(label):
# Replace the Label Indicator to 'I-'(\1) and keep the Label Name (\2)
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r'I-\2', label)
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r"I-\2", label)
return label
def _is_missing_begin_label(begin_label, inside_label):
""" Validate a inside label given an begin label
"""Validate a inside label given an begin label
Arguments:
begin_label {str} -- a begin NER label used to
begin_label {str} -- a begin NER label used to
check if the given label is part of a multi-token label
inside_label {str} -- an inside label to check for its validity
Returns:
True if the inside label is NOT paired with the given begin_label (i.e. its begin label is missing). False otherwise.
Also False if input is not an inside label
@ -87,20 +98,21 @@ def _is_missing_begin_label(begin_label, inside_label):
inside_label = _clean_multi_token_label(inside_label)
begin_label = _clean_multi_token_label(begin_label)
# convert inside label to a begin label for string comparison
# True if the two labels have different names
# True if the two labels have different names
# (e.g. B-LOC followed by I-ORG, and I-ORG is missing a begin label)
return _convert_to_begin_label(inside_label) != begin_label
else:
return True
def correct_ner_labels(labels):
""" Correct the given list of labels for the following case:
"""Correct the given list of labels for the following case:
1. Missing B-Label (i.e. I-PLACE I-PLACE -> B-PLACE I-PLACE)
Arguments:
labels {list} -- list of NER labels
Returns:
a list of NER labels
"""
@ -109,7 +121,7 @@ def correct_ner_labels(labels):
if _is_multi_token_label(label):
if _is_begin_label(label):
cur_begin_label = label
# else is an inside label, so we check if it's missing a begin label
# else is an inside label, so we check if it's missing a begin label
else:
if _is_missing_begin_label(cur_begin_label, label):
labels[i] = _convert_to_begin_label(label)
@ -118,10 +130,11 @@ def correct_ner_labels(labels):
else:
cur_begin_label = ""
return labels
def _select_from_multiple_ner_labels(label_indices):
""" Private method to select a NER label from a list of candidate
"""Private method to select a NER label from a list of candidate
Note: this method is used to tackle the issue when multiple gt tokens
are aligned to ONE ocr_token
@ -129,7 +142,7 @@ def _select_from_multiple_ner_labels(label_indices):
gt_labels: B-p I-p O O
| | | |
gt: New York is big
gt: New York is big
| \\ / |
ocr: New Yorkis big
| | |
@ -140,15 +153,16 @@ def _select_from_multiple_ner_labels(label_indices):
Arguments:
label_indices {list} -- a list of token indices
Returns:
a specific index
"""
# TODO: may need a more sophisticated way to select from multiple NER labels
return label_indices[0]
def _find_gap_char_candidates(gt_tokens, ocr_tokens):
""" Find a set of suitable GAP_CHARs based not in the set of input characters
"""Find a set of suitable GAP_CHARs based not in the set of input characters
Arguments:
gt_tokens {list} -- a list of tokens
@ -159,15 +173,18 @@ def _find_gap_char_candidates(gt_tokens, ocr_tokens):
1. the set of suitable GAP_CHARs
2. the set of input characters
"""
input_char_set = set(''.join(itertools.chain(gt_tokens, ocr_tokens))) # The set of input characters
input_char_set = set(
"".join(itertools.chain(gt_tokens, ocr_tokens))
) # The set of input characters
gap_char_set = GAP_CHAR_SET # The set of possible GAP_CHARs
# Find a set of gap_char that is NOT in the set of input characters
gap_char_candidates = gap_char_set.difference(input_char_set)
return gap_char_candidates, input_char_set
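For example, a small sketch (the tokens are made up; '@' is simply one printable character that does not occur in them):

from genalog.text import ner_label

gap_chars, input_chars = ner_label._find_gap_char_candidates(["New", "York"], ["N", "ewYork"])
print("@" in gap_chars)     # True -- '@' is not among the input characters
print(sorted(input_chars))  # the characters that must NOT be used as the gap char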
def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
""" Propagate NER label for ground truth tokens to to ocr tokens.
"""Propagate NER label for ground truth tokens to to ocr tokens.
NOTE that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
Invalid tokens are:
1. non-atomic tokens, or space-separated string ("New York")
@ -175,7 +192,7 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
4. string with spaces (" ")
Arguments:
gt_labels {list} -- a list of NER label for ground truth token
gt_labels {list} -- a list of NER label for ground truth token
gt_tokens {list} -- a list of ground truth string tokens
ocr_tokens {list} -- a list of OCR'ed text tokens
@ -185,8 +202,8 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
(default: {True})
Raises:
GapCharError:
when the set of input character is EQUAL
GapCharError:
when the set of input character is EQUAL
to set of all possible gap characters (GAP_CHAR_SET)
Returns:
@ -199,21 +216,30 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
`gap_char` is the char used by the alignment for inserting gaps
"""
# Find a set of suitable GAP_CHAR based not in the set of input characters
gap_char_candidates, input_char_set = _find_gap_char_candidates(gt_tokens, ocr_tokens)
gap_char_candidates, input_char_set = _find_gap_char_candidates(
gt_tokens, ocr_tokens
)
if len(gap_char_candidates) == 0:
raise GapCharError("Exhausted all possible GAP_CHAR candidates for alignment." +
" Consider reducing cardinality of the input character set.\n" +
f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" +
f"The set of input character is: '{''.join(sorted(input_char_set))}'")
raise GapCharError(
"Exhausted all possible GAP_CHAR candidates for alignment."
+ " Consider reducing cardinality of the input character set.\n"
+ f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n"
+ f"The set of input character is: '{''.join(sorted(input_char_set))}'"
)
else:
if alignment.GAP_CHAR in gap_char_candidates:
gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR
gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR
else:
gap_char = gap_char_candidates.pop()
return _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor)
return _propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor
)
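A usage sketch mirroring the example in the low-level docstring below (the exact aligned strings depend on the gap character picked at runtime):

from genalog.text import ner_label

gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["N", "ewYork", "big"]

ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)
print(ocr_labels)   # e.g. ["B-place", "I-place", "o"]
print(aligned_gt)   # e.g. "N@ew York is big"
print(aligned_ocr)  # e.g. "N ew@York@@@ big"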
def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True):
""" Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
def _propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True
):
"""Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
NOTE: that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
Invalid tokens are:
@ -221,10 +247,10 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
2. multiple occurrences of the GAP_CHAR ('@@@')
3. empty string ("")
4. string with spaces (" ")
Case Analysis:
******************************** MULTI-TOKEN-LABELS ********************************
Case 1: Case 2: Case 3: Case 4: Case 5:
one-to-many many-to-one many-to-many missing tokens missing tokens
(Case 1&2 comb) (I-label) (B-label)
@ -233,24 +259,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
gt_token New York New York New York New York New York City
/ \\ / \\ \\/ /\\ / | | |
ocr_token N ew Yo rk NewYork N ew@York New York City
| | | | | | | | | |
| | | | | | | | | |
ocr label B-p I-p I-p I-p B-p B-p I-p B-p B-p I-p
******************************** SINGLE-TOKEN-LABELS ********************************
Case 1: Case 2: Case 3: Case 4:
one-to-many many-to-one many-to-many missing tokens
(Case 1&2 comb)
gt label O V O O V W O O
| | | | | | | |
gt_token something is big this is huge is big
/ \\ \\ \\/ /\\ /\\/ |
ocr_token so me thing isbig th isi shuge is
| | | | | | | |
ocr label o o o V O O V O
Case 1: Case 2: Case 3: Case 4:
one-to-many many-to-one many-to-many missing tokens
(Case 1&2 comb)
gt label O V O O V W O O
| | | | | | | |
gt_token something is big this is huge is big
/ \\ \\ \\/ /\\ /\\/ |
ocr_token so me thing isbig th isi shuge is
| | | | | | | |
ocr label o o o V O O V O
Arguments:
gt_labels {list} -- a list of NER label for ground truth token
gt_labels {list} -- a list of NER label for ground truth token
gt_tokens {list} -- a list of ground truth string tokens
ocr_tokens {list} -- a list of OCR'ed text tokens
@ -259,13 +285,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
use_anchor {bool} -- use faster alignment method with anchors if set to True
(default: {True})
Raises:
ValueError: when
ValueError: when
1. there is unequal number of gt_tokens and gt_labels
2. there is a non-atomic token in gt_tokens or ocr_tokens
3. there is an empty string in gt_tokens or ocr_tokens
4. there is a token full of space characters only in gt_tokens or ocr_tokens
5. gt_to_ocr_mapping has more tokens than gt_tokens
GapCharError: when
GapCharError: when
1. there is a token consisted of GAP_CHAR only
@ -278,24 +304,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
`aligned_ocr` is the ocr text aligned with ground truth
`gap_char` is the char used by the alignment for inserting gaps
For example,
For example,
given input:
gt_labels: ["B-place", "I-place", "o", "o"]
gt_tokens: ["New", "York", "is", "big"]
ocr_tokens: ["N", "ewYork", "big"]
output:
(
["B-place", "I-place", "o"],
"N@ew York is big",
"N ew@York@@@ big"
"N ew@York@@@ big"
)
"""
# Pseudo-algorithm:
# ocr_to_gt_mapping = [
# gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City')
# gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City')
# | | | | | | | [3], ('i' maps to 'is')
# gt_txt: "New York City is in New York" [3, 4], ('sin' maps to 'is' and 'in')
# \/ /\ | /\ [5], ('N' maps to 'New')
@ -312,13 +338,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
#
# gt_to_ocr_mapping = [
# gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token)
# gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token)
# | | | | | | | [0], ('York' maps to 'YorkCity')
# gt_txt: "New York City is in New York" [0], ('City' maps to 'YorkCity')
# \/ /\ | /\ [1, 2], ('is' maps to 'i' and 'sin')
# ocr_txt: "YorkCity i sin N ew" [2], ('in' maps to 'sin)
# | | | | | [3,4], ('New' maps to 'N' and 'ew')
# I-P O O B-P B-P [] ('York' does not map to any ocr token)
# | | | | | [3,4], ('New' maps to 'N' and 'ew')
# I-P O O B-P B-P [] ('York' does not map to any ocr token)
# ]
# STEP 2, clean up corner cases from multi-token-labels
@ -332,41 +358,53 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
# YorkCity
# We can address MULTI-TOKEN-LABELS Case 1 with following pseudo-algorithm:
# 1. For each gt_token in gt_to_ocr_mapping:
# 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label
# 1. For all the ocr_tokens this gt_token mapped to
# 1. Keep the B-label for the 1st ocr_token
# 2. For the rest of the ocr_token, convert the B-label to an I-label
# 1. For each gt_token in gt_to_ocr_mapping:
# 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label
# 1. For all the ocr_tokens this gt_token mapped to
# 1. Keep the B-label for the 1st ocr_token
# 2. For the rest of the ocr_token, convert the B-label to an I-label
# We can address the MULTI-TOKEN-LABELS Case 5 with the '_correct_ner_labels()' method
# Sanity check:
if len(gt_tokens) != len(gt_labels):
raise ValueError(f"Unequal number of gt_tokens ({len(gt_tokens)})" +
f"to that of gt_labels ({len(gt_labels)})")
for tk in (gt_tokens + ocr_tokens):
raise ValueError(
f"Unequal number of gt_tokens ({len(gt_tokens)})"
+ f"to that of gt_labels ({len(gt_labels)})"
)
for tk in gt_tokens + ocr_tokens:
if len(preprocess.tokenize(tk)) > 1:
raise ValueError(f"Invalid token '{tk}'. Tokens must be atomic.")
if not alignment._is_valid_token(tk, gap_char=gap_char):
if re.search(rf'{re.escape(gap_char)}+', tk): # Escape special regex chars
raise GapCharError(f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'")
if re.search(rf"{re.escape(gap_char)}+", tk): # Escape special regex chars
raise GapCharError(
f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'"
)
else:
raise ValueError(f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)")
raise ValueError(
f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)"
)
# Stitch tokens together into one string for alignment
gt_txt = preprocess.join_tokens(gt_tokens)
ocr_txt = preprocess.join_tokens(ocr_tokens)
# Align the ground truth and ocr text first
if use_anchor:
aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_txt, ocr_txt, gap_char=gap_char)
aligned_gt, aligned_ocr = anchor.align_w_anchor(
gt_txt, ocr_txt, gap_char=gap_char
)
else:
aligned_gt, aligned_ocr = alignment.align(gt_txt, ocr_txt, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_gt, aligned_ocr, gap_char=gap_char)
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
aligned_gt, aligned_ocr, gap_char=gap_char
)
# Check invariant
if len(gt_to_ocr_mapping) != len(gt_tokens):
raise ValueError(f"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " +
f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment().")
raise ValueError(
"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: "
+ f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment()."
)
ocr_labels = []
# STEP 1: naively propagate NER label based on text-alignment
@ -374,7 +412,9 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
# if is not mapping to missing a token (Case 4)
if ocr_to_gt_token_relationship:
# Find the corresponding gt_token it is aligned to
ner_label_index = _select_from_multiple_ner_labels(ocr_to_gt_token_relationship)
ner_label_index = _select_from_multiple_ner_labels(
ocr_to_gt_token_relationship
)
# Get the NER label for that particular gt_token
ocr_labels.append(gt_labels[ner_label_index])
@ -397,7 +437,7 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
def format_labels(tokens, labels, label_top=True):
"""Format tokens and their NER label for display
Arguments:
tokens {list} -- a list of word tokens
labels {list} -- a list of NER labels
@ -405,7 +445,7 @@ def format_labels(tokens, labels, label_top=True):
Keyword Arguments:
label_top {bool} -- True if label is place on top of the token
(default: {True})
Returns:
a str with NER label align to the token it is labeling
@ -429,20 +469,28 @@ def format_labels(tokens, labels, label_top=True):
len_diff = abs(len(label) - len(token))
# Add padding spaces for whichever is shorter
if len(label) > len(token):
formatted_labels += label + ' '
formatted_tokens += token + ' '*len_diff + ' '
formatted_labels += label + " "
formatted_tokens += token + " " * len_diff + " "
else:
formatted_labels += label + ' '*len_diff + ' '
formatted_tokens += token + ' '
formatted_labels += label + " " * len_diff + " "
formatted_tokens += token + " "
if label_top:
return formatted_labels + '\n' + formatted_tokens + '\n'
return formatted_labels + "\n" + formatted_tokens + "\n"
else:
return formatted_tokens + '\n' + formatted_labels + '\n'
return formatted_tokens + "\n" + formatted_labels + "\n"
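For instance (the whitespace in the comment is only approximate; the helper pads whichever of token/label is shorter):

from genalog.text import ner_label

print(ner_label.format_labels(["New", "York", "is", "big"], ["B-PLACE", "I-PLACE", "O", "O"]))
# B-PLACE I-PLACE O  O
# New     York    is big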
def format_label_propagation(
gt_tokens,
gt_labels,
ocr_tokens,
ocr_labels,
aligned_gt,
aligned_ocr,
show_alignment=True,
):
"""Format label propagation for display
def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
aligned_gt, aligned_ocr, show_alignment=True):
""" Format label propagation for display
Arguments:
gt_tokens {list} -- list of ground truth tokens
gt_labels {list} -- list of NER labels for ground truth tokens
@ -450,15 +498,15 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
ocr_labels {list} -- list of NER labels for the OCR'ed tokens
aligned_gt {str} -- ground truth string aligned with the OCR'ed text
aligned_ocr {str} -- OCR'ed text aligned with ground truth
Keyword Arguments:
show_alignment {bool} -- if true, show alignment result (default: {True})
Returns:
a string formatted for display as follows:
if show_alignment=TRUE
"
"
B-PLACE I-PLACE V O [gt_labels]
New York is big [gt_txt]
New York is big [aligned_gt]
@ -468,19 +516,18 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
B-PLACE V O [ocr_labels]
"
else
"
"
B-PLACE I-PLACE V O [gt_labels]
New York is big [gt_txt]
New is big [ocr_txt]
B-PLACE V O [ocr_labels]
"
"""
gt_label_str = format_labels(gt_tokens, gt_labels)
label_str = format_labels(ocr_tokens, ocr_labels, label_top=False)
label_str = format_labels(ocr_tokens, ocr_labels, label_top=False)
if show_alignment:
alignment_str = alignment._format_alignment(aligned_gt, aligned_ocr)
return gt_label_str + alignment_str + label_str
else:
return gt_label_str + label_str

View file

@ -1,10 +1,11 @@
import re
import re
END_OF_TOKEN = {' ', '\t', '\n'}
END_OF_TOKEN = {" ", "\t", "\n"}
NON_ASCII_REPLACEMENT = "_"
def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
""" Remove non ascii characters in a token
"""Remove non ascii characters in a token
Arguments:
token {str} -- a word token
@ -15,27 +16,29 @@ def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
str -- a word token with non-ASCII characters removed
"""
# Remove non-ASCII characters in the token
ascii_token = str(token.encode('utf-8').decode('ascii', 'ignore'))
# If token becomes an empty string as a result
ascii_token = str(token.encode("utf-8").decode("ascii", "ignore"))
# If token becomes an empty string as a result
if len(ascii_token) == 0 and len(token) != 0:
ascii_token = replacement # replace with a default character
ascii_token = replacement # replace with a default character
return ascii_token
def tokenize(s):
""" Tokenize string
"""Tokenize string
Arguments:
s {str} -- aligned string
Returns:
a list of tokens
"""
# split alignment tokens by spaces, tabs and newline (and excluding them in the tokens)
return s.split()
def join_tokens(tokens):
""" Join a list of tokens into a string
"""Join a list of tokens into a string
Arguments:
tokens {list} -- a list of tokens
@ -44,14 +47,17 @@ def join_tokens(tokens):
"""
return " ".join(tokens)
def _is_spacing(c):
""" Determine if the character is ignorable """
return True if c in END_OF_TOKEN else False
def split_sentences(text, delimiter="\n"):
""" Split a text into sentences with a delimiter"""
return re.sub(r'(( /?[.!?])+ )', rf'\1{delimiter}', text)
return re.sub(r"(( /?[.!?])+ )", rf"\1{delimiter}", text)
def is_sentence_separator(token):
""" Returns true if the token is a sentence splitter """
return re.match(r'^/?[.!?]$', token) != None
return re.match(r"^/?[.!?]$", token) is not None
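A few hedged examples of these helpers (return values in the comments are indicative):

from genalog.text import preprocess

preprocess.tokenize("New York is big .")        # ["New", "York", "is", "big", "."]
preprocess.join_tokens(["New", "York", "is"])   # "New York is"
preprocess.remove_non_ascii("café")             # "caf" -- non-ASCII characters are dropped
preprocess.is_sentence_separator(".")           # True
preprocess.split_sentences("It is big . It is huge . ")  # inserts the delimiter after each sentence end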

View file

@ -1,5 +1,5 @@
"""This is a utility tool to split CoNLL formated files. It has the capability to pack sentences into generated
pages more tightly.
"""This is a utility tool to split CoNLL formated files.
It has the capability to pack sentences into generated pages more tightly.
usage: splitter.py [-h] [--doc_sep DOC_SEP] [--line_sep LINE_SEP]
[--force_doc_sep]
@ -22,16 +22,17 @@ example usage:
python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train
"""
import re
import os
import multiprocessing
import argparse
from tqdm import tqdm
from genalog.text import preprocess
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import multiprocessing
import os
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
from genalog.text import preprocess
# default buffer. Preferably set this to something large
# It holds the lines read from the CoNLL file
BUFFER_SIZE = 50000
@ -40,14 +41,15 @@ CONLL2012_DOC_SEPERATOR = ""
CONLL2003_DOC_SEPERATOR = "-DOCSTART-"
SEPERATOR = ""
STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text
MAX_SIZE = 100 # max number of sentences to pack on a doc page
STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text
MAX_SIZE = 100 # max number of sentences to pack on a doc page
SPLIT_ITERS = 2 # number of iterations to run to find a good split
SPLIT_ITERS = 2 # number of iterations to run to find a good split
WORKERS_PER_CPU = 2
default_generator = DocumentGenerator()
def unwrap(size, accumulator):
words = []
labels = []
@ -55,11 +57,14 @@ def unwrap(size, accumulator):
sentence = accumulator[i]
for word, tok in sentence:
words.append(word)
labels.append((word,tok))
labels.append((word, tok))
return words, labels
def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='text_block.html.jinja'):
"""Run a few iterations of binary search to find the best split point
def find_split_position(
accumulator, start_pos, iters=SPLIT_ITERS, template_name="text_block.html.jinja"
):
"""Run a few iterations of binary search to find the best split point
from the start to pack sentences into a page without overflow.
Args:
@ -71,38 +76,48 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
"""
global STARTING_SPLIT_GUESS
# use binary search to find page split point
start, end = start_pos, min(len(accumulator), MAX_SIZE+start_pos)
start, end = start_pos, min(len(accumulator), MAX_SIZE + start_pos)
best = None
count = 0
while start <= end:
if count==0 and (STARTING_SPLIT_GUESS+start_pos > start and STARTING_SPLIT_GUESS + start_pos < end):
if count == 0 and (
STARTING_SPLIT_GUESS + start_pos > start
and STARTING_SPLIT_GUESS + start_pos < end
):
split_point = STARTING_SPLIT_GUESS
else:
split_point = (start + end)//2
doc_buf = (start_pos,split_point)
split_point = (start + end) // 2
doc_buf = (start_pos, split_point)
content_words, labels = unwrap(doc_buf, accumulator)
content_types = [ContentType.PARAGRAPH]
text = " ".join(content_words)
content = CompositeContent([text], content_types)
doc_gen = default_generator.create_generator(content, [template_name])
doc_gen = default_generator.create_generator(content, [template_name])
doc = next(doc_gen)
if len(doc._document.pages) > 1:
end = split_point-1
end = split_point - 1
else:
start = split_point+1
best = split_point, doc, labels,text
start = split_point + 1
best = split_point, doc, labels, text
if count >= iters:
break
count += 1
STARTING_SPLIT_GUESS = split_point
return best
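The loop above is essentially a capped binary search for the largest prefix of buffered sentences that still renders onto a single page. A standalone sketch of that idea, with a stand-in fits_on_one_page predicate (hypothetical, not part of genalog):

def largest_fitting_prefix(items, start_pos, max_size, fits_on_one_page, iters=2):
    """Binary-search the largest end index so that items[start_pos:end] still fits on one page."""
    start, end = start_pos, min(len(items), start_pos + max_size)
    best = None
    count = 0
    while start <= end:
        split_point = (start + end) // 2
        if fits_on_one_page(items[start_pos:split_point]):
            best = split_point       # it fits -- try to pack more sentences
            start = split_point + 1
        else:
            end = split_point - 1    # it overflowed -- back off
        if count >= iters:           # cap the iterations, as the splitter does, trading accuracy for speed
            break
        count += 1
    return best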
def generate_splits(input_file, output_folder, sentence_seperator="", doc_seperator=None, pool=None,
force_doc_sep=False, ext="txt"):
def generate_splits(
input_file,
output_folder,
sentence_seperator="",
doc_seperator=None,
pool=None,
force_doc_sep=False,
ext="txt",
):
"""Processes the file line by line and add sentences to the buffer for processing.
Args:
@ -120,19 +135,21 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
with open(input_file) as f:
for line in f:
if line.strip() == sentence_seperator or line.strip() == doc_seperator:
if len(sentence) > 0:
if len(sentence) > 0:
accumulator.append(sentence)
sentence = []
if line.strip() == doc_seperator and force_doc_sep:
# progress to processing buffer immediately if force_doc_sep
pass
elif len(accumulator) < BUFFER_SIZE:
elif len(accumulator) < BUFFER_SIZE:
continue
start_pos = 0
while start_pos < len(accumulator):
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
start_pos = next_doc(
accumulator, doc_id, start_pos, output_folder, pool
)
doc_id += 1
progress_bar.update(1)
accumulator = []
@ -141,69 +158,93 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
word, tok = line.split("\t")
if word.strip() == "":
continue
sentence.append((word,tok))
sentence.append((word, tok))
# process any left over lines
start_pos = 0
if len(sentence) > 0 : accumulator.append(sentence)
if len(sentence) > 0:
accumulator.append(sentence)
while start_pos < len(accumulator):
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
start_pos = next_doc(accumulator, doc_id, start_pos, output_folder, pool)
doc_id += 1
progress_bar.update(1)
def next_doc(accumulator,doc_id, start_pos, output_folder,pool,ext="txt"):
split_pos,doc,labels,text = find_split_position(accumulator, start_pos)
def next_doc(accumulator, doc_id, start_pos, output_folder, pool, ext="txt"):
split_pos, doc, labels, text = find_split_position(accumulator, start_pos)
handle_doc(doc, labels, doc_id, text, output_folder, pool, ext)
return split_pos
def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=False):
if write_png:
f = f"{output_folder}/img/img_{doc_id}.png"
doc.render_png(target=f)
text += " " # adding a space at EOF
text += " " # adding a space at EOF
text = preprocess.split_sentences(text)
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as l:
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as fp:
for idx, (token, label) in enumerate(labels):
l.write(token + "\t" + label)
fp.write(token + "\t" + label)
next_token, _ = labels[(idx + 1) % len(labels)]
if preprocess.is_sentence_separator(token) and not \
preprocess.is_sentence_separator(next_token):
l.write("\n")
if idx == len(labels): # Reach the end of the document
l.write("\n")
if preprocess.is_sentence_separator(
token
) and not preprocess.is_sentence_separator(next_token):
fp.write("\n")
if idx == len(labels): # Reach the end of the document
fp.write("\n")
with open(f"{output_folder}/clean_text/{doc_id}.txt", "w") as text_file:
text_file.write(text)
return f"wrote: doc id: {doc_id}"
def _error_callback(err):
raise RuntimeError(err)
def handle_doc(doc, labels, doc_id, text, output_folder, pool, ext="txt"):
if pool:
pool.apply_async(write_doc, args=(doc, doc_id, labels, text, output_folder,ext), error_callback=_error_callback)
pool.apply_async(
write_doc,
args=(doc, doc_id, labels, text, output_folder, ext),
error_callback=_error_callback,
)
else:
write_doc(doc, doc_id, labels, text, output_folder)
def setup_folder(output_folder):
os.makedirs(os.path.join(output_folder,"clean_text"), exist_ok=True)
os.makedirs(os.path.join(output_folder,"clean_labels"), exist_ok=True)
os.makedirs(os.path.join(output_folder, "clean_text"), exist_ok=True)
os.makedirs(os.path.join(output_folder, "clean_labels"), exist_ok=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_file", default="CoNLL-2012_train.txt", help="path to input CoNLL formated file.")
parser.add_argument("output_folder", default="conll2012_train", help="folder to write results to.")
parser.add_argument(
"input_file",
default="CoNLL-2012_train.txt",
help="path to input CoNLL formated file.",
)
parser.add_argument(
"output_folder", default="conll2012_train", help="folder to write results to."
)
parser.add_argument("--doc_sep", help="CoNLL doc seperator")
parser.add_argument("--ext", help="file extension", default="txt")
parser.add_argument("--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator")
parser.add_argument("--force_doc_sep", default=False, action="store_true",
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)")
parser.add_argument(
"--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator"
)
parser.add_argument(
"--force_doc_sep",
default=False,
action="store_true",
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)",
)
args = parser.parse_args()
unescape = lambda s: s.encode('utf-8').decode('unicode_escape') if s else None
unescape = lambda s: s.encode("utf-8").decode("unicode_escape") if s else None # noqa: E731
input_file = args.input_file
output_folder = args.output_folder
@ -211,10 +252,18 @@ if __name__ == "__main__":
# allow special characters in separators
line_sep = unescape(args.line_sep) or ""
doc_sep = unescape(args.doc_sep)
doc_sep = unescape(args.doc_sep)
n_workers = WORKERS_PER_CPU * multiprocessing.cpu_count()
with ThreadPool(processes=n_workers) as pool:
generate_splits(input_file, output_folder, line_sep, doc_seperator=doc_sep, pool=pool, force_doc_sep=False, ext=args.ext)
generate_splits(
input_file,
output_folder,
line_sep,
doc_seperator=doc_sep,
pool=pool,
force_doc_sep=False,
ext=args.ext,
)
pool.close()
pool.join()
pool.join()

View file

@ -1,2 +0,0 @@
[pytest]
junit_family=xunit1

4
requirements-dev.txt Normal file
View file

@ -0,0 +1,4 @@
pytest
pytest-cov
flake8
flake8-import-order

View file

@ -1,6 +1,7 @@
import setuptools
import os
import setuptools
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'VERSION.txt')) as version_file:
BUILD_VERSION = version_file.read().strip()
@ -13,7 +14,7 @@ with open("README.md", "r", encoding="utf8") as fh:
setuptools.setup(
name="genalog",
install_requires=requirements,
install_requires=requirements,
version=BUILD_VERSION,
author="Team Enki",
author_email="ta_nerds@microsoft.com",

View file

@ -1,10 +1,10 @@
# Test cases for genalog.text.ner_label.propagate_label_to_ocr() method.
# For READABILITY purposes, ground truth and noisy text are presented as
# a whole string, not in their tokenized format.
# Notice the `propagate_label_to_ocr()` method has the contract of
# (list, list, list) -> (list, list, list)
# consuming both ground truth text and noisy text as lists of tokens.
# We will use `genalog.text.preprocess.tokenize()` to tokenize these strings
from genalog.text import preprocess
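# Illustrative sketch (not part of the test fixtures): how one case below maps onto the
# (list, list, list) -> (list, list, list) contract. The call signature is assumed from the
# tuple layout (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels) and is not shown in
# this diff, so the call itself is left as a comment.
example_gt_tokens = preprocess.tokenize("New York is big.")    # e.g. ["New", "York", "is", "big."]
example_ocr_tokens = preprocess.tokenize("N ewYork kis big.")  # e.g. ["N", "ewYork", "kis", "big."]
# from genalog.text import ner_label
# ocr_labels, *rest = ner_label.propagate_label_to_ocr(
#     ["B-PLACE", "I-PLACE", "O", "O"], example_gt_tokens, example_ocr_tokens
# )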
@ -106,12 +106,14 @@ desired_ocr_labels.append(["O", "B-FRUIT", "I-FRUIT", "O", "O"])
ner_labels.append(["O", "O", "ENTERTAINMENT", "O"])
gt_txt.append("@ new TV !")
ns_txt.append("@ n ow T\\/ |")
desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT" ,"O"])
desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT", "O"])
# Tokenize ground truth and noisy text strings
gt_tokens = [preprocess.tokenize(txt) for txt in gt_txt]
ns_tokens = [preprocess.tokenize(txt) for txt in ns_txt]
# The test function expects params in a tuple of
# (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels)
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels))
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(
zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels)
)

View file

@ -1,4 +1,3 @@
# Initializing test cases
# For extensibility, all parameters in a case are appended to the following arrays
gt_txt = []
@ -17,31 +16,35 @@ ns_txt.append("N ewYork kis big.")
aligned_gt.append("N@ew York @is big.")
aligned_ns.append("N ew@York kis big.")
gt_to_noise_maps.append(
[
# This shows that the first token in gt "New" maps to the
# first ("N") and second ("ewYork") token in the noise
[0,1],
[1],
[2],
[3]
])
[
# This shows that the first token in gt "New" maps to the
# first ("N") and second ("ewYork") token in the noise
[0, 1],
[1],
[2],
[3],
]
)
noise_to_gt_maps.append([
[0],
# Similarly, the following shows that the second token in noise "ewYork" maps to the
# first ("New") and second ("York") token in gt
[0,1],
[2],
[3]
])
noise_to_gt_maps.append(
[
[0],
# Similarly, the following shows that the second token in noise "ewYork" maps to the
# first ("New") and second ("York") token in gt
[0, 1],
[2],
[3],
]
)
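# Illustration (not a fixture): reading the pair of maps appended just above.
# gt tokens:    ["New", "York", "is", "big."]   from "New York is big."
# noise tokens: ["N", "ewYork", "kis", "big."]  from "N ewYork kis big."
assert gt_to_noise_maps[-1][0] == [0, 1]   # gt "New" overlaps noise "N" and "ewYork"
assert noise_to_gt_maps[-1][1] == [0, 1]   # noise "ewYork" overlaps gt "New" and "York"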
##############################################################################################
# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account
# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account
# Notice there's only 1 whitespace between 'oston' and 'grea'
# The ideal situation is that there are 2 whitespaces. Ex:
# ("B oston grea t")
ns_txt.append("B oston grea t")
gt_txt.append("Boston is great")
@ -52,18 +55,9 @@ gt_txt.append("Boston is great")
aligned_ns.append("B oston@@@ grea t")
aligned_gt.append("B@oston is grea@t")
gt_to_noise_maps.append([
[0,1],
[1], # 'is' is also mapped to 'oston'
[2,3]
])
gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) # 'is' is also mapped to 'oston'
noise_to_gt_maps.append([
[0],
[0,1], # 'oston' is to 'Boston' and 'is'
[2],
[2]
])
noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) # 'oston' is to 'Boston' and 'is'
############################################################################################
# Empty Cases:
@ -95,18 +89,9 @@ ns_txt.append("B oston bi g")
aligned_gt.append("B@oston is bi@g")
aligned_ns.append("B oston@@@ bi g")
gt_to_noise_maps.append([
[0,1],
[1],
[2,3]
])
gt_to_noise_maps.append([[0, 1], [1], [2, 3]])
noise_to_gt_maps.append([
[0],
[0,1],
[2],
[2]
])
noise_to_gt_maps.append([[0], [0, 1], [2], [2]])
############################################################################################
gt_txt.append("New York is big.")
@ -115,17 +100,9 @@ ns_txt.append("NewYork big")
aligned_gt.append("New York is big.")
aligned_ns.append("New@York @@@big@")
gt_to_noise_maps.append([
[0],
[0],
[1],
[1]
])
gt_to_noise_maps.append([[0], [0], [1], [1]])
noise_to_gt_maps.append([
[0, 1],
[2, 3]
])
noise_to_gt_maps.append([[0, 1], [2, 3]])
#############################################################################################
gt_txt.append("politicians who lag superfluous on the")
@ -134,23 +111,9 @@ ns_txt.append("politicians who kg superfluous on the")
aligned_gt.append("politicians who lag superfluous on the")
aligned_ns.append("politicians who @kg superfluous on the")
gt_to_noise_maps.append([
[0],
[1],
[2],
[3],
[4],
[5]
])
gt_to_noise_maps.append([[0], [1], [2], [3], [4], [5]])
noise_to_gt_maps.append([
[0],
[1],
[2],
[3],
[4],
[5]
])
noise_to_gt_maps.append([[0], [1], [2], [3], [4], [5]])
############################################################################################
@ -160,20 +123,9 @@ ns_txt.append("faithei uifoimtdon the subject")
aligned_gt.append("farther @informed on the subject.")
aligned_ns.append("faithei ui@foimtd@on the subject@")
gt_to_noise_maps.append([
[0],
[1],
[1],
[2],
[3]
])
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
noise_to_gt_maps.append([
[0],
[1,2],
[3],
[4]
])
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
############################################################################################
@ -183,20 +135,9 @@ ns_txt.append("New Yorkis big .")
aligned_gt.append("New York is big .")
aligned_ns.append("New York@is big .")
gt_to_noise_maps.append([
[0],
[1],
[1],
[2],
[3]
])
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
noise_to_gt_maps.append([
[0],
[1,2],
[3],
[4]
])
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
############################################################################################
@ -206,23 +147,14 @@ ns_txt.append("New Yo rk is big.")
aligned_gt.append("New Yo@rk is big.")
aligned_ns.append("New Yo rk is big.")
gt_to_noise_maps.append([
[0],
[1,2],
[3],
[4]
])
gt_to_noise_maps.append([[0], [1, 2], [3], [4]])
noise_to_gt_maps.append([
[0],
[1],
[1],
[2],
[3]
])
noise_to_gt_maps.append([[0], [1], [1], [2], [3]])
# Format tests for pytest
# Each test expects cases in the following format
# (aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(
aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps
)
ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns))

View file

@ -1,80 +1,98 @@
from genalog.degradation.degrader import Degrader, ImageState
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
import copy
from unittest.mock import patch
import numpy as np
import pytest
import copy
MOCK_IMAGE_SHAPE = (4,3)
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
from genalog.degradation.degrader import Degrader, ImageState
MOCK_IMAGE_SHAPE = (4, 3)
MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE)
@pytest.fixture
def empty_degrader():
effects = []
return Degrader(effects)
@pytest.fixture(params=[
[("blur", {"radius": 5})],
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
[
("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "open"}),
("morphology", {"operation": "close"}),
("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "dilate"}),
("morphology", {"operation": "erode"}),
],
[
("blur", {"radius": 5}),
("bleed_through", {
"src": ImageState.CURRENT_STATE,
"alpha": 0.7,
"background": ImageState.ORIGINAL_STATE,
}),
("morphology", {
"operation": "open",
"kernel_shape": (3,3),
"kernel_type": "ones"
}),
@pytest.fixture(
params=[
[("blur", {"radius": 5})],
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
[
("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "open"}),
("morphology", {"operation": "close"}),
("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "dilate"}),
("morphology", {"operation": "erode"}),
],
[
("blur", {"radius": 5}),
(
"bleed_through",
{
"src": ImageState.CURRENT_STATE,
"alpha": 0.7,
"background": ImageState.ORIGINAL_STATE,
},
),
(
"morphology",
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"},
),
],
]
])
)
def degrader(request):
effects = request.param
return Degrader(effects)
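# Usage sketch (illustration only): the same flow the tests below drive through fixtures.
example_degrader = Degrader([("blur", {"radius": 5}), ("morphology", {"operation": "open"})])
example_degraded = example_degrader.apply_effects(MOCK_IMAGE)
# per test_degrader_apply_effects_e2e below, the result keeps MOCK_IMAGE's shape and dtype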
def test_degrader_init(empty_degrader):
def test_empty_degrader_init(empty_degrader):
assert empty_degrader.effects_to_apply == []
def test_degrader_init(degrader):
assert degrader.effects_to_apply is not []
for effect_tuple in degrader.effects_to_apply:
method_name, method_kwargs = effect_tuple
assert DEFAULT_METHOD_PARAM_TO_INCLUDE in method_kwargs
param_value = method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE]
assert param_value is ImageState.ORIGINAL_STATE or param_value is ImageState.CURRENT_STATE
assert (
param_value is ImageState.ORIGINAL_STATE
or param_value is ImageState.CURRENT_STATE
)
@pytest.mark.parametrize("effects, error_thrown", [
([], None), #Empty effect
(None, TypeError),
([("blur", {"radius": 5})], None), # Validate input
([("not_a_func", {"radius": 5})], ValueError), # Invalid method name
([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs
([("blur")], ValueError), # Missing kwargs
(
[
("blur", {"radius": 5}),
("bleed_through", {"alpha":"0.8"}),
("morphology", {"operation": "open"})
], None
), # Multiple effects
(
[
("blur", {"radius": 5}),
("bleed_through", {"not_argument":"0.8"}),
("morphology", {"missing value"})
], ValueError
), # Multiple effects
])
@pytest.mark.parametrize(
"effects, error_thrown",
[
([], None), # Empty effect
(None, TypeError),
([("blur", {"radius": 5})], None), # Validate input
([("not_a_func", {"radius": 5})], ValueError), # Invalid method name
([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs
([("blur")], ValueError), # Missing kwargs
(
[
("blur", {"radius": 5}),
("bleed_through", {"alpha": "0.8"}),
("morphology", {"operation": "open"}),
],
None,
), # Multiple effects
(
[
("blur", {"radius": 5}),
("bleed_through", {"not_argument": "0.8"}),
("morphology", {"missing value"}),
],
ValueError,
), # Multiple effects
],
)
def test_degrader_validate_effects(effects, error_thrown):
if error_thrown:
with pytest.raises(error_thrown):
@ -82,23 +100,26 @@ def test_degrader_validate_effects(effects, error_thrown):
else:
Degrader.validate_effects(effects)
def test_degrader_apply_effects(degrader):
method_names = [effect[0] for effect in degrader.effects_to_apply]
with patch("genalog.degradation.effect") as mock_effect:
degraded = degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
for method in method_names:
assert mock_effect[method].is_called()
# assert degraded.shape == MOCK_IMAGE_SHAPE
def test_degrader_apply_effects_e2e(degrader):
degraded = degrader.apply_effects(MOCK_IMAGE)
assert degraded.shape == MOCK_IMAGE_SHAPE
assert degraded.dtype == np.uint8
def test_degrader_instructions(degrader):
original_instruction = copy.deepcopy(degrader.effects_to_apply)
degraded1 = degrader.apply_effects(MOCK_IMAGE)
degraded2 = degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
degrader.apply_effects(MOCK_IMAGE)
# Make sure the degradation instructions are not altered
assert len(original_instruction) == len(degrader.effects_to_apply)
for i in range(len(original_instruction)):
@ -107,5 +128,5 @@ def test_degrader_instructions(degrader):
assert org_method_name == method_name
assert len(org_method_arg) == len(method_arg)
for key in org_method_arg.keys():
assert type(org_method_arg[key]) == type(method_arg[key])
assert isinstance(org_method_arg[key], type(method_arg[key]))
assert org_method_arg[key] == method_arg[key]

View file

@ -1,30 +1,34 @@
from genalog.degradation import effect
from unittest.mock import patch
import numpy as np
import pytest
from genalog.degradation import effect
NEW_IMG_SHAPE = (100, 100)
MOCK_IMG_SHAPE = (100, 120)
MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8)
def test_blur():
dst = effect.blur(MOCK_IMG, radius=3)
assert dst.dtype == np.uint8  # preserve dtype
assert dst.shape == MOCK_IMG_SHAPE  # preserve image size
def test_translation():
offset_x = offset_y = 1
# Test that border pixels are not white (<255)
assert all([col_pixel < 255 for col_pixel in MOCK_IMG[:, 0]])
assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0 ,:]])
assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0, :]])
dst = effect.translation(MOCK_IMG, offset_x, offset_y)
# Test that border pixels are white (255)
assert all([col_pixel == 255 for col_pixel in dst[:,0]])
assert all([col_pixel == 255 for col_pixel in dst[:, 0]])
assert all([row_pixel == 255 for row_pixel in dst[0, :]])
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_overlay_weighted():
src = MOCK_IMG.copy()
src[0][0] = 10
@ -34,6 +38,7 @@ def test_overlay_weighted():
assert dst.shape == MOCK_IMG_SHAPE
assert dst[0][0] == src[0][0] * alpha + src[0][0] * beta
def test_overlay():
src1 = MOCK_IMG.copy()
src2 = MOCK_IMG.copy()
@ -45,6 +50,7 @@ def test_overlay():
assert dst[0][0] == 0
assert dst[0][1] == 1
@patch("genalog.degradation.effect.translation")
def test_bleed_through_default(mock_translation):
mock_translation.return_value = MOCK_IMG
@ -53,11 +59,15 @@ def test_bleed_through_default(mock_translation):
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.mark.parametrize("foreground, background, error_thrown", [
(MOCK_IMG, MOCK_IMG, None),
# Test unmatched shape
(MOCK_IMG, MOCK_IMG[:,:-1], Exception),
])
@pytest.mark.parametrize(
"foreground, background, error_thrown",
[
(MOCK_IMG, MOCK_IMG, None),
# Test unmatched shape
(MOCK_IMG, MOCK_IMG[:, :-1], Exception),
],
)
def test_bleed_through_kwargs(foreground, background, error_thrown):
if error_thrown:
assert foreground.shape != background.shape
@ -68,97 +78,125 @@ def test_bleed_through_kwargs(foreground, background, error_thrown):
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_pepper():
dst = effect.pepper(MOCK_IMG, amount=0.1)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_salt():
dst = effect.salt(MOCK_IMG, amount=0.1)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_salt_then_pepper():
dst = effect.salt_then_pepper(MOCK_IMG, 0.5, 0.001)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_pepper_then_salt():
dst = effect.pepper_then_salt(MOCK_IMG, 0.001, 0.5)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.mark.parametrize("kernel_shape, kernel_type", [
((3,3), "NOT_VALID_TYPE"),
(1, "ones"),
((1,2,3), "ones")
])
@pytest.mark.parametrize(
"kernel_shape, kernel_type",
[((3, 3), "NOT_VALID_TYPE"), (1, "ones"), ((1, 2, 3), "ones")],
)
def test_create_2D_kernel_error(kernel_shape, kernel_type):
with pytest.raises(Exception):
effect.create_2D_kernel(kernel_shape, kernel_type)
@pytest.mark.parametrize("kernel_shape, kernel_type, expected_kernel", [
((2,2), "ones", np.array([[1,1],[1,1]])), # sq kernel
((1,2), "ones", np.array([[1,1]])), # horizontal
((2,1), "ones", np.array([[1],[1]])), # vertical
((2,2), "upper_triangle", np.array([[1,1],[0,1]])),
((2,2), "lower_triangle", np.array([[1,0],[1,1]])),
((2,2), "x", np.array([[1,1],[1,1]])),
((3,3), "x", np.array([[1,0,1],[0,1,0],[1,0,1]])),
((2,2), "plus", np.array([[0,1],[1,1]])),
((3,3), "plus", np.array([[0,1,0],[1,1,1],[0,1,0]])),
((3,3), "ellipse", np.array([[0,1,0],[1,1,1],[0,1,0]])),
((5,5), "ellipse",
np.array([
[0, 0, 1, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 0, 1, 0, 0]
])),
])
@pytest.mark.parametrize(
"kernel_shape, kernel_type, expected_kernel",
[
((2, 2), "ones", np.array([[1, 1], [1, 1]])), # sq kernel
((1, 2), "ones", np.array([[1, 1]])), # horizontal
((2, 1), "ones", np.array([[1], [1]])), # vertical
((2, 2), "upper_triangle", np.array([[1, 1], [0, 1]])),
((2, 2), "lower_triangle", np.array([[1, 0], [1, 1]])),
((2, 2), "x", np.array([[1, 1], [1, 1]])),
((3, 3), "x", np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])),
((2, 2), "plus", np.array([[0, 1], [1, 1]])),
((3, 3), "plus", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
((3, 3), "ellipse", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
(
(5, 5),
"ellipse",
np.array(
[
[0, 0, 1, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 0, 1, 0, 0],
]
),
),
],
)
def test_create_2D_kernel(kernel_shape, kernel_type, expected_kernel):
kernel = effect.create_2D_kernel(kernel_shape, kernel_type)
assert np.array_equal(kernel, expected_kernel)
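# Sketch tying the kernel helper to the morphology wrapper (same arguments as the tests here):
example_kernel = effect.create_2D_kernel((3, 3), "plus")
example_opened = effect.open(MOCK_IMG, example_kernel)
# or via the higher-level wrapper exercised in test_morphology below:
example_opened_too = effect.morphology(
    MOCK_IMG, operation="open", kernel_shape=(3, 3), kernel_type="plus"
)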
def test_morphology_with_error():
INVALID_OPERATION = "NOT_A_OPERATION"
with pytest.raises(ValueError):
effect.morphology(MOCK_IMG, operation=INVALID_OPERATION)
@pytest.mark.parametrize("operation, kernel_shape, kernel_type", [
("open", (3,3), "ones"),
("close", (3,3), "ones"),
("dilate", (3,3), "ones"),
("erode", (3,3), "ones"),
])
@pytest.mark.parametrize(
"operation, kernel_shape, kernel_type",
[
("open", (3, 3), "ones"),
("close", (3, 3), "ones"),
("dilate", (3, 3), "ones"),
("erode", (3, 3), "ones"),
],
)
def test_morphology(operation, kernel_shape, kernel_type):
dst = effect.morphology(MOCK_IMG,
operation=operation, kernel_shape=kernel_shape,
kernel_type=kernel_type)
dst = effect.morphology(
MOCK_IMG,
operation=operation,
kernel_shape=kernel_shape,
kernel_type=kernel_type,
)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
@pytest.fixture(params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"])
@pytest.fixture(
params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"]
)
def kernel(request):
return effect.create_2D_kernel((5,5), request.param)
return effect.create_2D_kernel((5, 5), request.param)
def test_open(kernel):
dst = effect.open(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_close(kernel):
dst = effect.close(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_erode(kernel):
dst = effect.erode(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE
def test_dilate(kernel):
dst = effect.dilate(MOCK_IMG, kernel)
assert dst.dtype == np.uint8
assert dst.shape == MOCK_IMG_SHAPE

View file

@ -1,15 +1,18 @@
from genalog.text import alignment, anchor, preprocess
import glob
import pytest
import difflib
import glob
import warnings
@pytest.mark.parametrize("gt_file, ocr_file",
import pytest
from genalog.text import alignment, anchor, preprocess
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_*.txt")),
sorted(glob.glob("tests/text/data/ocr_*.txt"))
)
sorted(glob.glob("tests/text/data/ocr_*.txt")),
),
)
def test_align_w_anchor_and_align(gt_file, ocr_file):
gt_text = open(gt_file, "r").read()
@ -21,17 +24,22 @@ def test_align_w_anchor_and_align(gt_file, ocr_file):
aligned_anchor_gt = aligned_anchor_gt.split(".")
aligned_gt = aligned_gt.split(".")
str_diff = "\n".join(difflib.unified_diff(aligned_gt, aligned_anchor_gt))
warnings.warn(UserWarning(
"\n"+ f"{str_diff}" +
f"\n\n**** Inconsistent Alignment Results between align() and " +
f"align_w_anchor(). Ignore this if the delta is not significant. ****\n"))
warnings.warn(
UserWarning(
"\n"
+ f"{str_diff}"
+ "\n\n**** Inconsistent Alignment Results between align() and "
+ "align_w_anchor(). Ignore this if the delta is not significant. ****\n"
)
)
@pytest.mark.parametrize("gt_file, ocr_file",
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_*.txt")),
sorted(glob.glob("tests/text/data/ocr_*.txt"))
)
sorted(glob.glob("tests/text/data/ocr_*.txt")),
),
)
@pytest.mark.parametrize("max_seg_length", [25, 50, 75, 100, 150])
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
@ -39,7 +47,9 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
ocr_text = open(ocr_file, "r").read()
gt_tokens = preprocess.tokenize(gt_text)
ocr_tokens = preprocess.tokenize(ocr_text)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
# Ensure that each anchor word is the same word in both texts
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
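# Sketch of the anchor contract the loop above checks: the two returned lists are parallel
# token indices, and each anchor pair points at the same word in both sequences.
example_gt = preprocess.tokenize("New York is big .")
example_ocr = preprocess.tokenize("N ewYork is big .")
example_gt_anchors, example_ocr_anchors = anchor.find_anchor_recur(
    example_gt, example_ocr, max_seg_length=25
)
for g, o in zip(example_gt_anchors, example_ocr_anchors):
    assert example_gt[g] == example_ocr[o]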

View file

@ -1,47 +1,62 @@
from genalog.text import conll_format
from unittest import mock
import argparse
import pytest
import glob
import itertools
@pytest.mark.parametrize("required_args", [
(["tests/e2e/data/synthetic_dataset", "test_version"])
])
@pytest.mark.parametrize("optional_args", [
(["--train_subset"]),
(["--test_subset"]),
(["--gt_folder", "shared"]),
])
import pytest
from genalog.text import conll_format
@pytest.mark.parametrize(
"required_args", [(["tests/e2e/data/synthetic_dataset", "test_version"])]
)
@pytest.mark.parametrize(
"optional_args",
[
(["--train_subset"]),
(["--test_subset"]),
(["--gt_folder", "shared"]),
],
)
def test_conll_format(required_args, optional_args):
parser = conll_format.create_parser()
arg_list = required_args + optional_args
args = parser.parse_args(args=arg_list)
conll_format.main(args)
basepath = "tests/e2e/data/conll_formatter/"
@pytest.mark.parametrize("clean_label_filename, ocr_text_filename",
@pytest.mark.parametrize(
"clean_label_filename, ocr_text_filename",
zip(
sorted(glob.glob("tests/e2e/data/conll_formatter/clean_labels/*.txt")),
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt"))
)
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")),
),
)
def test_propagate_labels_sentence_single_file(clean_label_filename, ocr_text_filename):
with open(clean_label_filename, 'r', encoding='utf-8') as clf:
with open(clean_label_filename, "r", encoding="utf-8") as clf:
tokens_labels_str = clf.readlines()
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
clean_tokens = [
line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2
]
clean_labels = [
line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2
]
clean_sentences = conll_format.get_sentences_from_iob_format(tokens_labels_str)
# read ocr tokens
with open(ocr_text_filename, 'r', encoding='utf-8') as otf:
ocr_text_str = ' '.join(otf.readlines())
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
with open(ocr_text_filename, "r", encoding="utf-8") as otf:
ocr_text_str = " ".join(otf.readlines())
ocr_tokens = [
token.strip() for token in ocr_text_str.split()
] # already tokenized in data
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens
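# Sketch: driving the formatter end-to-end the way the parametrized test at the top of this
# file does (the dataset path and "test_version" are the fixture values used above).
example_parser = conll_format.create_parser()
example_args = example_parser.parse_args(
    args=["tests/e2e/data/synthetic_dataset", "test_version", "--train_subset"]
)
conll_format.main(example_args)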

View file

@ -1,10 +1,13 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import pytest
import os
CONTENT = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH])
import pytest
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
CONTENT = CompositeContent(
["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH]
)
UNSUPPORTED_CONTENT_FORMAT = ["foo bar"]
UNSUPPORTED_CONTENT_TYPE = CompositeContent(["foo"], [ContentType.TITLE])
@ -21,9 +24,10 @@ CUSTOM_STYLE = {
"font_family": ["Calibri", "Times"],
"font_size": ["10px"],
"text_align": ["right"],
"hyphenate": [True, False]
"hyphenate": [True, False],
}
def test_default_template_generation():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(CONTENT, doc_gen.template_list)
@ -31,28 +35,36 @@ def test_default_template_generation():
html_str = doc.render_html()
assert "Unsupported Content Type:" not in html_str
assert "No content loaded" not in html_str
def test_default_template_generation_w_unsupported_content_format():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list)
generator = doc_gen.create_generator(
UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list
)
for doc in generator:
html_str = doc.render_html()
assert "No content loaded" in html_str
def test_default_template_generation_w_unsupported_content_type():
doc_gen = DocumentGenerator()
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"])
generator = doc_gen.create_generator(
UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"]
)
for doc in generator:
html_str = doc.render_html()
assert "Unsupported Content Type: ContentType.TITLE" in html_str
def test_custom_template_generation():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
doc = next(generator)
result = doc.render_html()
assert result == str(CONTENT)
def test_undefined_template_generation():
doc_gen = DocumentGenerator()
assert UNDEFINED_TEMPLATE_NAME not in doc_gen.template_list
@ -60,6 +72,7 @@ def test_undefined_template_generation():
with pytest.raises(FileNotFoundError):
next(generator)
def test_custom_style_template_generation():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
assert len(doc_gen.styles_to_generate) == 1
@ -70,14 +83,16 @@ def test_custom_style_template_generation():
result = doc.render_html()
assert doc.styles["font_family"] == result
def test_render_pdf_and_png():
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
for doc in generator:
pdf_bytes = doc.render_pdf()
png_bytes = doc.render_png()
assert pdf_bytes != None
assert png_bytes != None
assert pdf_bytes is not None
assert png_bytes is not None
def test_save_document_as_png():
if not os.path.exists(TEST_OUTPUT_DIR):
@ -86,15 +101,16 @@ def test_save_document_as_png():
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
for doc in generator:
doc.render_png(target=FILE_DESTINATION, resolution=100)
# Check if the document is saved in filepath
assert os.path.exists(FILE_DESTINATION)
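# Minimal generate-and-save sketch using the same names defined at the top of this file:
example_doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
example_doc = next(example_doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME]))
example_doc.render_png(target=FILE_DESTINATION, resolution=100)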
def test_save_document_as_separate_png():
if not os.path.exists(TEST_OUTPUT_DIR):
os.mkdir(TEST_OUTPUT_DIR)
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
generator = doc_gen.create_generator(CONTENT, [MULTI_PAGE_TEMPLATE_NAME])
document = next(generator)
document.render_png(target=FILE_DESTINATION, split_pages=True, resolution=100)
# Check if the document is saved as separated .png files
@ -102,6 +118,7 @@ def test_save_document_as_separate_png():
printed_doc_name = FILE_DESTINATION.replace(".png", f"_pg_{page_num}.png")
assert os.path.exists(printed_doc_name)
def test_overwriting_style():
new_font = "NewFontFamily"
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
@ -111,4 +128,4 @@ def test_overwriting_style():
assert doc.styles["font_family"] != new_font
doc.update_style(font_family=new_font)
result = doc.render_html()
assert new_font == result

View file

@ -1,23 +1,25 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
from genalog.degradation.degrader import Degrader
from genalog.degradation import effect
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
import numpy as np
import cv2
TEST_OUTPUT_DIR = "test_out/"
SAMPLE_TXT = "Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , was picked on Thursday for the Scottish squad after a 20-month exile ."
SAMPLE_TXT = """Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday ,
was picked on Thursday for the Scottish squad after a 20-month exile ."""
DEFAULT_TEMPLATE = "text_block.html.jinja"
DEGRADATION_EFFECTS = [
("blur", {"radius": 5}),
("bleed_through", {"alpha": 0.8}),
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "plus"}),
(
"morphology",
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"},
),
("morphology", {"operation": "close"}),
("morphology", {"operation": "dilate"}),
("morphology", {"operation": "erode"})
("morphology", {"operation": "erode"}),
]
def test_generation_and_degradation():
# Initiate content
content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
@ -32,6 +34,4 @@ def test_generation_and_degradation():
# get the image in bytes in RGBA channels
src = doc.render_array(resolution=100, channel="GRAYSCALE")
# run each degradation effect
dst = degrader.apply_effects(src)
degrader.apply_effects(src)
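# The generation -> degradation pipeline above, condensed into a single sketch:
example_content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
example_doc = next(DocumentGenerator().create_generator(example_content, [DEFAULT_TEMPLATE]))
example_src = example_doc.render_array(resolution=100, channel="GRAYSCALE")
example_degraded = Degrader(DEGRADATION_EFFECTS).apply_effects(example_src)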

View file

@ -1,41 +1,44 @@
from genalog.generation.document import DocumentGenerator
from genalog.generation.content import CompositeContent, ContentType
import pytest
import numpy as np
import cv2
import pytest
from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator
TEMPLATE_PATH = "tests/e2e/templates"
TEST_OUT_FOLDER = "test_out/"
SAMPLE_TXT = "foo"
CONTENT = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
@pytest.fixture
def doc_generator():
return DocumentGenerator(template_path=TEMPLATE_PATH)
def test_red_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:
doc.update_style(background_color="red")
doc.update_style(background_color="red")
img_array = doc.render_array(resolution=100, channel="BGRA")
# css "red" is rgb(255,0,0) or bgra(0,0,255,255)
assert tuple(img_array[0][0]) == (0, 0, 255, 255)
cv2.imwrite(TEST_OUT_FOLDER + "red.png", img_array)
def test_green_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:
doc.update_style(background_color="green")
doc.update_style(background_color="green")
img_array = doc.render_array(resolution=100, channel="BGRA")
# css "green" is rgb(0,128,0) or bgra(0,128,0,255)
assert tuple(img_array[0][0]) == (0, 128, 0, 255)
cv2.imwrite(TEST_OUT_FOLDER + "green.png", img_array)
def test_blue_channel(doc_generator):
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
for doc in generator:
doc.update_style(background_color="blue")
doc.update_style(background_color="blue")
img_array = doc.render_array(resolution=100, channel="BGRA")
# css "blue" is rgb(0,0,255) or bgra(255,0,0,255)
assert tuple(img_array[0][0]) == (255, 0, 0, 255)

View file

@ -1,48 +1,60 @@
from genalog.ocr.rest_client import GrokRestClient
import json
import pytest
from dotenv import load_dotenv
from genalog.ocr.blob_client import GrokBlobClient
from genalog.ocr.grok import Grok
import requests
import pytest
import time
import json
import os
from dotenv import load_dotenv
load_dotenv("tests/ocr/.env")
class TestBlobClient:
@pytest.mark.parametrize("use_async",[True, False])
@pytest.mark.parametrize("use_async", [True, False])
def test_upload_images(self, use_async):
blob_client = GrokBlobClient.create_from_env_var()
subfolder = "tests/ocr/data/img"
file_prefix = subfolder.replace("/", "_")
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, use_async=use_async)
subfolder.replace("/", "_")
dst_folder, _ = blob_client.upload_images_to_blob(
subfolder, use_async=use_async
)
uploaded_items, _ = blob_client.list_blobs(dst_folder)
uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name)
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
assert uploaded_items[0].name == f"{dst_folder}/0.png"
assert uploaded_items[1].name == f"{dst_folder}/1.png"
assert uploaded_items[2].name == f"{dst_folder}/11.png"
blob_client.delete_blobs_folder(dst_folder)
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
assert (
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
), f"folder {dst_folder} was not deleted"
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, "test_images", use_async=use_async)
assert dst_folder == "test_images"
dst_folder, _ = blob_client.upload_images_to_blob(
subfolder, "test_images", use_async=use_async
)
assert dst_folder == "test_images"
uploaded_items, _ = blob_client.list_blobs(dst_folder)
uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name)
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
assert uploaded_items[0].name == f"{dst_folder}/0.png"
assert uploaded_items[1].name == f"{dst_folder}/1.png"
assert uploaded_items[2].name == f"{dst_folder}/11.png"
blob_client.delete_blobs_folder(dst_folder)
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
assert (
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
), f"folder {dst_folder} was not deleted"
class TestGROKe2e:
@pytest.mark.parametrize("use_async",[False,True])
@pytest.mark.parametrize("use_async", [False, True])
def test_grok_e2e(self, tmpdir, use_async):
grok = Grok.create_from_env_var()
src_folder = "tests/ocr/data/img"
grok.run_grok(src_folder, tmpdir, blob_dest_folder="testimages", use_async=use_async, cleanup=True)
json_folder = "tests/ocr/data/json"
json_hash = "521c38122f783673598856cd81d91c21"
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
grok.run_grok(
src_folder,
tmpdir,
blob_dest_folder="testimages",
use_async=use_async,
cleanup=True,
)
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
assert json.load(open(f"{tmpdir}/1.json", "r"))[0]["text"]
assert json.load(open(f"{tmpdir}/11.json", "r"))[0]["text"]

View file

@ -1,32 +1,43 @@
from genalog import pipeline
import pytest
import glob
import pytest
from genalog import pipeline
EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt"
@pytest.fixture
def default_analog_generator():
return pipeline.AnalogDocumentGeneration()
@pytest.fixture
def custom_analog_generator():
custom_styles = {"font_size": ["5px"]}
custom_degradation = [("blur", {"radius": 3})]
return pipeline.AnalogDocumentGeneration(
styles=custom_styles,
degradations=custom_degradation,
resolution=300)
styles=custom_styles, degradations=custom_degradation, resolution=300
)
def test_default_generate_img(default_analog_generator):
example_template = default_analog_generator.list_templates()[0]
img_array = default_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
default_analog_generator.generate_img(
EXAMPLE_TEXT_FILE, example_template, target_folder=None
)
def test_custom_generate_img(custom_analog_generator):
example_template = custom_analog_generator.list_templates()[0]
img_array = custom_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
custom_analog_generator.generate_img(
EXAMPLE_TEXT_FILE, example_template, target_folder=None
)
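# Sketch of the single-document path both tests above take; the original code bound the
# return value to img_array, so it is assumed here to be the generated image array.
example_generation = pipeline.AnalogDocumentGeneration(
    styles={"font_size": ["5px"]}, degradations=[("blur", {"radius": 3})], resolution=300
)
example_template = example_generation.list_templates()[0]
example_img = example_generation.generate_img(
    EXAMPLE_TEXT_FILE, example_template, target_folder=None
)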
def test_generate_dataset_multiprocess():
INPUT_TEXT_FILENAMES = glob.glob("tests/text/data/gt_*.txt")
with pytest.deprecated_call():
pipeline.generate_dataset_multiprocess(INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja")
pipeline.generate_dataset_multiprocess(
INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja"
)

View file

@ -1,7 +1,8 @@
import os
import difflib
import os
from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits
from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR
def _compare_content(file1, file2):
txt1 = open(file1, "r").read()
@ -12,15 +13,32 @@ def _compare_content(file1, file2):
str_diff = "\n".join(difflib.unified_diff(sentences_txt1, sentences_txt2))
assert False, f"Delta between outputs: \n {str_diff}"
def test_splitter(tmpdir):
# tmpdir = "test_out"
os.makedirs(f"{tmpdir}/clean_labels")
os.makedirs(f"{tmpdir}/clean_text")
generate_splits("tests/e2e/data/splitter/example_conll2012.txt", tmpdir,
doc_seperator=CONLL2003_DOC_SEPERATOR, sentence_seperator="")
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/0.txt", f"{tmpdir}/clean_text/0.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/1.txt", f"{tmpdir}/clean_text/1.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/0.txt", f"{tmpdir}/clean_labels/0.txt")
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/1.txt", f"{tmpdir}/clean_labels/1.txt")
generate_splits(
"tests/e2e/data/splitter/example_conll2012.txt",
tmpdir,
doc_seperator=CONLL2003_DOC_SEPERATOR,
sentence_seperator="",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_text/0.txt",
f"{tmpdir}/clean_text/0.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_text/1.txt",
f"{tmpdir}/clean_text/1.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_labels/0.txt",
f"{tmpdir}/clean_labels/0.txt",
)
_compare_content(
"tests/e2e/data/splitter/example_splits/clean_labels/1.txt",
f"{tmpdir}/clean_labels/1.txt",
)
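# Sketch of calling the splitter directly with the same arguments as the test above;
# "splits_out" is a placeholder output directory.
os.makedirs("splits_out/clean_labels", exist_ok=True)
os.makedirs("splits_out/clean_text", exist_ok=True)
generate_splits(
    "tests/e2e/data/splitter/example_conll2012.txt",
    "splits_out",
    doc_seperator=CONLL2003_DOC_SEPERATOR,
    sentence_seperator="",
)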

View file

@ -1,62 +1,76 @@
from genalog.generation.content import *
import pytest
from genalog.generation.content import CompositeContent, Content, ContentType
from genalog.generation.content import Paragraph, Title
CONTENT_LIST = ["foo", "bar"]
COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH]
TEXT = "foo bar"
@pytest.fixture
def content_base_class():
return Content()
@pytest.fixture
def paragraph():
return Paragraph(TEXT)
@pytest.fixture
def title():
return Title(TEXT)
@pytest.fixture
def section():
return CompositeContent(CONTENT_LIST, COMPOSITE_CONTENT_TYPE)
def test_content_set_content_type(content_base_class):
with pytest.raises(TypeError):
content_base_class.set_content_type("NOT VALID CONTENT TYPE")
content_base_class.set_content_type(ContentType.PARAGRAPH)
def test_paragraph_init(paragraph):
with pytest.raises(TypeError):
Paragraph([])
assert paragraph.content_type == ContentType.PARAGRAPH
def test_paragraph_print(paragraph):
assert paragraph.__str__()
def test_paragraph_iterable_indexable(paragraph):
for index, character in enumerate(paragraph):
assert character == paragraph[index]
def test_title_init(title):
with pytest.raises(TypeError):
Title([])
assert title.content_type == ContentType.TITLE
def test_title_iterable_indexable(title):
for index, character in enumerate(title):
assert character == title[index]
def test_composite_content_init(section):
with pytest.raises(TypeError):
CompositeContent((),[])
CompositeContent((), [])
assert section.content_type == ContentType.COMPOSITE
def test_composite_content_iterable(section):
for index, content in enumerate(section):
assert content.content_type == COMPOSITE_CONTENT_TYPE[index]
def test_composite_content_print(section):
assert "foo" in section.__str__()
assert "bar" in section.__str__()

View file

@ -1,8 +1,10 @@
from genalog.generation.document import Document, DocumentGenerator
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import MagicMock, patch
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
from genalog.generation.document import Document, DocumentGenerator
FRENCH = "fr"
CONTENT = ["some text"]
@ -21,84 +23,106 @@ DEFAULT_TEMPLATE_NAME = "text_block.html.jinja"
DEFAULT_PACKAGE_NAME = "genalog.generation"
DEFAULT_TEMPLATE_FOLDER = "templates"
@pytest.fixture
def default_document():
mock_jinja_template = MagicMock()
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
return Document(CONTENT, mock_jinja_template)
@pytest.fixture
def french_document():
mock_jinja_template = MagicMock()
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
return Document(CONTENT, mock_jinja_template, language=FRENCH)
def test_document_init(default_document):
assert default_document.styles == DEFAULT_DOCUMENT_STYLE
assert default_document._document != None
assert default_document.compiled_html != None
assert default_document._document is not None
assert default_document.compiled_html is not None
def test_document_init_with_kwargs(french_document):
assert french_document.styles["language"] == FRENCH
assert french_document._document != None
assert french_document.compiled_html != None
assert french_document._document is not None
assert french_document.compiled_html is not None
def test_document_render_html(french_document):
compiled_document = french_document.render_html()
assert compiled_document == MOCK_COMPILED_DOCUMENT
french_document.template.render.assert_called_with(content=CONTENT, **french_document.styles)
french_document.template.render.assert_called_with(
content=CONTENT, **french_document.styles
)
def test_document_render_pdf(default_document):
default_document._document = MagicMock()
# run tested function
default_document.render_pdf(target=FILE_DESTINATION_PDF, zoom=2)
default_document._document.write_pdf.assert_called_with(target=FILE_DESTINATION_PDF, zoom=2)
default_document._document.write_pdf.assert_called_with(
target=FILE_DESTINATION_PDF, zoom=2
)
def test_document_render_png(default_document):
default_document._document = MagicMock()
# run tested function
default_document.render_png(target=FILE_DESTINATION_PNG, resolution=100)
default_document._document.write_png.assert_called_with(target=FILE_DESTINATION_PNG, resolution=100)
default_document._document.write_png.assert_called_with(
target=FILE_DESTINATION_PNG, resolution=100
)
def test_document_render_png_split_pages(default_document):
default_document._document.copy = MagicMock()
# run tested function
default_document.render_png(target=FILE_DESTINATION_PNG, split_pages=True, resolution=100)
default_document.render_png(
target=FILE_DESTINATION_PNG, split_pages=True, resolution=100
)
result_destination = FILE_DESTINATION_PNG.replace(".png", "_pg_0.png")
# assertion
document_copy = default_document._document.copy.return_value
document_copy.write_png.assert_called_with(target=result_destination, resolution=100)
document_copy.write_png.assert_called_with(
target=result_destination, resolution=100
)
def test_document_render_array_valid_args(default_document):
# setup mock
mock_surface = MagicMock()
mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32
mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32
mock_surface.get_data = MagicMock(return_value=IMG_BYTES) # loading a 2x2 image
mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2))
default_document._document.write_image_surface = mock_write_image_surface
channel_types = ["RGBA", "RGB", "GRAYSCALE", "BGRA", "BGR"]
expected_img_shape = [(2,2,4), (2,2,3), (2,2), (2,2,4), (2,2,3)]
expected_img_shape = [(2, 2, 4), (2, 2, 3), (2, 2), (2, 2, 4), (2, 2, 3)]
for channel_type, expected_img_shape in zip(channel_types, expected_img_shape):
img_array = default_document.render_array(resolution=100, channel=channel_type)
assert img_array.shape == expected_img_shape
def test_document_render_array_invalid_args(default_document):
invalid_channel_types = "INVALID"
with pytest.raises(ValueError):
default_document.render_array(resolution=100, channel=invalid_channel_types)
def test_document_render_array_invalid_format(default_document):
# setup mock
mock_surface = MagicMock()
mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32
mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32
mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2))
default_document._document.write_image_surface = mock_write_image_surface
with pytest.raises(RuntimeError):
default_document.render_array(resolution=100)
def test_document_update_style(default_document):
new_style = {"language": FRENCH, "new_property": "some value"}
# Ensure that a new property is not already defined
@ -111,11 +135,13 @@ def test_document_update_style(default_document):
# Ensure that a new property is added
assert default_document.styles["new_property"] == new_style["new_property"]
@patch("genalog.generation.document.Environment")
@patch("genalog.generation.document.PackageLoader")
@patch("genalog.generation.document.FileSystemLoader")
def test_document_generator_init_default_setting(mock_file_system_loader,
mock_package_loader, mock_environment):
def test_document_generator_init_default_setting(
mock_file_system_loader, mock_package_loader, mock_environment
):
# setup mock template environment
mock_environment_instance = mock_environment.return_value
mock_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
@ -123,15 +149,19 @@ def test_document_generator_init_default_setting(mock_file_system_loader,
document_generator = DocumentGenerator()
# Ensure the right loader is called
mock_file_system_loader.assert_not_called()
mock_package_loader.assert_called_with(DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER)
mock_package_loader.assert_called_with(
DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER
)
# Ensure that the default template in the package is loaded
assert DEFAULT_TEMPLATE_NAME in document_generator.template_list
@patch("genalog.generation.document.Environment")
@patch("genalog.generation.document.PackageLoader")
@patch("genalog.generation.document.FileSystemLoader")
def test_document_generator_init_custom_template(mock_file_system_loader,
mock_package_loader, mock_environment):
def test_document_generator_init_custom_template(
mock_file_system_loader, mock_package_loader, mock_environment
):
# setup mock template environment
mock_environment_instance = mock_environment.return_value
mock_environment_instance.list_templates.return_value = [CUSTOM_TEMPLATE_NAME]
@ -143,67 +173,79 @@ def test_document_generator_init_custom_template(mock_file_system_loader,
# Ensure that the expected template is registered
assert CUSTOM_TEMPLATE_NAME in document_generator.template_list
@pytest.fixture
def default_document_generator():
with patch("genalog.generation.document.Environment") as MockEnvironment:
template_environment_instance = MockEnvironment.return_value
template_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
template_environment_instance.list_templates.return_value = [
DEFAULT_TEMPLATE_NAME
]
template_environment_instance.get_template.return_value = MOCK_TEMPLATE
doc_gen = DocumentGenerator()
return doc_gen
def test_document_generator_create_generator(default_document_generator):
available_templates = default_document_generator.template_list
assert len(available_templates) < 2
generator = default_document_generator.create_generator(CONTENT, available_templates)
doc = next(generator)
generator = default_document_generator.create_generator(
CONTENT, available_templates
)
next(generator)
with pytest.raises(StopIteration):
next(generator)
def test_document_generator_create_generator_(default_document_generator):
# setup test case
available_templates = default_document_generator.template_list
undefined_template = "NOT A VALID TEMPLATE"
assert undefined_template not in available_templates
generator = default_document_generator.create_generator(CONTENT, [undefined_template])
generator = default_document_generator.create_generator(
CONTENT, [undefined_template]
)
with pytest.raises(FileNotFoundError):
doc = next(generator)
next(generator)
@pytest.mark.parametrize("template_name, expected_output", [
("base.html.jinja", False),
("text_block.html.jinja", True),
("text_block.css.jinja", False),
("macro/dimension.css.jinja", False)
])
@pytest.mark.parametrize(
"template_name, expected_output",
[
("base.html.jinja", False),
("text_block.html.jinja", True),
("text_block.css.jinja", False),
("macro/dimension.css.jinja", False),
],
)
def test__keep_templates(template_name, expected_output):
output = DocumentGenerator._keep_template(template_name)
assert output == expected_output
def test_set_styles_to_generate(default_document_generator):
assert len(default_document_generator.styles_to_generate) == 1
default_document_generator.set_styles_to_generate({"foo": ["bar", "bar"]})
assert len(default_document_generator.styles_to_generate) == 2
@pytest.mark.parametrize("styles, expected_output", [
({}, []), # empty case
({"size": ["10px"], "color": [] }, []), #empty value will result in null combinations
(
{"size": ["10px"], "color": ["red"] },
[{"size":"10px", "color":"red"}]
),
(
{"size": ["5px", "10px"]},
[{"size": "5px"}, {"size": "10px"}]
@pytest.mark.parametrize(
"styles, expected_output",
[
({}, []), # empty case
(
{"size": ["10px"], "color": []},
[],
), # empty value will result in null combinations
({"size": ["10px"], "color": ["red"]}, [{"size": "10px", "color": "red"}]),
({"size": ["5px", "10px"]}, [{"size": "5px"}, {"size": "10px"}]),
(
{"size": ["10px", "15px"], "color": ["blue"]},
[{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}],
),
(
{"size": ["10px", "15px"], "color": ["blue"] },
[
{"size":"10px", "color": "blue"},
{"size":"15px", "color": "blue"}
]
)
])
],
)
def test_document_generator_expand_style_combinations(styles, expected_output):
output = DocumentGenerator.expand_style_combinations(styles)
assert output == expected_output

View file

@ -1,212 +1,611 @@
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_metrics, get_stats
from genalog.text.anchor import align_w_anchor
from genalog.text.alignment import GAP_CHAR, align
from genalog.text.ner_label import _find_gap_char_candidates
from pandas._testing import assert_frame_equal
import pytest
import genalog.ocr.metrics
import pandas as pd
import numpy as np
import json
import pickle
import os
import genalog.ocr.metrics
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_stats
from genalog.text.alignment import align, GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
genalog.ocr.metrics.LOG_LEVEL = 0
@pytest.mark.parametrize("src_string, target, expected_stats",
[
("a worn coat", "a wom coat",
{'edit_insert': 1, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
(" ", "a",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("a", " ",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("a", "a",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("ab", "ac",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("ac", "ab",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("New York is big.", "N ewYork kis big.",
{'edit_insert': 0, 'edit_delete': 1, 'edit_replace': 0, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
("B oston grea t", "Boston is great",
{'edit_insert': 0, 'edit_delete': 2, 'edit_replace': 0, 'edit_insert_spacing': 2, 'edit_delete_spacing': 1}),
("New York is big.", "N ewyork kis big",
{'edit_insert': 1, 'edit_delete': 1, 'edit_replace': 1, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
("dog", "d@g", # Test against default gap_char "@"
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
("some@one.com", "some@one.com",
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0})
])
def test_editops_stats(src_string, target, expected_stats):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
alignment = align(target, src_string)
stats, actions = get_editops_stats(alignment, gap_char)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k,stats[k], expected_stats[k])
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions",
[
(
"a worn coat", "a wom coat",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 11, 'total_words': 3,
'matching_chars': 9, 'matching_words': 2,"matching_alnum_words" : 2, 'word_accuracy': 2/3, 'char_accuracy': 9/11},
{('rn', 'm'): 1}
),
(
"a c", "def",
{'insert': 0, 'delete': 0, 'replace': 1 , 'spacing': 0, 'total_chars': 3, 'total_words': 1,
'matching_chars': 0, 'matching_words': 0,"matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 0},
{('a c', 'def'): 1}
),
(
"a", "a b",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 2,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 0.5, 'char_accuracy': 1/3},
{}
),
(
"a b", "b",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
{}
),
(
"a b", "a",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
{}
),
(
"b ..", "a b ..",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 6, 'total_words': 3, 'total_alnum_words': 2,
'matching_chars': 4, 'matching_words': 2, "matching_alnum_words" : 1, 'word_accuracy': 2/3, 'char_accuracy': 4/6},
{}
),
(
"taxi cab", "taxl c b",
{'insert': 0, 'delete': 1, 'replace': 1, 'spacing': 1, 'total_chars': 9, 'total_words': 3,
'matching_chars': 6, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 6/9},
{('i','l'):1}
),
(
"taxl c b ri de", "taxi cab ride",
{'insert': 1, 'delete': 0, 'replace': 1, 'spacing': 6, 'total_chars': 18, 'total_words': 3,
'matching_chars': 11, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 11/18},
{('l','i'):1}
),
(
"ab", "ac",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1,
'matching_chars': 1, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.5},
{}
),
(
"a", "a",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 1, 'total_words': 1,
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
{}
),
(
"New York is big.", "N ewYork kis big.",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 2, 'total_chars': 17, 'total_words': 4,
'matching_chars': 15, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1/4, 'char_accuracy': 15/17},
{}
),
(
"B oston grea t", "Boston is great",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 3, 'total_chars': 15, 'total_words': 3,
'matching_chars': 12, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.8},
{}
),
(
"New York is big.", "N ewyork kis big",
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
{('Y', 'y'): 1}
),
(
"dog", "d@g",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{('o', '@'): 1}
),
(
"some@one.com", "some@one.com",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, {}
@pytest.mark.parametrize(
"src_string, target, expected_stats",
[
(
"a worn coat",
"a wom coat",
{
"edit_insert": 1,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
" ",
"a",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"a",
" ",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"a",
"a",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"ab",
"ac",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"ac",
"ab",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"New York is big.",
"N ewYork kis big.",
{
"edit_insert": 0,
"edit_delete": 1,
"edit_replace": 0,
"edit_insert_spacing": 1,
"edit_delete_spacing": 1,
},
),
(
"B oston grea t",
"Boston is great",
{
"edit_insert": 0,
"edit_delete": 2,
"edit_replace": 0,
"edit_insert_spacing": 2,
"edit_delete_spacing": 1,
},
),
(
"New York is big.",
"N ewyork kis big",
{
"edit_insert": 1,
"edit_delete": 1,
"edit_replace": 1,
"edit_insert_spacing": 1,
"edit_delete_spacing": 1,
},
),
(
"dog",
"d@g", # Test against default gap_char "@"
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 1,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
(
"some@one.com",
"some@one.com",
{
"edit_insert": 0,
"edit_delete": 0,
"edit_replace": 0,
"edit_insert_spacing": 0,
"edit_delete_spacing": 0,
},
),
],
)
def test_editops_stats(src_string, target, expected_stats):
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
])
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(target, src_string)
stats, actions = get_editops_stats(alignment, gap_char)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
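For readers skimming the reformatted parametrizations, the flow under test here is: pick a gap character that does not occur in either input string, align the two strings, then count the edit operations. A minimal sketch follows; the data comes from the "dog" vs "d@g" case above, while the import paths are assumptions (the test module's real imports sit above this hunk).

# Sketch of the edit-op flow exercised by test_editops_stats above.
# NOTE: the module paths in these imports are assumptions.
from genalog.text.alignment import align, GAP_CHAR
from genalog.text.ner_label import _find_gap_char_candidates
from genalog.ocr.metrics import get_editops_stats

src, target = "dog", "d@g"  # '@' occurs in the input, so the default GAP_CHAR may be unusable
candidates, _ = _find_gap_char_candidates([src], [target])
gap_char = GAP_CHAR if GAP_CHAR in candidates else candidates.pop()

stats, actions = get_editops_stats(align(target, src, gap_char=gap_char), gap_char)
assert stats["edit_replace"] == 1  # 'o' vs '@' is a single replacement, per the parametrized case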
@pytest.mark.parametrize(
"src_string, target, expected_stats, expected_substitutions",
[
(
"a worn coat",
"a wom coat",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 11,
"total_words": 3,
"matching_chars": 9,
"matching_words": 2,
"matching_alnum_words": 2,
"word_accuracy": 2 / 3,
"char_accuracy": 9 / 11,
},
{("rn", "m"): 1},
),
(
"a c",
"def",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"matching_chars": 0,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 0,
},
{("a c", "def"): 1},
),
(
"a",
"a b",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 2,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 0.5,
"char_accuracy": 1 / 3,
},
{},
),
(
"a b",
"b",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1,
"char_accuracy": 1 / 3,
},
{},
),
(
"a b",
"a",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 1,
"total_chars": 3,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1,
"char_accuracy": 1 / 3,
},
{},
),
(
"b ..",
"a b ..",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 1,
"total_chars": 6,
"total_words": 3,
"total_alnum_words": 2,
"matching_chars": 4,
"matching_words": 2,
"matching_alnum_words": 1,
"word_accuracy": 2 / 3,
"char_accuracy": 4 / 6,
},
{},
),
(
"taxi cab",
"taxl c b",
{
"insert": 0,
"delete": 1,
"replace": 1,
"spacing": 1,
"total_chars": 9,
"total_words": 3,
"matching_chars": 6,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 6 / 9,
},
{("i", "l"): 1},
),
(
"taxl c b ri de",
"taxi cab ride",
{
"insert": 1,
"delete": 0,
"replace": 1,
"spacing": 6,
"total_chars": 18,
"total_words": 3,
"matching_chars": 11,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 11 / 18,
},
{("l", "i"): 1},
),
(
"ab",
"ac",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"matching_chars": 1,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0.0,
"char_accuracy": 0.5,
},
{},
),
(
"a",
"a",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 1,
"total_words": 1,
"matching_chars": 1,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
),
(
"New York is big.",
"N ewYork kis big.",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 2,
"total_chars": 17,
"total_words": 4,
"matching_chars": 15,
"matching_words": 1,
"matching_alnum_words": 1,
"word_accuracy": 1 / 4,
"char_accuracy": 15 / 17,
},
{},
),
(
"B oston grea t",
"Boston is great",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 3,
"total_chars": 15,
"total_words": 3,
"matching_chars": 12,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0.0,
"char_accuracy": 0.8,
},
{},
),
(
"New York is big.",
"N ewyork kis big",
{
"insert": 1,
"delete": 1,
"replace": 1,
"spacing": 2,
"total_chars": 16,
"total_words": 4,
"matching_chars": 13,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 13 / 16,
},
{("Y", "y"): 1},
),
(
"dog",
"d@g",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{("o", "@"): 1},
),
(
"some@one.com",
"some@one.com",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 12,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 12,
"matching_alnum_words": 1,
"matching_words": 1,
"alnum_word_accuracy": 1.0,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
),
],
)
def test_align_stats(src_string, target, expected_stats, expected_substitutions):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(src_string, target, gap_char=gap_char)
stats, substitution_dict = get_align_stats(alignment, src_string, target, gap_char)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
for k in expected_substitutions:
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
assert substitution_dict[k] == expected_substitutions[k], (
substitution_dict,
expected_substitutions,
)
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions, expected_actions", [
(
"ab", "a",
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
{},
{1: 'D'}
),
(
"ab", "abb",
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{},
{2: ('I', 'b')}
),
(
"ab", "ac",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
{('b', 'c'): 1},
{1: ('R', 'c')}
),
(
"New York is big.", "N ewyork kis big",
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
{('Y', 'y'): 1},
{1: ('I', ' '), 4: 'D', 5: ('R', 'y'), 10: ('I', 'k'), 17: 'D'}
),
(
"dog", "d@g",
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
{('o', '@'): 1},
{1: ('R', '@')}
),
(
"some@one.com", "some@one.com",
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
{},
{}
)
])
def test_get_stats(src_string, target, expected_stats, expected_substitutions, expected_actions ):
@pytest.mark.parametrize(
"src_string, target, expected_stats, expected_substitutions, expected_actions",
[
(
"ab",
"a",
{
"insert": 0,
"delete": 1,
"replace": 0,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 1,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 1 / 2,
},
{},
{1: "D"},
),
(
"ab",
"abb",
{
"insert": 1,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{},
{2: ("I", "b")},
),
(
"ab",
"ac",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 2,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 1,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 1 / 2,
},
{("b", "c"): 1},
{1: ("R", "c")},
),
(
"New York is big.",
"N ewyork kis big",
{
"insert": 1,
"delete": 1,
"replace": 1,
"spacing": 2,
"total_chars": 16,
"total_words": 4,
"matching_chars": 13,
"matching_words": 0,
"matching_alnum_words": 0,
"word_accuracy": 0,
"char_accuracy": 13 / 16,
},
{("Y", "y"): 1},
{1: ("I", " "), 4: "D", 5: ("R", "y"), 10: ("I", "k"), 17: "D"},
),
(
"dog",
"d@g",
{
"insert": 0,
"delete": 0,
"replace": 1,
"spacing": 0,
"total_chars": 3,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 2,
"matching_alnum_words": 0,
"matching_words": 0,
"alnum_word_accuracy": 0.0,
"word_accuracy": 0.0,
"char_accuracy": 2 / 3,
},
{("o", "@"): 1},
{1: ("R", "@")},
),
(
"some@one.com",
"some@one.com",
{
"insert": 0,
"delete": 0,
"replace": 0,
"spacing": 0,
"total_chars": 12,
"total_words": 1,
"total_alnum_words": 1,
"matching_chars": 12,
"matching_alnum_words": 1,
"matching_words": 1,
"alnum_word_accuracy": 1.0,
"word_accuracy": 1.0,
"char_accuracy": 1.0,
},
{},
{},
),
],
)
def test_get_stats(
src_string, target, expected_stats, expected_substitutions, expected_actions
):
stats, substitution_dict, actions = get_stats(target, src_string)
for k in expected_stats:
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
for k in expected_substitutions:
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
assert substitution_dict[k] == expected_substitutions[k], (
substitution_dict,
expected_substitutions,
)
for k in expected_actions:
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
@pytest.mark.parametrize("src_string, target, expected_actions",
[
("dog and cat", "g and at",
{0: ('I', 'd'), 1: ('I', 'o'), 8:('I', 'c')}),
])
@pytest.mark.parametrize(
"src_string, target, expected_actions",
[
("dog and cat", "g and at", {0: ("I", "d"), 1: ("I", "o"), 8: ("I", "c")}),
],
)
def test_actions_stats(src_string, target, expected_actions):
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
gap_char_candidates, input_char_set = _find_gap_char_candidates(
[src_string], [target]
)
gap_char = (
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
)
alignment = align(target, src_string, gap_char=gap_char)
_ , actions = get_editops_stats(alignment, gap_char)
_, actions = get_editops_stats(alignment, gap_char)
print(actions)
for k in expected_actions:
assert actions[k] == expected_actions[k], (k,actions[k], expected_actions[k])
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
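Condensed, the highest-level helper covered above reduces to a single call; the values are taken from the "dog" vs "d@g" parametrized case, and only the import path is an assumption (the test module's imports are above this hunk).

# get_stats() on one parametrized case from above ("dog" read as "d@g").
# NOTE: the module path in this import is an assumption.
from genalog.ocr.metrics import get_stats

stats, substitutions, actions = get_stats("d@g", "dog")  # (target, src_string), as called in the test
assert stats["replace"] == 1
assert stats["char_accuracy"] == 2 / 3
assert substitutions[("o", "@")] == 1  # 'o' was read as '@'
assert actions[1] == ("R", "@")        # replace action at character index 1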

View file

@ -1,13 +1,15 @@
from genalog.ocr.rest_client import GrokRestClient
from genalog.ocr.blob_client import GrokBlobClient
from genalog.ocr.grok import Grok
import requests
import pytest
import time
import json
import pytest
import requests
from dotenv import load_dotenv
from genalog.ocr.rest_client import GrokRestClient
load_dotenv("tests/ocr/.env")
@pytest.fixture(autouse=True)
def setup_monkeypatch(monkeypatch):
def mock_http(*args, **kwargs):
@ -21,7 +23,6 @@ def setup_monkeypatch(monkeypatch):
class MockedResponse:
def __init__(self, args, kwargs):
self.url = args[0]
self.text = "response"
@ -34,27 +35,39 @@ class MockedResponse:
if "search.windows.net/indexers/" in self.url:
if "status" in self.url:
return {
"lastResult": {"status": "success"},
"status": "finished"
}
return {"lastResult": {"status": "success"}, "status": "finished"}
return {}
if "search.windows.net/indexes/" in self.url:
if "docs/search" in self.url:
return {
"value" : [
{
"value": [
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_0.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", "r"))
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json",
"r",
)
),
},
{
"metadata_storage_name":"521c38122f783673598856cd81d91c21_1.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", "r"))
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_1.png",
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json",
"r",
)
),
},
{
"metadata_storage_name":"521c38122f783673598856cd81d91c21_11.png",
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", "r"))
{
"metadata_storage_name": "521c38122f783673598856cd81d91c21_11.png",
"layoutText": json.load(
open(
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json",
"r",
)
),
},
]
}
@ -69,7 +82,6 @@ class MockedResponse:
class TestGROK:
def test_creating_indexing_pipeline(self):
grok_rest_client = GrokRestClient.create_from_env_var()
grok_rest_client.create_indexing_pipeline()
@ -78,11 +90,11 @@ class TestGROK:
def test_running_indexer(self):
grok_rest_client = GrokRestClient.create_from_env_var()
grok_rest_client.create_indexing_pipeline()
indexer_status = grok_rest_client.get_indexer_status()
if indexer_status["status"] == "error":
raise RuntimeError(f"indexer error: {indexer_status}")
# if not already running start the indexer
if indexer_status["lastResult"]["status"] != "inProgress":
grok_rest_client.run_indexer()
@ -90,5 +102,4 @@ class TestGROK:
grok_rest_client.run_indexer()
indexer_status = grok_rest_client.poll_indexer_till_complete()
assert indexer_status["lastResult"]["status"] == "success"
grok_rest_client.delete_indexer_pipeline()
grok_rest_client.delete_indexer_pipeline()
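For context, the indexer round trip that TestGROK walks through (with HTTP calls monkeypatched above) condenses to the sketch below; it only uses methods already called in these tests and assumes the Azure Search and blob environment variables are configured.

# Condensed indexer workflow mirroring the TestGROK cases above.
from genalog.ocr.rest_client import GrokRestClient

grok = GrokRestClient.create_from_env_var()  # reads connection settings from env vars
grok.create_indexing_pipeline()

status = grok.get_indexer_status()
if status["status"] == "error":
    raise RuntimeError(f"indexer error: {status}")

if status["lastResult"]["status"] != "inProgress":
    grok.run_indexer()  # start the indexer only if it is not already running

final_status = grok.poll_indexer_till_complete()
assert final_status["lastResult"]["status"] == "success"
grok.delete_indexer_pipeline()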

View file

@ -1,43 +1,59 @@
from genalog.text import alignment
from genalog.text.alignment import MATCH_REWARD, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXT_PENALTY
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES, ALIGNMENT_REGRESSION_TEST_CASES
import warnings
from random import randint
from unittest.mock import MagicMock
import pytest
import warnings
from genalog.text import alignment
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES
RANDOM_INT = randint(1, 100)
MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)]
# Set up mock for third-party library call
@pytest.fixture
def mock_pairwise2_align(monkeypatch):
mock = MagicMock()
def mock_globalcs(*args, **kwargs):
mock.globalcs(*args, **kwargs)
return MOCK_ALIGNMENT_RESULT
# replace target method reference with the mock method
monkeypatch.setattr("Bio.pairwise2.align.globalcs", mock_globalcs)
return mock
def test__align_seg(mock_pairwise2_align):
# setup method input
required_arg = ("A", "B")
optional_arg = (alignment.MATCH_REWARD, alignment.MISMATCH_PENALTY, alignment.GAP_PENALTY, alignment.GAP_EXT_PENALTY)
optional_kwarg = {"gap_char": alignment.GAP_CHAR, "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY}
optional_arg = (
alignment.MATCH_REWARD,
alignment.MISMATCH_PENALTY,
alignment.GAP_PENALTY,
alignment.GAP_EXT_PENALTY,
)
optional_kwarg = {
"gap_char": alignment.GAP_CHAR,
"one_alignment_only": alignment.ONE_ALIGNMENT_ONLY,
}
# test method
result = alignment._align_seg(*required_arg + optional_arg, **optional_kwarg)
# assertion
mock_pairwise2_align.globalcs.assert_called()
assert result == MOCK_ALIGNMENT_RESULT
@pytest.mark.parametrize("alignments, target_num_tokens, raised_exception",
[
(MOCK_ALIGNMENT_RESULT, 1, None),
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
([("X", "XY", 0, 0, 1)], 1, ValueError)
])
@pytest.mark.parametrize(
"alignments, target_num_tokens, raised_exception",
[
(MOCK_ALIGNMENT_RESULT, 1, None),
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
([("X", "XY", 0, 0, 1)], 1, ValueError),
],
)
def test__select_alignment_candidates(alignments, target_num_tokens, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -46,24 +62,27 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
result = alignment._select_alignment_candidates(alignments, target_num_tokens)
assert result == MOCK_ALIGNMENT_RESULT[0]
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
("", -1, None, ValueError), # Empty case
# Index at start of string
(" token", 0, 2, None),
("\t\ntoken", 0, 2, None),
# Index reach end of string
("token ", 5, 5, None),
("token", 4, 4, None),
# Index in-between tokens
("token", 0, 0, None),
("t1 t2", 2, 7, None),
("t1 \t \n t2", 3, 7, None),
# Gap char
(" @", 0, 1, None),
])
@pytest.mark.parametrize(
"s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
("", -1, None, ValueError), # Empty case
# Index at start of string
(" token", 0, 2, None),
("\t\ntoken", 0, 2, None),
# Index reach end of string
("token ", 5, 5, None),
("token", 4, 4, None),
# Index in-between tokens
("token", 0, 0, None),
("t1 t2", 2, 7, None),
("t1 \t \n t2", 3, 7, None),
# Gap char
(" @", 0, 1, None),
],
)
def test__find_token_start(s, index, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -72,25 +91,28 @@ def test__find_token_start(s, index, desired_output, raised_exception):
output = alignment._find_token_start(s, index)
assert output == desired_output
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
("", -1, None, ValueError), # Empty case
# Index at start of string
(" ", 0, 0, None),
("\t\ntoken", 0, 0, None),
("token", 0, 4, None),
("token\t", 0, 5, None),
("token\n", 0, 5, None),
# Index reach end of string
("token ", 5, 5, None),
("token", 4, 4, None),
# Single Char
(".", 0, 0, None),
# Gap char
("@@ @", 0, 2, None),
])
@pytest.mark.parametrize(
"s, index, desired_output, raised_exception",
[
# Test exceptions
("s", 2, None, IndexError),
("", -1, None, ValueError), # Empty case
# Index at start of string
(" ", 0, 0, None),
("\t\ntoken", 0, 0, None),
("token", 0, 4, None),
("token\t", 0, 5, None),
("token\n", 0, 5, None),
# Index reach end of string
("token ", 5, 5, None),
("token", 4, 4, None),
# Single Char
(".", 0, 0, None),
# Gap char
("@@ @", 0, 2, None),
],
)
def test__find_token_end(s, index, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -99,61 +121,81 @@ def test__find_token_end(s, index, desired_output, raised_exception):
output = alignment._find_token_end(s, index)
assert output == desired_output
@pytest.mark.parametrize("s, start, desired_output",
[
("token", 0, (0,4)),
("token\t", 0, (0,5)),
("token \n", 0, (0,5)),
(" token ", 0, (1,6)),
# mix with GAP_CHAR
(" @@@@ ", 0, (1,5)),
("\n\t tok@n@@ \n\t", 0, (3,10)),
# single character string
("s", 0, (0,0)),
# punctuation
(" !,.: ", 0, (2,6))
])
@pytest.mark.parametrize(
"s, start, desired_output",
[
("token", 0, (0, 4)),
("token\t", 0, (0, 5)),
("token \n", 0, (0, 5)),
(" token ", 0, (1, 6)),
# mix with GAP_CHAR
(" @@@@ ", 0, (1, 5)),
("\n\t tok@n@@ \n\t", 0, (3, 10)),
# single character string
("s", 0, (0, 0)),
# punctuation
(" !,.: ", 0, (2, 6)),
],
)
def test__find_next_token(s, start, desired_output):
output = alignment._find_next_token(s, start)
assert output == desired_output
@pytest.mark.parametrize("token, desired_output",
[
# Valid tokens
("\n\t token.!,:\n\t ", True),
("token", True),
(" @@@t@@@ ", True),
("@@token@@", True),
(" @@token@@ ", True),
(f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), #i.e. 't1@t2'
# Invalid tokens (i.e. multiples of the GAP_CHAR)
("", False),
(" ", False),
("@@", False),
(" @@ ", False),
("\t\n@", False),
(alignment.GAP_CHAR*1, False),
(alignment.GAP_CHAR*RANDOM_INT, False),
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False)
])
@pytest.mark.parametrize(
"token, desired_output",
[
# Valid tokens
("\n\t token.!,:\n\t ", True),
("token", True),
(" @@@t@@@ ", True),
("@@token@@", True),
(" @@token@@ ", True),
(f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), # i.e. 't1@t2'
# Invalid tokens (i.e. multiples of the GAP_CHAR)
("", False),
(" ", False),
("@@", False),
(" @@ ", False),
("\t\n@", False),
(alignment.GAP_CHAR * 1, False),
(alignment.GAP_CHAR * RANDOM_INT, False),
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False),
],
)
def test__is_valid_token(token, desired_output):
result = alignment._is_valid_token(token)
assert result == desired_output
@pytest.mark.parametrize("aligned_gt, aligned_noise," +
"expected_gt_to_noise_map, expected_noise_to_gt_map",
PARSE_ALIGNMENT_REGRESSION_TEST_CASES)
def test_parse_alignment(aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map):
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(aligned_gt, aligned_noise)
@pytest.mark.parametrize(
"aligned_gt, aligned_noise," + "expected_gt_to_noise_map, expected_noise_to_gt_map",
PARSE_ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_parse_alignment(
aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map
):
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(
aligned_gt, aligned_noise
)
assert gt_to_noise_map == expected_gt_to_noise_map
assert noise_to_gt_map == expected_noise_to_gt_map
@pytest.mark.parametrize("gt_txt, noisy_txt," +
"expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_txt, noisy_txt," + "expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
if aligned_gt != expected_aligned_gt:
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
expected_alignment = alignment._format_alignment(
expected_aligned_gt, expected_aligned_noise
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)
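The public surface exercised by this file is alignment.align() plus alignment.parse_alignment(). A brief sketch is below; the input strings come from the test data, but the exact placement of '@' gap characters depends on the scoring constants, so the results are printed rather than asserted.

# Quick look at the alignment API covered by the tests above.
from genalog.text import alignment

aligned_gt, aligned_noise = alignment.align("New York is big.", "N ewYork kis big.")
print(alignment._format_alignment(aligned_gt, aligned_noise))  # '@' marks inserted gaps

# Token-index mappings between the aligned ground truth and the aligned noisy text
gt_to_noise, noise_to_gt = alignment.parse_alignment(aligned_gt, aligned_noise)
print(gt_to_noise, noise_to_gt)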

View file

@ -1,43 +1,56 @@
import glob
import warnings
import pytest
from genalog.text import alignment, anchor, preprocess
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
import glob
import pytest
import warnings
@pytest.mark.parametrize("tokens, case_sensitive, desired_output", [
([], True, set()),
([], False, set()),
(["a", "A"], True, set(["a", "A"])),
(["a", "A"], False, set()),
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
(["An", "an", "ab"], False, set(["ab"])),
])
@pytest.mark.parametrize(
"tokens, case_sensitive, desired_output",
[
([], True, set()),
([], False, set()),
(["a", "A"], True, set(["a", "A"])),
(["a", "A"], False, set()),
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
(["An", "an", "ab"], False, set(["ab"])),
],
)
def test_get_unique_words(tokens, case_sensitive, desired_output):
output = anchor.get_unique_words(tokens, case_sensitive=case_sensitive)
assert desired_output == output
@pytest.mark.parametrize("tokens, desired_output", [
([], 0),
([""], 0),
(["a", "b"], 2),
(["abc.", "def!"], 8)
])
@pytest.mark.parametrize(
"tokens, desired_output",
[([], 0), ([""], 0), (["a", "b"], 2), (["abc.", "def!"], 8)],
)
def test_segment_len(tokens, desired_output):
output = anchor.segment_len(tokens)
assert desired_output == output
@pytest.mark.parametrize("unique_words, src_tokens, desired_output, raised_exception", [
(set(), [], [], None),
(set(), ["a"], [], None),
(set("a"), [], [], ValueError), # unique word not in src_tokens
(set("a"), ["b"], [], ValueError),
(set("a"), ["A"], [], ValueError), # case sensitive
(set("a"), ["an", "na", " a "], [], ValueError), # substring
(set("a"), ["a"], [("a", 0)], None), # valid input
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
(set("ab"), ["c", "b", "a"], [("b", 1), ("a", 2)], None), # multiple matches ordered by index
])
@pytest.mark.parametrize(
"unique_words, src_tokens, desired_output, raised_exception",
[
(set(), [], [], None),
(set(), ["a"], [], None),
(set("a"), [], [], ValueError), # unique word not in src_tokens
(set("a"), ["b"], [], ValueError),
(set("a"), ["A"], [], ValueError), # case sensitive
(set("a"), ["an", "na", " a "], [], ValueError), # substring
(set("a"), ["a"], [("a", 0)], None), # valid input
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
(
set("ab"),
["c", "b", "a"],
[("b", 1), ("a", 2)],
None,
), # multiple matches ordered by index
],
)
def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception):
if raised_exception:
with pytest.raises(raised_exception):
@ -46,85 +59,150 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
output = anchor.get_word_map(unique_words, src_tokens)
assert desired_output == output
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no common unique words
(["a", "a"], ["a"], ([], [])), # no unique words
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
(["a", "b", "c"], ["a", "b", "c"], # common unique words
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)])),
(["a", "b", "c"], ["c", "b", "a"], # common unique words but not in same order
([("b", 1)], [("b", 1)])),
(["b", "a", "c"], ["c", "b", "a"], # LCS has multiple results
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])),
(["c", "a", "b"], ["c", "b", "a"],
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)])),
(["c", "a", "b"], ["a", "c", "b"], # LCS has multiple results
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)])),
])
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_output",
[
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no common unique words
(["a", "a"], ["a"], ([], [])), # no unique words
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
(
["a", "b", "c"],
["a", "b", "c"], # common unique words
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)]),
),
(
["a", "b", "c"],
["c", "b", "a"], # common unique words but not in same order
([("b", 1)], [("b", 1)]),
),
(
["b", "a", "c"],
["c", "b", "a"], # LCS has multiple results
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)]),
),
(
["c", "a", "b"],
["c", "b", "a"],
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)]),
),
(
["c", "a", "b"],
["a", "c", "b"], # LCS has multiple results
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)]),
),
],
)
def test_get_anchor_map(gt_tokens, ocr_tokens, desired_output):
desired_gt_map, desired_ocr_map = desired_output
gt_map, ocr_map = anchor.get_anchor_map(gt_tokens, ocr_tokens)
assert desired_gt_map == gt_map
assert desired_ocr_map == ocr_map
# max_seg_length does not change the following output
@pytest.mark.parametrize("max_seg_length", [0, 1, 2, 3, 5, 4, 6])
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no anchors
(["a", "a"], ["a"], ([], [])),
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([0], [0])), # anchors exist
("a1 w w w".split(), "a1 w w w".split(), # no anchors in the subsequence [w w w]
([0], [0])),
("a1 w w w a2".split(), "a1 w w w a2".split(),
([0, 4], [0, 4])),
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(),
([0, 4], [0, 4])),
("a1 a2 a3".split(), "a1 a2 a3".split(), # all words are anchors
([0, 1, 2], [0, 1, 2])),
("a1 a2 a3".split(), "A1 A2 A3".split(), # anchor words must be in the same casing
([], [])),
("a1 w w a2".split(), "a1 w W a2".split(), # unique words are case insensitive
([0, 3], [0, 3])),
("a1 w w a2".split(), "A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
([], [])),
])
def test_find_anchor_recur_various_seg_len(max_seg_length, gt_tokens, ocr_tokens, desired_output):
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_output",
[
([], [], ([], [])), # empty
([""], [""], ([], [])),
(["a"], ["b"], ([], [])), # no anchors
(["a", "a"], ["a"], ([], [])),
(["a"], ["a", "a"], ([], [])),
(["a"], ["a"], ([0], [0])), # anchors exist
(
"a1 w w w".split(),
"a1 w w w".split(), # no anchors in the subsequence [w w w]
([0], [0]),
),
("a1 w w w a2".split(), "a1 w w w a2".split(), ([0, 4], [0, 4])),
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(), ([0, 4], [0, 4])),
(
"a1 a2 a3".split(),
"a1 a2 a3".split(), # all words are anchors
([0, 1, 2], [0, 1, 2]),
),
(
"a1 a2 a3".split(),
"A1 A2 A3".split(), # anchor words must be in the same casing
([], []),
),
(
"a1 w w a2".split(),
"a1 w W a2".split(), # unique words are case insensitive
([0, 3], [0, 3]),
),
(
"a1 w w a2".split(),
"A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
([], []),
),
],
)
def test_find_anchor_recur_various_seg_len(
max_seg_length, gt_tokens, ocr_tokens, desired_output
):
desired_gt_anchors, desired_ocr_anchors = desired_output
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
assert desired_gt_anchors == gt_anchors
assert desired_ocr_anchors == ocr_anchors
# Test the recursion behavior
@pytest.mark.parametrize("gt_tokens, ocr_tokens, max_seg_length, desired_output", [
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6,
([0, 3], [0, 3])),
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 3, 4], [0, 3, 4])),
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5])),
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 w_ a3".split(), 2, # missing ocr token
([0, 3, 6], [0, 2, 4])),
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 W_ A3".split(), 2, # changing cases
([0, 3], [0, 2])),
])
def test_find_anchor_recur_fixed_seg_len(gt_tokens, ocr_tokens, max_seg_length, desired_output):
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, max_seg_length, desired_output",
[
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6, ([0, 3], [0, 3])),
(
"a1 w_ w_ a2 a3 a2".split(),
"a1 w_ w_ a2 a3 a2".split(),
4,  # a2 is an anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 3, 4], [0, 3, 4]),
),
(
"a1 w_ w_ a2 a3 a2".split(),
"a1 w_ w_ a2 a3 a2".split(),
2,  # a2 is an anchor word in subsequence [a1 w_ w_ a2 a3]
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5]),
),
(
"a1 w_ w_ a2 w_ w_ a3".split(),
"a1 w_ a2 w_ a3".split(),
2, # missing ocr token
([0, 3, 6], [0, 2, 4]),
),
(
"a1 w_ w_ a2 w_ w_ a3".split(),
"a1 w_ a2 W_ A3".split(),
2, # changing cases
([0, 3], [0, 2]),
),
],
)
def test_find_anchor_recur_fixed_seg_len(
gt_tokens, ocr_tokens, max_seg_length, desired_output
):
desired_gt_anchors, desired_ocr_anchors = desired_output
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
assert desired_gt_anchors == gt_anchors
assert desired_ocr_anchors == ocr_anchors
@pytest.mark.parametrize("gt_file, ocr_file",
@pytest.mark.parametrize(
"gt_file, ocr_file",
zip(
sorted(glob.glob("tests/text/data/gt_1.txt")),
sorted(glob.glob("tests/text/data/ocr_1.txt"))
)
sorted(glob.glob("tests/text/data/ocr_1.txt")),
),
)
@pytest.mark.parametrize("max_seg_length", [75])
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
@ -132,15 +210,27 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
ocr_text = open(ocr_file, "r").read()
gt_tokens = preprocess.tokenize(gt_text)
ocr_tokens = preprocess.tokenize(ocr_text)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
)
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
# Ensure that each anchor word is the same word in both text
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align_w_anchor(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
aligned_gt, aligned_noise = anchor.align_w_anchor(gt_txt, noisy_txt)
if aligned_gt != expected_aligned_gt:
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
expected_alignment = alignment._format_alignment(
expected_aligned_gt, expected_aligned_noise
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)
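To make the intent of these parametrizations easier to follow, here is a small sketch of the two anchor helpers; the token lists and expected anchor indices are lifted from the cases above, while align_w_anchor() output is printed rather than asserted because the exact gap placement is alignment-dependent.

# Anchor helpers in brief, using inputs from the parametrized cases above.
from genalog.text import anchor

gt_tokens = "a1 w w w a2".split()
ocr_tokens = "a1 w w w a2".split()

# "a1" and "a2" are unique to both token lists, so they become anchor points.
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=2)
assert gt_anchors == [0, 4] and ocr_anchors == [0, 4]

# align_w_anchor() applies the same idea to raw strings before aligning them.
aligned_gt, aligned_ocr = anchor.align_w_anchor("New York is big.", "N ewYork kis big.")
print(aligned_gt)
print(aligned_ocr)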

View file

@ -1,157 +1,286 @@
from genalog.text import conll_format
import itertools
import warnings
from unittest.mock import patch
import itertools
import pytest
import warnings
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception", [
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
(["w1", "w3"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal tokens
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w3"]], ["w", "w"], ValueError), # Unequal tokens
(["w1", "w3"], ["l1", "l2"], [["w1"]],["w", "w"], ValueError), # Unequal length
(["w1"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal length
])
def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception):
from genalog.text import conll_format
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception",
[
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
(
["w1", "w3"],
["l1", "l2"],
[["w1"], ["w2"]],
["w", "w"],
ValueError,
), # Unequal tokens
(
["w1", "w2"],
["l1", "l2"],
[["w1"], ["w3"]],
["w", "w"],
ValueError,
), # Unequal tokens
(
["w1", "w3"],
["l1", "l2"],
[["w1"]],
["w", "w"],
ValueError,
), # Unequal length
(
["w1"],
["l1", "l2"],
[["w1"], ["w2"]],
["w", "w"],
ValueError,
), # Unequal length
],
)
def test_propagate_labels_sentences_error(
clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
else:
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", [
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1", "a2", "b2"], # ocr token
[["a1", "b1"], ["a2","b2"]], [["l1", "l2"], ["l3", "l4"]] # desired output
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1"], # Missing sentence 2
# Ideally we would expect [["a1", "b1"], []]
# But the limitation of text alignment, which yield
# "a1 b1 a2 b2"
# "a1 b1@@@@@@"
# It is difficult to decide the location of "b1"
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@"
# NOTE: this is a improper behavior but the best
# solution to this corner case by preserving the number of OCR tokens.
[["a1"], ["b1"]], [["l1"], ["l2"]]
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
[["a"], ["a2", "b2"]], [["l1"], ["l3", "l4"]]
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
[["a1", "b1"], ["a2"]], [["l1", "l2"], ["l3"]]
),
(
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
[["b1", "a2"], ["b2"]], [["l2", "l3"], ["l4"]]
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
["a1", "b1", "a2", "b2"], # ocr token (missing c1 token at middle of sentence)
[["a1"], ["b1", "a2"], ["b2"]], [["l1"], ["l2", "l4"], ["l5"]]
),
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels",
[
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
[["a1", "b1"], ["c1", "a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
[["a1"], ["b1", "b2"]], [["l1"], ["l2", "l5"]]
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
[[], ["a1", "c1", "a2"], ["b2"]], [[], ["l1", "l3", "l4"], ["l5"]]
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2","b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
),
(
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2","b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
),
])
def test_propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels):
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1", "a2", "b2"], # ocr token
[["a1", "b1"], ["a2", "b2"]],
[["l1", "l2"], ["l3", "l4"]], # desired output
),
(
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]], # clean sentences
["a1", "b1"], # Missing sentence 2
# Ideally we would expect [["a1", "b1"], []]
# But given the limitation of text alignment, which yields
# "a1 b1 a2 b2"
# "a1 b1@@@@@@"
# it is difficult to decide the location of "b1"
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@".
# NOTE: this is improper behavior, but it is the best
# solution to this corner case because it preserves the number of OCR tokens.
[["a1"], ["b1"]],
[["l1"], ["l2"]],
),
(
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
[["a"], ["a2", "b2"]],
[["l1"], ["l3", "l4"]],
),
(
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
[["a1", "b1"], ["a2"]],
[["l1", "l2"], ["l3"]],
),
(
"a1 b1 a2 b2".split(),
"l1 l2 l3 l4".split(),
[["a1", "b1"], ["a2", "b2"]],
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
[["b1", "a2"], ["b2"]],
[["l2", "l3"], ["l4"]],
),
(
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
[
"a1",
"b1",
"a2",
"b2",
], # ocr token (missing c1 token at middle of sentence)
[["a1"], ["b1", "a2"], ["b2"]],
[["l1"], ["l2", "l4"], ["l5"]],
),
(
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1"], ["c1", "a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
(
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1"], ["b1", "c1", "a2"], ["b2"]],
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
[[], ["a1", "c1", "a2"], ["b2"]],
[[], ["l1", "l3", "l4"], ["l5"]],
),
(
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
(
"a1 b1 c1 a2 b2".split(),
"l1 l2 l3 l4 l5".split(),
[["a1", "b1", "c1"], ["a2", "b2"]],
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
[["a1"], ["b1", "b2"]],
[["l1"], ["l2", "l5"]],
),
],
)
def test_propagate_labels_sentences(
clean_tokens,
clean_labels,
clean_sentences,
ocr_tokens,
desired_sentences,
desired_labels,
):
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens
if desired_sentences != ocr_text_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
)
)
if desired_labels != ocr_labels_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
)
)
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens," +
"mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", [
(
"a b c d".split(), "l1 l2 l3 l4".split(),
[["a", "b"], ["c", "d"]],
["a", "b"], # Sentence is empty
[[0], [1], [], []],
[[0], [1]],
[["a", "b"], []],
[["l1", "l2"], []]
),
(
"a b c d".split(), "l1 l2 l3 l4".split(),
[["a", "b",], ["c", "d"]],
["a", "b", "d"], # Missing sentence start
[[0], [1], [], [2]],
[[0], [1], [3]],
[["a", "b"], ["d"]],
[["l1", "l2"], ["l4"]]
),
(
"a b c d".split(), "l1 l2 l3 l4".split(),
[["a", "b",], ["c", "d"]],
["a", "c", "d"], # Missing sentence end
[[0], [], [1], [2]],
[[0], [2], [3]],
[["a"], ["c", "d"]],
[["l1"], ["l3", "l4"]]
),
])
def test_propagate_labels_sentences_text_alignment_corner_cases(clean_tokens, clean_labels, clean_sentences, ocr_tokens,
mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels):
@pytest.mark.parametrize(
"clean_tokens, clean_labels, clean_sentences, ocr_tokens,"
+ "mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels",
[
(
"a b c d".split(),
"l1 l2 l3 l4".split(),
[["a", "b"], ["c", "d"]],
["a", "b"], # Sentence is empty
[[0], [1], [], []],
[[0], [1]],
[["a", "b"], []],
[["l1", "l2"], []],
),
(
"a b c d".split(),
"l1 l2 l3 l4".split(),
[
[
"a",
"b",
],
["c", "d"],
],
["a", "b", "d"], # Missing sentence start
[[0], [1], [], [2]],
[[0], [1], [3]],
[["a", "b"], ["d"]],
[["l1", "l2"], ["l4"]],
),
(
"a b c d".split(),
"l1 l2 l3 l4".split(),
[
[
"a",
"b",
],
["c", "d"],
],
["a", "c", "d"], # Missing sentence end
[[0], [], [1], [2]],
[[0], [2], [3]],
[["a"], ["c", "d"]],
[["l1"], ["l3", "l4"]],
),
],
)
def test_propagate_labels_sentences_text_alignment_corner_cases(
clean_tokens,
clean_labels,
clean_sentences,
ocr_tokens,
mock_gt_to_ocr_mapping,
mock_ocr_to_gt_mapping,
desired_sentences,
desired_labels,
):
with patch("genalog.text.alignment.parse_alignment") as mock_alignment:
mock_alignment.return_value = (mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping)
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
(
ocr_text_sentences,
ocr_labels_sentences,
) = conll_format.propagate_labels_sentences(
clean_tokens, clean_labels, clean_sentences, ocr_tokens
)
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
assert len(ocr_text_sentences) == len(clean_sentences)
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
assert len(ocr_sentences_flatten) == len(
ocr_tokens
) # ensure aligned ocr tokens == ocr tokens
if desired_sentences != ocr_text_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
)
)
if desired_labels != ocr_labels_sentences:
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
)
)
@pytest.mark.parametrize("s, desired_output", [
("", []),
("\n\n", []),
("a1\tb1\na2\tb2", [["a1", "a2"]]),
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
])
@pytest.mark.parametrize(
"s, desired_output",
[
("", []),
("\n\n", []),
("a1\tb1\na2\tb2", [["a1", "a2"]]),
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
],
)
def test_get_sentences_from_iob_format(s, desired_output):
output = conll_format.get_sentences_from_iob_format(s.splitlines(True))
assert desired_output == output
assert desired_output == output
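A compact usage sketch of the two conll_format helpers tested above; the inputs and expected values are copied from the parametrized cases, and the propagation result is printed with the expected output noted in comments.

# conll_format helpers in brief, with data from the parametrized cases above.
from genalog.text import conll_format

# IOB-formatted lines (token<TAB>label); blank lines separate sentences.
iob_lines = "a1\tb1\n\na2\tb2".splitlines(True)
assert conll_format.get_sentences_from_iob_format(iob_lines) == [["a1"], ["a2"]]

# Propagate clean-token labels onto OCR tokens, sentence by sentence.
sentences, labels = conll_format.propagate_labels_sentences(
    "a1 b1 a2 b2".split(),         # clean tokens
    "l1 l2 l3 l4".split(),         # clean labels
    [["a1", "b1"], ["a2", "b2"]],  # clean sentences
    ["a1", "b1", "a2", "b2"],      # OCR tokens (identical here)
)
print(sentences)  # expected [['a1', 'b1'], ['a2', 'b2']] per the test data
print(labels)     # expected [['l1', 'l2'], ['l3', 'l4']]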

View file

@ -1,37 +1,54 @@
from genalog.text.lcs import LCS
import pytest
@pytest.fixture(params=[
("", ""), # empty
("abcde", "ace"), # naive case
])
from genalog.text.lcs import LCS
@pytest.fixture(
params=[
("", ""), # empty
("abcde", "ace"), # naive case
]
)
def lcs(request):
str1, str2 = request.param
return LCS(str1, str2)
def test_lcs_init(lcs):
assert lcs._lcs_len is not None
assert lcs._lcs is not None
@pytest.mark.parametrize("str1, str2, expected_len, expected_lcs", [
("", "", 0, ""), # empty
("abc", "abc", 3, "abc"),
("abcde", "ace", 3, "ace"), # naive case
("a", "", 0, ""), # no results
("abc", "cba", 1, "c"), # multiple cases
("abcdgh", "aedfhr", 3, "adh"),
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
("New York @", "New @ York", len("New York"), "New York"), # with space-separated, tokens
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
# mixed order with similar tokens
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
# casing
("Is A Big City IS a", "IS a Big City Is A", len("I Big City I "), "I Big City I "),
])
@pytest.mark.parametrize(
"str1, str2, expected_len, expected_lcs",
[
("", "", 0, ""), # empty
("abc", "abc", 3, "abc"),
("abcde", "ace", 3, "ace"), # naive case
("a", "", 0, ""), # no results
("abc", "cba", 1, "c"), # multiple cases
("abcdgh", "aedfhr", 3, "adh"),
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
(
"New York @",
"New @ York",
len("New York"),
"New York",
), # with space-separated, tokens
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
# mixed order with similar tokens
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
# casing
(
"Is A Big City IS a",
"IS a Big City Is A",
len("I Big City I "),
"I Big City I ",
),
],
)
def test_lcs_e2e(str1, str2, expected_len, expected_lcs):
lcs = LCS(str1, str2)
assert expected_lcs == lcs.get_str()
assert expected_len == lcs.get_len()
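For completeness, the LCS class reduces to two calls; the example values come straight from the parametrized cases above.

# LCS in a nutshell, using a case from the parametrized data above.
from genalog.text.lcs import LCS

lcs = LCS("abcde", "ace")
assert lcs.get_str() == "ace"  # longest common subsequence
assert lcs.get_len() == 3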

View file

@ -1,171 +1,295 @@
import pytest
from genalog.text import ner_label
from genalog.text import alignment
from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES
import pytest
import string
@pytest.mark.parametrize("label, desired_output", [
# Positive Cases
("B-org", True), (" B-org ", True), #whitespae tolerant
("\tB-ORG\n", True),
# Negative Cases
("I-ORG", False), ("O", False), ("other-B-label", False),
])
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("B-org", True),
(" B-org ", True), # whitespae tolerant
("\tB-ORG\n", True),
# Negative Cases
("I-ORG", False),
("O", False),
("other-B-label", False),
],
)
def test__is_begin_label(label, desired_output):
output = ner_label._is_begin_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
# Positive Cases
("I-ORG", True), (" \t I-ORG ", True),
# Negative Cases
("O", False), ("B-LOC", False),("B-ORG", False),
])
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-ORG", True),
(" \t I-ORG ", True),
# Negative Cases
("O", False),
("B-LOC", False),
("B-ORG", False),
],
)
def test__is_inside_label(label, desired_output):
output = ner_label._is_inside_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
# Positive Cases
("I-ORG", True), ("B-ORG", True),
# Negative Cases
("O", False)
])
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-ORG", True),
("B-ORG", True),
# Negative Cases
("O", False),
],
)
def test__is_multi_token_label(label, desired_output):
output = ner_label._is_multi_token_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
# Positive Cases
("I-Place", "B-Place"), (" \t I-place ", "B-place"),
# Negative Cases
("O", "O"), ("B-LOC", "B-LOC"), (" B-ORG ", " B-ORG ")
])
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("I-Place", "B-Place"),
(" \t I-place ", "B-place"),
# Negative Cases
("O", "O"),
("B-LOC", "B-LOC"),
(" B-ORG ", " B-ORG "),
],
)
def test__convert_to_begin_label(label, desired_output):
output = ner_label._convert_to_begin_label(label)
assert output == desired_output
@pytest.mark.parametrize("label, desired_output", [
# Positive Cases
("B-LOC", "I-LOC"),
(" B-ORG ", "I-ORG"),
# Negative Cases
("", ""), ("O", "O"), ("I-Place", "I-Place"),
(" \t I-place ", " \t I-place ")
])
@pytest.mark.parametrize(
"label, desired_output",
[
# Positive Cases
("B-LOC", "I-LOC"),
(" B-ORG ", "I-ORG"),
# Negative Cases
("", ""),
("O", "O"),
("I-Place", "I-Place"),
(" \t I-place ", " \t I-place "),
],
)
def test__convert_to_inside_label(label, desired_output):
output = ner_label._convert_to_inside_label(label)
assert output == desired_output
@pytest.mark.parametrize("begin_label, inside_label, desired_output", [
# Positive Cases
("", "I-LOC", True),
("B-LOC", "I-ORG", True),
("", "I-ORG", True),
# Negative Cases
("", "", False), ("O", "O", False), ("", "", False),
("B-LOC", "O", False),
("B-LOC", "B-ORG", False),
("B-LOC", "I-LOC", False),
(" B-ORG ", "I-ORG", False),
])
@pytest.mark.parametrize(
"begin_label, inside_label, desired_output",
[
# Positive Cases
("", "I-LOC", True),
("B-LOC", "I-ORG", True),
("", "I-ORG", True),
# Negative Cases
("", "", False),
("O", "O", False),
("", "", False),
("B-LOC", "O", False),
("B-LOC", "B-ORG", False),
("B-LOC", "I-LOC", False),
(" B-ORG ", "I-ORG", False),
],
)
def test__is_missing_begin_label(begin_label, inside_label, desired_output):
output = ner_label._is_missing_begin_label(begin_label, inside_label)
assert output == desired_output
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_input_char_set", [
(["a","b"], ["c", "d"], set("abcd")),
(["New", "York"], ["is", "big"], set("NewYorkisbig")),
(["word1", "word2"], ["word1", "word2"], set("word12")),
])
@pytest.mark.parametrize(
"gt_tokens, ocr_tokens, desired_input_char_set",
[
(["a", "b"], ["c", "d"], set("abcd")),
(["New", "York"], ["is", "big"], set("NewYorkisbig")),
(["word1", "word2"], ["word1", "word2"], set("word12")),
],
)
def test__find_gap_char_candidates(gt_tokens, ocr_tokens, desired_input_char_set):
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
gt_tokens, ocr_tokens
)
assert input_char_set == desired_input_char_set
assert ner_label.GAP_CHAR_SET.difference(input_char_set) == gap_char_candidates
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception",
[
(["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token
(["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token
(["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR
(["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens
(["o"], ["@"], ["New"], ner_label.GapCharError), # invalid token with gap char only (gt_token)
(["o"], ["New"], ["@"], ner_label.GapCharError), # invalid token with gap char only (ocr_token)
(["o", "o"], ["New", "@"], ["New", "@"], ner_label.GapCharError), # invalid token (both)
(["o"], [" \n\t@@"], ["New"], ner_label.GapCharError), # invalid token with gap char and space chars (gt_token)
(["o"], ["New"], [" \n\t@"], ner_label.GapCharError), # invalid token with gap char and space chars (ocr_token)
(["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token)
(["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token)
(["o"], [" \n\t"], ["New"], ValueError), # invalid token: space characters only (gt_token)
(["o"], ["New"], [" \n\t"], ValueError), # invalid token: space characters only (ocr_token)
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
])
def test__propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
[
(["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token
(["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token
(["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR
(["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens
(
["o"],
["@"],
["New"],
ner_label.GapCharError,
), # invalid token with gap char only (gt_token)
(
["o"],
["New"],
["@"],
ner_label.GapCharError,
), # invalid token with gap char only (ocr_token)
(
["o", "o"],
["New", "@"],
["New", "@"],
ner_label.GapCharError,
), # invalid token (both)
(
["o"],
[" \n\t@@"],
["New"],
ner_label.GapCharError,
), # invalid token with gap char and space chars (gt_token)
(
["o"],
["New"],
[" \n\t@"],
ner_label.GapCharError,
), # invalid token with gap char and space chars (ocr_token)
(["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token)
(["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token)
(
["o"],
[" \n\t"],
["New"],
ValueError,
), # invalid token: space characters only (gt_token)
(
["o"],
["New"],
[" \n\t"],
ValueError,
), # invalid token: space characters only (ocr_token)
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
],
)
def test__propagate_label_to_ocr_error(
gt_labels, gt_tokens, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
)
else:
ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char="@"
)
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test__propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
gap_char_candidates, _ = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
# run regression test for each GAP_CHAR candidate to make sure
    # label propagation is functioning correctly
for gap_char in gap_char_candidates:
ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char
)
assert ocr_labels == desired_ocr_labels
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", [
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
(["o"], list(ner_label.GAP_CHAR_SET), [""], ner_label.GapCharError), # input char set == GAP_CHAR_SET
(["o"], [""], list(ner_label.GAP_CHAR_SET), ner_label.GapCharError), # input char set == GAP_CHAR_SET
# all possible gap chars set split between ocr and gt tokens
(["o"], list(ner_label.GAP_CHAR_SET)[:10], list(ner_label.GAP_CHAR_SET)[10:], ner_label.GapCharError),
])
def test_propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, raised_exception",
[
(["o"], ["New"], ["New"], None), # positive case
(["o"], ["New@"], ["New"], None), # positive case with gap char
(["o"], ["New"], ["@@New"], None), # positive case with gap char
(
["o"],
list(ner_label.GAP_CHAR_SET),
[""],
ner_label.GapCharError,
), # input char set == GAP_CHAR_SET
(
["o"],
[""],
list(ner_label.GAP_CHAR_SET),
ner_label.GapCharError,
), # input char set == GAP_CHAR_SET
# all possible gap chars set split between ocr and gt tokens
(
["o"],
list(ner_label.GAP_CHAR_SET)[:10],
list(ner_label.GAP_CHAR_SET)[10:],
ner_label.GapCharError,
),
],
)
def test_propagate_label_to_ocr_error(
gt_labels, gt_tokens, ocr_tokens, raised_exception
):
if raised_exception:
with pytest.raises(raised_exception):
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
else:
ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens
)
assert ocr_labels == desired_ocr_labels
@pytest.mark.parametrize("tokens, labels, label_top, desired_output",
[
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
True,
"B-place I-place o o \n" +
"New York is big \n"
),
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
False,
"New York is big \n" +
"B-place I-place o o \n"
)
])
@pytest.mark.parametrize(
"tokens, labels, label_top, desired_output",
[
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
True,
"B-place I-place o o \n" + "New York is big \n",
),
(
["New", "York", "is", "big"],
["B-place", "I-place", "o", "o"],
False,
"New York is big \n" + "B-place I-place o o \n",
),
],
)
def test_format_label(tokens, labels, label_top, desired_output):
output = ner_label.format_labels(tokens, labels, label_top=label_top)
assert output == desired_output
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES)
@pytest.mark.parametrize(
"gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_format_gt_ocr_w_labels(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
gt_labels, gt_tokens, ocr_tokens
)
ner_label.format_label_propagation(
gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr
)
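For orientation, the label-propagation API exercised by the tests above can be driven roughly as follows. This is a minimal sketch and not part of the diff: it assumes the module is importable as genalog.text.ner_label (mirroring the other genalog.text utilities touched in this PR), and the OCR tokens and labels are purely illustrative.

from genalog.text import ner_label

gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["New", "Yerk", "is", "big"]  # hypothetical noisy OCR output

# propagate_label_to_ocr returns the propagated labels along with the aligned
# ground-truth text, the aligned OCR text, and the gap character it selected,
# matching the 4-tuple unpacked in the tests above.
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)

# format_labels renders tokens and labels as two aligned text rows for inspection.
print(ner_label.format_labels(ocr_tokens, ocr_labels, label_top=True))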

View file

@ -1,124 +1,148 @@
import pytest
@pytest.mark.parametrize("token, replacement, desired_output",
[
("", "_", ""), # Do nothing to empty string
(" ", "_", " "), # Do nothing to whitespaces
(" \n\t", "_", " \n\t"),
("ascii", "_", "ascii"),
("a s\nc\tii", "_", "a s\nc\tii"),
("ascii·", "_", "ascii"), # Tokens with non-ASCII values
("·", "_", "_"), # Tokens with non-ASCII values
])
from genalog.text import preprocess
from genalog.text.alignment import GAP_CHAR
@pytest.mark.parametrize(
"token, replacement, desired_output",
[
("", "_", ""), # Do nothing to empty string
(" ", "_", " "), # Do nothing to whitespaces
(" \n\t", "_", " \n\t"),
("ascii", "_", "ascii"),
("a s\nc\tii", "_", "a s\nc\tii"),
("ascii·", "_", "ascii"), # Tokens with non-ASCII values
("·", "_", "_"), # Tokens with non-ASCII values
],
)
def test_remove_non_ascii(token, replacement, desired_output):
for code in range(128, 1000): # non-ASCII values
token.replace("·", chr(code))
output = preprocess.remove_non_ascii(token, replacement)
assert output == desired_output
@pytest.mark.parametrize("s, desired_output",
[
(
" New \t \n",
["New"]
),
@pytest.mark.parametrize(
"s, desired_output",
[
(" New \t \n", ["New"]),
# Mixed in gap char "@"
(" @ @", ["@", "@"]),
("New York is big", ["New", "York", "is", "big"]),
# Mixed multiple spaces and tabs
(" New York \t is \t big", ["New", "York", "is", "big"]),
# Mixed in punctuation
("New .York is, big !", ["New", ".York", "is,", "big", "!"]),
# Mixed in gap char "@"
("@N@ew York@@@is,\t big@@@@@", ["@N@ew", "York@@@is,", "big@@@@@"]),
],
)
def test_tokenize(s, desired_output):
output = preprocess.tokenize(s)
assert output == desired_output
@pytest.mark.parametrize("tokens, desired_output",
[
(
["New", "York", "is", "big"],
"New York is big",
),
@pytest.mark.parametrize(
"tokens, desired_output",
[
(
["New", "York", "is", "big"],
"New York is big",
),
# Mixed in punctuation
(
["New", ".York", "is,", "big", "!"],
"New .York is, big !",
),
# Mixed in gap char "@"
(
["@N@ew", "York@@@is,", "big@@@@@"],
"@N@ew York@@@is, big@@@@@",
),
],
)
def test_join_tokens(tokens, desired_output):
output = preprocess.join_tokens(tokens)
assert output == desired_output
@pytest.mark.parametrize("c, desired_output",
[
# Gap char
(GAP_CHAR, False),
# Alphabet char
('a', False), ('A', False),
# Punctuation
('.', False), ('!', False), (',', False), ('-', False),
# Token separators
(' ', True), ('\n', True), ('\t', True)
])
@pytest.mark.parametrize(
"c, desired_output",
[
# Gap char
(GAP_CHAR, False),
# Alphabet char
("a", False),
("A", False),
# Punctuation
(".", False),
("!", False),
(",", False),
("-", False),
# Token separators
(" ", True),
("\n", True),
("\t", True),
],
)
def test__is_spacing(c, desired_output):
assert desired_output == preprocess._is_spacing(c)
@pytest.mark.parametrize("text, desired_output", [
("", ""),
("w .", "w ."), ("w !", "w !"), ("w ?", "w ?"),
("w /.", "w /."), ("w /!", "w /!"), ("w /?", "w /?"),
("w1 , w2 .", "w1 , w2 ."),
("w1 . w2 .", "w1 . \nw2 ."), ("w1 /. w2 /.", "w1 /. \nw2 /."),
("w1 ! w2 .", "w1 ! \nw2 ."), ("w1 /! w2 /.", "w1 /! \nw2 /."),
("w1 ? w2 .", "w1 ? \nw2 ."), ("w1 /? w2 /.", "w1 /? \nw2 /."),
("U.S. . w2 .", "U.S. . \nw2 ."),
("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting
("w1 !!! w2 .", "w1 !!! w2 ."),
("w1 ... . w2 .", "w1 ... . \nw2 ."),
("w1 ... /. w2 /.", "w1 ... /. \nw2 /."),
("w1 /. /. w2 .", "w1 /. /. \nw2 ."),
("w1 /. /.", "w1 /. \n/."),
("w1 /. /. ", "w1 /. /. \n"),
("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."),
("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
])
@pytest.mark.parametrize(
"text, desired_output",
[
("", ""),
("w .", "w ."),
("w !", "w !"),
("w ?", "w ?"),
("w /.", "w /."),
("w /!", "w /!"),
("w /?", "w /?"),
("w1 , w2 .", "w1 , w2 ."),
("w1 . w2 .", "w1 . \nw2 ."),
("w1 /. w2 /.", "w1 /. \nw2 /."),
("w1 ! w2 .", "w1 ! \nw2 ."),
("w1 /! w2 /.", "w1 /! \nw2 /."),
("w1 ? w2 .", "w1 ? \nw2 ."),
("w1 /? w2 /.", "w1 /? \nw2 /."),
("U.S. . w2 .", "U.S. . \nw2 ."),
("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting
("w1 !!! w2 .", "w1 !!! w2 ."),
("w1 ... . w2 .", "w1 ... . \nw2 ."),
("w1 ... /. w2 /.", "w1 ... /. \nw2 /."),
("w1 /. /. w2 .", "w1 /. /. \nw2 ."),
("w1 /. /.", "w1 /. \n/."),
("w1 /. /. ", "w1 /. /. \n"),
("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."),
("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
],
)
def test_split_sentences(text, desired_output):
assert desired_output == preprocess.split_sentences(text)
@pytest.mark.parametrize("token, desired_output", [
("", False), (" ", False), ("\n", False), ("\t", False),
(" \n \t", False),
("...", False),
("???", False), ("!!!", False),
(".", True), ("!", True), ("?", True),
("/.", True), ("/!", True), ("/?", True),
])
@pytest.mark.parametrize(
"token, desired_output",
[
("", False),
(" ", False),
("\n", False),
("\t", False),
(" \n \t", False),
("...", False),
("???", False),
("!!!", False),
(".", True),
("!", True),
("?", True),
("/.", True),
("/!", True),
("/?", True),
],
)
def test_is_sentence_separator(token, desired_output):
assert desired_output == preprocess.is_sentence_separator(token)

View file

@ -1,14 +1,16 @@
import random
import warnings
import pytest
from genalog.text import alignment
from genalog.text.alignment import GAP_CHAR
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
def random_utf8_char(byte_len=1):
if byte_len == 1:
return chr(random.randint(0, 0x007F))
elif byte_len == 2:
return chr(random.randint(0x007F, 0x07FF))
elif byte_len == 3:
@ -16,33 +18,61 @@ def random_utf8_char(byte_len=1):
elif byte_len == 4:
return chr(random.randint(0xFFFF, 0x10FFFF))
else:
raise ValueError(f"Invalid byte length: {byte_len}." +
"utf-8 does not encode characters with more than 4 bytes in length")
raise ValueError(
f"Invalid byte length: {byte_len}."
+ "utf-8 does not encode characters with more than 4 bytes in length"
)
@pytest.mark.parametrize("num_utf_char_to_test", [100]) # Number of char per byte length
@pytest.mark.parametrize("byte_len", [1,2,3,4]) # UTF does not encode with more than 4 bytes
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
invalid_char = set(gt_txt).union(set(GAP_CHAR)) # character to replace to cannot be in this set
@pytest.mark.parametrize(
"num_utf_char_to_test", [100]
) # Number of char per byte length
@pytest.mark.parametrize(
"byte_len", [1, 2, 3, 4]
) # UTF does not encode with more than 4 bytes
@pytest.mark.parametrize(
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align(
num_utf_char_to_test,
byte_len,
gt_txt,
noisy_txt,
expected_aligned_gt,
expected_aligned_noise,
):
invalid_char = set(gt_txt).union(
set(GAP_CHAR)
) # character to replace to cannot be in this set
for _ in range(num_utf_char_to_test):
utf_char = random_utf8_char(byte_len)
while (
utf_char in invalid_char
): # find a utf char not in the input string and not GAP_CHAR
utf_char = random_utf8_char(byte_len)
char_to_replace = random.choice(list(invalid_char)) if gt_txt else ""
gt_txt.replace(char_to_replace, utf_char)
noisy_txt.replace(char_to_replace, utf_char)
expected_aligned_gt_sub = expected_aligned_gt.replace(char_to_replace, utf_char)
expected_aligned_noise_sub = expected_aligned_noise.replace(
char_to_replace, utf_char
)
# Run alignment
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
aligned_gt = aligned_gt.replace(char_to_replace, utf_char)
aligned_noise = aligned_noise.replace(char_to_replace, utf_char)
if aligned_gt != expected_aligned_gt_sub:
expected_alignment = alignment._format_alignment(
expected_aligned_gt_sub, expected_aligned_noise_sub
)
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
warnings.warn(
RuntimeWarning(
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
)
)
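The alignment helpers used by the test above can also be called directly. The snippet below is an illustrative sketch, not part of the diff: the input strings are invented, and it assumes only the behaviour the test relies on, namely that align() returns two equal-length strings padded with GAP_CHAR (the "@" seen in the test cases) and that _format_alignment() renders them side by side.

from genalog.text import alignment

gt_text = "New York is big"
ocr_text = "New Yrk is big"  # hypothetical noisy OCR output

# align() inserts GAP_CHAR so the two strings line up character-for-character.
aligned_gt, aligned_noise = alignment.align(gt_text, ocr_text)
print(alignment._format_alignment(aligned_gt, aligned_noise))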

36
tox.ini Normal file
View file

@ -0,0 +1,36 @@
; [tox]
; envlist = flake8, py36 # add other python versions if necessary
; [testenv]
; # Reading additional dependencies to run the test
; # https://tox.readthedocs.io/en/latest/example/basic.html#depending-on-requirements-txt-or-defining-constraints
; deps = -rdev-requirements.txt
; commands =
; pytest
; [testenv:flake8]
; deps = flake8
; skip_install = True
; commands = flake8 .
; # Configurations for running pytest
[pytest]
junit_family=xunit2
testpaths =
tests
addopts =
-rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml
[flake8]
# Configs for flake8-import-order, see https://pypi.org/project/flake8-import-order/ for more info.
import-order-style=edited
application-import-names=genalog, tests
# Native flake8 configs
max-line-length = 140
exclude =
build, dist
.env*,.venv* # local virtual environments
.tox
; [mypy]
; ignore_missing_imports = True
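For reference, the flake8-import-order settings above (import-order-style = edited with application-import-names = genalog, tests) group imports the way the reordered test modules earlier in this PR do: standard library first, then third-party packages such as pytest, then the project's own genalog and tests packages. A conforming import block would look roughly like this (illustrative only, with the blank-line separation shown as a convention rather than something the linter enforces):

import random  # standard library
import warnings

import pytest  # third-party

from genalog.text import alignment  # application packages named above
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES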