Mirror of https://github.com/microsoft/genalog.git
Merge pull request #10 from microsoft/laserprec/use_linter
Use flake8 as linter and fix code format issues.
This commit is contained in:
Commit
cda2c1e77d
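For context, the pipeline changes below run the linter with "python -m flake8" after installing dependencies from requirements-dev.txt (which presumably pins flake8; the requirements files themselves are not part of this diff). A minimal local equivalent of that CI step, under those assumptions, would be:

    python -m venv .venv
    source .venv/bin/activate          # use .venv/Scripts/activate on Windows
    python -m pip install --upgrade pip
    python -m pip install -r requirements.txt -r requirements-dev.txt
    python -m flake8                   # lint the repository, as the 'Run Linter (flake8)' step does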
@@ -35,31 +35,20 @@ steps:
  displayName: 'Use Python $(python.version)'

- bash: |
    python -m venv .venv
  displayName: 'Create virtual environment'

- bash: |
    if [[ '$(Agent.OS)' == Windows* ]]
    then
      source .venv/Scripts/activate
    else
      source .venv/bin/activate
    fi
    pip install --upgrade pip
    pip install setuptools wheel
    pip install -r requirements.txt
    pip install pytest==5.3.5 pytest-cov==2.8.1
    python -m pip install --upgrade pip
    python -m pip install setuptools wheel
    python -m pip install -r requirements.txt
    python -m pip install -r requirements-dev.txt
  workingDirectory: $(Build.SourcesDirectory)
  displayName: 'Install dependencies'

- bash: |
    if [[ '$(Agent.OS)' == Windows* ]]
    then
      source .venv/Scripts/activate
    else
      source .venv/bin/activate
    fi
    python -m pytest tests --cov=genalog --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=html
    python -m flake8
  workingDirectory: $(Build.SourcesDirectory)
  displayName: 'Run Linter (flake8)'

- bash: |
    python -m pytest tests
  env:
    BLOB_KEY : $(BLOB_KEY)
    SEARCH_SERVICE_KEY: $(SEARCH_SERVICE_KEY)

@@ -86,12 +75,6 @@ steps:
  displayName: 'Publish test coverage'

- bash: |
    if [[ '$(Agent.OS)' == Windows* ]]
    then
      source .venv/Scripts/activate
    else
      source .venv/bin/activate
    fi
    python setup.py bdist_wheel --build-number $(Build.BuildNumber) --dist-dir dist
  workingDirectory: $(Build.SourcesDirectory)
  displayName: 'Building wheel package'
@@ -1,25 +1,29 @@
from genalog.degradation import effect
from enum import Enum
import copy
import inspect
from enum import Enum

from genalog.degradation import effect

DEFAULT_METHOD_PARAM_TO_INCLUDE = "src"


class ImageState(Enum):
    ORIGINAL_STATE = "ORIGINAL_STATE"
    CURRENT_STATE = "CURRENT_STATE"

class Degrader():


class Degrader:
    """ An object for applying multiple degradation effects onto an image"""

    def __init__(self, effects):
        """ Initialize a Degrader object
        """Initialize a Degrader object

        Arguments:
            effects {list} -- a list of 2-element tuple that defines:

                (method_name, method_kwargs)

            1. method_name: the name of the degradation method
                1. method_name: the name of the degradation method
                (method must be defined in 'genalog.degradation.effect')
            2. method_kwargs: the keyword arguments of the corresponding method

@@ -28,10 +32,10 @@ class Degrader():
            [
                ("blur", {"radius": 3}),
                ("bleed_through", {"alpha": 0.8),
                ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
                ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
            ]

        The example above will apply degradation effects to the images
            The example above will apply degradation effects to the images
            in the following sequence:

            blur -> bleed_through -> morphological operation (open)

@@ -42,14 +46,14 @@ class Degrader():

    @staticmethod
    def validate_effects(effects):
        """ Validate the effects list
        """Validate the effects list

        Arguments:
            effects {list} -- a list of 2-element tuple that defines:

                (method_name, method_kwargs)

            1. method_name: the name of the degradation method
                1. method_name: the name of the degradation method
                (method must be defined in 'genalog.degradation.effect')
            2. method_kwargs: the keyword arguments of the corresponding method

@@ -58,11 +62,11 @@ class Degrader():
            [
                ("blur", {"radius": "3"}),
                ("bleed_through", {"alpha":"0.8"}),
                ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
                ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}),
            ]

        Raises:
            ValueError: raise this error when
            ValueError: raise this error when
                1. method_name not defined in "genalog.degradation.effect"
                2. method_kwargs is not a valid keyword arguments in the
                   corresponding method

@@ -73,37 +77,46 @@ class Degrader():
                # Try to find corresponding degradation method in the module
                method = getattr(effect, method_name)
            except AttributeError:
                raise ValueError(f"Method '{method_name}' is not defined in 'genalog.degradation.effect'")
                raise ValueError(
                    f"Method '{method_name}' is not defined in 'genalog.degradation.effect'"
                )
            # Get the method signatures
            method_sign = inspect.signature(method)
            # Check if method parameters are valid
            for param_name in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...]
                if not param_name in method_sign.parameters:
            for (
                param_name
            ) in method_kwargs.keys():  # i.e. ["operation", "kernel_shape", ...]
                if param_name not in method_sign.parameters:
                    method_args = [param for param in method_sign.parameters]
                    raise ValueError(f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. Method parameter names are: {method_args}")
                    raise ValueError(
                        f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. " +
                        f"Method parameter names are: {method_args}"
                    )

    def _add_default_method_param(self):
        """ All methods in "genalog.degradation.effect" module have a required
        """All methods in "genalog.degradation.effect" module have a required
        method parameter named "src". This parameter will be included if not provided
        by the input keyword argument dictionary.
        """
        for effect_tuple in self.effects_to_apply:
            method_name, method_kwargs = effect_tuple
            if DEFAULT_METHOD_PARAM_TO_INCLUDE not in method_kwargs:
                method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] = ImageState.CURRENT_STATE

                method_kwargs[
                    DEFAULT_METHOD_PARAM_TO_INCLUDE
                ] = ImageState.CURRENT_STATE

    def apply_effects(self, src):
        """ Apply degradation effects in sequence

        """Apply degradation effects in sequence

        Arguments:
            src {numpy.ndarray} -- source image of shape (rows, cols)

        Returns:
            a copy of the source image {numpy.ndarray} after apply the effects
        """
        self.original_state = src
        self.current_state = src
        # Preserve the original effect instructions
        # Preserve the original effect instructions
        effects_to_apply = copy.deepcopy(self.effects_to_apply)
        for effect_tuple in effects_to_apply:
            method_name, method_kwargs = effect_tuple

@@ -115,7 +128,7 @@ class Degrader():
        return self.current_state

    def insert_image_state(self, kwargs):
        """ Replace the enumeration (ImageState) with the actual image in
        """Replace the enumeration (ImageState) with the actual image in
        the keyword argument dictionary

        Arguments:

@@ -124,12 +137,12 @@ class Degrader():
            Ex: {"src": ImageState.ORIGINAL_STATE, "radius": 5}

        Returns:
            return keyword argument dictionary replaced with
            reference to the image
            return keyword argument dictionary replaced with
            reference to the image
        """
        for keyword, argument in kwargs.items():
            if argument is ImageState.ORIGINAL_STATE:
                kwargs[keyword] = self.original_state.copy()
            if argument is ImageState.CURRENT_STATE:
                kwargs[keyword] = self.current_state.copy()
        return kwargs
        return kwargs
@@ -1,31 +1,34 @@
import cv2
import numpy as np
from math import floor

import cv2
import numpy as np


def blur(src, radius=5):
    """ Wrapper function for cv2.GaussianBlur

    """Wrapper function for cv2.GaussianBlur

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)

    Keyword Arguments:
        radius {int} -- size of the square kernel,
        radius {int} -- size of the square kernel,
            MUST be an odd integer (default: {5})

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
    return cv2.GaussianBlur(src, (radius, radius), cv2.BORDER_DEFAULT)


def overlay_weighted(src, background, alpha, beta, gamma=0):
    """ overlay two images together, pixels from each image is weighted as follow
    """overlay two images together, pixels from each image is weighted as follow

        dst[i] = alpha*src[i] + beta*background[i] + gamma

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        background {numpy.ndarray} -- background image. Must be in same shape are `src`
        alpha {float} -- transparent factor for the foreground
        alpha {float} -- transparent factor for the foreground
        beta {float} -- transparent factor for the background

    Keyword Arguments:

@@ -36,8 +39,9 @@ def overlay_weighted(src, background, alpha, beta, gamma=0):
    """
    return cv2.addWeighted(src, alpha, background, beta, gamma).astype(np.uint8)


def overlay(src, background):
    """ Overlay two images together via bitwise-and:
    """Overlay two images together via bitwise-and:

        dst[i] = src[i] & background[i]

@@ -50,33 +54,35 @@ def overlay(src, background):
    """
    return cv2.bitwise_and(src, background).astype(np.uint8)


def translation(src, offset_x, offset_y):
    """ Shift the image in x, y direction

    """Shift the image in x, y direction

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        offset_x {int} -- pixels in the x direction.
        offset_x {int} -- pixels in the x direction.
            Positive value shifts right and negative shifts right.
        offset_y {int} -- pixels in the y direction.
            Positive value shifts down and negative shifts up.

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
    rows, cols = src.shape
    trans_matrix = np.float32([[1,0,offset_x], [0,1,offset_y]])
    trans_matrix = np.float32([[1, 0, offset_x], [0, 1, offset_y]])
    # size of the output image should be in the form of (width, height)
    dst = cv2.warpAffine(src, trans_matrix, (cols, rows), borderValue=255)
    return dst.astype(np.uint8)


def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y=5):
    """ Apply bleed through effect, background is flipped horizontally.

    """Apply bleed through effect, background is flipped horizontally.

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)

    Keyword Arguments:
        background {numpy.ndarray} -- background image. Must be in same
        background {numpy.ndarray} -- background image. Must be in same
            shape as foreground (default: {None})
        alpha {float} -- transparent factor for the foreground (default: {0.8})
        gamma {int} -- luminance constant (default: {0})

@@ -84,30 +90,31 @@ def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y
            Positive value shifts right and negative shifts right.
        offset_y {int} -- background translation offset (default: {5})
            Positive value shifts down and negative shifts up.

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect.
        Pixel value ranges [0, 255]
    """
    if background is None:
        background = src.copy()
    background = cv2.flip(background, 1) # flipped horizontally
    background = cv2.flip(background, 1)  # flipped horizontally
    background = translation(background, offset_x, offset_y)
    beta = 1 - alpha
    return overlay_weighted(src, background, alpha, beta, gamma)


def pepper(src, amount=0.05):
    """ Randomly sprinkle dark pixels on src image.
    """Randomly sprinkle dark pixels on src image.
    Wrapper function for skimage.util.noise.random_noise().
    See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)

    Keyword Arguments:
        amount {float} -- proportion of pixels in range [0, 1] to apply the effect
            (default: {0.05})

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect.
        Pixel value ranges [0, 255] as uint8.

@@ -118,18 +125,19 @@ def pepper(src, amount=0.05):
    dst[noise < amount] = 0
    return dst.astype(np.uint8)


def salt(src, amount=0.3):
    """ Randomly sprinkle white pixels on src image.
    """Randomly sprinkle white pixels on src image.
    Wrapper function for skimage.util.noise.random_noise().
    See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)

    Keyword Arguments:
        amount {float} -- proportion of pixels in range [0, 1] to apply the effect
            (default: {0.05})

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect.
        Pixel value ranges [0, 255]

@@ -140,18 +148,19 @@ def salt(src, amount=0.3):
    dst[noise < amount] = 255
    return dst.astype(np.uint8)


def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
    """ Randomly add salt then add pepper onto the image.

    """Randomly add salt then add pepper onto the image.

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        salt_amount {float} -- proportion of pixels in range [0, 1] to
        salt_amount {float} -- proportion of pixels in range [0, 1] to
            apply the salt effect
            (default: {0.1})
        pepper_amount {float} -- proportion of pixels in range [0, 1] to
            apply the pepper effect
        pepper_amount {float} -- proportion of pixels in range [0, 1] to
            apply the pepper effect
            (default: {0.05})

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect.
        Pixel value ranges [0, 255] as uint8.

@@ -159,18 +168,19 @@ def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05):
    salted = salt(src, amount=salt_amount)
    return pepper(salted, amount=pepper_amount)


def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
    """ Randomly add pepper then salt onto the image.

    """Randomly add pepper then salt onto the image.

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        pepper_amount {float} -- proportion of pixels in range [0, 1] to
        pepper_amount {float} -- proportion of pixels in range [0, 1] to
            apply the pepper effect.
            (default: {0.05})
        salt_amount {float} -- proportion of pixels in range [0, 1] to
            (default: {0.05})
        salt_amount {float} -- proportion of pixels in range [0, 1] to
            apply the salt effect.
            (default: {0.1})

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect.
        Pixel value ranges [0, 255] as uint8.

@@ -178,16 +188,17 @@ def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1):
    peppered = pepper(src, amount=pepper_amount)
    return salt(peppered, amount=salt_amount)


def create_2D_kernel(kernel_shape, kernel_type="ones"):
    """ Create 2D kernel for morphological operations.

    """Create 2D kernel for morphological operations.

    Arguments:
        kernel_shape {tuple} -- shape of the kernel (rows, cols)

    Keyword Arguments:
        kernel_type {str} -- type of kernel (default: {"ones"}).
        kernel_type {str} -- type of kernel (default: {"ones"}).
            All supported kernel types are below:

            "ones": kernel is filled with all 1s in shape (rows, cols)
                [[1,1,1],
                 [1,1,1],

@@ -215,9 +226,9 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
                 [1, 1, 1, 1, 1],
                 [0, 0, 1, 0, 0]]
    Raises:
        ValueError: if kernel is not a 2-element tuple or
        ValueError: if kernel is not a 2-element tuple or
            kernel_type is not one of the supported values

    Returns:
        a 2D array {numpy.ndarray} of shape `kernel_shape`.
    """

@@ -233,37 +244,47 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"):
    elif kernel_type == "x":
        diagonal = np.eye(kernel_rows, kernel_cols)
        kernel = np.add(diagonal, np.fliplr(diagonal))
        kernel[kernel>1] = 1
        kernel[kernel > 1] = 1
    elif kernel_type == "plus":
        kernel = np.zeros(kernel_shape)
        center_col = floor(kernel.shape[0]/2)
        center_row = floor(kernel.shape[1]/2)
        center_col = floor(kernel.shape[0] / 2)
        center_row = floor(kernel.shape[1] / 2)
        kernel[:, center_col] = 1
        kernel[center_row, :] = 1
    elif kernel_type == "ellipse":
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape)
    else:
        valid_kernel_types = {"ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"}
        raise ValueError(f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}")
        valid_kernel_types = {
            "ones",
            "upper_triangle",
            "lower_triangle",
            "x",
            "plus",
            "ellipse",
        }
        raise ValueError(
            f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}"
        )

    return kernel.astype(np.uint8)

def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
    """ Dynamic calls different morphological operations


def morphology(src, operation="open", kernel_shape=(3, 3), kernel_type="ones"):
    """Dynamic calls different morphological operations
    ("open", "close", "dilate" and "erode") with the given parameters

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)

    Keyword Arguments:
        operation {str} -- name of a morphological operation:
            ("open", "close", "dilate", "erode")
            (default: {"open"})
        kernel_shape {tuple} -- shape of the kernel (rows, cols)
        kernel_shape {tuple} -- shape of the kernel (rows, cols)
            (default: {(3,3)})
        kernel_type {str} -- type of kernel (default: {"ones"})
            Supported kernel_types are:
            ["ones", "upper_triangle", "lower_triangle",
            ["ones", "upper_triangle", "lower_triangle",
            "x", "plus", "ellipse"]

    Returns:

@@ -280,7 +301,10 @@ def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"):
        return erode(src, kernel)
    else:
        valid_operations = ["open", "close", "dilate", "erode"]
        raise ValueError(f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}")
        raise ValueError(
            f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}"
        )


def open(src, kernel):
    """ "open" morphological operation. Like morphological "erosion", it removes

@@ -289,59 +313,62 @@ def open(src, kernel):
    For more information see:
    1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
    2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/open.htm

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
    return cv2.morphologyEx(src, cv2.MORPH_OPEN, kernel)


def close(src, kernel):
    """ "close" morphological operation. Like morphological "dilation", it grows the
    """ "close" morphological operation. Like morphological "dilation", it grows the
    boundary of the foreground (white pixels), however, it is less destructive than
    dilation of the original boundary shape.

    For more information see:
    1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
    2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/close.htm

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
    return cv2.morphologyEx(src, cv2.MORPH_CLOSE, kernel)


def erode(src, kernel):
    """ "erode" morphological operation. Erodes foreground pixels (white pixels).
    For more information see:
    1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
    2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/erode.htm

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
    return cv2.erode(src, kernel)


def dilate(src, kernel):
    """ "dilate" morphological operation. Grows foreground pixels (white pixels).
    """ "dilate" morphological operation. Grows foreground pixels (white pixels).
    For more information see:
    1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html
    2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/dilate.htm

    Arguments:
        src {numpy.ndarray} -- source image of shape (rows, cols)
        kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect

    Returns:
        a copy of the source image {numpy.ndarray} after apply the effect
    """
@@ -1,4 +1,5 @@
from enum import Enum, auto
from enum import auto, Enum


class ContentType(Enum):
    PARAGRAPH = auto()

@@ -6,14 +7,17 @@ class ContentType(Enum):
    IMAGE = auto()
    COMPOSITE = auto()

class Content():


class Content:
    def __init__(self):
        self.iterable = True
        self._content = None

    def set_content_type(self, content_type):
        if type(content_type) != ContentType:
            raise TypeError(f"Invalid content type: {content_type}, valid types are {list(ContentType)}")
            raise TypeError(
                f"Invalid content type: {content_type}, valid types are {list(ContentType)}"
            )
        self.content_type = content_type

    def validate_content(self):

@@ -21,13 +25,14 @@ class Content():

    def __str__(self):
        return self._content.__str__()

    def __iter__(self):
        return self._content.__iter__()

    def __getitem__(self, key):
        return self._content.__getitem__(key)


class Paragraph(Content):
    def __init__(self, content):
        self.set_content_type(ContentType.PARAGRAPH)

@@ -38,16 +43,18 @@ class Paragraph(Content):
        if not isinstance(content, str):
            raise TypeError(f"Expect a str, but got {type(content)}")


class Title(Content):
    def __init__(self, content):
        self.set_content_type(ContentType.TITLE)
        self.validate_content(content)
        self._content = content

    def validate_content(self, content):
        if not isinstance(content, str):
            raise TypeError(f"Expect a str, but got {type(content)}")


class CompositeContent(Content):
    def __init__(self, content_list, content_type_list):
        self.set_content_type(ContentType.COMPOSITE)

@@ -68,7 +75,7 @@ class CompositeContent(Content):
            self._content.append(Paragraph(content))
        else:
            raise NotImplementedError(f"{content_type} is not currently supported")

    def insert_content(self, new_content, index):
        NotImplementedError

@@ -82,5 +89,5 @@ class CompositeContent(Content):
        """get a string transparent of the nested object types"""
        transparent_str = "["
        for content in self._content:
            transparent_str += '"' + content.__str__() + "\", "
            transparent_str += '"' + content.__str__() + '", '
        return transparent_str + "]"
@@ -1,14 +1,12 @@
from jinja2 import PackageLoader, FileSystemLoader
from jinja2 import Environment, select_autoescape
from weasyprint import HTML
from cairocffi import FORMAT_ARGB32
import numpy as np

import itertools
import os
import cv2

import cv2
import numpy as np
from cairocffi import FORMAT_ARGB32
from jinja2 import Environment, select_autoescape
from jinja2 import FileSystemLoader, PackageLoader
from weasyprint import HTML

DEFAULT_DOCUMENT_STYLE = {
    "language": "en_US",

@@ -27,19 +25,21 @@ DEFAULT_STYLE_COMBINATION = {
    "hyphenate": [False],
}


class Document(object):
    """ A composite object that represents a document """

    def __init__(self, content, template, **styles):
        """Initialize a Document object with source template and content

        Arguments:
            content {CompositeContent} -- a iterable object whose elements
            template {Template} -- a jinja2.Template object

        Optional Argument:
            styles [dict] -- a kwargs dictionary (context) whose keys and values are
            styles [dict] -- a kwargs dictionary (context) whose keys and values are
                the template variable and their respective values

        Example:
            {
                "font_family": "Calibri",

@@ -48,24 +48,24 @@ class Document(object):
            }

        Note that this assumes that "font_family", "font_size", "hyphenate" are valid
        variables declared in the loaded template. There will be **NO SIDE-EFFECT**
        variables declared in the loaded template. There will be **NO SIDE-EFFECT**
        providing an variable undefined in the template.

        You can also provide these key-value pairs via Python keyword arguments:

        Document(content, template, font_family="Calibri, font_size="10px", hyphenate=True)
        """
        self.content = content
        self.template = template
        self.styles = DEFAULT_DOCUMENT_STYLE.copy()
        # This is a rendered document ready to be painted on a cairo surface
        self._document = None # weasyprint.document.Document object
        self._document = None  # weasyprint.document.Document object
        self.compiled_html = None
        # Update the default styles and initialize self._document object
        self.update_style(**styles)

    def render_html(self):
        """ Wrapper function for Jinjia2.Template.render(). Each template
        """Wrapper function for Jinjia2.Template.render(). Each template
        declare its template variables. This method assigns each variable to
        its respective value and compiles the template.

@@ -77,8 +77,8 @@ class Document(object):
        return self.template.render(content=self.content, **self.styles)

    def render_pdf(self, target=None, zoom=1):
        """ Wrapper function for WeasyPrint.Document.write_pdf

        """Wrapper function for WeasyPrint.Document.write_pdf

        Arguments:
            target -- a filename, file-like object, or None
            split_pages {bool} -- true if saving each document page as a separate file.

@@ -92,32 +92,37 @@ class Document(object):
        return self._document.write_pdf(target=target, zoom=zoom)

    def render_png(self, target=None, split_pages=False, resolution=300):
        """ Wrapper function for WeasyPrint.Document.write_png

        """Wrapper function for WeasyPrint.Document.write_png

        Arguments:
            target -- a filename, file-like object, or None
            split_pages {bool} -- true if save each document page as a separate file.
            resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), PNG pixels match the CSS px unit.

            resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default),
                PNG pixels match the CSS px unit.

        Returns:
            The image as bytes if target is not provided or None, otherwise None (the PDF is written to target)
        """
        if target != None and split_pages:
        if target is not None and split_pages:
            # get destination filename and extension
            filename, ext = os.path.splitext(target)
            for page_num, page in enumerate(self._document.pages):
                page_name = filename + f"_pg_{page_num}" + ext
                self._document.copy([page]).write_png(target=page_name, resolution=resolution)
                self._document.copy([page]).write_png(
                    target=page_name, resolution=resolution
                )
            return None
        elif target == None:
        elif target is None:
            # return image bytes string if no target is specified
            png_bytes, png_width, png_height = self._document.write_png(target=target, resolution=resolution)
            png_bytes, png_width, png_height = self._document.write_png(
                target=target, resolution=resolution
            )
            return png_bytes
        else:
            return self._document.write_png(target=target, resolution=resolution)

    def render_array(self, resolution=300, channel="GRAYSCALE"):
        """ Render document as a numpy.ndarray.
        """Render document as a numpy.ndarray.

        Keyword Arguments:
            resolution {int} -- in units dpi (default: {300})

@@ -125,24 +130,29 @@ class Document(object):
                available values are: "GRAYSCALE", "RGB", "RGBA", "BGRA", "BGR"

                Note that "RGB" is 3-channel, "RGBA" is 4-channel and "GRAYSCALE" is single channel

        Returns:
            A numpy.ndarray representation of the document.
        """
        # Method below returns a cairocffi.ImageSurface object
        # https://cairocffi.readthedocs.io/en/latest/api.html#cairocffi.ImageSurface
        surface, width, height = self._document.write_image_surface(resolution=resolution)
        surface, width, height = self._document.write_image_surface(
            resolution=resolution
        )
        img_format = surface.get_format()

        # This is BGRA channel in little endian (reverse)
        if img_format != FORMAT_ARGB32:
            raise RuntimeError(f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'")

            raise RuntimeError(
                f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}." +
                "Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'"
            )

        img_buffer = surface.get_data()
        # Returns image array in "BGRA" channel
        img_array = np.ndarray(shape=(height, width, 4),
                               dtype=np.uint8,
                               buffer=img_buffer)
        img_array = np.ndarray(
            shape=(height, width, 4), dtype=np.uint8, buffer=img_buffer
        )
        if channel == "GRAYSCALE":
            return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)
        elif channel == "RGBA":

@@ -155,13 +165,15 @@ class Document(object):
            return cv2.cvtColor(img_array, cv2.COLOR_BGRA2BGR)
        else:
            valid_channels = ["GRAYSCALE", "RGB", "RGBA", "BGR", "BGRA"]
            raise ValueError(f"Invalid channel code {channel}. Valid values are: {valid_channels}.")
            raise ValueError(
                f"Invalid channel code {channel}. Valid values are: {valid_channels}."
            )

    def update_style(self, **style):
        """ Update template variables that controls the document style and re-compile the document to reflect the style change.

        """Update template variables that controls the document style and re-compile the document to reflect the style change.

        Optional Arguments:
            style {dict} -- a kwargs dictionary whose keys and values are
            style {dict} -- a kwargs dictionary whose keys and values are
                the template variable and their respective values

        Example:

@@ -175,44 +187,49 @@ class Document(object):
        self.styles.update(style)
        # Recompile the html template and the document obj
        self.compiled_html = self.render_html()
        self._document = HTML(string=self.compiled_html).render() # weasyprinter.document.Document object
        self._document = HTML(
            string=self.compiled_html
        ).render()  # weasyprinter.document.Document object


class DocumentGenerator():
class DocumentGenerator:
    """ Document generator class """

    def __init__(self, template_path=None):
        """ Initialize a DocumentGenerator class

        """Initialize a DocumentGenerator class

        Keyword Arguments:
            template_path {str} -- filepath of custom templates (default: {None})
            *** Important *** if not set, will use the default templates from the
            *** Important *** if not set, will use the default templates from the
            package "genalog.generation.templates".
        """
        if template_path:
            self.template_env = Environment(
                loader=FileSystemLoader(template_path),
                autoescape=select_autoescape(['html', 'xml'])
            )
                loader=FileSystemLoader(template_path),
                autoescape=select_autoescape(["html", "xml"]),
            )
            self.template_list = self.template_env.list_templates()
        else:
            # Loading built-in templates from the genalog package
            self.template_env = Environment(
                loader=PackageLoader("genalog.generation", "templates"),
                autoescape=select_autoescape(['html', 'xml'])
            )
                loader=PackageLoader("genalog.generation", "templates"),
                autoescape=select_autoescape(["html", "xml"]),
            )
            # Remove macros and css templates from rendering
            self.template_list = self.template_env.list_templates(filter_func=DocumentGenerator._keep_template)
            self.template_list = self.template_env.list_templates(
                filter_func=DocumentGenerator._keep_template
            )

        self.set_styles_to_generate(DEFAULT_STYLE_COMBINATION)

    @staticmethod
    def _keep_template(template_name):
        """ Auxiliary function for Jinja2.Environment.list_templates().
        """Auxiliary function for Jinja2.Environment.list_templates().
        This function filters out non-html templates and base templates

        Arguments:
            template_name {str} -- target of the template

        Returns:
            [bool] -- True if keeping the template in the list. False otherwise.
        """

@@ -220,16 +237,16 @@ class DocumentGenerator():
        if any(name in template_name for name in TEMPLATES_TO_REMOVE):
            return False
        return True

    def set_styles_to_generate(self, style_combinations):
        """
        Set new styles to generate.

        Arguments:
            style_combination {dict} -- a dictionary {str: list} enlisting the combinations
                of values to generate per style property
            style_combination {dict} -- a dictionary {str: list} enlisting the combinations
                of values to generate per style property
                (default: {None})

        Example:
            {
                "font_family": ["Calibri", "Times"],

@@ -248,14 +265,16 @@ class DocumentGenerator():
        variables declared in the loaded template. There will be NO side-effect providing
        an variable UNDEFINED in the template.

        If this parameter is not provided, generator will use default document
        If this parameter is not provided, generator will use default document
        styles: DEFAULT_STYLE_COMBINATION
        """
        self.styles_to_generate = DocumentGenerator.expand_style_combinations(style_combinations)
        self.styles_to_generate = DocumentGenerator.expand_style_combinations(
            style_combinations
        )

    def create_generator(self, content, templates_to_render):
        """ Create a Document generator

        """Create a Document generator

        Arguments:
            content {list} -- a list [str] of string to populate the template
            templates_to_render {list} -- a list [str] or templates to render

@@ -266,15 +285,17 @@ class DocumentGenerator():
        """
        for template_name in templates_to_render:
            if template_name not in self.template_list:
                raise FileNotFoundError(f"File '{template_name}' not found. Available templates are {self.template_list}")
                raise FileNotFoundError(
                    f"File '{template_name}' not found. Available templates are {self.template_list}"
                )
            template = self.template_env.get_template(template_name)
            for style in self.styles_to_generate:
                yield Document(content, template, **style)

    @staticmethod
    def expand_style_combinations(styles):
        """ Expand the list of style values into all possible style combinations

        """Expand the list of style values into all possible style combinations

        Example:
            styles =
            {

@@ -291,12 +312,12 @@ class DocumentGenerator():
                {"font_family": "Times", "font_size": "12px", "hyphenate":True }
            ]

        The result dictionaries are intended to be used as a kwargs to initialize a
        The result dictionaries are intended to be used as a kwargs to initialize a
        Document Object:

        Example:
            Document(template, content, **{"font_family": "Calibri", "font_size": ...})

        Arguments:
            styles {dict} -- a dictionary {str: list} enlisting the combinations of values
                to generate per style property

@@ -308,10 +329,14 @@ class DocumentGenerator():
        if not styles:
            return []
        # Python 2.x+ guarantees that the order in keys() and values() is preserved
        style_properties = styles.keys() # ex) ["font_family", "font_size", "hyphenate"]
        property_values = styles.values() # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]

        # Generate all possible combinations:
        style_properties = (
            styles.keys()
        )  # ex) ["font_family", "font_size", "hyphenate"]
        property_values = (
            styles.values()
        )  # ex) [["Calibri", "Times"], ["10px", "12px"], [True]]

        # Generate all possible combinations:
        # [("Calibri", "10px", True), ("Calibri", "12px", True), ...]
        property_value_combinations = itertools.product(*property_values)

@@ -328,4 +353,4 @@ class DocumentGenerator():
                style_dict[style_property] = property_value
            style_combinations.append(style_dict)

        return style_combinations
        return style_combinations
@ -1,20 +1,20 @@
|
|||
"""Uses the python sdk to make operation on Azure Blob storage.
|
||||
see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
|
||||
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
|
||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
||||
from tqdm import tqdm
|
||||
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
import base64
|
||||
|
||||
import aiofiles
|
||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
|
||||
from tqdm import tqdm
|
||||
|
||||
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
|
||||
# maximum number of simultaneous requests
|
||||
REQUEST_SEMAPHORE = asyncio.Semaphore(50)
|
||||
|
@ -24,13 +24,21 @@ FILE_SEMAPHORE = asyncio.Semaphore(500)
|
|||
|
||||
MAX_RETRIES = 5
|
||||
|
||||
|
||||
class GrokBlobClient:
|
||||
"""This class is a client that is used to upload and delete files from Azure Blob storage
|
||||
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
|
||||
"""
|
||||
def __init__(self, datasource_container_name, blob_account_name, blob_key, projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasource_container_name,
|
||||
blob_account_name,
|
||||
blob_key,
|
||||
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
|
||||
):
|
||||
"""Creates the blob storage client given the key and storage account name
|
||||
|
||||
|
||||
Args:
|
||||
datasource_container_name (str): container name. This container does not need to be existing
|
||||
projections_container_name (str): projections container to store ocr projections.
|
||||
|
@ -42,38 +50,54 @@ class GrokBlobClient:
|
|||
self.PROJECTIONS_CONTAINER_NAME = projections_container_name
|
||||
self.BLOB_NAME = blob_account_name
|
||||
self.BLOB_KEY = blob_key
|
||||
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
|
||||
self.BLOB_CONNECTION_STRING = (
|
||||
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
|
||||
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_from_env_var():
|
||||
"""Created the blob client using values in the environment variables
|
||||
|
||||
|
||||
Returns:
|
||||
GrokBlobClient: the new blob client
|
||||
"""
|
||||
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
|
||||
BLOB_NAME = os.environ["BLOB_NAME"]
|
||||
BLOB_KEY = os.environ["BLOB_KEY"]
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
|
||||
client = GrokBlobClient(DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY, projections_container_name=PROJECTIONS_CONTAINER_NAME)
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get(
|
||||
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
)
|
||||
client = GrokBlobClient(
|
||||
DATASOURCE_CONTAINER_NAME,
|
||||
BLOB_NAME,
|
||||
BLOB_KEY,
|
||||
projections_container_name=PROJECTIONS_CONTAINER_NAME,
|
||||
)
|
||||
return client
|
||||
|
||||
def upload_images_to_blob(self, src_folder_path, dest_folder_name=None,check_existing_cache=True, use_async=True):
|
||||
def upload_images_to_blob(
|
||||
self,
|
||||
src_folder_path,
|
||||
dest_folder_name=None,
|
||||
check_existing_cache=True,
|
||||
use_async=True,
|
||||
):
|
||||
"""Uploads images from the src_folder_path to blob storage at the destination folder.
|
||||
The destination folder is created if it doesn't exist. If a destination folder is not
|
||||
given a folder is created named by the md5 hash of the files.
|
||||
|
||||
|
||||
Args:
|
||||
src_folder_path (src): path to local folder that has images
|
||||
dest_folder_name (str, optional): destination folder name. Defaults to None.
|
||||
|
||||
|
||||
Returns:
|
||||
str: the destination folder name
|
||||
"""
|
||||
self._create_container()
|
||||
blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
|
||||
if dest_folder_name is None:
|
||||
dest_folder_name = self.get_folder_hash(src_folder_path)
|
||||
|
@ -94,30 +118,50 @@ class GrokBlobClient:
|
|||
return (upload_file_path, blob_name)
|
||||
|
||||
if check_existing_cache:
|
||||
existing_blobs,_ = self.list_blobs(dest_folder_name or "")
|
||||
existing_blobs = list(map(lambda blob: blob["name"] , existing_blobs))
|
||||
file_blob_names = filter(lambda file_blob_names: not file_blob_names[1] in existing_blobs, zip(files_to_upload, blob_names))
|
||||
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in file_blob_names ]
|
||||
existing_blobs, _ = self.list_blobs(dest_folder_name or "")
|
||||
existing_blobs = list(map(lambda blob: blob["name"], existing_blobs))
|
||||
file_blob_names = filter(
|
||||
lambda file_blob_names: not file_blob_names[1] in existing_blobs,
|
||||
zip(files_to_upload, blob_names),
|
||||
)
|
||||
job_args = [
|
||||
get_job_args(file_path, blob_name)
|
||||
for file_path, blob_name in file_blob_names
|
||||
]
|
||||
else:
|
||||
job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in zip(files_to_upload, blob_names)]
|
||||
job_args = [
|
||||
get_job_args(file_path, blob_name)
|
||||
for file_path, blob_name in zip(files_to_upload, blob_names)
|
||||
]
|
||||
|
||||
print("uploading ", len(job_args), "files")
|
||||
if not use_async:
|
||||
blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
blob_container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
blob_container_client = blob_service_client.get_container_client(
|
||||
self.DATASOURCE_CONTAINER_NAME
|
||||
)
|
||||
jobs = [(blob_container_client,) + x for x in job_args]
|
||||
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs) ):
|
||||
for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)):
|
||||
pass
|
||||
else:
|
||||
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
|
||||
async def async_upload():
|
||||
async with async_blob_service_client:
|
||||
async_blob_container_client = async_blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
|
||||
async_blob_container_client = (
|
||||
async_blob_service_client.get_container_client(
|
||||
self.DATASOURCE_CONTAINER_NAME
|
||||
)
|
||||
)
|
||||
jobs = [(async_blob_container_client,) + x for x in job_args]
|
||||
for f in tqdm(asyncio.as_completed(map(_upload_worker_async,jobs)), total=len(jobs)):
|
||||
for f in tqdm(
|
||||
asyncio.as_completed(map(_upload_worker_async, jobs)),
|
||||
total=len(jobs),
|
||||
):
|
||||
await f
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -150,7 +194,7 @@ class GrokBlobClient:
|
|||
|
||||
def delete_blobs_folder(self, folder_name):
|
||||
"""Deletes all blobs in a folder
|
||||
|
||||
|
||||
Args:
|
||||
folder_name (str): folder to delete
|
||||
"""
|
||||
|
@ -158,53 +202,83 @@ class GrokBlobClient:
|
|||
blobs_list, blob_service_client = self.list_blobs(folder_name)
|
||||
for blob in blobs_list:
|
||||
blob_client = blob_service_client.get_blob_client(
|
||||
container=self.DATASOURCE_CONTAINER_NAME, blob=blob)
|
||||
container=self.DATASOURCE_CONTAINER_NAME, blob=blob
|
||||
)
|
||||
blob_client.delete_blob()
|
||||
|
||||
def list_blobs(self, folder_name):
|
||||
blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
|
||||
return container_client.list_blobs(name_starts_with=folder_name), blob_service_client
|
||||
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
container_client = blob_service_client.get_container_client(
|
||||
self.DATASOURCE_CONTAINER_NAME
|
||||
)
|
||||
return (
|
||||
container_client.list_blobs(name_starts_with=folder_name),
|
||||
blob_service_client,
|
||||
)
|
||||
|
||||
def _create_container(self):
|
||||
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist.
|
||||
"""
|
||||
"""Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist."""
|
||||
# Create the BlobServiceClient object which will be used to create a container client
|
||||
blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
|
||||
try:
|
||||
blob_service_client.create_container(
|
||||
self.DATASOURCE_CONTAINER_NAME)
|
||||
blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME)
|
||||
except ResourceExistsError:
|
||||
print("container already exists:", self.DATASOURCE_CONTAINER_NAME)
|
||||
|
||||
# create the container for storing ocr projections
|
||||
try:
|
||||
print("creating projections storage container:", self.PROJECTIONS_CONTAINER_NAME)
|
||||
blob_service_client.create_container(
|
||||
self.PROJECTIONS_CONTAINER_NAME)
|
||||
print(
|
||||
"creating projections storage container:",
|
||||
self.PROJECTIONS_CONTAINER_NAME,
|
||||
)
|
||||
blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME)
|
||||
except ResourceExistsError:
|
||||
print("container already exists:", self.PROJECTIONS_CONTAINER_NAME)
|
||||
|
||||
def get_ocr_json(self, remote_path, output_folder, use_async=True):
|
||||
blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
container_client = blob_service_client.get_container_client(
|
||||
self.DATASOURCE_CONTAINER_NAME
|
||||
)
|
||||
blobs_list = list(container_client.list_blobs(name_starts_with=remote_path))
|
||||
container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}"
|
||||
|
||||
if use_async:
|
||||
async_blob_service_client = asyncBlobServiceClient.from_connection_string(
|
||||
self.BLOB_CONNECTION_STRING)
|
||||
self.BLOB_CONNECTION_STRING
|
||||
)
|
||||
|
||||
async def async_download():
|
||||
async with async_blob_service_client:
|
||||
async_projection_container_client = async_blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
|
||||
jobs = list(map(lambda blob :(blob, async_projection_container_client, container_uri, output_folder), blobs_list ))
|
||||
for f in tqdm(asyncio.as_completed(map(_download_worker_async,jobs)), total=len(jobs)):
|
||||
async_projection_container_client = (
|
||||
async_blob_service_client.get_container_client(
|
||||
self.PROJECTIONS_CONTAINER_NAME
|
||||
)
|
||||
)
|
||||
jobs = list(
|
||||
map(
|
||||
lambda blob: (
|
||||
blob,
|
||||
async_projection_container_client,
|
||||
container_uri,
|
||||
output_folder,
|
||||
),
|
||||
blobs_list,
|
||||
)
|
||||
)
|
||||
for f in tqdm(
|
||||
asyncio.as_completed(map(_download_worker_async, jobs)),
|
||||
total=len(jobs),
|
||||
):
|
||||
await f
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
result = loop.create_task(async_download())
|
||||
|
@ -212,11 +286,24 @@ class GrokBlobClient:
|
|||
result = loop.run_until_complete(async_download())
|
||||
return result
|
||||
else:
|
||||
projection_container_client = blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME)
|
||||
jobs = list(map(lambda blob : (blob, projection_container_client, container_uri, output_folder), blobs_list))
|
||||
projection_container_client = blob_service_client.get_container_client(
|
||||
self.PROJECTIONS_CONTAINER_NAME
|
||||
)
|
||||
jobs = list(
|
||||
map(
|
||||
lambda blob: (
|
||||
blob,
|
||||
projection_container_client,
|
||||
container_uri,
|
||||
output_folder,
|
||||
),
|
||||
blobs_list,
|
||||
)
|
||||
)
|
||||
print("downloading", len(jobs), "files")
|
||||
for _ in tqdm(map( _download_worker_sync, jobs), total=len(jobs)):
|
||||
pass
|
||||
for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)):
|
||||
pass
|
||||
|
||||
|
||||
def _get_projection_path(container_uri, blob):
|
||||
blob_uri = f"{container_uri}/{blob.name}"
|
||||
|
@ -226,13 +313,16 @@ def _get_projection_path(container_uri, blob):
|
|||
# hopefully this doesn't change soon otherwise we will have to do linear search over all docs to find
|
||||
# the projections we want
|
||||
projection_path = base64.b64encode(blob_uri.encode()).decode()
|
||||
projection_path = projection_path.replace("=","") + str(projection_path.count("="))
|
||||
projection_path = projection_path.replace("=", "") + str(projection_path.count("="))
|
||||
return projection_path
|
||||
|
||||
|
||||
def _download_worker_sync(args):
|
||||
blob, projection_container_client, container_uri, output_folder = args
|
||||
projection_path = _get_projection_path(container_uri,blob)
|
||||
blob_client = projection_container_client.get_blob_client(blob=f"{projection_path}/document.json")
|
||||
projection_path = _get_projection_path(container_uri, blob)
|
||||
blob_client = projection_container_client.get_blob_client(
|
||||
blob=f"{projection_path}/document.json"
|
||||
)
|
||||
doc = json.loads(blob_client.download_blob().readall())
|
||||
file_name = os.path.basename(blob.name)
|
||||
base_name, ext = os.path.splitext(file_name)
|
||||
|
@ -242,10 +332,13 @@ def _download_worker_sync(args):
|
|||
json.dump(text, open(output_file, "w", encoding="utf-8"), ensure_ascii=False)
|
||||
return output_file
|
||||
|
||||
|
||||
async def _download_worker_async(args):
|
||||
blob, async_projection_container_client, container_uri, output_folder = args
|
||||
projection_path = _get_projection_path(container_uri,blob)
|
||||
async_blob_client = async_projection_container_client.get_blob_client( blob=f"{projection_path}/document.json")
|
||||
projection_path = _get_projection_path(container_uri, blob)
|
||||
async_blob_client = async_projection_container_client.get_blob_client(
|
||||
blob=f"{projection_path}/document.json"
|
||||
)
|
||||
file_name = os.path.basename(blob.name)
|
||||
base_name, ext = os.path.splitext(file_name)
|
||||
for retry in range(MAX_RETRIES):
|
||||
|
@ -260,14 +353,15 @@ async def _download_worker_async(args):
|
|||
json.dump(text, open(output_file, "w"))
|
||||
return output_file
|
||||
except ResourceNotFoundError:
|
||||
print(f"blob doesn't exist in OCR projection. try rerunning OCR", blob.name)
|
||||
print(f"Blob '{blob.name}'' doesn't exist in OCR projection. try rerunning OCR")
|
||||
return
|
||||
except Exception as e:
|
||||
print("error getting blob OCR projection", blob.name, e)
|
||||
|
||||
|
||||
# sleep for a bit then retry
|
||||
asyncio.sleep(2 * random.random())
|
||||
|
||||
|
||||
async def _upload_worker_async(args):
|
||||
async_blob_container_client, upload_file_path, blob_name = args
|
||||
async with FILE_SEMAPHORE:
|
||||
|
@ -276,24 +370,33 @@ async def _upload_worker_async(args):
|
|||
for retry in range(MAX_RETRIES):
|
||||
async with REQUEST_SEMAPHORE:
|
||||
try:
|
||||
await async_blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
|
||||
await async_blob_container_client.upload_blob(
|
||||
name=blob_name, max_concurrency=8, data=data
|
||||
)
|
||||
return blob_name
|
||||
except ResourceExistsError:
|
||||
print("blob already exists:", blob_name)
|
||||
return
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"blob upload error. retry count: {retry}/{MAX_RETRIES} :", blob_name, e)
|
||||
print(
|
||||
f"blob upload error. retry count: {retry}/{MAX_RETRIES} :",
|
||||
blob_name,
|
||||
e,
|
||||
)
|
||||
# sleep for a bit then retry
|
||||
await asyncio.sleep(2 * random.random())
|
||||
return blob_name
|
||||
|
||||
|
||||
def _upload_worker_sync(args):
|
||||
blob_container_client, upload_file_path, blob_name = args
|
||||
with open(upload_file_path, "rb") as data:
|
||||
try:
|
||||
blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data)
|
||||
blob_container_client.upload_blob(
|
||||
name=blob_name, max_concurrency=8, data=data
|
||||
)
|
||||
except ResourceExistsError:
|
||||
print("blob already exists:", blob_name)
|
||||
except Exception as e:
|
||||
print("blob upload error:", blob_name, e)
|
||||
return blob_name
|
||||
return blob_name
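The async upload worker above retries with a jittered sleep between attempts. A minimal standalone sketch of that pattern, assuming container_client is an azure.storage.blob.aio ContainerClient, looks like this:

import asyncio
import random

MAX_RETRIES = 3

async def upload_with_retry(container_client, blob_name, data):
    for attempt in range(MAX_RETRIES):
        try:
            await container_client.upload_blob(name=blob_name, data=data)
            return blob_name
        except Exception as exc:
            print(f"upload failed ({attempt + 1}/{MAX_RETRIES}):", blob_name, exc)
            # jittered backoff before retrying; note the await, unlike a bare asyncio.sleep()
            await asyncio.sleep(2 * random.random())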
|
||||
|
|
|
@ -1 +1 @@
|
|||
DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections"
|
||||
DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections"
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
from .rest_client import GrokRestClient
|
||||
from .blob_client import GrokBlobClient
|
||||
import time
|
||||
|
||||
from .blob_client import GrokBlobClient
|
||||
from .rest_client import GrokRestClient
|
||||
|
||||
|
||||
class Grok:
|
||||
|
||||
@staticmethod
|
||||
def create_from_env_var():
|
||||
"""Initializes Grok based on keys in the environment variables.
|
||||
|
||||
|
||||
Returns:
|
||||
Grok: the Grok client
|
||||
"""
|
||||
|
@ -16,31 +16,41 @@ class Grok:
|
|||
grok_blob_client = GrokBlobClient.create_from_env_var()
|
||||
return Grok(grok_rest_client, grok_blob_client)
|
||||
|
||||
def __init__(self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient):
|
||||
def __init__(
|
||||
self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient
|
||||
):
|
||||
self.grok_rest_client = grok_rest_client
|
||||
self.grok_blob_client = grok_blob_client
|
||||
|
||||
def run_grok(self, src_folder_path, dest_folder_path, blob_dest_folder=None,cleanup=False,use_async=True):
|
||||
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
|
||||
def run_grok(
|
||||
self,
|
||||
src_folder_path,
|
||||
dest_folder_path,
|
||||
blob_dest_folder=None,
|
||||
cleanup=False,
|
||||
use_async=True,
|
||||
):
|
||||
"""Uploads images in the source folder to blob, sets up an indexing pipeline to run
|
||||
GROK OCR on this blob storage as a source, then downloads the OCR output json to the destination
|
||||
folder. The resulting json files have the same names as the original images, except prefixed
|
||||
folder. The resulting json files have the same names as the original images, except prefixed
|
||||
with the name of their folder on the blob storage and suffixed with the .json extension.
|
||||
|
||||
|
||||
Args:
|
||||
src_folder_path (str): Path to folder holding the images. This folder must only contain png or jpg files
|
||||
dest_folder_path (str): Path to folder where OCR json files will be placed
|
||||
blob_dest_folder (str, optional): Folder tag to use on the blob storage. If set to None, a hash is generated
|
||||
based on the names of files in the src folder. Defaults to None.
|
||||
cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are
|
||||
cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are
|
||||
deleted from the blob after running. Defaults to False.
|
||||
use_async (bool, optional): If set to True, use asyncio to speed up blob transfers. Defaults to True.
|
||||
|
||||
|
||||
Returns:
|
||||
indexer_status json, blob folder name
|
||||
"""
|
||||
print("uploading images to blob")
|
||||
blob_folder_name,_ = self.grok_blob_client.upload_images_to_blob(
|
||||
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async)
|
||||
blob_folder_name, _ = self.grok_blob_client.upload_images_to_blob(
|
||||
src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async
|
||||
)
|
||||
print(f"images upload under folder {blob_folder_name}")
|
||||
try:
|
||||
print("creating and running indexer")
|
||||
|
@ -50,10 +60,13 @@ class Grok:
|
|||
indexer_status = self.grok_rest_client.get_indexer_status()
|
||||
if indexer_status["status"] == "error":
|
||||
raise RuntimeError(f"indexer error: {indexer_status}")
|
||||
|
||||
|
||||
# if not already running start the indexer
|
||||
print("indexer_status", indexer_status)
|
||||
if indexer_status["lastResult"] == None or indexer_status["lastResult"]["status"] != "inProgress":
|
||||
if (
|
||||
indexer_status["lastResult"] is None
|
||||
or indexer_status["lastResult"]["status"] != "inProgress"
|
||||
):
|
||||
self.grok_rest_client.run_indexer()
|
||||
|
||||
time.sleep(1)
|
||||
|
@ -62,9 +75,13 @@ class Grok:
|
|||
if indexer_status["lastResult"]["status"] == "success":
|
||||
time.sleep(30)
|
||||
print("fetching ocr json results.")
|
||||
self.grok_blob_client.get_ocr_json(blob_folder_name, dest_folder_path, use_async=use_async)
|
||||
self.grok_blob_client.get_ocr_json(
|
||||
blob_folder_name, dest_folder_path, use_async=use_async
|
||||
)
|
||||
print(f"indexer status {indexer_status}")
|
||||
print(f"finished running indexer. json files saved to {dest_folder_path}")
|
||||
print(
|
||||
f"finished running indexer. json files saved to {dest_folder_path}"
|
||||
)
|
||||
else:
|
||||
print("GROK failed", indexer_status["status"])
|
||||
raise RuntimeError("GROK failed", indexer_status["status"])
|
||||
|
@ -77,7 +94,7 @@ class Grok:
|
|||
def cleanup(self, folder_name):
|
||||
"""Deletes the indexing pipeline (index, indexer, datasource, skillset) from the search service.
|
||||
Deletes uploaded files from the blob
|
||||
|
||||
|
||||
Args:
|
||||
folder_name (str): blob folder name tag to remove
|
||||
"""
|
||||
|
|
|
@ -2,54 +2,59 @@
|
|||
Utility functions to support getting OCR metrics
|
||||
|
||||
OCR Metrics
|
||||
1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412.
|
||||
1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412.
|
||||
Accuracy = Correct Words/Total Words (in target strings)
|
||||
|
||||
2. Count of edit distance ops:
|
||||
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". This is based on Levenshtein edit distance.
|
||||
2. Count of edit distance ops:
|
||||
insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing".
|
||||
This is based on Levenshtein edit distance.
|
||||
|
||||
3. By looking at the gaps in alignment we also generate substitution dicts:
|
||||
e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution
|
||||
3. By looking at the gaps in alignment we also generate substitution dicts:
|
||||
e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution
|
||||
since the rest of the segments align. The assumption here is that we do not expect to have very long gaps in alignment,
|
||||
hence collecting and counting these substitutions will be manageable.
|
||||
|
||||
"""
|
||||
import string
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
from genalog.text.ner_label import _find_gap_char_candidates
|
||||
from genalog.text.anchor import align_w_anchor
|
||||
from genalog.text.ner_label import _find_gap_char_candidates
|
||||
|
||||
LOG_LEVEL = 0
|
||||
WORKERS_PER_CPU = 2
|
||||
|
||||
|
||||
def _log(*args, **kwargs):
|
||||
if LOG_LEVEL:
|
||||
print(args)
|
||||
|
||||
|
||||
def _trim_whitespace(src_string):
|
||||
return re.sub(r"\s+", " ", src_string.strip())
|
||||
|
||||
|
||||
def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
|
||||
"""Given two string that differ and have no alignment at all,
|
||||
update the alignment dict and fill in substitution if replacements are found.
|
||||
update alignment stats with counts of the edit operation to transform the source
|
||||
string to the targes
|
||||
|
||||
|
||||
Args:
|
||||
src (str): source string
|
||||
target (str): target string
|
||||
target (str): target string
|
||||
align_stats (dict): key-value dictionary that stores the counts of inserts, deletes,
|
||||
spacing and replacements
|
||||
substitution_dict (dict): store the counts of mapping from one substring to another of
|
||||
the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2
|
||||
the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2
|
||||
times this will be { ('rm','m'): 2}
|
||||
gap_char (str): gap character used in alignment
|
||||
"""
|
||||
|
@ -70,30 +75,40 @@ def _update_align_stats(src, target, align_stats, substitution_dict, gap_char):
|
|||
else:
|
||||
align_stats["replace"] += 1
|
||||
_log("replacing", source_substr, target_substr)
|
||||
substitution_dict[source_substr, target_substr] = substitution_dict.get(
|
||||
(source_substr, target_substr), 0) + 1
|
||||
substitution_dict[source_substr, target_substr] = (
|
||||
substitution_dict.get((source_substr, target_substr), 0) + 1
|
||||
)
|
||||
_log("spacing count", spacing_count)
|
||||
align_stats["spacing"] += spacing_count
|
||||
|
||||
def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matching_chars_count, \
|
||||
matching_words_count, matching_alnum_words_count):
|
||||
|
||||
def _update_word_stats(
|
||||
aligned_src,
|
||||
aligned_target,
|
||||
gap_char,
|
||||
start,
|
||||
end,
|
||||
matching_chars_count,
|
||||
matching_words_count,
|
||||
matching_alnum_words_count,
|
||||
):
|
||||
"""Given two string segments that align. update the counts of matching words and characters
|
||||
|
||||
|
||||
Args:
|
||||
aligned_src (str): full source string
|
||||
aligned_target (str): full target string
|
||||
gap_char (str): gap character used in alignment
|
||||
gap_char (str): gap character used in alignment
|
||||
start (int): start position of alignment
|
||||
end (int): end position of alignment
|
||||
matching_chars_count (int): current count of matching characters
|
||||
matching_words_count (int): current count of matching words
|
||||
matching_alnum_words_count (int): current count of alphanumeric matching words
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(int,int,int): the updated matching_chars_count, matching_words_count, matching_alnum_words_count
|
||||
"""
|
||||
aligned_part = aligned_src[start:end]
|
||||
matching_chars_count += (end-start)
|
||||
matching_chars_count += end - start
|
||||
# aligned_part = seq.strip()
|
||||
_log("aligned", aligned_part, start, end)
|
||||
if len(aligned_src) != len(aligned_target):
|
||||
|
@ -112,10 +127,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
|
|||
# to be compared with the full string to see if they have space before or after
|
||||
|
||||
if i == 0:
|
||||
if start != 0 and (aligned_target[start] != " " or aligned_src[start] != " " ):
|
||||
if start != 0 and (
|
||||
aligned_target[start] != " " or aligned_src[start] != " "
|
||||
):
|
||||
# if this was the start of the string in the target or source
|
||||
if not(aligned_src[:start].replace(gap_char,"").replace(" ","") == "" and aligned_target[start-1] == " ") and \
|
||||
not(aligned_target[:start].replace(gap_char,"").replace(" ","") == "" and aligned_src[start-1] == " "):
|
||||
if not (
|
||||
aligned_src[:start].replace(gap_char, "").replace(" ", "") == ""
|
||||
and aligned_target[start - 1] == " "
|
||||
) and not (
|
||||
aligned_target[:start].replace(gap_char, "").replace(" ", "")
|
||||
== ""
|
||||
and aligned_src[start - 1] == " "
|
||||
):
|
||||
# beginning word not matching completely
|
||||
_log("removing first match word from count", word, aligned_part)
|
||||
matching_words_count -= 1
|
||||
|
@ -123,34 +146,43 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi
|
|||
matching_alnum_words_count -= 1
|
||||
continue
|
||||
|
||||
if i == len(words)-1:
|
||||
if end != len(aligned_target) and (aligned_target[end] != " " or aligned_src[end] != " " ):
|
||||
if i == len(words) - 1:
|
||||
if end != len(aligned_target) and (
|
||||
aligned_target[end] != " " or aligned_src[end] != " "
|
||||
):
|
||||
# this was not the end of the string in the src and not end of string in target
|
||||
if not(aligned_src[end:].replace(gap_char,"").replace(" ","") == "" and aligned_target[end] == " ") and \
|
||||
not(aligned_target[end:].replace(gap_char,"").replace(" ","") == ""and aligned_src[end] == " "):
|
||||
if not (
|
||||
aligned_src[end:].replace(gap_char, "").replace(" ", "") == ""
|
||||
and aligned_target[end] == " "
|
||||
) and not (
|
||||
aligned_target[end:].replace(gap_char, "").replace(" ", "")
|
||||
== ""
|
||||
and aligned_src[end] == " "
|
||||
):
|
||||
# last word not matching completely
|
||||
_log("removing last match word from count",word, aligned_part)
|
||||
_log("removing last match word from count", word, aligned_part)
|
||||
matching_words_count -= 1
|
||||
if re.search(r"\w", word):
|
||||
matching_alnum_words_count -= 1
|
||||
|
||||
_log("matched count", matching_words_count)
|
||||
_log("matched alnum count", matching_alnum_words_count)
|
||||
return matching_chars_count, matching_words_count, matching_alnum_words_count
|
||||
|
||||
return matching_chars_count, matching_words_count, matching_alnum_words_count
|
||||
|
||||
|
||||
def _get_align_stats(alignment, src_string, target, gap_char):
|
||||
"""Given an alignment, this function get the align stats and substitution mapping to
|
||||
"""Given an alignment, this function get the align stats and substitution mapping to
|
||||
transform the source string to the target string
|
||||
|
||||
|
||||
Args:
|
||||
alignment (tuple(str, str)): the result of calling align on the two strings
|
||||
src_string (str): the source string
|
||||
target (str) : the target string
|
||||
gap_char (str) : the gap character used in alignment
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: if any of the aligned strings are empty
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(dict, dict): align stats dict, substitution mappings dict
|
||||
"""
|
||||
|
@ -167,17 +199,24 @@ def _get_align_stats(alignment, src_string, target, gap_char):
|
|||
word_count = len(words)
|
||||
|
||||
# alphanumeric words are defined here as words with at least one alphanumeric character
|
||||
alnum_words_count = len(list(filter(lambda x: re.search(r"\w",x) , words )))
|
||||
|
||||
alnum_words_count = len(list(filter(lambda x: re.search(r"\w", x), words)))
|
||||
|
||||
char_count = max(len(target), len(src_string))
|
||||
matching_chars_count = 0
|
||||
matching_words_count = 0
|
||||
matching_alnum_words_count = 0
|
||||
|
||||
align_stats = {"insert": 0, "delete": 0, "replace": 0, "spacing": 0,
|
||||
"total_chars": char_count, "total_words": word_count, "total_alnum_words": alnum_words_count}
|
||||
align_stats = {
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": char_count,
|
||||
"total_words": word_count,
|
||||
"total_alnum_words": alnum_words_count,
|
||||
}
|
||||
start = 0
|
||||
|
||||
|
||||
_log("######### Alignment ############")
|
||||
_log(aligned_src)
|
||||
_log(aligned_target)
|
||||
|
@ -190,57 +229,105 @@ def _get_align_stats(alignment, src_string, target, gap_char):
|
|||
# since this substring aligns, simply count the number of matching words and chars and update
|
||||
# the word stats
|
||||
end = i
|
||||
_log("sequences", aligned_src[start:end], aligned_target[start:end], start, end)
|
||||
_log(
|
||||
"sequences",
|
||||
aligned_src[start:end],
|
||||
aligned_target[start:end],
|
||||
start,
|
||||
end,
|
||||
)
|
||||
assert aligned_src[start:end] == aligned_target[start:end]
|
||||
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
|
||||
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
|
||||
(
|
||||
matching_chars_count,
|
||||
matching_words_count,
|
||||
matching_alnum_words_count,
|
||||
) = _update_word_stats(
|
||||
aligned_src,
|
||||
aligned_target,
|
||||
gap_char,
|
||||
start,
|
||||
end,
|
||||
matching_chars_count,
|
||||
matching_words_count,
|
||||
matching_alnum_words_count,
|
||||
)
|
||||
start = end + 1
|
||||
if gap_start is None:
|
||||
gap_start = end
|
||||
else:
|
||||
gap_end = i
|
||||
if not gap_start is None:
|
||||
if gap_start is not None:
|
||||
# since characters now match, gap_start:i contains a substring of the characters that didn't align before
|
||||
# handle this gap alignment by calling _update_align_stats
|
||||
_log("gap", aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], gap_start, gap_end)
|
||||
_update_align_stats(aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], align_stats, substitution_dict, gap_char)
|
||||
_log(
|
||||
"gap",
|
||||
aligned_src[gap_start:gap_end],
|
||||
aligned_target[gap_start:gap_end],
|
||||
gap_start,
|
||||
gap_end,
|
||||
)
|
||||
_update_align_stats(
|
||||
aligned_src[gap_start:gap_end],
|
||||
aligned_target[gap_start:gap_end],
|
||||
align_stats,
|
||||
substitution_dict,
|
||||
gap_char,
|
||||
)
|
||||
gap_start = None
|
||||
|
||||
# Now handle any leftover string segments from the for loop
|
||||
if gap_start is not None:
|
||||
# handle last alignment gap
|
||||
_log("last gap", aligned_src[gap_start:], aligned_target[gap_start:])
|
||||
_update_align_stats(aligned_src[gap_start:], aligned_target[gap_start:], align_stats, substitution_dict, gap_char)
|
||||
_update_align_stats(
|
||||
aligned_src[gap_start:],
|
||||
aligned_target[gap_start:],
|
||||
align_stats,
|
||||
substitution_dict,
|
||||
gap_char,
|
||||
)
|
||||
else:
|
||||
# handle last aligned substring
|
||||
seq = aligned_src[start:]
|
||||
aligned_part = seq.strip()
|
||||
end = len(aligned_src)
|
||||
_log("last aligned", aligned_part)
|
||||
matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src,
|
||||
aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count)
|
||||
(
|
||||
matching_chars_count,
|
||||
matching_words_count,
|
||||
matching_alnum_words_count,
|
||||
) = _update_word_stats(
|
||||
aligned_src,
|
||||
aligned_target,
|
||||
gap_char,
|
||||
start,
|
||||
end,
|
||||
matching_chars_count,
|
||||
matching_words_count,
|
||||
matching_alnum_words_count,
|
||||
)
|
||||
|
||||
align_stats["matching_chars"] = matching_chars_count
|
||||
align_stats["matching_alnum_words"] = matching_alnum_words_count
|
||||
align_stats["matching_words"] = matching_words_count
|
||||
align_stats["alnum_word_accuracy"] = matching_alnum_words_count/alnum_words_count
|
||||
align_stats["word_accuracy"] = matching_words_count/word_count
|
||||
align_stats["char_accuracy"] = matching_chars_count/char_count
|
||||
align_stats["alnum_word_accuracy"] = matching_alnum_words_count / alnum_words_count
|
||||
align_stats["word_accuracy"] = matching_words_count / word_count
|
||||
align_stats["char_accuracy"] = matching_chars_count / char_count
|
||||
return align_stats, substitution_dict
|
||||
|
||||
|
||||
def get_editops_stats(alignment, gap_char):
|
||||
"""Get stats for character level edit operations that need to be done to
|
||||
"""Get stats for character level edit operations that need to be done to
|
||||
transform the source string to the target string. Inputs must not be empty
|
||||
and must be the result of running the align function.
|
||||
|
||||
|
||||
Args:
|
||||
alignment (tuple(str, str)): the results from the string alignment biopy function
|
||||
gap_char (str): gap character used in alignment
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If any of the string in the alignment are empty
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(dict, dict): a dict of edit-operation counts and a dict mapping string positions to the edit actions
|
||||
"""
|
||||
|
@ -248,42 +335,49 @@ def get_editops_stats(alignment, gap_char):
|
|||
aligned_src, aligned_target = alignment
|
||||
if aligned_src == "" or aligned_target == "":
|
||||
raise ValueError("one of the input strings is empty")
|
||||
stats = {"edit_insert": 0, "edit_delete": 0, "edit_replace": 0,
|
||||
"edit_insert_spacing": 0, "edit_delete_spacing": 0}
|
||||
stats = {
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
}
|
||||
actions = {}
|
||||
for i, (char_1, char_2) in enumerate(zip(aligned_src, aligned_target)):
|
||||
if LOG_LEVEL > 1: _log(char_1, char_2)
|
||||
if LOG_LEVEL > 1:
|
||||
_log(char_1, char_2)
|
||||
if char_1 == gap_char:
|
||||
# insert
|
||||
if char_2 == " ":
|
||||
stats["edit_insert_spacing"] += 1
|
||||
else:
|
||||
stats["edit_insert"] += 1
|
||||
actions[i] = ("I",char_2)
|
||||
actions[i] = ("I", char_2)
|
||||
elif char_2 == gap_char:
|
||||
# delete
|
||||
if char_1 == " ":
|
||||
stats["edit_delete_spacing"] += 1
|
||||
else:
|
||||
stats["edit_delete"] += 1
|
||||
actions[i] = ("D")
|
||||
actions[i] = "D"
|
||||
elif char_2 != char_1:
|
||||
stats["edit_replace"] += 1
|
||||
actions[i] = ("R",char_2)
|
||||
actions[i] = ("R", char_2)
|
||||
return stats, actions
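A toy example of the bookkeeping above, using '@' as the gap character and a hand-made alignment tuple:

stats, actions = get_editops_stats(("worn", "wo@n"), "@")
# stats["edit_delete"] == 1   (the 'r' has no counterpart in the target)
# actions == {2: "D"}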
|
||||
|
||||
|
||||
def get_align_stats(alignment, src_string, target, gap_char):
|
||||
"""Get alignment stats
|
||||
|
||||
"""Get alignment stats
|
||||
|
||||
Args:
|
||||
alignment (tuple(str,str)): the result of calling the align function
|
||||
src_string (str): the original source string
|
||||
target (str): the original target string
|
||||
gap_char (str): the gap character used in alignment
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: if any of the strings are empty
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(dict, dict): dict of the align starts and dict of the substitution mappings
|
||||
"""
|
||||
|
@ -293,52 +387,61 @@ def get_align_stats(alignment, src_string, target, gap_char):
|
|||
_log("alignment results")
|
||||
_log(alignment)
|
||||
align_stats, substitution_dict = _get_align_stats(
|
||||
alignment, src_string, target, gap_char)
|
||||
alignment, src_string, target, gap_char
|
||||
)
|
||||
return align_stats, substitution_dict
|
||||
|
||||
|
||||
def get_stats(target, src_string):
|
||||
"""Get align stats, edit stats, and substitution mappings for transforming the
|
||||
"""Get align stats, edit stats, and substitution mappings for transforming the
|
||||
source string to the target string. Edit stats refer to the character level edit operations
|
||||
required to transform the source to target. Align stats refer to the substring level operations
|
||||
required to transform the source to target. Align stats have keys insert, replace, delete and the special
|
||||
key spacing which counts spacing differences between the two strings. Edit stats have the keys edit_insert,
|
||||
edit_replace, edit_delete which count the character level edits.
|
||||
|
||||
|
||||
Args:
|
||||
src_string (str): the source string
|
||||
target (str): the target string
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(dict, dict, dict): a dict containing the edit and align stats, a dict containing the substitutions, and a dict of edit actions
|
||||
"""
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align_w_anchor(src_string, target, gap_char=gap_char)
|
||||
align_stats, substitution_dict = get_align_stats(alignment,src_string, target, gap_char)
|
||||
align_stats, substitution_dict = get_align_stats(
|
||||
alignment, src_string, target, gap_char
|
||||
)
|
||||
edit_stats, actions = get_editops_stats(alignment, gap_char)
|
||||
_log("alignment", align_stats)
|
||||
return {**edit_stats, **align_stats}, substitution_dict, actions
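A hedged usage sketch of get_stats, reusing the example strings from the module docstring (the first argument is the target/ground truth, the second the OCR source):

stats, substitutions, actions = get_stats("a worn coat", "a wom coat")
# stats mixes the edit_* counts with the word/char accuracy fields computed above;
# substitutions records the aligned-gap replacement between "m" and "rn".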
|
||||
|
||||
|
||||
def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True):
|
||||
"""Given a path to the folder containing the source text and a folder containing
|
||||
the output OCR json, this generates the metrics for all files in the source folder.
|
||||
def get_metrics(
|
||||
src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True
|
||||
):
|
||||
"""Given a path to the folder containing the source text and a folder containing
|
||||
the output OCR json, this generates the metrics for all files in the source folder.
|
||||
This assumes that the files in the json folder have the same names as the text files, except they
|
||||
are prefixed by the parameter folder_hash followed by underscore and suffixed by .png.json.
|
||||
|
||||
|
||||
Args:
|
||||
src_text_path (str): path to source txt files
|
||||
ocr_json_path (str): path to OCR json files
|
||||
folder_hash (str): prefix for OCR json files
|
||||
use_multiprocessing (bool): use multiprocessing
|
||||
|
||||
|
||||
Returns:
|
||||
tuple(pandas.DataFrame, dict, dict): A pandas dataframe of the metrics with each file in a row,
|
||||
a dict containing the substitution mappings for each file, and a dict of the edit actions for each file.
|
||||
The key to each dict is the filename and the values are the mappings or actions for that file.
|
||||
"""
|
||||
|
||||
|
||||
rows = []
|
||||
substitutions = {}
|
||||
actions_map = {}
|
||||
|
@ -347,16 +450,25 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
|
|||
cpu_count = multiprocessing.cpu_count()
|
||||
n_workers = WORKERS_PER_CPU * cpu_count
|
||||
|
||||
job_args = list(map(lambda f: (f, src_text_path, ocr_json_path, folder_hash) , os.listdir(src_text_path)))
|
||||
job_args = list(
|
||||
map(
|
||||
lambda f: (f, src_text_path, ocr_json_path, folder_hash),
|
||||
os.listdir(src_text_path),
|
||||
)
|
||||
)
|
||||
|
||||
if use_multiprocessing:
|
||||
with Pool(n_workers) as pool:
|
||||
for f, stats, actions, subs in tqdm(pool.imap_unordered(_worker, job_args), total=len(job_args)):
|
||||
for f, stats, actions, subs in tqdm(
|
||||
pool.imap_unordered(_worker, job_args), total=len(job_args)
|
||||
):
|
||||
substitutions[f] = subs
|
||||
actions_map[f] = actions
|
||||
rows.append(stats)
|
||||
rows.append(stats)
|
||||
else:
|
||||
for f, stats, actions, subs in tqdm(map(_worker, job_args), total=len(job_args)):
|
||||
for f, stats, actions, subs in tqdm(
|
||||
map(_worker, job_args), total=len(job_args)
|
||||
):
|
||||
substitutions[f] = subs
|
||||
actions_map[f] = actions
|
||||
rows.append(stats)
|
||||
|
@ -364,16 +476,17 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess
|
|||
df = pd.DataFrame(rows)
|
||||
return df, substitutions, actions_map
|
||||
|
||||
|
||||
def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
|
||||
src_filename = os.path.join(src_text_path, f)
|
||||
if folder_hash:
|
||||
ocr_filename = os.path.join(
|
||||
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}")
|
||||
ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}"
|
||||
)
|
||||
else:
|
||||
ocr_filename = os.path.join(
|
||||
ocr_json_path, f"{f.split('txt')[0] + 'json'}")
|
||||
ocr_filename = os.path.join(ocr_json_path, f"{f.split('txt')[0] + 'json'}")
|
||||
try:
|
||||
src_string = open(src_filename, "r", errors='ignore', encoding="utf8").read()
|
||||
src_string = open(src_filename, "r", errors="ignore", encoding="utf8").read()
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {src_filename}, skipping this file.")
|
||||
return f, {}, {}, {}
|
||||
|
@ -395,34 +508,41 @@ def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash):
|
|||
stats["filename"] = f
|
||||
return f, stats, actions, subs
|
||||
|
||||
|
||||
def _worker(args):
|
||||
(f, src_text_path, ocr_json_path, folder_hash) = args
|
||||
return get_file_metrics(f, src_text_path, ocr_json_path, folder_hash)
|
||||
|
||||
|
||||
def _get_sorted_text(ocr_json):
|
||||
if "lines" in ocr_json[0]:
|
||||
lines = ocr_json[0]["lines"]
|
||||
sorted_lines = sorted(lines, key=lambda line : line["boundingBox"][0]["y"])
|
||||
sorted_lines = sorted(lines, key=lambda line: line["boundingBox"][0]["y"])
|
||||
return " ".join([line["text"] for line in sorted_lines])
|
||||
else:
|
||||
return ocr_json[0]["text"]
|
||||
|
||||
|
||||
def substitution_dict_to_json(substitution_dict):
|
||||
"""Converts substitution dict to list of tuples of (source_substring, target_substring, count)
|
||||
|
||||
|
||||
Args:
|
||||
substitution_dict (dict): mapping from (source_substring, target_substring) tuples to their counts
|
||||
"""
|
||||
to_tuple = lambda x: [(k +(x[k],)) for k in x]
|
||||
to_tuple = lambda x: [(k + (x[k],)) for k in x] # noqa: E731
|
||||
out = {}
|
||||
for filename in substitution_dict:
|
||||
out[filename] = to_tuple(substitution_dict[filename])
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("src", help="path to folder with text files.")
|
||||
parser.add_argument("ocr", help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.")
|
||||
parser.add_argument(
|
||||
"ocr",
|
||||
help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.",
|
||||
)
|
||||
parser.add_argument("--ocr_prefix", help="the prefix of the ocr files")
|
||||
parser.add_argument("--output", help="output names of metrics files")
|
||||
|
||||
|
|
|
@ -1,33 +1,45 @@
|
|||
"""Uses the REST api to perform operations on the search service.
|
||||
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
import sys
|
||||
import time
|
||||
from itertools import cycle
|
||||
|
||||
import requests
|
||||
|
||||
from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
|
||||
API_VERSION = '?api-version=2019-05-06-Preview'
|
||||
API_VERSION = "?api-version=2019-05-06-Preview"
|
||||
|
||||
# 15 min schedule
|
||||
SCHEDULE_INTERVAL= "PT15M"
|
||||
SCHEDULE_INTERVAL = "PT15M"
|
||||
|
||||
|
||||
class GrokRestClient:
|
||||
"""This is a REST client. It is a wrapper around the REST api for the Azure Search Service
|
||||
see: https://docs.microsoft.com/en-us/rest/api/searchservice/
|
||||
|
||||
This class can be used to create an indexing pipeline and can be used to run and monitor
|
||||
This class can be used to create an indexing pipeline and can be used to run and monitor
|
||||
ongoing indexers. The indexing pipeline can allow you to run batch OCR enrichment of documents.
|
||||
"""
|
||||
|
||||
def __init__(self, cognitive_service_key, search_service_key, search_service_name, skillset_name,
|
||||
index_name, indexer_name, datasource_name, datasource_container_name, blob_account_name, blob_key,
|
||||
projections_container_name = DEFAULT_PROJECTIONS_CONTAINER_NAME):
|
||||
def __init__(
|
||||
self,
|
||||
cognitive_service_key,
|
||||
search_service_key,
|
||||
search_service_name,
|
||||
skillset_name,
|
||||
index_name,
|
||||
indexer_name,
|
||||
datasource_name,
|
||||
datasource_container_name,
|
||||
blob_account_name,
|
||||
blob_key,
|
||||
projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
|
||||
):
|
||||
"""Creates the REST client
|
||||
|
||||
Args:
|
||||
|
@ -36,7 +48,7 @@ class GrokRestClient:
|
|||
search_service_name (str): name of the search service account
|
||||
skillset_name (str): name of the skillset
|
||||
index_name (str): name of the index
|
||||
indexer_name (str): the name of indexer
|
||||
indexer_name (str): the name of indexer
|
||||
datasource_name (str): the name to give the attached blob storage source
|
||||
datasource_container_name (str): the container in the blob storage that hosts the files
|
||||
blob_account_name (str): blob storage account name that will host the documents to push through the pipeline
|
||||
|
@ -70,8 +82,10 @@ class GrokRestClient:
|
|||
|
||||
self.API_VERSION = API_VERSION
|
||||
|
||||
self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \
|
||||
self.BLOB_CONNECTION_STRING = (
|
||||
f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
|
||||
f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_from_env_var():
|
||||
|
@ -85,51 +99,68 @@ class GrokRestClient:
|
|||
DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
|
||||
BLOB_NAME = os.environ["BLOB_NAME"]
|
||||
BLOB_KEY = os.environ["BLOB_KEY"]
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME)
|
||||
PROJECTIONS_CONTAINER_NAME = os.environ.get(
|
||||
"PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
|
||||
)
|
||||
|
||||
client = GrokRestClient(COGNITIVE_SERVICE_KEY, SEARCH_SERVICE_KEY, SEARCH_SERVICE_NAME, SKILLSET_NAME, INDEX_NAME,
|
||||
INDEXER_NAME, DATASOURCE_NAME, DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY,projections_container_name=PROJECTIONS_CONTAINER_NAME)
|
||||
client = GrokRestClient(
|
||||
COGNITIVE_SERVICE_KEY,
|
||||
SEARCH_SERVICE_KEY,
|
||||
SEARCH_SERVICE_NAME,
|
||||
SKILLSET_NAME,
|
||||
INDEX_NAME,
|
||||
INDEXER_NAME,
|
||||
DATASOURCE_NAME,
|
||||
DATASOURCE_CONTAINER_NAME,
|
||||
BLOB_NAME,
|
||||
BLOB_KEY,
|
||||
projections_container_name=PROJECTIONS_CONTAINER_NAME,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def create_skillset(self):
|
||||
"""Adds a skillset that performs OCR on images
|
||||
"""
|
||||
"""Adds a skillset that performs OCR on images"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
skillset_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/skillset.json"))
|
||||
skillset_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/skillset.json")
|
||||
)
|
||||
|
||||
skillset_json["name"] = self.SKILLSET_NAME
|
||||
skillset_json["cognitiveServices"]["key"] = self.COGNITIVE_SERVICE_KEY
|
||||
|
||||
knowledge_store_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/knowledge_store.json"))
|
||||
knowledge_store_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/knowledge_store.json")
|
||||
)
|
||||
knowledge_store_json["storageConnectionString"] = self.BLOB_CONNECTION_STRING
|
||||
knowledge_store_json["projections"][0]["objects"][0]["storageContainer"] = self.PROJECTIONS_CONTAINER_NAME
|
||||
knowledge_store_json["projections"][0]["objects"][0][
|
||||
"storageContainer"
|
||||
] = self.PROJECTIONS_CONTAINER_NAME
|
||||
skillset_json["knowledgeStore"] = knowledge_store_json
|
||||
print(skillset_json)
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(skillset_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(skillset_json), headers=headers
|
||||
)
|
||||
print("skillset response", r.text)
|
||||
r.raise_for_status()
|
||||
print("added skillset", self.SKILLSET_NAME, r)
|
||||
|
||||
def create_datasource(self):
|
||||
"""Attaches the blob data store to the search service as a source for image documents
|
||||
"""
|
||||
"""Attaches the blob data store to the search service as a source for image documents"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
datasource_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/datasource.json"))
|
||||
datasource_json = json.loads(
|
||||
pkgutil.get_data(__name__, "templates/datasource.json")
|
||||
)
|
||||
datasource_json["name"] = self.DATASOURCE_NAME
|
||||
datasource_json["credentials"]["connectionString"] = self.BLOB_CONNECTION_STRING
|
||||
datasource_json["type"] = "azureblob"
|
||||
|
@ -137,27 +168,27 @@ class GrokRestClient:
|
|||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(datasource_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(datasource_json), headers=headers
|
||||
)
|
||||
print("datasource response", r.text)
|
||||
r.raise_for_status()
|
||||
print("added datasource", self.DATASOURCE_NAME, r)
|
||||
|
||||
def create_index(self):
|
||||
"""Create an index with the layoutText column to store OCR output from the enrichment
|
||||
"""
|
||||
"""Create an index with the layoutText column to store OCR output from the enrichment"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
index_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/index.json"))
|
||||
index_json = json.loads(pkgutil.get_data(__name__, "templates/index.json"))
|
||||
index_json["name"] = self.INDEX_NAME
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(index_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(index_json), headers=headers
|
||||
)
|
||||
print("index response", r.text)
|
||||
r.raise_for_status()
|
||||
print("created index", self.INDEX_NAME, r)
|
||||
|
@ -167,24 +198,26 @@ class GrokRestClient:
|
|||
The enriched results are pushed to the index.
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
indexer_json = json.loads(pkgutil.get_data(
|
||||
__name__, "templates/indexer.json"))
|
||||
indexer_json = json.loads(pkgutil.get_data(__name__, "templates/indexer.json"))
|
||||
|
||||
indexer_json["name"] = self.INDEXER_NAME
|
||||
indexer_json["skillsetName"] = self.SKILLSET_NAME
|
||||
indexer_json["targetIndexName"] = self.INDEX_NAME
|
||||
indexer_json["dataSourceName"] = self.DATASOURCE_NAME
|
||||
indexer_json["schedule"] = {"interval": SCHEDULE_INTERVAL}
|
||||
indexer_json["parameters"]["configuration"]["excludedFileNameExtensions"] = extension_to_exclude
|
||||
indexer_json["parameters"]["configuration"][
|
||||
"excludedFileNameExtensions"
|
||||
] = extension_to_exclude
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}"
|
||||
|
||||
r = requests.put(endpoint + self.API_VERSION,
|
||||
json.dumps(indexer_json), headers=headers)
|
||||
r = requests.put(
|
||||
endpoint + self.API_VERSION, json.dumps(indexer_json), headers=headers
|
||||
)
|
||||
print("indexer response", r.text)
|
||||
r.raise_for_status()
|
||||
print("created indexer", self.INDEXER_NAME, r)
|
||||
|
@ -203,14 +236,14 @@ class GrokRestClient:
|
|||
created
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
endpoints = [
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}",
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}"
|
||||
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}",
|
||||
]
|
||||
|
||||
for endpoint in endpoints:
|
||||
|
@ -220,8 +253,8 @@ class GrokRestClient:
|
|||
|
||||
def run_indexer(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/run"
|
||||
|
@ -235,23 +268,26 @@ class GrokRestClient:
|
|||
i = 0
|
||||
while True:
|
||||
# attempt a call every 100 steps
|
||||
if i % 100 == 0:
|
||||
if i % 100 == 0:
|
||||
request_json = self.get_indexer_status()
|
||||
if request_json["status"] == "error":
|
||||
raise RuntimeError("Indexer failed")
|
||||
if request_json["lastResult"] and not request_json["lastResult"]["status"] == "inProgress":
|
||||
if (
|
||||
request_json["lastResult"]
|
||||
and not request_json["lastResult"]["status"] == "inProgress"
|
||||
):
|
||||
print(request_json["lastResult"]["status"], self.INDEXER_NAME)
|
||||
return request_json
|
||||
|
||||
|
||||
sys.stdout.write(next(progress))
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.05)
|
||||
i = (1+i) % 1000 # to avoid overflow
|
||||
i = (1 + i) % 1000 # to avoid overflow
|
||||
|
||||
def get_indexer_status(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': self.SEARCH_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.SEARCH_SERVICE_KEY,
|
||||
}
|
||||
endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/status"
|
||||
response = requests.get(endpoint + self.API_VERSION, headers=headers)
|
||||
|
@ -259,5 +295,5 @@ class GrokRestClient:
|
|||
return response.json()
|
||||
|
||||
def _checkArg(self, name, value):
|
||||
if not(value):
|
||||
if not (value):
|
||||
raise ValueError(f"argument {name} is not set")
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
from json import JSONEncoder
|
||||
from tqdm import tqdm
|
||||
|
||||
import concurrent.futures
|
||||
import timeit
|
||||
import cv2
|
||||
import os
|
||||
import timeit
|
||||
from json import JSONEncoder
|
||||
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
|
||||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DEFAULT_STYLE_COMBINATION
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
|
||||
class ImageStateEncoder(JSONEncoder):
|
||||
def default(self, obj):
|
||||
|
@ -16,8 +19,15 @@ class ImageStateEncoder(JSONEncoder):
|
|||
return obj.value
|
||||
return JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
class AnalogDocumentGeneration(object):
|
||||
def __init__(self, template_path=None, styles=DEFAULT_STYLE_COMBINATION, degradations=[], resolution=300):
|
||||
def __init__(
|
||||
self,
|
||||
template_path=None,
|
||||
styles=DEFAULT_STYLE_COMBINATION,
|
||||
degradations=[],
|
||||
resolution=300,
|
||||
):
|
||||
self.doc_generator = DocumentGenerator(template_path=template_path)
|
||||
self.doc_generator.set_styles_to_generate(styles)
|
||||
self.degrader = Degrader(degradations)
|
||||
|
@ -25,7 +35,7 @@ class AnalogDocumentGeneration(object):
|
|||
self.resolution = resolution
|
||||
|
||||
def list_templates(self):
|
||||
""" List available templates to generate documents from
|
||||
"""List available templates to generate documents from
|
||||
|
||||
Returns:
|
||||
list -- a list of template names
|
||||
|
@ -33,7 +43,7 @@ class AnalogDocumentGeneration(object):
|
|||
return self.doc_generator.template_list
|
||||
|
||||
def generate_img(self, full_text_path, template, target_folder=None):
|
||||
""" Generate synthetic images given the filepath of a text document
|
||||
"""Generate synthetic images given the filepath of a text document
|
||||
|
||||
Arguments:
|
||||
full_text_path {str} -- full filepath of a text document (i.e /dataset/doc.txt)
|
||||
|
@ -44,18 +54,18 @@ class AnalogDocumentGeneration(object):
|
|||
target_folder {str} -- folder path in which the generated images are stored
|
||||
(default: {None})
|
||||
resolution {int} -- resolution in dpi (default: {300})
|
||||
"""
|
||||
with open(full_text_path, "r",encoding="utf8") as f: # read file
|
||||
"""
|
||||
with open(full_text_path, "r", encoding="utf8") as f: # read file
|
||||
text = f.read()
|
||||
content = CompositeContent([text], [ContentType.PARAGRAPH])
|
||||
|
||||
|
||||
generator = self.doc_generator.create_generator(content, [template])
|
||||
# Generate the image
|
||||
doc = next(generator)
|
||||
src = doc.render_array(resolution=self.resolution, channel="GRAYSCALE")
|
||||
# Degrade the image
|
||||
dst = self.degrader.apply_effects(src)
|
||||
|
||||
|
||||
if not target_folder:
|
||||
# return the analog document as numpy.ndarray
|
||||
return dst
|
||||
|
@ -67,45 +77,71 @@ class AnalogDocumentGeneration(object):
|
|||
cv2.imwrite(img_dst_path, dst)
|
||||
return
|
||||
|
||||
|
||||
def _divide_batches(a, batch_size):
|
||||
for i in range(0, len(a), batch_size):
|
||||
yield a[i:i + batch_size]
|
||||
for i in range(0, len(a), batch_size):
|
||||
yield a[i: i + batch_size]
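A quick illustration of the batching helper:

list(_divide_batches(list(range(7)), 3))   # -> [[0, 1, 2], [3, 4, 5], [6]]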
|
||||
|
||||
|
||||
def _setup_folder(output_folder):
|
||||
os.makedirs(os.path.join(output_folder, "img"), exist_ok=True)
|
||||
|
||||
|
||||
|
||||
def batch_img_generate(args):
|
||||
input_files, output_folder, styles, degradations, template, resolution = args
|
||||
generator = AnalogDocumentGeneration(styles=styles, degradations=degradations, resolution=resolution)
|
||||
generator = AnalogDocumentGeneration(
|
||||
styles=styles, degradations=degradations, resolution=resolution
|
||||
)
|
||||
for file in input_files:
|
||||
generator.generate_img(file, template, target_folder=output_folder)
|
||||
|
||||
def _set_batch_generate_args(file_batches, output_folder, styles, degradations, template, resolution):
|
||||
return list(map(
|
||||
lambda batch:
|
||||
(batch, output_folder, styles, degradations, template, resolution),
|
||||
file_batches
|
||||
))
|
||||
|
||||
def _set_batch_generate_args(
|
||||
file_batches, output_folder, styles, degradations, template, resolution
|
||||
):
|
||||
return list(
|
||||
map(
|
||||
lambda batch: (
|
||||
batch,
|
||||
output_folder,
|
||||
styles,
|
||||
degradations,
|
||||
template,
|
||||
resolution,
|
||||
),
|
||||
file_batches,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_dataset_multiprocess(
|
||||
input_text_files, output_folder, styles, degradations, template,
|
||||
resolution=300, batch_size=25
|
||||
):
|
||||
input_text_files,
|
||||
output_folder,
|
||||
styles,
|
||||
degradations,
|
||||
template,
|
||||
resolution=300,
|
||||
batch_size=25,
|
||||
):
|
||||
_setup_folder(output_folder)
|
||||
print(f"Storing generated images in {output_folder}")
|
||||
print(f"Storing generated images in {output_folder}")
|
||||
|
||||
batches = list(_divide_batches(input_text_files, batch_size))
|
||||
print(f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}")
|
||||
batches = list(_divide_batches(input_text_files, batch_size))
|
||||
print(
|
||||
f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}"
|
||||
)
|
||||
|
||||
batch_img_generate_args = _set_batch_generate_args(batches, output_folder, styles, degradations, template, resolution)
|
||||
batch_img_generate_args = _set_batch_generate_args(
|
||||
batches, output_folder, styles, degradations, template, resolution
|
||||
)
|
||||
|
||||
# Default to the number of processors on the machine
|
||||
start_time = timeit.default_timer()
|
||||
with concurrent.futures.ProcessPoolExecutor() as executor:
|
||||
batch_iterator = executor.map(batch_img_generate, batch_img_generate_args)
|
||||
for _ in tqdm(batch_iterator, total=len(batch_img_generate_args)): # wrapping tqdm for progress report
|
||||
for _ in tqdm(
|
||||
batch_iterator, total=len(batch_img_generate_args)
|
||||
): # wrapping tqdm for progress report
|
||||
pass
|
||||
elapsed = timeit.default_timer() - start_time
|
||||
print(f"Time to generate {len(input_text_files)} documents: {elapsed:.3f} sec")
|
||||
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
|
|
@ -1,42 +1,53 @@
|
|||
from genalog.text.preprocess import _is_spacing, tokenize
|
||||
from Bio import pairwise2
|
||||
import re
|
||||
|
||||
from Bio import pairwise2
|
||||
|
||||
from genalog.text.preprocess import _is_spacing, tokenize
|
||||
|
||||
# Configuration params for global sequence alignment algorithm (Needleman-Wunsch)
|
||||
MATCH_REWARD = 1
|
||||
GAP_PENALTY = -0.5
|
||||
GAP_EXT_PENALTY = -0.5
|
||||
MISMATCH_PENALTY = -0.5
|
||||
GAP_CHAR = '@'
|
||||
GAP_CHAR = "@"
|
||||
ONE_ALIGNMENT_ONLY = False
|
||||
SPACE_MISMATCH_PENALTY = .1
|
||||
SPACE_MISMATCH_PENALTY = 0.1
|
||||
|
||||
|
||||
def _join_char_list(alignment_tuple):
|
||||
""" Post-process alignment results for unicode support """
|
||||
gt_char_list, noise_char_list, score, start, end = alignment_tuple
|
||||
return "".join(gt_char_list), "".join(noise_char_list), score, start, end
|
||||
|
||||
def _align_seg(gt, noise,
|
||||
match_reward=MATCH_REWARD, mismatch_pen=MISMATCH_PENALTY,
|
||||
gap_pen=GAP_PENALTY, gap_ext_pen=GAP_EXT_PENALTY, space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
|
||||
gap_char=GAP_CHAR, one_alignment_only=ONE_ALIGNMENT_ONLY):
|
||||
""" Wrapper function for Bio.pairwise2.align.globalms(), which
|
||||
|
||||
def _align_seg(
|
||||
gt,
|
||||
noise,
|
||||
match_reward=MATCH_REWARD,
|
||||
mismatch_pen=MISMATCH_PENALTY,
|
||||
gap_pen=GAP_PENALTY,
|
||||
gap_ext_pen=GAP_EXT_PENALTY,
|
||||
space_mismatch_penalty=SPACE_MISMATCH_PENALTY,
|
||||
gap_char=GAP_CHAR,
|
||||
one_alignment_only=ONE_ALIGNMENT_ONLY,
|
||||
):
|
||||
"""Wrapper function for Bio.pairwise2.align.globalms(), which
|
||||
calls the sequence alignment algorithm (Needleman-Wunsch)
|
||||
|
||||
Arguments:
|
||||
gt {str} -- a ground truth string
|
||||
noise {str} -- a string with ocr noise
|
||||
|
||||
|
||||
Keyword Arguments:
|
||||
match_reward {int} -- reward for matching characters (default: {MATCH_REWARD})
|
||||
mismatch_pen {int} -- penalty for mismatching characters (default: {MISMATCH_PENALTY})
|
||||
gap_pen {int} -- penalty for creating a gap (default: {GAP_PENALTY})
|
||||
gap_ext_pen {int} -- penalty for extending a gap (default: {GAP_EXT_PENALTY})
|
||||
|
||||
|
||||
Returns:
|
||||
list -- a list of alignment tuples. Each alignment tuple
|
||||
is one possible alignment candidate.
|
||||
|
||||
is one possible alignment candidate.
|
||||
|
||||
A tuple (str, str, int, int, int) contains the following information:
|
||||
(aligned_gt, aligned_noise, alignment_score, alignment_start, alignment_end)
|
||||
|
||||
|
@ -47,22 +58,32 @@ def _align_seg(gt, noise,
|
|||
...
|
||||
]
|
||||
"""
|
||||
def match_reward_fn (x,y) :
|
||||
|
||||
def match_reward_fn(x, y):
|
||||
if x == y:
|
||||
return match_reward
|
||||
elif x == " " or y == " ":
|
||||
elif x == " " or y == " ":
|
||||
# mismatch of a character with a space get a stronger penalty
|
||||
return mismatch_pen - space_mismatch_penalty
|
||||
return mismatch_pen - space_mismatch_penalty
|
||||
else:
|
||||
return mismatch_pen
|
||||
# NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters
|
||||
alignments = pairwise2.align.globalcs(list(gt), list(noise), match_reward_fn,
|
||||
gap_pen, gap_ext_pen, gap_char=[gap_char], one_alignment_only=ONE_ALIGNMENT_ONLY)
|
||||
|
||||
# NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters
|
||||
alignments = pairwise2.align.globalcs(
|
||||
list(gt),
|
||||
list(noise),
|
||||
match_reward_fn,
|
||||
gap_pen,
|
||||
gap_ext_pen,
|
||||
gap_char=[gap_char],
|
||||
one_alignment_only=ONE_ALIGNMENT_ONLY,
|
||||
)
|
||||
# Alignment result is a list of char instead of string because of the work-around
|
||||
return list(map(_join_char_list, alignments))
|
||||
|
||||
|
||||
def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
||||
""" Return an alignment that contains the desired number
|
||||
"""Return an alignment that contains the desired number
|
||||
of ground truth tokens from a list of possible alignments
|
||||
|
||||
Case Analysis:
|
||||
|
@ -70,20 +91,20 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
|||
be guaranteed by the nature of text alignment.
|
||||
Invariant 2: we should not expect alignment introducing
|
||||
additional ground truth tokens.
|
||||
However, in some cases, the alignment algorithm can
|
||||
introduce a group of GAP_CHARs as a separate token at the
|
||||
However, in some cases, the alignment algorithm can
|
||||
introduce a group of GAP_CHARs as a separate token at the
|
||||
end of string, especially if there are lingering whitespaces.
|
||||
E.g:
|
||||
gt: "Boston is big " (num_tokens = 3)
|
||||
noise: "B oston bi g"
|
||||
noise: "B oston bi g"
|
||||
aligned_gt: "B@oston is big @" (num_tokens = 4)
|
||||
aligned_noise: "B oston @@@bi@ g"
|
||||
|
||||
Remember, the example above is just one out of the many possible alignment
|
||||
Remember, the example above is just one out of the many possible alignment
|
||||
candidates, and we need to search for the one with the target number of gt_tokens
|
||||
E.g:
|
||||
gt: "Boston is big " (num_tokens = 3)
|
||||
noise: "B oston bi g"
|
||||
noise: "B oston bi g"
|
||||
aligned_gt: "B@oston is bi@g " (num_tokens = 3)
|
||||
aligned_noise: "B oston @@@bi g@"
|
||||
|
||||
|
@ -93,12 +114,12 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
|||
alignments {list} -- a list of alignment tuples as follows:
|
||||
[(str1, str2, alignment_score, alignment_start, alignment_end), (str1, str2, ...), ...]
|
||||
target_num_gt_tokens {int} -- the number of token in the aligned ground truth string should have
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: raises this error if
|
||||
ValueError: raises this error if
|
||||
1. all the alignment candidates do NOT have the target number of tokens OR
|
||||
2. the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length
|
||||
|
||||
2. the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length
|
||||
|
||||
Returns:
|
||||
an alignment tuple (str, str, int, int, int) with following information:
|
||||
(str1, str2, alignment_score, alignment_start, alignment_end)
|
||||
|
@ -111,29 +132,34 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens):
|
|||
if num_aligned_gt_tokens == target_num_gt_tokens:
|
||||
# Invariant 1
|
||||
if len(aligned_gt) != len(aligned_noise):
|
||||
raise ValueError(f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n")
|
||||
raise ValueError(
|
||||
f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n"
|
||||
)
|
||||
# Returns the FIRST candidate that satisfies the invariant
|
||||
return alignment
|
||||
|
||||
raise ValueError(f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}")
|
||||
|
||||
raise ValueError(
|
||||
f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}"
|
||||
)
|
||||
|
||||
|
||||
def align(gt, noise, gap_char=GAP_CHAR):
|
||||
"""Align two text segments via sequence alignment algorithm
|
||||
|
||||
NOTE: this algorithm is O(N^2) and is NOT efficient for longer text.
|
||||
NOTE: this algorithm is O(N^2) and is NOT efficient for longer text.
|
||||
Please refer to `genalog.text.anchor` for faster alignment on longer strings.
|
||||
|
||||
|
||||
Arguments:
|
||||
gt {str} -- ground true text (should not contain GAP_CHAR)
|
||||
noise {str} -- str with ocr noise (should not contain GAP_CHAR)
|
||||
|
||||
|
||||
Keyword Arguments:
|
||||
gap_char {char} -- gap char used in alignment algorithm (default: {GAP_CHAR})
|
||||
|
||||
|
||||
Returns:
|
||||
a tuple (str, str) of aligned ground truth and noise:
|
||||
(aligned_gt, aligned_noise)
|
||||
|
||||
|
||||
Invariants:
|
||||
The returned aligned strings will satisfy the following invariants:
|
||||
1. len(aligned_gt) == len(aligned_noise)
|
||||
|
@ -143,34 +169,39 @@ def align(gt, noise, gap_char=GAP_CHAR):
|
|||
aligned_gt: "N@ew @@York @is big@@" (num_tokens = 4)
|
||||
|
||||
"""
|
||||
if not gt and not noise: # Both inputs are empty string
|
||||
return '', ''
|
||||
elif not gt: # Either is empty
|
||||
return gap_char*len(noise), noise
|
||||
if not gt and not noise: # Both inputs are empty string
|
||||
return "", ""
|
||||
elif not gt: # Either is empty
|
||||
return gap_char * len(noise), noise
|
||||
elif not noise:
|
||||
return gt, gap_char*len(gt)
|
||||
return gt, gap_char * len(gt)
|
||||
else:
|
||||
num_gt_tokens = len(tokenize(gt))
|
||||
alignments = _align_seg(gt, noise, gap_char=gap_char)
|
||||
try:
|
||||
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(alignments, num_gt_tokens)
|
||||
aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(
|
||||
alignments, num_gt_tokens
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error with input strings '{gt}' and '{noise}': \n{str(e)}")
|
||||
raise ValueError(
|
||||
f"Error with input strings '{gt}' and '{noise}': \n{str(e)}"
|
||||
)
|
||||
return aligned_gt, aligned_noise
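A minimal usage sketch of `align` (illustrative input; exact gap placement depends on the scoring scheme, but Invariant 1 always holds):

    from genalog.text.alignment import align

    aligned_gt, aligned_noise = align("New York is big", "New Yorkis big")
    # Invariant 1: the two aligned strings are padded to equal length
    assert len(aligned_gt) == len(aligned_noise)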
|
||||
|
||||
|
||||
def _format_alignment(align1, align2):
|
||||
"""Wrapper function for Bio.pairwise2.format_alignment()
|
||||
|
||||
|
||||
Arguments:
|
||||
align1 {str} -- alignment str
|
||||
align2 {str} -- second str for alignment
|
||||
|
||||
Returns:
|
||||
a string with formatted alignment.
|
||||
a string with formatted alignment.
|
||||
'|' is for matching character
|
||||
'.' is for substitution
|
||||
'-' indicates gap
|
||||
|
||||
|
||||
For example:
|
||||
"
|
||||
New York is big.
|
||||
|
@ -178,44 +209,50 @@ def _format_alignment(align1, align2):
|
|||
New Yerk@is big.
|
||||
"
|
||||
"""
|
||||
formatted_str = pairwise2.format_alignment(align1, align2, 0, 0, len(align1), full_sequences=True)
|
||||
formatted_str = pairwise2.format_alignment(
|
||||
align1, align2, 0, 0, len(align1), full_sequences=True
|
||||
)
|
||||
# Remove the "Score=0" from the str
|
||||
formatted_str_no_score = formatted_str.replace("\n Score=0", "")
|
||||
return formatted_str_no_score
|
||||
|
||||
|
||||
def _find_token_start(s, index):
|
||||
"""Find the position of the start of token
|
||||
|
||||
"""Find the position of the start of token
|
||||
|
||||
Arguments:
|
||||
s {str} -- string to search in
|
||||
index {int} -- index to begin search from
|
||||
|
||||
|
||||
Returns:
|
||||
- position {int} of the first non-whitespace character
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: if input s is an empty string
|
||||
IndexError: if the index is out of bounds
|
||||
"""
|
||||
max_index = len(s) - 1
|
||||
if len(s) == 0: raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
if len(s) == 0:
|
||||
raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index:
|
||||
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
|
||||
while index < max_index and _is_spacing(s[index]):
|
||||
index += 1
|
||||
return index
|
||||
|
||||
|
||||
def _find_token_end(s, index):
|
||||
"""Find the position of the end of a token
|
||||
|
||||
|
||||
*** Important ***
|
||||
This method ALWAYS returns an index within the bounds of the string.
|
||||
So, for single character string (eg. "c"), it will return 0.
|
||||
|
||||
So, for single character string (eg. "c"), it will return 0.
|
||||
|
||||
Arguments:
|
||||
s {str} -- string to search in
|
||||
index {int} -- index to begin search from
|
||||
|
||||
|
||||
Returns:
|
||||
- position {int} of the first whitespace character after the token start (or the last index of the string if the token runs to the end)
|
||||
|
||||
|
@ -224,40 +261,44 @@ def _find_token_end(s, index):
|
|||
IndexError: if the index is out of bounds
|
||||
"""
|
||||
max_index = len(s) - 1
|
||||
if len(s) == 0: raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
if len(s) == 0:
|
||||
raise ValueError("Cannot search in an empty string")
|
||||
if index > max_index:
|
||||
raise IndexError(f"Out-of-bound index: {index} in string: {s}")
|
||||
|
||||
while index < max_index and not _is_spacing(s[index]):
|
||||
index += 1
|
||||
return index
|
||||
|
||||
|
||||
def _find_next_token(s, start):
|
||||
""" Return the start and end index of a token in a string
|
||||
|
||||
"""Return the start and end index of a token in a string
|
||||
|
||||
*** Important ***
|
||||
This method ALWAYS returns indices within the bounds of the string.
|
||||
So, for single character string (eg. "c"), it will return (0,0)
|
||||
|
||||
Arguments:
|
||||
s {str} -- the string to search token in
|
||||
start {int} -- the starting index to start search in
|
||||
|
||||
start {int} -- the starting index to start search in
|
||||
|
||||
Returns:
|
||||
a tuple of (int, int) corresponding to the start and end indices of
|
||||
a token in the given s.
|
||||
a token in the given s.
|
||||
"""
|
||||
token_start = _find_token_start(s, start)
|
||||
token_end = _find_token_end(s, token_start)
|
||||
return token_start, token_end
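For example, a small sketch of the clamped index behavior described above (start is the first character of the token, end is the first whitespace after it):

    s = "Boston is big"
    assert _find_next_token(s, 0) == (0, 6)  # "Boston" occupies indices [0, 6)
    assert _find_next_token(s, 6) == (7, 9)  # "is" occupies indices [7, 9)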
|
||||
|
||||
|
||||
def _is_valid_token(token, gap_char=GAP_CHAR):
|
||||
""" Returns true if token is valid (i.e. compose of non-gap characters)
|
||||
Invalid tokens are
|
"""Returns true if token is valid (i.e. composed of non-gap characters)
|
||||
Invalid tokens are
|
||||
1. multiple occurrences of the GAP_CHAR (e.g. '@@@')
|
||||
2. empty string ("")
|
||||
3. string with spaces (" ")
|
||||
|
||||
**Important: this method expects one token and not multiple space-separated tokens
|
||||
**Important: this method expects one token and not multiple space-separated tokens
|
||||
|
||||
Arguments:
|
||||
token {str} -- input string token
|
||||
|
@ -269,12 +310,15 @@ def _is_valid_token(token, gap_char=GAP_CHAR):
|
|||
bool-- True if is a valid token, false otherwise
|
||||
"""
|
||||
# Matches multiples of 'gap_char' that are padded with whitespace characters on either end
|
||||
INVALID_TOKEN_REGEX = rf'^\s*{re.escape(gap_char)}*\s*$' # Escape special regex chars
|
||||
INVALID_TOKEN_REGEX = (
|
||||
rf"^\s*{re.escape(gap_char)}*\s*$" # Escape special regex chars
|
||||
)
|
||||
return not re.match(INVALID_TOKEN_REGEX, token)
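A quick sketch of what the regex above accepts and rejects (using the default gap char '@'):

    import re

    gap_char = "@"
    invalid_token = rf"^\s*{re.escape(gap_char)}*\s*$"
    assert re.match(invalid_token, "@@@")      # run of gap chars -> invalid
    assert re.match(invalid_token, "   ")      # whitespace only  -> invalid
    assert not re.match(invalid_token, "big")  # ordinary token   -> valid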
|
||||
|
||||
|
||||
def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
||||
"""Parse alignment to pair ground truth tokens with noise tokens
|
||||
|
||||
|
||||
Case 1: Case 2: Case 3: Case 4: Case 5:
|
||||
one-to-many many-to-one many-to-many missing tokens one-to-one
|
||||
(Case 1&2 Composite)
|
||||
|
@ -295,10 +339,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
|
||||
Returns:
|
||||
a tuple (list, list) of two 2D int arrays as follows:
|
||||
|
||||
|
||||
(gt_to_noise_mapping, noise_to_gt_mapping)
|
||||
|
||||
where each array defines the mapping between aligned gt tokens
|
||||
|
||||
where each array defines the mapping between aligned gt tokens
|
||||
to noise tokens and vice versa.
|
||||
|
||||
For example:
|
||||
|
@ -316,44 +360,44 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
# tk_start_gt=12 tk_index_gt = 4 total_tokens = 4
|
||||
# | tk_end_gt=15 tk_index_noise = 3 total_tokens = 3
|
||||
# | |
|
||||
# "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]]
|
||||
# "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]]
|
||||
# "New@york @@ big " noise_token:big noise_to_gt_mapping: [[0][][3]]
|
||||
# | |
|
||||
# | tk_end_noise=15 INVALID TOKENS: @*
|
||||
# tk_start_noise=12
|
||||
|
||||
# 1. Initialization:
|
||||
#1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow
|
||||
#2. find the first gt_token and the first noise_token
|
||||
#3. tk_index_gt = tk_index_noise = 0
|
||||
# 1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow
|
||||
# 2. find the first gt_token and the first noise_token
|
||||
# 3. tk_index_gt = tk_index_noise = 0
|
||||
# 2. While tk_index_gt < total_tk_gt and tk_index_noise < total_tk_noise:
|
||||
#1. if tk_end_gt == tk_end_noise (1-1 case)
|
||||
#1. check if the two tokens are valid
|
||||
#1. if so, register tokens in mapping
|
||||
#2. find next gt_token token and next noise_token
|
||||
#3. tk_index_gt ++, tk_index_noise ++
|
||||
#3. if tk_end_gt < tk_end_noise (many-1 case)
|
||||
#1. while tk_end_gt < tk_end_noise
|
||||
#1. check if gt_token and noise_token are BOTH valid
|
||||
#1. if so register tokens in mapping
|
||||
#2. find next gt_token
|
||||
#3. tk_index_gt ++
|
||||
#4. if tk_end_gt > tk_end_noise (1-many case)
|
||||
#1. while tk_end_gt > tk_end_noise
|
||||
#1. check if gt_token and noise_token are BOTH valid
|
||||
#1. if so register tokens in mapping
|
||||
#2. find next noise token
|
||||
#3. tk_index_noise ++
|
||||
# 1. if tk_end_gt == tk_end_noise (1-1 case)
|
||||
# 1. check if the two tokens are valid
|
||||
# 1. if so, register tokens in mapping
|
||||
# 2. find next gt_token token and next noise_token
|
||||
# 3. tk_index_gt ++, tk_index_noise ++
|
||||
# 3. if tk_end_gt < tk_end_noise (many-1 case)
|
||||
# 1. while tk_end_gt < tk_end_noise
|
||||
# 1. check if gt_token and noise_token are BOTH valid
|
||||
# 1. if so register tokens in mapping
|
||||
# 2. find next gt_token
|
||||
# 3. tk_index_gt ++
|
||||
# 4. if tk_end_gt > tk_end_noise (1-many case)
|
||||
# 1. while tk_end_gt > tk_end_noise
|
||||
# 1. check if gt_token and noise_token are BOTH valid
|
||||
# 1. if so register tokens in mapping
|
||||
# 2. find next noise token
|
||||
# 3. tk_index_noise ++
|
||||
# sanity check
|
||||
if len(aligned_gt) != len(aligned_noise):
|
||||
raise ValueError("Aligned strings are not equal in length")
|
||||
|
||||
raise ValueError("Aligned strings are not equal in length")
|
||||
|
||||
total_gt_tokens = len(tokenize(aligned_gt))
|
||||
total_noise_tokens = len(tokenize(aligned_noise))
|
||||
|
||||
# Initialization
|
||||
aligned_gt += ' ' # add whitespace padding to prevent ptr overflow
|
||||
aligned_noise += ' ' # add whitespace padding to prevent ptr overflow
|
||||
aligned_gt += " " # add whitespace padding to prevent ptr overflow
|
||||
aligned_noise += " " # add whitespace padding to prevent ptr overflow
|
||||
tk_index_gt = tk_index_noise = 0
|
||||
tk_start_gt, tk_end_gt = _find_next_token(aligned_gt, 0)
|
||||
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, 0)
|
||||
|
@ -364,8 +408,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
# If both tokens are aligned (one-to-one case)
|
||||
if tk_end_gt == tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these tokens in the gt_to_noise_mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -381,8 +428,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
elif tk_end_gt < tk_end_noise:
|
||||
while tk_end_gt < tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these tokens in the gt_to_noise_mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -397,8 +447,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
else:
|
||||
while tk_end_gt > tk_end_noise:
|
||||
# if both gt_token and noise_token are valid (missing token case)
|
||||
if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \
|
||||
and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char):
|
||||
if _is_valid_token(
|
||||
aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char
|
||||
) and _is_valid_token(
|
||||
aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char
|
||||
):
|
||||
# register the index of these token in the gt_to_noise mapping
|
||||
index_row = gt_to_noise_mapping[tk_index_gt]
|
||||
index_row.append(tk_index_noise)
|
||||
|
@ -406,8 +459,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR):
|
|||
index_row = noise_to_gt_mapping[tk_index_noise]
|
||||
index_row.append(tk_index_gt)
|
||||
# Find the next gt_token
|
||||
tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, tk_end_noise)
|
||||
tk_start_noise, tk_end_noise = _find_next_token(
|
||||
aligned_noise, tk_end_noise
|
||||
)
|
||||
# Increment index
|
||||
tk_index_noise += 1
|
||||
|
||||
return gt_to_noise_mapping, noise_to_gt_mapping
|
||||
return gt_to_noise_mapping, noise_to_gt_mapping
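Restating the worked example from the comments above as a call (the mapping values are quoted from that example, not re-derived):

    gt_to_noise, noise_to_gt = parse_alignment(
        "New York is big ", "New@york @@ big ", gap_char="@"
    )
    # gt_to_noise == [[0], [0], [], [2]]  ("is" maps to no OCR token)
    # noise_to_gt == [[0], [], [3]]       (the gap-only token "@@" maps to nothing)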
@ -1,36 +1,38 @@
|
|||
"""
|
||||
Baseline alignment algorithm is slow on long documents.
|
||||
The idea is to break down the longer text into smaller fragments
|
||||
for quicker alignment on individual pieces. We refer "anchor words"
|
||||
Baseline alignment algorithm is slow on long documents.
|
||||
The idea is to break down the longer text into smaller fragments
|
||||
for quicker alignment on individual pieces. We refer to these points
of breakage as "anchor words".
|
||||
|
||||
The bulk of this algorithm is to identify these "anchor words".
|
||||
|
||||
This is an re-implementation of the algorithm in this paper
|
||||
This is a re-implementation of the algorithm in this paper
|
||||
"A Fast Alignment Scheme for Automatic OCR Evaluation of Books"
|
||||
(https://ieeexplore.ieee.org/document/6065412)
|
||||
|
||||
|
||||
We rely on `genalog.text.alignment` to align the subsequences.
|
||||
"""
|
||||
import itertools
|
||||
from collections import Counter
|
||||
from genalog.text import preprocess, alignment
|
||||
from genalog.text.lcs import LCS
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
|
||||
# The recursively portion of the algorithm will run on
|
||||
# segments longer than this value to find anchor points in
|
||||
from genalog.text import alignment, preprocess
|
||||
from genalog.text.alignment import GAP_CHAR
|
||||
from genalog.text.lcs import LCS
|
||||
|
||||
# The recursive portion of the algorithm will run on
|
||||
# segments longer than this value to find anchor points in
|
||||
# the longer segment (to break it up further).
|
||||
MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length
|
||||
MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length
|
||||
|
||||
|
||||
def get_unique_words(tokens, case_sensitive=False):
|
||||
""" Get a set of unique words from a Counter dictionary of word occurrences
|
||||
|
"""Get the set of unique (non-repeating) words from a list of tokens
|
||||
|
||||
Arguments:
|
||||
tokens {list} -- a list of tokens
|
||||
|
||||
|
||||
Keyword Arguments:
|
||||
case_sensitive {bool} -- whether unique words are case sensitive
|
||||
case_sensitive {bool} -- whether unique words are case sensitive
|
||||
(default: {False})
|
||||
|
||||
Returns:
|
||||
|
@ -38,14 +40,15 @@ def get_unique_words(tokens, case_sensitive=False):
|
|||
"""
|
||||
if case_sensitive:
|
||||
word_count = Counter(tokens)
|
||||
return {word for word, count in word_count.items() if count < 2 }
|
||||
return {word for word, count in word_count.items() if count < 2}
|
||||
else:
|
||||
tokens_lowercase = [tk.lower() for tk in tokens]
|
||||
word_count = Counter(tokens_lowercase)
|
||||
return {tk for tk in tokens if word_count[tk.lower()] < 2 }
|
||||
return {tk for tk in tokens if word_count[tk.lower()] < 2}
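For example, a sketch of the default case-insensitive behavior:

    tokens = ["Boston", "is", "big", "IS"]
    # "is"/"IS" collide once lowercased, so only the truly unique words remain
    assert get_unique_words(tokens) == {"Boston", "big"}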
|
||||
|
||||
|
||||
def segment_len(tokens):
|
||||
""" Get length of the segment
|
||||
"""Get length of the segment
|
||||
|
||||
Arguments:
|
||||
segment {list} -- a list of tokens
|
||||
|
@ -54,8 +57,9 @@ def segment_len(tokens):
|
|||
"""
|
||||
return sum(map(len, tokens))
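For example, segment length is measured in characters, not tokens:

    assert segment_len(["Boston", "is", "big"]) == 11  # whitespace is not counted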
|
||||
|
||||
|
||||
def get_word_map(unique_words, src_tokens):
|
||||
""" Arrange the set of unique words by the order they original appear in the text
|
||||
"""Arrange the set of unique words by the order they original appear in the text
|
||||
|
||||
Arguments:
|
||||
unique_words {set} -- a set of unique words
|
||||
|
@ -70,18 +74,19 @@ def get_word_map(unique_words, src_tokens):
|
|||
# Find the indices of the unique words in the source text
|
||||
unique_word_indices = map(src_tokens.index, unique_words)
|
||||
word_map = list(zip(unique_words, unique_word_indices))
|
||||
word_map.sort(key = lambda x: x[1]) # Re-arrange order by the index
|
||||
word_map.sort(key=lambda x: x[1]) # Re-arrange order by the index
|
||||
return word_map
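For example (a sketch; set iteration order does not matter because the result is re-sorted by index):

    assert get_word_map({"b", "c"}, ["b", "a", "c"]) == [("b", 0), ("c", 2)]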
|
||||
|
||||
|
||||
def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
||||
""" Find the location of anchor words in both the gt and ocr text.
|
||||
Anchor words are location where we can split both the source gt
|
||||
"""Find the location of anchor words in both the gt and ocr text.
|
||||
Anchor words are locations where we can split both the source gt
|
||||
and ocr text into smaller text fragment for faster alignment.
|
||||
|
||||
Arguments:
|
||||
gt_tokens {list} -- a list of ground truth tokens
|
||||
ocr_tokens {list} -- a list of tokens from OCR'ed document
|
||||
|
||||
|
||||
Keyword Arguments:
|
||||
min_anchor_len {int} -- minimum len of the anchor word
|
||||
(default: {2})
|
||||
|
@ -91,9 +96,9 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
|||
(anchor_map_gt, anchor_map_ocr)
|
||||
1. `anchor_map_gt` is a `word_map` that locates all the anchor words in the gt tokens
|
||||
2. `anchor_map_ocr` is a `word_map` that locates all the anchor words in the ocr tokens
|
||||
|
||||
|
||||
For example:
|
||||
Input:
|
||||
Input:
|
||||
gt_tokens: ["b", "a", "c"]
|
||||
ocr_tokens: ["c", "b", "a"]
|
||||
Output:
|
||||
|
@ -113,15 +118,15 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
|||
unique_word_map_ocr = get_word_map(unique_words_common, ocr_tokens)
|
||||
# Unzip to get the ordered unique_words
|
||||
ordered_unique_words_gt, _ = zip(*unique_word_map_gt)
|
||||
ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr)
|
||||
ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr)
|
||||
# Join words into a space-separated string for finding LCS
|
||||
unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt)
|
||||
unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt)
|
||||
unique_words_ocr_str = preprocess.join_tokens(ordered_unique_words_ocr)
|
||||
|
||||
# 3. Find the LCS between the two ordered list of unique words
|
||||
lcs = LCS(unique_words_gt_str, unique_words_ocr_str)
|
||||
lcs_str = lcs.get_str()
|
||||
|
||||
|
||||
# 4. Break up the LCS string into tokens
|
||||
lcs_words = set(preprocess.tokenize(lcs_str))
|
||||
|
||||
|
@ -129,19 +134,30 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
|
|||
anchor_words = lcs_words.intersection(unique_words_common)
|
||||
|
||||
# 6. Filter the unique words to keep the anchor words ONLY
|
||||
anchor_map_gt = list(filter(
|
||||
# This is a list of (unique_word, unique_word_index)
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_gt
|
||||
))
|
||||
anchor_map_ocr = list(filter(
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_ocr
|
||||
))
|
||||
anchor_map_gt = list(
|
||||
filter(
|
||||
# This is a list of (unique_word, unique_word_index)
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words,
|
||||
unique_word_map_gt,
|
||||
)
|
||||
)
|
||||
anchor_map_ocr = list(
|
||||
filter(
|
||||
lambda word_coordinate: word_coordinate[0] in anchor_words,
|
||||
unique_word_map_ocr,
|
||||
)
|
||||
)
|
||||
return anchor_map_gt, anchor_map_ocr
|
||||
|
||||
def find_anchor_recur(gt_tokens, ocr_tokens,
|
||||
start_pos_gt=0, start_pos_ocr=0,
|
||||
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
|
||||
""" Recursively find anchor positions in the gt and ocr text
|
||||
|
||||
def find_anchor_recur(
|
||||
gt_tokens,
|
||||
ocr_tokens,
|
||||
start_pos_gt=0,
|
||||
start_pos_ocr=0,
|
||||
max_seg_length=MAX_ALIGN_SEGMENT_LENGTH,
|
||||
):
|
||||
"""Recursively find anchor positions in the gt and ocr text
|
||||
|
||||
Arguments:
|
||||
gt_tokens {list} -- a list of ground truth tokens
|
||||
|
@ -150,12 +166,12 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
Keyword Arguments:
|
||||
start_pos {int} -- a constant to add to all the resulting indices
|
||||
(default: {0})
|
||||
max_seg_length {int} -- trigger recursion if any text segment is larger than this
|
||||
max_seg_length {int} -- trigger recursion if any text segment is larger than this
|
||||
(default: {MAX_ALIGN_SEGMENT_LENGTH})
|
||||
|
||||
Raises:
|
||||
ValueError: when there is a different number of anchor points in gt and ocr.
|
||||
|
||||
|
||||
Returns:
|
||||
tuple -- two lists of token indices:
|
||||
(output_gt_anchors, output_ocr_anchors)
|
||||
|
@ -165,7 +181,7 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
# 1. Try to find anchor words
|
||||
anchor_word_map_gt, anchor_word_map_ocr = get_anchor_map(gt_tokens, ocr_tokens)
|
||||
|
||||
# 2. Check invariant
|
||||
# 2. Check invariant
|
||||
if len(anchor_word_map_gt) != len(anchor_word_map_ocr):
|
||||
raise ValueError("Unequal number of anchor points across gt and ocr string")
|
||||
# Return empty if no anchor word found
|
||||
|
@ -182,17 +198,26 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
seg_start_gt = list(itertools.chain([0], anchor_indices_gt))
|
||||
seg_start_ocr = list(itertools.chain([0], anchor_indices_ocr))
|
||||
start_n_end_gt = zip(seg_start_gt, itertools.chain(anchor_indices_gt, [None]))
|
||||
start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None]))
|
||||
start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None]))
|
||||
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
|
||||
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
|
||||
# 4. Loop through each segment
|
||||
for gt_seg, ocr_seg, gt_start, ocr_start in zip(gt_segments, ocr_segments, seg_start_gt, seg_start_ocr):
|
||||
if segment_len(gt_seg) > max_seg_length or segment_len(ocr_seg) > max_seg_length:
|
||||
# 4. Loop through each segment
|
||||
for gt_seg, ocr_seg, gt_start, ocr_start in zip(
|
||||
gt_segments, ocr_segments, seg_start_gt, seg_start_ocr
|
||||
):
|
||||
if (
|
||||
segment_len(gt_seg) > max_seg_length
|
||||
or segment_len(ocr_seg) > max_seg_length
|
||||
):
|
||||
# recur on the segment in between the two anchors.
|
||||
# We assume the first token in the segment is an anchor word
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(gt_seg[1:], ocr_seg[1:],
|
||||
start_pos_gt=gt_start + 1, start_pos_ocr=ocr_start + 1,
|
||||
max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(
|
||||
gt_seg[1:],
|
||||
ocr_seg[1:],
|
||||
start_pos_gt=gt_start + 1,
|
||||
start_pos_ocr=ocr_start + 1,
|
||||
max_seg_length=max_seg_length,
|
||||
)
|
||||
# shift the token indices
|
||||
# (these are indices of a subsequence and does not reflect true position in the source sequence)
|
||||
gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors))
|
||||
|
@ -200,12 +225,13 @@ def find_anchor_recur(gt_tokens, ocr_tokens,
|
|||
# merge recursion results
|
||||
output_gt_anchors = output_gt_anchors.union(gt_anchors)
|
||||
output_ocr_anchors = output_ocr_anchors.union(ocr_anchors)
|
||||
|
||||
|
||||
return sorted(output_gt_anchors), sorted(output_ocr_anchors)
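A usage sketch (the returned values are token indices into the two token lists; which words become anchors depends on the input):

    gt_tokens = preprocess.tokenize("The planet Mars is red and far away")
    ocr_tokens = preprocess.tokenize("The plamet Mars is red and far away")
    gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens)
    # gt_anchors[i] and ocr_anchors[i] are intended to point at the same anchor word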
|
||||
|
||||
|
||||
def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
|
||||
"""A faster alignment scheme of two text segments. This method first
|
||||
breaks the strings into smaller segments with anchor words.
|
||||
breaks the strings into smaller segments with anchor words.
|
||||
Then these smaller segments are aligned.
|
||||
|
||||
NOTE: this function shares the same contract as `genalog.text.alignment.align()`
|
||||
|
@ -222,7 +248,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
|
||||
"The planet Mar, "
|
||||
"The plamet Maris, "
|
||||
|
||||
|
||||
"I scarcely need "
|
||||
"I scacely neee "
|
||||
|
||||
|
@ -230,7 +256,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
"remind te reader,"
|
||||
|
||||
And run sequence alignment on each pair.
|
||||
|
||||
|
||||
Arguments:
|
||||
gt {str} -- ground truth text
|
||||
ocr {str} -- text with ocr noise
|
||||
|
@ -248,11 +274,17 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
ocr_tokens = preprocess.tokenize(ocr)
|
||||
|
||||
# 1. Find anchor positions
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
|
||||
# 2. Split into segments
|
||||
start_n_end_gt = zip(itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None]))
|
||||
start_n_end_ocr = zip(itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None]))
|
||||
start_n_end_gt = zip(
|
||||
itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])
|
||||
)
|
||||
start_n_end_ocr = zip(
|
||||
itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])
|
||||
)
|
||||
gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
|
||||
ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
|
||||
|
||||
|
@ -263,13 +295,15 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_
|
|||
gt_segment = preprocess.join_tokens(gt_segment)
|
||||
noisy_segment = preprocess.join_tokens(noisy_segment)
|
||||
# Run alignment algorithm
|
||||
aligned_seg_gt, aligned_seg_ocr = alignment.align(gt_segment, noisy_segment, gap_char=gap_char)
|
||||
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
|
||||
aligned_seg_gt, aligned_seg_ocr = alignment.align(
|
||||
gt_segment, noisy_segment, gap_char=gap_char
|
||||
)
|
||||
if aligned_seg_gt and aligned_seg_ocr: # if not empty string ""
|
||||
aligned_segments_gt.append(aligned_seg_gt)
|
||||
aligned_segments_ocr.append(aligned_seg_ocr)
|
||||
|
||||
# Stitch all segments together
|
||||
aligned_gt = ' '.join(aligned_segments_gt)
|
||||
aligned_noise = ' '.join(aligned_segments_ocr)
|
||||
|
||||
return aligned_gt, aligned_noise
|
||||
# Stitch all segments together
|
||||
aligned_gt = " ".join(aligned_segments_gt)
|
||||
aligned_noise = " ".join(aligned_segments_ocr)
|
||||
|
||||
return aligned_gt, aligned_noise
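A minimal end-to-end sketch (the strings echo the fragment example in the docstring; exact gap placement may vary):

    from genalog.text.anchor import align_w_anchor

    gt = "The planet Mars, I scarcely need remind the reader, is very far away"
    ocr = "The plamet Maris, I scacely neee remind te reader, is very far away"
    aligned_gt, aligned_ocr = align_w_anchor(gt, ocr)
    assert len(aligned_gt) == len(aligned_ocr)  # same contract as alignment.align()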
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
usage: conll_format.py [-h] [--train_subset] [--test_subset]
|
||||
[--gt_folder GT_FOLDER]
|
||||
base_folder degraded_folder
|
||||
base_folder degraded_folder
|
||||
|
||||
positional arguments:
|
||||
base_folder base directory containing the collection of dataset
|
||||
|
@ -18,41 +18,42 @@ optional arguments:
|
|||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
|
||||
example usage
|
||||
example usage
|
||||
(to run for specified degradation of the dataset on both train and test)
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
|
||||
(to run for specified degradation of the dataset and ground truth)
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
--gt_folder='shared'
|
||||
|
||||
(to run for specified degradation of the dataset on only test subset)
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
--test_subset
|
||||
|
||||
(to run for specified degradation of the dataset on only train subset)
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all'
|
||||
--train_subset
|
||||
"""
|
||||
import itertools
|
||||
import difflib
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import difflib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import timeit
|
||||
import concurrent.futures
|
||||
|
||||
from tqdm import tqdm
|
||||
from genalog.text import ner_label, ner_label, alignment
|
||||
|
||||
from genalog.text import alignment, ner_label
|
||||
|
||||
EMPTY_SENTENCE_SENTINEL = "<<<<EMPTY_OCR_SENTENCE>>>>"
|
||||
EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O"
|
||||
|
||||
|
||||
def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens):
|
||||
"""
|
||||
propagate_labels_sentences propagates the labels of the clean tokens onto the ocr tokens and splits the ocr tokens into sentences
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clean_tokens : list
|
||||
|
@ -63,7 +64,7 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
list of sentences (each sentence is a list of tokens)
|
||||
ocr_tokens : list
|
||||
list of tokens in ocr text
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
list, list
|
||||
|
@ -73,9 +74,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
# Ensure equal number of tokens in both clean_tokens and clean_sentences
|
||||
merged_sentences = list(itertools.chain(*clean_sentences))
|
||||
if merged_sentences != clean_tokens:
|
||||
delta = "\n".join(difflib.unified_diff(merged_sentences, clean_tokens, fromfile='merged_clean_sentences', tofile="clean_tokens"))
|
||||
raise ValueError(f"Inconsistent tokens. " +
|
||||
f"Delta between clean_text and clean_labels:\n{delta}")
|
||||
delta = "\n".join(
|
||||
difflib.unified_diff(
|
||||
merged_sentences,
|
||||
clean_tokens,
|
||||
fromfile="merged_clean_sentences",
|
||||
tofile="clean_tokens",
|
||||
)
|
||||
)
|
||||
raise ValueError(
|
||||
"Inconsistent tokens. "
|
||||
+ f"Delta between clean_text and clean_labels:\n{delta}"
|
||||
)
|
||||
# Ensure that there's OCR result
|
||||
if len(ocr_tokens) == 0:
|
||||
raise ValueError("Empty OCR tokens.")
|
||||
|
@ -83,14 +93,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
raise ValueError("Empty clean tokens.")
|
||||
|
||||
# 1. Propagate labels + align
|
||||
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(clean_labels, clean_tokens, ocr_tokens)
|
||||
ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
|
||||
clean_labels, clean_tokens, ocr_tokens
|
||||
)
|
||||
|
||||
# 2. Parse alignment to get mapping
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_clean, aligned_ocr, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
|
||||
aligned_clean, aligned_ocr, gap_char=gap_char
|
||||
)
|
||||
|
||||
# 3. Find sentence breaks in clean text sentences
|
||||
gt_to_ocr_mapping_is_empty = [len(mapping) == 0 for mapping in gt_to_ocr_mapping]
|
||||
gt_to_ocr_mapping_is_empty_reverse = gt_to_ocr_mapping_is_empty[::-1]
|
||||
|
||||
sentence_index = []
|
||||
sentence_token_counts = 0
|
||||
for sentence in clean_sentences:
|
||||
|
@ -108,12 +122,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
ocr_start = 0
|
||||
# if gt token at sentence break is not mapped to any ocr token
|
||||
elif len(gt_to_ocr_mapping[gt_start]) < 1:
|
||||
try: # finding next gt token that is mapped to an ocr token
|
||||
try: # finding next gt token that is mapped to an ocr token
|
||||
new_gt_start = gt_to_ocr_mapping_is_empty.index(False, gt_start)
|
||||
ocr_start = gt_to_ocr_mapping[new_gt_start][0]
|
||||
# If no valid token mapping in the remaining gt tokens
|
||||
except ValueError:
|
||||
ocr_start = len(ocr_tokens) # use the last ocr token
|
||||
ocr_start = len(ocr_tokens) # use the last ocr token
|
||||
else:
|
||||
ocr_start = gt_to_ocr_mapping[gt_start][0]
|
||||
|
||||
|
@ -121,12 +135,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
if gt_end >= len(gt_to_ocr_mapping):
|
||||
ocr_end = len(ocr_tokens)
|
||||
elif len(gt_to_ocr_mapping[gt_end]) < 1:
|
||||
try: # finding next gt token that is mapped to an ocr token
|
||||
try: # finding next gt token that is mapped to an ocr token
|
||||
new_gt_end = gt_to_ocr_mapping_is_empty.index(False, gt_end)
|
||||
ocr_end = gt_to_ocr_mapping[new_gt_end][0]
|
||||
# If no valid token mapping in the remaining gt tokens
|
||||
except ValueError:
|
||||
ocr_end = len(ocr_tokens) # use the last ocr token
|
||||
ocr_end = len(ocr_tokens) # use the last ocr token
|
||||
else:
|
||||
ocr_end = gt_to_ocr_mapping[gt_end][0]
|
||||
ocr_sentence = ocr_tokens[ocr_start:ocr_end]
|
||||
|
@ -135,11 +149,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_
|
|||
ocr_labels_sentences.append(ocr_sentence_labels)
|
||||
return ocr_text_sentences, ocr_labels_sentences
|
||||
|
||||
|
||||
def get_sentences_from_iob_format(iob_format_str):
|
||||
sentences = []
|
||||
sentence = []
|
||||
for line in iob_format_str:
|
||||
if line.strip() == '': # if line is empty (sentence separator)
|
||||
if line.strip() == "": # if line is empty (sentence separator)
|
||||
sentences.append(sentence)
|
||||
sentence = []
|
||||
else:
|
||||
|
@ -149,47 +164,81 @@ def get_sentences_from_iob_format(iob_format_str):
|
|||
# filter any empty sentences
|
||||
return list(filter(lambda sentence: len(sentence) > 0, sentences))
|
||||
|
||||
def propagate_labels_sentence_single_file(arg):
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, input_filename = arg
|
||||
|
||||
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(clean_label_ext, ".txt")
|
||||
def propagate_labels_sentence_single_file(arg):
|
||||
(
|
||||
clean_labels_dir,
|
||||
output_text_dir,
|
||||
output_labels_dir,
|
||||
clean_label_ext,
|
||||
input_filename,
|
||||
) = arg
|
||||
|
||||
clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(
|
||||
clean_label_ext, ".txt"
|
||||
)
|
||||
ocr_text_file = os.path.join(output_text_dir, input_filename)
|
||||
ocr_labels_file = os.path.join(output_labels_dir, input_filename)
|
||||
if not os.path.exists(clean_labels_file):
|
||||
print(f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index...")
|
||||
print(
|
||||
f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index..."
|
||||
)
|
||||
return
|
||||
elif not os.path.exists(ocr_text_file):
|
||||
print(f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index...")
|
||||
print(
|
||||
f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index..."
|
||||
)
|
||||
return
|
||||
else:
|
||||
with open(clean_labels_file, 'r', encoding='utf-8') as clf:
|
||||
with open(clean_labels_file, "r", encoding="utf-8") as clf:
|
||||
tokens_labels_str = clf.readlines()
|
||||
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_tokens = [
|
||||
line.split()[0].strip()
|
||||
for line in tokens_labels_str
|
||||
if len(line.split()) == 2
|
||||
]
|
||||
clean_labels = [
|
||||
line.split()[1].strip()
|
||||
for line in tokens_labels_str
|
||||
if len(line.split()) == 2
|
||||
]
|
||||
clean_sentences = get_sentences_from_iob_format(tokens_labels_str)
|
||||
# read ocr tokens
|
||||
with open(ocr_text_file, 'r', encoding='utf-8') as otf:
|
||||
ocr_text_str = ' '.join(otf.readlines())
|
||||
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
|
||||
with open(ocr_text_file, "r", encoding="utf-8") as otf:
|
||||
ocr_text_str = " ".join(otf.readlines())
|
||||
ocr_tokens = [
|
||||
token.strip() for token in ocr_text_str.split()
|
||||
] # already tokenized in data
|
||||
try:
|
||||
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file...")
|
||||
print(
|
||||
f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file..."
|
||||
)
|
||||
return
|
||||
# Write result to file
|
||||
with open(ocr_labels_file, 'w', encoding="utf-8") as olf:
|
||||
for ocr_tokens, ocr_labels in zip(ocr_tokens_sentences, ocr_labels_sentences):
|
||||
if len(ocr_tokens) == 0: # if empty OCR sentences
|
||||
olf.write(f'{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n')
|
||||
else:
|
||||
for token, label in zip(ocr_tokens, ocr_labels):
|
||||
olf.write(f"{token}\t{label}\n")
|
||||
olf.write('\n')
|
||||
# Write result to file
|
||||
with open(ocr_labels_file, "w", encoding="utf-8") as olf:
|
||||
for ocr_tokens, ocr_labels in zip(
|
||||
ocr_tokens_sentences, ocr_labels_sentences
|
||||
):
|
||||
if len(ocr_tokens) == 0: # if empty OCR sentences
|
||||
olf.write(
|
||||
f"{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n"
|
||||
)
|
||||
else:
|
||||
for token, label in zip(ocr_tokens, ocr_labels):
|
||||
olf.write(f"{token}\t{label}\n")
|
||||
olf.write("\n")
|
||||
|
||||
def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext):
|
||||
|
||||
def propagate_labels_sentences_multiprocess(
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext
|
||||
):
|
||||
"""
|
||||
propagate_labels_sentences_multiprocess propagates labels and sentences for all files in the dataset
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clean_labels_dir : str
|
||||
|
@ -204,20 +253,28 @@ def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, o
|
|||
file extension of the clean_labels
|
||||
"""
|
||||
clean_label_files = os.listdir(clean_labels_dir)
|
||||
args = list(map(
|
||||
lambda clean_label_filename:
|
||||
(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, clean_label_filename),
|
||||
clean_label_files
|
||||
))
|
||||
args = list(
|
||||
map(
|
||||
lambda clean_label_filename: (
|
||||
clean_labels_dir,
|
||||
output_text_dir,
|
||||
output_labels_dir,
|
||||
clean_label_ext,
|
||||
clean_label_filename,
|
||||
),
|
||||
clean_label_files,
|
||||
)
|
||||
)
|
||||
with concurrent.futures.ProcessPoolExecutor() as executor:
|
||||
iterator = executor.map(propagate_labels_sentence_single_file, args)
|
||||
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
|
||||
for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report
|
||||
pass
|
||||
|
||||
|
||||
def extract_ocr_text(input_file, output_file):
|
||||
"""
|
||||
extract_ocr_text extracts the OCR'd text from a GROK json file
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
|
@ -227,16 +284,17 @@ def extract_ocr_text(input_file, output_file):
|
|||
"""
|
||||
out_dir = os.path.dirname(output_file)
|
||||
in_file_name = os.path.basename(input_file)
|
||||
file_pre = in_file_name.split('_')[-1].split('.')[0]
|
||||
output_file_name = '{}.txt'.format(file_pre)
|
||||
file_pre = in_file_name.split("_")[-1].split(".")[0]
|
||||
output_file_name = "{}.txt".format(file_pre)
|
||||
output_file = os.path.join(out_dir, output_file_name)
|
||||
with open(input_file, 'r', encoding='utf-8') as fin:
|
||||
with open(input_file, "r", encoding="utf-8") as fin:
|
||||
json_data = json.load(fin)
|
||||
json_dict = json_data[0]
|
||||
text = json_dict['text']
|
||||
with open(output_file, 'wb') as fout:
|
||||
text = json_dict["text"]
|
||||
with open(output_file, "wb") as fout:
|
||||
fout.write(text.encode("utf-8"))
|
||||
|
||||
|
||||
def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
||||
"""
|
||||
check_n_sentences prints the file name if the number of sentences differs between the clean and OCR files
|
||||
|
@ -253,49 +311,54 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext):
|
|||
text_files = os.listdir(output_labels_dir)
|
||||
skip_files = []
|
||||
for text_filename in tqdm(text_files):
|
||||
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(".txt", clean_label_ext)
|
||||
clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(
|
||||
".txt", clean_label_ext
|
||||
)
|
||||
ocr_labels_file = os.path.join(output_labels_dir, text_filename)
|
||||
remove_first_line(clean_labels_file, clean_labels_file)
|
||||
remove_first_line(ocr_labels_file, ocr_labels_file)
|
||||
remove_last_line(clean_labels_file, clean_labels_file)
|
||||
remove_last_line(ocr_labels_file, ocr_labels_file)
|
||||
with open(clean_labels_file, 'r', encoding='utf-8') as lf:
|
||||
with open(clean_labels_file, "r", encoding="utf-8") as lf:
|
||||
clean_tokens_labels = lf.readlines()
|
||||
with open(ocr_labels_file, 'r', encoding='utf-8') as of:
|
||||
with open(ocr_labels_file, "r", encoding="utf-8") as of:
|
||||
ocr_tokens_labels = of.readlines()
|
||||
error = False
|
||||
n_clean_sentences = 0
|
||||
nl = False
|
||||
for line in clean_tokens_labels:
|
||||
if line == '\n':
|
||||
if line == "\n":
|
||||
if nl is True:
|
||||
error = True
|
||||
else:
|
||||
nl=True
|
||||
nl = True
|
||||
n_clean_sentences += 1
|
||||
else:
|
||||
nl = False
|
||||
n_ocr_sentences = 0
|
||||
nl = False
|
||||
for line in ocr_tokens_labels:
|
||||
if line == '\n':
|
||||
if line == "\n":
|
||||
if nl is True:
|
||||
error = True
|
||||
else:
|
||||
nl=True
|
||||
nl = True
|
||||
n_ocr_sentences += 1
|
||||
else:
|
||||
nl = False
|
||||
if error or n_ocr_sentences != n_clean_sentences:
|
||||
print(f"Warning: Inconsistent numbers of sentences in '{text_filename}''." +
|
||||
f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}")
|
||||
print(
|
||||
f"Warning: Inconsistent numbers of sentences in '{text_filename}''."
|
||||
+ f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}"
|
||||
)
|
||||
skip_files.append(text_filename)
|
||||
return skip_files
|
||||
|
||||
|
||||
def remove_first_line(input_file, output_file):
|
||||
"""
|
||||
remove_first_line removes the first line from a file (some clean CoNLL files have an empty first line)
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
|
@ -303,17 +366,18 @@ def remove_first_line(input_file, output_file):
|
|||
output_file : str
|
||||
output file path
|
||||
"""
|
||||
with open(input_file, 'r', encoding='utf-8') as in_f:
|
||||
with open(input_file, "r", encoding="utf-8") as in_f:
|
||||
lines = in_f.readlines()
|
||||
if len(lines) > 1 and lines[0].strip() == '':
|
||||
if len(lines) > 1 and lines[0].strip() == "":
|
||||
# the clean CoNLL formatted files had a newline as the first line
|
||||
with open(output_file, 'w', encoding='utf-8') as out_f:
|
||||
with open(output_file, "w", encoding="utf-8") as out_f:
|
||||
out_f.writelines(lines[1:])
|
||||
|
||||
|
||||
def remove_last_line(input_file, output_file):
|
||||
"""
|
||||
remove_last_line removes the last line from a file (some clean CoNLL files have an empty last line)
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
|
@ -322,16 +386,17 @@ def remove_last_line(input_file, output_file):
|
|||
output file path
|
||||
"""
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as in_f:
|
||||
with open(input_file, "r", encoding="utf-8") as in_f:
|
||||
lines = in_f.readlines()
|
||||
if len(lines) > 1 and lines[-1].strip() == '':
|
||||
with open(output_file, 'w', encoding='utf-8') as out_f:
|
||||
if len(lines) > 1 and lines[-1].strip() == "":
|
||||
with open(output_file, "w", encoding="utf-8") as out_f:
|
||||
out_f.writelines(lines[:-1])
|
||||
|
||||
|
||||
def for_all_files(input_dir, output_dir, func):
|
||||
"""
|
||||
for_all_files applies a function to every file in a directory
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dir : str
|
||||
|
@ -347,25 +412,34 @@ def for_all_files(input_dir, output_dir, func):
|
|||
output_file = os.path.join(output_dir, text_filename)
|
||||
func(input_file, output_file)
|
||||
|
||||
|
||||
def main(args):
|
||||
if not args.train_subset and not args.test_subset:
|
||||
subsets = ['train', 'test']
|
||||
subsets = ["train", "test"]
|
||||
else:
|
||||
subsets = []
|
||||
if args.train_subset:
|
||||
subsets.append('train')
|
||||
subsets.append("train")
|
||||
|
||||
if args.test_subset:
|
||||
subsets.append('test')
|
||||
subsets.append("test")
|
||||
|
||||
for subset in subsets:
|
||||
print("Processing {} subset...".format(subset))
|
||||
|
||||
clean_labels_dir = os.path.join(args.base_folder, args.gt_folder, subset,'clean_labels')
|
||||
ocr_json_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr')
|
||||
clean_labels_dir = os.path.join(
|
||||
args.base_folder, args.gt_folder, subset, "clean_labels"
|
||||
)
|
||||
ocr_json_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr"
|
||||
)
|
||||
|
||||
output_text_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_text')
|
||||
output_labels_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_labels')
|
||||
output_text_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr_text"
|
||||
)
|
||||
output_labels_dir = os.path.join(
|
||||
args.base_folder, args.degraded_folder, subset, "ocr_labels"
|
||||
)
|
||||
|
||||
# remove first empty line of labels file, if exists
|
||||
for_all_files(clean_labels_dir, clean_labels_dir, remove_first_line)
|
||||
|
@ -380,21 +454,50 @@ def main(args):
|
|||
os.mkdir(output_labels_dir)
|
||||
|
||||
# make ocr labels files by propagating clean labels to ocr_text and creating files in ocr_labels
|
||||
propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext)
|
||||
propagate_labels_sentences_multiprocess(
|
||||
clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext
|
||||
)
|
||||
print("Validating number of sentences in gt and ocr labels")
|
||||
check_n_sentences(clean_labels_dir, output_labels_dir, args.clean_label_ext) # check number of sentences and make sure same; print anomaly files
|
||||
check_n_sentences(
|
||||
clean_labels_dir, output_labels_dir, args.clean_label_ext
|
||||
) # check number of sentences and make sure same; print anomaly files
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("base_folder", help="base directory containing the collection of dataset")
|
||||
parser.add_argument("degraded_folder", help="directory containing train and test subset for degradation")
|
||||
parser.add_argument("--gt_folder", type=str, default="shared", help="directory containing the ground truth")
|
||||
parser.add_argument("--clean_label_ext", type=str, default=".txt", help="file extension of the clean_labels files")
|
||||
parser.add_argument('--train_subset', help="include if only train folder should be processed", action='store_true')
|
||||
parser.add_argument('--test_subset', help="include if only test folder should be processed", action='store_true')
|
||||
parser.add_argument(
|
||||
"base_folder", help="base directory containing the collection of dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"degraded_folder",
|
||||
help="directory containing train and test subset for degradation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gt_folder",
|
||||
type=str,
|
||||
default="shared",
|
||||
help="directory containing the ground truth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean_label_ext",
|
||||
type=str,
|
||||
default=".txt",
|
||||
help="file extension of the clean_labels files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_subset",
|
||||
help="include if only train folder should be processed",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_subset",
|
||||
help="include if only test folder should be processed",
|
||||
action="store_true",
|
||||
)
|
||||
return parser
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
start = timeit.default_timer()
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from genalog.text import preprocess
|
||||
|
||||
class LCS():
|
||||
class LCS:
|
||||
""" Compute the Longest Common Subsequence (LCS) of two given string."""
|
||||
|
||||
def __init__(self, str_m, str_n):
|
||||
self.str_m_len = len(str_m)
|
||||
self.str_n_len = len(str_n)
|
||||
dp_table = self._construct_dp_table(str_m, str_n)
|
||||
self._lcs_len = dp_table[self.str_m_len][self.str_n_len]
|
||||
self._lcs = self._find_lcs_str(str_m, str_n, dp_table)
|
||||
|
||||
|
||||
def _construct_dp_table(self, str_m, str_n):
|
||||
m = self.str_m_len
|
||||
n = self.str_n_len
|
||||
|
@ -16,14 +15,14 @@ class LCS():
|
|||
# Initialize DP table
|
||||
dp = [[0 for j in range(n + 1)] for i in range(m + 1)]
|
||||
|
||||
for i in range(1, m+1):
|
||||
for j in range(1, n+1):
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
# Case 1: if char1 == char2
|
||||
if str_m[i-1] == str_n[j-1]:
|
||||
dp[i][j] = 1 + dp[i-1][j-1]
|
||||
if str_m[i - 1] == str_n[j - 1]:
|
||||
dp[i][j] = 1 + dp[i - 1][j - 1]
|
||||
# Case 2: take the max of the values in the top and left cell
|
||||
else:
|
||||
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
|
||||
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
|
||||
return dp
|
||||
|
||||
def _find_lcs_str(self, str_m, str_n, dp_table):
|
||||
|
@ -32,13 +31,13 @@ class LCS():
|
|||
lcs = ""
|
||||
while m > 0 and n > 0:
|
||||
# same char
|
||||
if str_m[m-1] == str_n[n-1]:
|
||||
if str_m[m - 1] == str_n[n - 1]:
|
||||
# prepend the character
|
||||
lcs = str_m[m - 1] + lcs
|
||||
lcs = str_m[m - 1] + lcs
|
||||
m -= 1
|
||||
n -= 1
|
||||
# top cell > left cell
|
||||
elif dp_table[m-1][n] > dp_table[m][n-1]:
|
||||
elif dp_table[m - 1][n] > dp_table[m][n - 1]:
|
||||
m -= 1
|
||||
else:
|
||||
n -= 1
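A usage sketch of the class above (`get_str` is the accessor already used by `genalog.text.anchor.get_anchor_map` and is assumed to return the computed subsequence):

    from genalog.text.lcs import LCS

    lcs = LCS("New York is big", "New Yorkis big")
    print(lcs.get_str())  # longest common subsequence of the two inputs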
@ -1,8 +1,9 @@
|
|||
from genalog.text import alignment, anchor
|
||||
from genalog.text import preprocess
|
||||
import itertools
|
||||
import re
|
||||
import string
|
||||
import itertools
|
||||
|
||||
from genalog.text import alignment, anchor
|
||||
from genalog.text import preprocess
|
||||
|
||||
# Both regex below has the following behavior:
|
||||
# 1. whitespace-tolerant at both ends of the string
|
||||
|
@ -10,71 +11,81 @@ import itertools
|
|||
# For example, given a label 'B-PLACE'
|
||||
# Group 1 (denoted by \1): Label Indicator (B-)
|
||||
# Group 2 (denoted by \2): Label Name (PLACE)
|
||||
MULTI_TOKEN_BEGIN_LABEL_REGEX = r'^\s*(B-)([a-z|A-Z]+)\s*$'
|
||||
MULTI_TOKEN_INSIDE_LABEL_REGEX = r'^\s*(I-)([a-z|A-Z]+)\s*$'
|
||||
MULTI_TOKEN_LABEL_REGEX = r'^\s*([B|I]-)([a-z|A-Z]+)\s*'
|
||||
MULTI_TOKEN_BEGIN_LABEL_REGEX = r"^\s*(B-)([a-z|A-Z]+)\s*$"
|
||||
MULTI_TOKEN_INSIDE_LABEL_REGEX = r"^\s*(I-)([a-z|A-Z]+)\s*$"
|
||||
MULTI_TOKEN_LABEL_REGEX = r"^\s*([B|I]-)([a-z|A-Z]+)\s*"
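To make the two capture groups concrete (a small sketch):

    import re

    m = re.match(r"^\s*([B|I]-)([a-z|A-Z]+)\s*", " B-PLACE ")
    assert m.group(1) == "B-"     # Group 1: label indicator
    assert m.group(2) == "PLACE"  # Group 2: label name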
|
||||
|
||||
# To avoid confusion in the Python interpreter,
|
||||
# gap char should not be any of the following special characters
|
||||
SPECIAL_CHAR = set(" \t\n'\x0b''\x0c''\r'") # Notice space characters (' ', '\t', '\n') are in this set.
|
||||
# gap char should not be any of the following special characters
|
||||
SPECIAL_CHAR = set(
|
||||
" \t\n'\x0b''\x0c''\r'"
|
||||
) # Notice space characters (' ', '\t', '\n') are in this set.
|
||||
GAP_CHAR_SET = set(string.printable).difference(SPECIAL_CHAR)
|
||||
# GAP_CHAR_SET = '!"#$%&()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~'
|
||||
|
||||
|
||||
class GapCharError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _is_begin_label(label):
|
||||
""" Return true if the NER label is a begin label (eg. B-PLACE) """
|
||||
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _is_inside_label(label):
|
||||
""" Return true if the NER label is an inside label (eg. I-PLACE) """
|
||||
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _is_multi_token_label(label):
|
||||
""" Return true if the NER label is a multi token label (eg. B-PLACE, I-PLACE) """
|
||||
return re.match(MULTI_TOKEN_LABEL_REGEX, label) != None
|
||||
return re.match(MULTI_TOKEN_LABEL_REGEX, label) is not None
|
||||
|
||||
|
||||
def _clean_multi_token_label(label):
|
||||
""" Rid the multi-token-labels of whitespaces"""
|
||||
return re.sub(MULTI_TOKEN_LABEL_REGEX, r'\1\2', label)
|
||||
return re.sub(MULTI_TOKEN_LABEL_REGEX, r"\1\2", label)
|
||||
|
||||
|
||||
def _convert_to_begin_label(label):
|
||||
""" Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
|
||||
|
||||
"""Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE)
|
||||
|
||||
Arguments:
|
||||
label {str} -- an NER label
|
||||
|
||||
|
||||
Returns:
|
||||
an NER label. This method DOES NOT alter the label unless it is an inside label
|
||||
"""
|
||||
if _is_inside_label(label):
|
||||
# Replace the Label Indicator to 'B-'(\1) and keep the Label Name (\2)
|
||||
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r'B-\2', label)
|
||||
return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r"B-\2", label)
|
||||
return label
|
||||
|
||||
|
||||
def _convert_to_inside_label(label):
|
||||
""" Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE)
|
||||
|
"""Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. I-PLACE)
|
||||
|
||||
Arguments:
|
||||
label {str} -- an NER label
|
||||
|
||||
|
||||
Returns:
|
||||
an NER label. This method DOES NOT alter the label unless it is a begin label
|
||||
"""
|
||||
if _is_begin_label(label):
|
||||
# Replace the Label Indicator to 'I-'(\1) and keep the Label Name (\2)
|
||||
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r'I-\2', label)
|
||||
return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r"I-\2", label)
|
||||
return label
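# Small illustrations of the label helpers above (values follow directly from the regexes):
#   _clean_multi_token_label(" B-PLACE ")  -> "B-PLACE"
#   _convert_to_begin_label("I-PLACE")     -> "B-PLACE"
#   _convert_to_inside_label("B-PLACE")    -> "I-PLACE"
#   _convert_to_begin_label("O")           -> "O"   (labels that are not inside labels pass through unchanged)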
|
||||
|
||||
|
||||
def _is_missing_begin_label(begin_label, inside_label):
|
||||
""" Validate a inside label given an begin label
|
||||
|
||||
"""Validate a inside label given an begin label
|
||||
|
||||
Arguments:
|
||||
begin_label {str} -- a begin NER label used to
|
||||
begin_label {str} -- a begin NER label used to
|
||||
check if the given label is part of a multi-token label
|
||||
inside_label {str} -- an inside label to check for its validity
|
||||
|
||||
|
||||
Returns:
|
||||
True if the inside label is missing its begin label (i.e. it is not paired with begin_label). False otherwise.
Also False if input is not an inside label
|
||||
|
@ -87,20 +98,21 @@ def _is_missing_begin_label(begin_label, inside_label):
|
|||
inside_label = _clean_multi_token_label(inside_label)
|
||||
begin_label = _clean_multi_token_label(begin_label)
|
||||
# convert inside label to a begin label for string comparison
|
||||
# True if the two labels have different names
|
||||
# True if the two labels have different names
|
||||
# (e.g. B-LOC followed by I-ORG, and I-ORG is missing a begin label)
|
||||
return _convert_to_begin_label(inside_label) != begin_label
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def correct_ner_labels(labels):
|
||||
""" Correct the given list of labels for the following case:
|
||||
"""Correct the given list of labels for the following case:
|
||||
|
||||
1. Missing B-Label (i.e. I-PLACE I-PLACE -> B-PLACE I-PLACE)
|
||||
|
||||
|
||||
Arguments:
|
||||
labels {list} -- list of NER labels
|
||||
|
||||
|
||||
Returns:
|
||||
a list of NER labels
|
||||
"""
|
||||
|
@ -109,7 +121,7 @@ def correct_ner_labels(labels):
|
|||
if _is_multi_token_label(label):
|
||||
if _is_begin_label(label):
|
||||
cur_begin_label = label
|
||||
# else is an inside label, so we check if it's missing a begin label
|
||||
# else is an inside label, so we check if it's missing a begin label
|
||||
else:
|
||||
if _is_missing_begin_label(cur_begin_label, label):
|
||||
labels[i] = _convert_to_begin_label(label)
|
||||
|
@ -118,10 +130,11 @@ def correct_ner_labels(labels):
|
|||
else:
|
||||
cur_begin_label = ""
|
||||
return labels
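# A minimal usage sketch of correct_ner_labels(), assuming the module is importable
# as genalog.text.ner_label; mirrors case 1 from the docstring above.
from genalog.text import ner_label

fixed = ner_label.correct_ner_labels(["I-PLACE", "I-PLACE", "O"])
# expected: ["B-PLACE", "I-PLACE", "O"]  (the leading inside label gains its begin label)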
|
||||
|
||||
|
||||
|
||||
def _select_from_multiple_ner_labels(label_indices):
|
||||
""" Private method to select a NER label from a list of candidate
|
||||
|
||||
"""Private method to select a NER label from a list of candidate
|
||||
|
||||
Note: this method is used to tackle the issue when multiple gt tokens
|
||||
are aligned to ONE ocr_token
|
||||
|
||||
|
@ -129,7 +142,7 @@ def _select_from_multiple_ner_labels(label_indices):
|
|||
|
||||
gt_labels: B-p I-p O O
|
||||
| | | |
|
||||
gt: New York is big
|
||||
gt: New York is big
|
||||
| \\ / |
|
||||
ocr: New Yorkis big
|
||||
| | |
|
||||
|
@ -140,15 +153,16 @@ def _select_from_multiple_ner_labels(label_indices):
|
|||
|
||||
Arguments:
|
||||
label_indices {list} -- a list of token indices
|
||||
|
||||
|
||||
Returns:
|
||||
a specific index
|
||||
"""
|
||||
# TODO: may need a more sophisticated way to select from multiple NER labels
|
||||
return label_indices[0]
|
||||
|
||||
|
||||
def _find_gap_char_candidates(gt_tokens, ocr_tokens):
|
||||
""" Find a set of suitable GAP_CHARs based not in the set of input characters
|
||||
"""Find a set of suitable GAP_CHARs based not in the set of input characters
|
||||
|
||||
Arguments:
|
||||
gt_tokens {list} -- a list of tokens
|
||||
|
@ -159,15 +173,18 @@ def _find_gap_char_candidates(gt_tokens, ocr_tokens):
|
|||
1. the set of suitable GAP_CHARs
|
||||
2. the set of input characters
|
||||
"""
|
||||
input_char_set = set(''.join(itertools.chain(gt_tokens, ocr_tokens))) # The set of input characters
|
||||
input_char_set = set(
|
||||
"".join(itertools.chain(gt_tokens, ocr_tokens))
|
||||
) # The set of input characters
|
||||
gap_char_set = GAP_CHAR_SET # The set of possible GAP_CHARs
|
||||
# Find a set of gap_char that is NOT in the set of input characters
|
||||
gap_char_candidates = gap_char_set.difference(input_char_set)
|
||||
return gap_char_candidates, input_char_set
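# Hedged sketch of the helper above: every character that occurs in the input tokens is
# excluded from the candidate gap characters (import path assumed to be genalog.text.ner_label).
from genalog.text.ner_label import _find_gap_char_candidates

candidates, used_chars = _find_gap_char_candidates(["New", "York"], ["N", "ewYork"])
assert "@" in candidates       # '@' appears in no input token, so it remains a candidate
assert "N" not in candidates   # 'N' is an input character, so it is ruled out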
|
||||
|
||||
|
||||
def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
|
||||
""" Propagate NER label for ground truth tokens to to ocr tokens.
|
||||
|
||||
"""Propagate NER label for ground truth tokens to to ocr tokens.
|
||||
|
||||
NOTE that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
|
||||
Invalid tokens are:
|
||||
1. non-atomic tokens, or space-separated string ("New York")
|
||||
|
@ -175,7 +192,7 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
|
|||
4. string with spaces (" ")
|
||||
|
||||
Arguments:
|
||||
gt_labels {list} -- a list of NER label for ground truth token
|
||||
gt_labels {list} -- a list of NER label for ground truth token
|
||||
gt_tokens {list} -- a list of ground truth string tokens
|
||||
ocr_tokens {list} -- a list of OCR'ed text tokens
|
||||
|
||||
|
@ -185,8 +202,8 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
|
|||
(default: {True})
|
||||
|
||||
Raises:
|
||||
GapCharError:
|
||||
when the set of input characters is EQUAL
GapCharError:
when the set of input characters is EQUAL
to the set of all possible gap characters (GAP_CHAR_SET)
|
||||
|
||||
Returns:
|
||||
|
@ -199,21 +216,30 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True):
|
|||
`gap_char` is the char used to alignment for inserting gaps
|
||||
"""
|
||||
# Find a set of suitable GAP_CHAR based not in the set of input characters
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(gt_tokens, ocr_tokens)
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
gt_tokens, ocr_tokens
|
||||
)
|
||||
if len(gap_char_candidates) == 0:
|
||||
raise GapCharError("Exhausted all possible GAP_CHAR candidates for alignment." +
|
||||
" Consider reducing cardinality of the input character set.\n" +
|
||||
f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" +
|
||||
f"The set of input character is: '{''.join(sorted(input_char_set))}'")
|
||||
raise GapCharError(
|
||||
"Exhausted all possible GAP_CHAR candidates for alignment."
|
||||
+ " Consider reducing cardinality of the input character set.\n"
|
||||
+ f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n"
|
||||
+ f"The set of input character is: '{''.join(sorted(input_char_set))}'"
|
||||
)
|
||||
else:
|
||||
if alignment.GAP_CHAR in gap_char_candidates:
|
||||
gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR
|
||||
gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR
|
||||
else:
|
||||
gap_char = gap_char_candidates.pop()
|
||||
return _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor)
|
||||
return _propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor
|
||||
)
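# A short end-to-end sketch of propagate_label_to_ocr(), mirroring the example given in the
# low-level implementation's docstring below (assumes the package is importable as genalog):
from genalog.text import ner_label

ocr_labels, aligned_gt, aligned_ocr = ner_label.propagate_label_to_ocr(
    ["B-place", "I-place", "o", "o"],   # gt_labels
    ["New", "York", "is", "big"],       # gt_tokens
    ["N", "ewYork", "big"],             # ocr_tokens
)
# expected: ocr_labels == ["B-place", "I-place", "o"]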
|
||||
|
||||
def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True):
|
||||
""" Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
|
||||
|
||||
def _propagate_label_to_ocr(
|
||||
gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True
|
||||
):
|
||||
"""Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation
|
||||
|
||||
NOTE: that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens.
|
||||
Invalid tokens are:
|
||||
|
@ -221,10 +247,10 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
2. multiple occurrences of the GAP_CHAR ('@@@')
|
||||
3. empty string ("")
|
||||
4. string with spaces (" ")
|
||||
|
||||
|
||||
Case Analysis:
|
||||
******************************** MULTI-TOKEN-LABELS ********************************
|
||||
|
||||
|
||||
Case 1: Case 2: Case 3: Case 4: Case 5:
|
||||
one-to-many many-to-one many-to-many missing tokens missing tokens
|
||||
(Case 1&2 comb) (I-label) (B-label)
|
||||
|
@ -233,24 +259,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
gt_token New York New York New York New York New York City
|
||||
/ \\ / \\ \\/ /\\ / | | |
|
||||
ocr_token N ew Yo rk NewYork N ew@York New York City
|
||||
| | | | | | | | | |
|
||||
| | | | | | | | | |
|
||||
ocr label B-p I-p I-p I-p B-p B-p I-p B-p B-p I-p
|
||||
|
||||
******************************** SINGLE-TOKEN-LABELS ********************************
|
||||
|
||||
Case 1: Case 2: Case 3: Case 4:
|
||||
one-to-many many-to-one many-to-many missing tokens
|
||||
(Case 1&2 comb)
|
||||
gt label O V O O V W O O
|
||||
| | | | | | | |
|
||||
gt_token something is big this is huge is big
|
||||
/ \\ \\ \\/ /\\ /\\/ |
|
||||
ocr_token so me thing isbig th isi shuge is
|
||||
| | | | | | | |
|
||||
ocr label o o o V O O V O
|
||||
|
||||
|
||||
Case 1: Case 2: Case 3: Case 4:
|
||||
one-to-many many-to-one many-to-many missing tokens
|
||||
(Case 1&2 comb)
|
||||
gt label O V O O V W O O
|
||||
| | | | | | | |
|
||||
gt_token something is big this is huge is big
|
||||
/ \\ \\ \\/ /\\ /\\/ |
|
||||
ocr_token so me thing isbig th isi shuge is
|
||||
| | | | | | | |
|
||||
ocr label o o o V O O V O
|
||||
|
||||
Arguments:
|
||||
gt_labels {list} -- a list of NER label for ground truth token
|
||||
gt_labels {list} -- a list of NER label for ground truth token
|
||||
gt_tokens {list} -- a list of ground truth string tokens
|
||||
ocr_tokens {list} -- a list of OCR'ed text tokens
|
||||
|
||||
|
@ -259,13 +285,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
use_anchor {bool} -- use faster alignment method with anchors if set to True
|
||||
(default: {True})
|
||||
Raises:
|
||||
ValueError: when
|
||||
ValueError: when
|
||||
1. there is unequal number of gt_tokens and gt_labels
|
||||
2. there is a non-atomic token in gt_tokens or ocr_tokens
|
||||
3. there is an empty string in gt_tokens or ocr_tokens
|
||||
4. there is a token full of space characters only in gt_tokens or ocr_tokens
|
||||
5. gt_to_ocr_mapping has more tokens than gt_tokens
|
||||
GapCharError: when
|
||||
GapCharError: when
|
||||
1. there is a token consisting of GAP_CHAR only
|
||||
|
||||
|
||||
|
@ -278,24 +304,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
`aligned_ocr` is the ocr text aligned with ground true
|
||||
`gap_char` is the char used to alignment for inserting gaps
|
||||
|
||||
For example,
|
||||
For example,
|
||||
given input:
|
||||
|
||||
gt_labels: ["B-place", "I-place", "o", "o"]
|
||||
gt_tokens: ["New", "York", "is", "big"]
|
||||
ocr_tokens: ["N", "ewYork", "big"]
|
||||
|
||||
|
||||
output:
|
||||
(
|
||||
["B-place", "I-place", "o"],
|
||||
"N@ew York is big",
|
||||
"N ew@York@@@ big"
|
||||
"N ew@York@@@ big"
|
||||
)
|
||||
"""
|
||||
# Pseudo-algorithm:
|
||||
|
||||
# ocr_to_gt_mapping = [
|
||||
# gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City')
|
||||
# gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City')
|
||||
# | | | | | | | [3], ('i' maps to 'is')
|
||||
# gt_txt: "New York City is in New York" [3, 4], ('sin' maps to 'is' and 'in')
|
||||
# \/ /\ | /\ [5], ('N' maps to 'New')
|
||||
|
@ -312,13 +338,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
#
|
||||
|
||||
# gt_to_ocr_mapping = [
|
||||
# gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token)
|
||||
# gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token)
|
||||
# | | | | | | | [0], ('York' maps to 'YorkCity')
|
||||
# gt_txt: "New York City is in New York" [0], ('City' maps to 'YorkCity')
|
||||
# \/ /\ | /\ [1, 2], ('is' maps to 'i' and 'sin')
|
||||
# ocr_txt: "YorkCity i sin N ew" [2], ('in' maps to 'sin)
|
||||
# | | | | | [3,4], ('New' maps to 'N' and 'ew')
|
||||
# I-P O O B-P B-P [] ('York' does not map to any ocr token)
|
||||
# | | | | | [3,4], ('New' maps to 'N' and 'ew')
|
||||
# I-P O O B-P B-P [] ('York' does not map to any ocr token)
|
||||
# ]
|
||||
|
||||
# STEP 2, clean up corner cases from multi-token-labels
|
||||
|
@ -332,41 +358,53 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
# YorkCity
|
||||
|
||||
# We can address MULTI-TOKEN-LABELS Case 1 with following pseudo-algorithm:
|
||||
# 1. For each gt_token in gt_to_ocr_mapping:
|
||||
# 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label
|
||||
# 1. For all the ocr_tokens this gt_token mapped to
|
||||
# 1. Keep the B-label for the 1st ocr_token
|
||||
# 2. For the rest of the ocr_token, convert the B-label to an I-label
|
||||
|
||||
# 1. For each gt_token in gt_to_ocr_mapping:
|
||||
# 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label
|
||||
# 1. For all the ocr_tokens this gt_token mapped to
|
||||
# 1. Keep the B-label for the 1st ocr_token
|
||||
# 2. For the rest of the ocr_token, convert the B-label to an I-label
|
||||
|
||||
# We can address the MULTI-TOKEN-LABELS Case 5 with the '_correct_ner_labels()' method
|
||||
|
||||
|
||||
# Sanity check:
|
||||
if len(gt_tokens) != len(gt_labels):
|
||||
raise ValueError(f"Unequal number of gt_tokens ({len(gt_tokens)})" +
|
||||
f"to that of gt_labels ({len(gt_labels)})")
|
||||
|
||||
for tk in (gt_tokens + ocr_tokens):
|
||||
raise ValueError(
|
||||
f"Unequal number of gt_tokens ({len(gt_tokens)})"
|
||||
+ f"to that of gt_labels ({len(gt_labels)})"
|
||||
)
|
||||
|
||||
for tk in gt_tokens + ocr_tokens:
|
||||
if len(preprocess.tokenize(tk)) > 1:
|
||||
raise ValueError(f"Invalid token '{tk}'. Tokens must be atomic.")
|
||||
if not alignment._is_valid_token(tk, gap_char=gap_char):
|
||||
if re.search(rf'{re.escape(gap_char)}+', tk): # Escape special regex chars
|
||||
raise GapCharError(f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'")
|
||||
if re.search(rf"{re.escape(gap_char)}+", tk): # Escape special regex chars
|
||||
raise GapCharError(
|
||||
f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)")
|
||||
raise ValueError(
|
||||
f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)"
|
||||
)
|
||||
|
||||
# Stitch tokens together into one string for alignment
|
||||
gt_txt = preprocess.join_tokens(gt_tokens)
|
||||
ocr_txt = preprocess.join_tokens(ocr_tokens)
|
||||
# Align the ground truth and ocr text first
|
||||
if use_anchor:
|
||||
aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_txt, ocr_txt, gap_char=gap_char)
|
||||
aligned_gt, aligned_ocr = anchor.align_w_anchor(
|
||||
gt_txt, ocr_txt, gap_char=gap_char
|
||||
)
|
||||
else:
|
||||
aligned_gt, aligned_ocr = alignment.align(gt_txt, ocr_txt, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_gt, aligned_ocr, gap_char=gap_char)
|
||||
gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(
|
||||
aligned_gt, aligned_ocr, gap_char=gap_char
|
||||
)
|
||||
# Check invariant
|
||||
if len(gt_to_ocr_mapping) != len(gt_tokens):
|
||||
raise ValueError(f"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " +
|
||||
f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment().")
|
||||
raise ValueError(
|
||||
"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: "
|
||||
+ f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment()."
|
||||
)
|
||||
|
||||
ocr_labels = []
|
||||
# STEP 1: naively propagate NER label based on text-alignment
|
||||
|
@ -374,7 +412,9 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
# if is not mapping to missing a token (Case 4)
|
||||
if ocr_to_gt_token_relationship:
|
||||
# Find the corresponding gt_token it is aligned to
|
||||
ner_label_index = _select_from_multiple_ner_labels(ocr_to_gt_token_relationship)
|
||||
ner_label_index = _select_from_multiple_ner_labels(
|
||||
ocr_to_gt_token_relationship
|
||||
)
|
||||
# Get the NER label for that particular gt_token
|
||||
ocr_labels.append(gt_labels[ner_label_index])
|
||||
|
||||
|
@ -397,7 +437,7 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment
|
|||
|
||||
def format_labels(tokens, labels, label_top=True):
|
||||
"""Format tokens and their NER label for display
|
||||
|
||||
|
||||
Arguments:
|
||||
tokens {list} -- a list of word tokens
|
||||
labels {list} -- a list of NER labels
|
||||
|
@ -405,7 +445,7 @@ def format_labels(tokens, labels, label_top=True):
|
|||
Keyword Arguments:
|
||||
label_top {bool} -- True if the label is placed on top of the token
|
||||
(default: {True})
|
||||
|
||||
|
||||
Returns:
|
||||
a str with NER label align to the token it is labeling
|
||||
|
||||
|
@ -429,20 +469,28 @@ def format_labels(tokens, labels, label_top=True):
|
|||
len_diff = abs(len(label) - len(token))
|
||||
# Add padding spaces for whichever is shorter
|
||||
if len(label) > len(token):
|
||||
formatted_labels += label + ' '
|
||||
formatted_tokens += token + ' '*len_diff + ' '
|
||||
formatted_labels += label + " "
|
||||
formatted_tokens += token + " " * len_diff + " "
|
||||
else:
|
||||
formatted_labels += label + ' '*len_diff + ' '
|
||||
formatted_tokens += token + ' '
|
||||
formatted_labels += label + " " * len_diff + " "
|
||||
formatted_tokens += token + " "
|
||||
if label_top:
|
||||
return formatted_labels + '\n' + formatted_tokens + '\n'
|
||||
return formatted_labels + "\n" + formatted_tokens + "\n"
|
||||
else:
|
||||
return formatted_tokens + '\n' + formatted_labels + '\n'
|
||||
return formatted_tokens + "\n" + formatted_labels + "\n"
|
||||
|
||||
|
||||
def format_label_propagation(
|
||||
gt_tokens,
|
||||
gt_labels,
|
||||
ocr_tokens,
|
||||
ocr_labels,
|
||||
aligned_gt,
|
||||
aligned_ocr,
|
||||
show_alignment=True,
|
||||
):
|
||||
"""Format label propagation for display
|
||||
|
||||
def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
|
||||
aligned_gt, aligned_ocr, show_alignment=True):
|
||||
""" Format label propagation for display
|
||||
|
||||
Arguments:
|
||||
gt_tokens {list} -- list of ground truth tokens
|
||||
gt_labels {list} -- list of NER labels for ground truth tokens
|
||||
|
@ -450,15 +498,15 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
|
|||
ocr_labels {list} -- list of NER labels for the OCR'ed tokens
|
||||
aligned_gt {str} -- ground truth string aligned with the OCR'ed text
|
||||
aligned_ocr {str} -- OCR'ed text aligned with ground truth
|
||||
|
||||
|
||||
Keyword Arguments:
|
||||
show_alignment {bool} -- if true, show alignment result (default: {True})
|
||||
|
||||
|
||||
Returns:
|
||||
a string formatted for display as follows:
|
||||
|
||||
if show_alignment=TRUE
|
||||
"
|
||||
"
|
||||
B-PLACE I-PLACE V O [gt_labels]
|
||||
New York is big [gt_txt]
|
||||
New York is big [aligned_gt]
|
||||
|
@ -468,19 +516,18 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \
|
|||
B-PLACE V O [ocr_labels]
|
||||
"
|
||||
else
|
||||
"
|
||||
"
|
||||
B-PLACE I-PLACE V O [gt_labels]
|
||||
New York is big [gt_txt]
|
||||
New is big [ocr_txt]
|
||||
B-PLACE V O [ocr_labels]
|
||||
"
|
||||
"""
|
||||
|
||||
|
||||
gt_label_str = format_labels(gt_tokens, gt_labels)
|
||||
label_str = format_labels(ocr_tokens, ocr_labels, label_top=False)
|
||||
label_str = format_labels(ocr_tokens, ocr_labels, label_top=False)
|
||||
if show_alignment:
|
||||
alignment_str = alignment._format_alignment(aligned_gt, aligned_ocr)
|
||||
return gt_label_str + alignment_str + label_str
|
||||
else:
|
||||
return gt_label_str + label_str
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import re
|
||||
import re
|
||||
|
||||
END_OF_TOKEN = {' ', '\t', '\n'}
|
||||
END_OF_TOKEN = {" ", "\t", "\n"}
|
||||
NON_ASCII_REPLACEMENT = "_"
|
||||
|
||||
|
||||
def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
|
||||
""" Remove non ascii characters in a token
|
||||
"""Remove non ascii characters in a token
|
||||
|
||||
Arguments:
|
||||
token {str} -- a word token
|
||||
|
@ -15,27 +16,29 @@ def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT):
|
|||
str -- a word token with non-ASCII characters removed
|
||||
"""
|
||||
# Remove non-ASCII characters in the token
|
||||
ascii_token = str(token.encode('utf-8').decode('ascii', 'ignore'))
|
||||
# If token becomes an empty string as a result
|
||||
ascii_token = str(token.encode("utf-8").decode("ascii", "ignore"))
|
||||
# If token becomes an empty string as a result
|
||||
if len(ascii_token) == 0 and len(token) != 0:
|
||||
ascii_token = replacement # replace with a default character
|
||||
ascii_token = replacement # replace with a default character
|
||||
return ascii_token
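# Hedged examples for the helper above (module path assumed to be genalog.text.preprocess):
from genalog.text import preprocess

preprocess.remove_non_ascii("café")   # -> "caf"  (the non-ASCII character is dropped)
preprocess.remove_non_ascii("θΘψΨ")   # -> "_"    (an all-non-ASCII token becomes the replacement char)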
|
||||
|
||||
|
||||
def tokenize(s):
|
||||
""" Tokenize string
|
||||
|
||||
"""Tokenize string
|
||||
|
||||
Arguments:
|
||||
s {str} -- aligned string
|
||||
|
||||
|
||||
Returns:
|
||||
a list of tokens
|
||||
"""
|
||||
# split alignment tokens by spaces, tabs and newline (and excluding them in the tokens)
|
||||
return s.split()
|
||||
|
||||
|
||||
def join_tokens(tokens):
|
||||
""" Join a list of tokens into a string
|
||||
|
||||
"""Join a list of tokens into a string
|
||||
|
||||
Arguments:
|
||||
tokens {list} -- a list of tokens
|
||||
|
||||
|
@ -44,14 +47,17 @@ def join_tokens(tokens):
|
|||
"""
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
def _is_spacing(c):
|
||||
""" Determine if the character is ignorable """
|
||||
return c in END_OF_TOKEN
|
||||
|
||||
|
||||
def split_sentences(text, delimiter="\n"):
|
||||
""" Split a text into sentences with a delimiter"""
|
||||
return re.sub(r'(( /?[.!?])+ )', rf'\1{delimiter}', text)
|
||||
return re.sub(r"(( /?[.!?])+ )", rf"\1{delimiter}", text)
|
||||
|
||||
|
||||
def is_sentence_separator(token):
|
||||
""" Returns true if the token is a sentence splitter """
|
||||
return re.match(r'^/?[.!?]$', token) != None
|
||||
return re.match(r"^/?[.!?]$", token) is not None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""This is a utility tool to split CoNLL formated files. It has the capability to pack sentences into generated
|
||||
pages more tightly.
|
||||
"""This is a utility tool to split CoNLL formated files.
|
||||
It has the capability to pack sentences into generated pages more tightly.
|
||||
|
||||
usage: splitter.py [-h] [--doc_sep DOC_SEP] [--line_sep LINE_SEP]
|
||||
[--force_doc_sep]
|
||||
|
@ -22,16 +22,17 @@ example usage:
|
|||
python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train
|
||||
|
||||
"""
|
||||
import re
|
||||
import os
|
||||
import multiprocessing
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from genalog.text import preprocess
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
import multiprocessing
|
||||
import os
|
||||
from multiprocessing.pool import ThreadPool
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.text import preprocess
|
||||
|
||||
# default buffer. Preferably set this to something large
|
||||
# It holds the lines read from the CoNLL file
|
||||
BUFFER_SIZE = 50000
|
||||
|
@ -40,14 +41,15 @@ CONLL2012_DOC_SEPERATOR = ""
|
|||
CONLL2003_DOC_SEPERATOR = "-DOCSTART-"
|
||||
|
||||
SEPERATOR = ""
|
||||
STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text
|
||||
MAX_SIZE = 100 # max number of sentences to pack on a doc page
|
||||
STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text
|
||||
MAX_SIZE = 100 # max number of sentences to pack on a doc page
|
||||
|
||||
SPLIT_ITERS = 2 # number of iterations to run to find a good split
|
||||
SPLIT_ITERS = 2 # number of iterations to run to find a good split
|
||||
WORKERS_PER_CPU = 2
|
||||
|
||||
default_generator = DocumentGenerator()
|
||||
|
||||
|
||||
def unwrap(size, accumulator):
|
||||
words = []
|
||||
labels = []
|
||||
|
@ -55,11 +57,14 @@ def unwrap(size, accumulator):
|
|||
sentence = accumulator[i]
|
||||
for word, tok in sentence:
|
||||
words.append(word)
|
||||
labels.append((word,tok))
|
||||
labels.append((word, tok))
|
||||
return words, labels
|
||||
|
||||
def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='text_block.html.jinja'):
|
||||
"""Run a few iterations of binary search to find the best split point
|
||||
|
||||
def find_split_position(
|
||||
accumulator, start_pos, iters=SPLIT_ITERS, template_name="text_block.html.jinja"
|
||||
):
|
||||
"""Run a few iterations of binary search to find the best split point
|
||||
from the start to pack in sentences into a page without overflow.
|
||||
|
||||
Args:
|
||||
|
@ -71,38 +76,48 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='
|
|||
"""
|
||||
global STARTING_SPLIT_GUESS
|
||||
# use binary search to find page split point
|
||||
start, end = start_pos, min(len(accumulator), MAX_SIZE+start_pos)
|
||||
start, end = start_pos, min(len(accumulator), MAX_SIZE + start_pos)
|
||||
best = None
|
||||
count = 0
|
||||
while start <= end:
|
||||
if count==0 and (STARTING_SPLIT_GUESS+start_pos > start and STARTING_SPLIT_GUESS + start_pos < end):
|
||||
if count == 0 and (
|
||||
STARTING_SPLIT_GUESS + start_pos > start
|
||||
and STARTING_SPLIT_GUESS + start_pos < end
|
||||
):
|
||||
split_point = STARTING_SPLIT_GUESS
|
||||
else:
|
||||
split_point = (start + end)//2
|
||||
doc_buf = (start_pos,split_point)
|
||||
split_point = (start + end) // 2
|
||||
doc_buf = (start_pos, split_point)
|
||||
content_words, labels = unwrap(doc_buf, accumulator)
|
||||
content_types = [ContentType.PARAGRAPH]
|
||||
|
||||
text = " ".join(content_words)
|
||||
content = CompositeContent([text], content_types)
|
||||
doc_gen = default_generator.create_generator(content, [template_name])
|
||||
doc_gen = default_generator.create_generator(content, [template_name])
|
||||
|
||||
doc = next(doc_gen)
|
||||
|
||||
|
||||
if len(doc._document.pages) > 1:
|
||||
end = split_point-1
|
||||
end = split_point - 1
|
||||
else:
|
||||
start = split_point+1
|
||||
best = split_point, doc, labels,text
|
||||
start = split_point + 1
|
||||
best = split_point, doc, labels, text
|
||||
if count >= iters:
|
||||
break
|
||||
count += 1
|
||||
STARTING_SPLIT_GUESS = split_point
|
||||
return best
|
||||
|
||||
|
||||
def generate_splits(input_file, output_folder, sentence_seperator="", doc_seperator=None, pool=None,
|
||||
force_doc_sep=False, ext="txt"):
|
||||
|
||||
def generate_splits(
|
||||
input_file,
|
||||
output_folder,
|
||||
sentence_seperator="",
|
||||
doc_seperator=None,
|
||||
pool=None,
|
||||
force_doc_sep=False,
|
||||
ext="txt",
|
||||
):
|
||||
"""Processes the file line by line and add sentences to the buffer for processing.
|
||||
|
||||
Args:
|
||||
|
@ -120,19 +135,21 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
|
|||
with open(input_file) as f:
|
||||
for line in f:
|
||||
if line.strip() == sentence_seperator or line.strip() == doc_seperator:
|
||||
|
||||
if len(sentence) > 0:
|
||||
|
||||
if len(sentence) > 0:
|
||||
accumulator.append(sentence)
|
||||
sentence = []
|
||||
|
||||
if line.strip() == doc_seperator and force_doc_sep:
|
||||
# progress to processing buffer immediately if force_doc_sep
|
||||
pass
|
||||
elif len(accumulator) < BUFFER_SIZE:
|
||||
elif len(accumulator) < BUFFER_SIZE:
|
||||
continue
|
||||
start_pos = 0
|
||||
while start_pos < len(accumulator):
|
||||
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
|
||||
start_pos = next_doc(
|
||||
accumulator, doc_id, start_pos, output_folder, pool
|
||||
)
|
||||
doc_id += 1
|
||||
progress_bar.update(1)
|
||||
accumulator = []
|
||||
|
@ -141,69 +158,93 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera
|
|||
word, tok = line.split("\t")
|
||||
if word.strip() == "":
|
||||
continue
|
||||
sentence.append((word,tok))
|
||||
sentence.append((word, tok))
|
||||
|
||||
# process any left over lines
|
||||
start_pos = 0
|
||||
if len(sentence) > 0 : accumulator.append(sentence)
|
||||
if len(sentence) > 0:
|
||||
accumulator.append(sentence)
|
||||
while start_pos < len(accumulator):
|
||||
start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool)
|
||||
start_pos = next_doc(accumulator, doc_id, start_pos, output_folder, pool)
|
||||
doc_id += 1
|
||||
progress_bar.update(1)
|
||||
|
||||
def next_doc(accumulator,doc_id, start_pos, output_folder,pool,ext="txt"):
|
||||
split_pos,doc,labels,text = find_split_position(accumulator, start_pos)
|
||||
|
||||
def next_doc(accumulator, doc_id, start_pos, output_folder, pool, ext="txt"):
|
||||
split_pos, doc, labels, text = find_split_position(accumulator, start_pos)
|
||||
handle_doc(doc, labels, doc_id, text, output_folder, pool, ext)
|
||||
return split_pos
|
||||
|
||||
|
||||
|
||||
def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=False):
|
||||
|
||||
if write_png:
|
||||
f = f"{output_folder}/img/img_{doc_id}.png"
|
||||
doc.render_png(target=f)
|
||||
|
||||
text += " " # adding a space at EOF
|
||||
text += " " # adding a space at EOF
|
||||
text = preprocess.split_sentences(text)
|
||||
|
||||
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as l:
|
||||
with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as fp:
|
||||
for idx, (token, label) in enumerate(labels):
|
||||
l.write(token + "\t" + label)
|
||||
fp.write(token + "\t" + label)
|
||||
next_token, _ = labels[(idx + 1) % len(labels)]
|
||||
if preprocess.is_sentence_separator(token) and not \
|
||||
preprocess.is_sentence_separator(next_token):
|
||||
l.write("\n")
|
||||
if idx == len(labels): # Reach the end of the document
|
||||
l.write("\n")
|
||||
|
||||
if preprocess.is_sentence_separator(
|
||||
token
|
||||
) and not preprocess.is_sentence_separator(next_token):
|
||||
fp.write("\n")
|
||||
if idx == len(labels): # Reach the end of the document
|
||||
fp.write("\n")
|
||||
|
||||
with open(f"{output_folder}/clean_text/{doc_id}.txt", "w") as text_file:
|
||||
text_file.write(text)
|
||||
return f"wrote: doc id: {doc_id}"
|
||||
|
||||
|
||||
def _error_callback(err):
|
||||
raise RuntimeError(err)
|
||||
|
||||
|
||||
def handle_doc(doc, labels, doc_id, text, output_folder, pool, ext="txt"):
|
||||
if pool:
|
||||
pool.apply_async(write_doc, args=(doc, doc_id, labels, text, output_folder,ext), error_callback=_error_callback)
|
||||
pool.apply_async(
|
||||
write_doc,
|
||||
args=(doc, doc_id, labels, text, output_folder, ext),
|
||||
error_callback=_error_callback,
|
||||
)
|
||||
else:
|
||||
write_doc(doc, doc_id, labels, text, output_folder)
|
||||
|
||||
|
||||
def setup_folder(output_folder):
|
||||
os.makedirs(os.path.join(output_folder,"clean_text"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_folder,"clean_labels"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_folder, "clean_text"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_folder, "clean_labels"), exist_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input_file", default="CoNLL-2012_train.txt", help="path to input CoNLL formated file.")
|
||||
parser.add_argument("output_folder", default="conll2012_train", help="folder to write results to.")
|
||||
parser.add_argument(
|
||||
"input_file",
|
||||
default="CoNLL-2012_train.txt",
|
||||
help="path to input CoNLL formated file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_folder", default="conll2012_train", help="folder to write results to."
|
||||
)
|
||||
parser.add_argument("--doc_sep", help="CoNLL doc seperator")
|
||||
parser.add_argument("--ext", help="file extension", default="txt")
|
||||
parser.add_argument("--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator")
|
||||
parser.add_argument("--force_doc_sep", default=False, action="store_true",
|
||||
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)")
|
||||
parser.add_argument(
|
||||
"--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_doc_sep",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
unescape = lambda s: s.encode('utf-8').decode('unicode_escape') if s else None
|
||||
unescape = lambda s: s.encode("utf-8").decode("unicode_escape") if s else None # noqa: E731
|
||||
|
||||
input_file = args.input_file
|
||||
output_folder = args.output_folder
|
||||
|
@ -211,10 +252,18 @@ if __name__ == "__main__":
|
|||
|
||||
# allow special characters in seperators
|
||||
line_sep = unescape(args.line_sep) or ""
|
||||
doc_sep = unescape(args.doc_sep)
|
||||
doc_sep = unescape(args.doc_sep)
|
||||
|
||||
n_workers = WORKERS_PER_CPU * multiprocessing.cpu_count()
|
||||
with ThreadPool(processes=n_workers) as pool:
|
||||
generate_splits(input_file, output_folder, line_sep, doc_seperator=doc_sep, pool=pool, force_doc_sep=False, ext=args.ext)
|
||||
generate_splits(
|
||||
input_file,
|
||||
output_folder,
|
||||
line_sep,
|
||||
doc_seperator=doc_sep,
|
||||
pool=pool,
|
||||
force_doc_sep=False,
|
||||
ext=args.ext,
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
pool.join()
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
[pytest]
|
||||
junit_family=xunit1
|
|
@ -0,0 +1,4 @@
|
|||
pytest
|
||||
pytest-cov
|
||||
flake8
|
||||
flake8-import-order
|
setup.py
|
@ -1,6 +1,7 @@
|
|||
import setuptools
|
||||
import os
|
||||
|
||||
import setuptools
|
||||
|
||||
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'VERSION.txt')) as version_file:
|
||||
BUILD_VERSION = version_file.read().strip()
|
||||
|
||||
|
@ -13,7 +14,7 @@ with open("README.md", "r", encoding="utf8") as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name="genalog",
|
||||
install_requires=requirements,
|
||||
install_requires=requirements,
|
||||
version=BUILD_VERSION,
|
||||
author="Team Enki",
|
||||
author_email="ta_nerds@microsoft.com",
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# Test cases for genalog.text.ner_label.propagate_label_to_ocr() method.
|
||||
# For READABILITY purpose, ground truth and noisy text are presented as
|
||||
# For READABILITY purpose, ground truth and noisy text are presented as
|
||||
# a whole string, not in their tokenized format.
|
||||
|
||||
# Notice the `propagate_label_to_ocr()` method has the contract of
|
||||
# (list, list, list) -> (list, list, list)
|
||||
# consuming both ground truth text and noisy text as lists of tokens.
|
||||
# (list, list, list) -> (list, list, list)
|
||||
# consuming both ground truth text and noisy text as lists of tokens.
|
||||
# We will use `genalog.text.preprocess.tokenize()` to tokenize these strings
|
||||
from genalog.text import preprocess
|
||||
|
||||
|
@ -106,12 +106,14 @@ desired_ocr_labels.append(["O", "B-FRUIT", "I-FRUIT", "O", "O"])
|
|||
ner_labels.append(["O", "O", "ENTERTAINMENT", "O"])
|
||||
gt_txt.append("@ new TV !")
|
||||
ns_txt.append("@ n ow T\\/ |")
|
||||
desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT" ,"O"])
|
||||
desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT", "O"])
|
||||
|
||||
# Tokenize ground truth and noisy text strings
|
||||
# Tokenize ground truth and noisy text strings
|
||||
gt_tokens = [preprocess.tokenize(txt) for txt in gt_txt]
|
||||
ns_tokens = [preprocess.tokenize(txt) for txt in ns_txt]
|
||||
|
||||
# test function expect params in tuple of
|
||||
# test function expect params in tuple of
|
||||
# (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels)
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels))
|
||||
LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(
|
||||
zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels)
|
||||
)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# Initializing test cases
|
||||
# For extensibility, all parameters in a case are appended to the following arrays
|
||||
gt_txt = []
|
||||
|
@ -17,31 +16,35 @@ ns_txt.append("N ewYork kis big.")
|
|||
aligned_gt.append("N@ew York @is big.")
|
||||
aligned_ns.append("N ew@York kis big.")
|
||||
gt_to_noise_maps.append(
|
||||
[
|
||||
# This shows that the first token in gt "New" maps to the
|
||||
# first ("N") and second ("ewYork") token in the noise
|
||||
[0,1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
[
|
||||
# This shows that the first token in gt "New" maps to the
|
||||
# first ("N") and second ("ewYork") token in the noise
|
||||
[0, 1],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
]
|
||||
)
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
# Similarly, the following shows that the second token in noise "ewYork" maps to the
|
||||
# first ("New") and second ("York") token in gt
|
||||
[0,1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
|
||||
noise_to_gt_maps.append(
|
||||
[
|
||||
[0],
|
||||
# Similarly, the following shows that the second token in noise "ewYork" maps to the
|
||||
# first ("New") and second ("York") token in gt
|
||||
[0, 1],
|
||||
[2],
|
||||
[3],
|
||||
]
|
||||
)
|
||||
|
||||
##############################################################################################
|
||||
|
||||
# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account
|
||||
# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account
|
||||
# for missing tokens
|
||||
# Notice there's only 1 whitespace b/w 'oston' and 'grea'
|
||||
# The ideal situation is that there are 2 whitespaces. Ex:
|
||||
# ("B oston grea t")
|
||||
# ("B oston grea t")
|
||||
ns_txt.append("B oston grea t")
|
||||
gt_txt.append("Boston is great")
|
||||
|
||||
|
@ -52,18 +55,9 @@ gt_txt.append("Boston is great")
|
|||
aligned_ns.append("B oston@@@ grea t")
|
||||
aligned_gt.append("B@oston is grea@t")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0,1],
|
||||
[1], # 'is' is also mapped to 'oston'
|
||||
[2,3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) # 'is' is also mapped to 'oston'
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[0,1], # 'oston' is to 'Boston' and 'is'
|
||||
[2],
|
||||
[2]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) # 'oston' is to 'Boston' and 'is'
|
||||
############################################################################################
|
||||
|
||||
# Empty Cases:
|
||||
|
@ -95,18 +89,9 @@ ns_txt.append("B oston bi g")
|
|||
aligned_gt.append("B@oston is bi@g")
|
||||
aligned_ns.append("B oston@@@ bi g")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0,1],
|
||||
[1],
|
||||
[2,3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0, 1], [1], [2, 3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[0,1],
|
||||
[2],
|
||||
[2]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [0, 1], [2], [2]])
|
||||
|
||||
############################################################################################
|
||||
gt_txt.append("New York is big.")
|
||||
|
@ -115,17 +100,9 @@ ns_txt.append("NewYork big")
|
|||
aligned_gt.append("New York is big.")
|
||||
aligned_ns.append("New@York @@@big@")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[0],
|
||||
[1],
|
||||
[1]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [0], [1], [1]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0, 1],
|
||||
[2, 3]
|
||||
])
|
||||
noise_to_gt_maps.append([[0, 1], [2, 3]])
|
||||
|
||||
#############################################################################################
|
||||
gt_txt.append("politicians who lag superfluous on the")
|
||||
|
@ -134,23 +111,9 @@ ns_txt.append("politicians who kg superfluous on the")
|
|||
aligned_gt.append("politicians who lag superfluous on the")
|
||||
aligned_ns.append("politicians who @kg superfluous on the")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [2], [3], [4], [5]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1], [2], [3], [4], [5]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -160,20 +123,9 @@ ns_txt.append("faithei uifoimtdon the subject")
|
|||
aligned_gt.append("farther @informed on the subject.")
|
||||
aligned_ns.append("faithei ui@foimtd@on the subject@")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -183,20 +135,9 @@ ns_txt.append("New Yorkis big .")
|
|||
aligned_gt.append("New York is big .")
|
||||
aligned_ns.append("New York@is big .")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
############################################################################################
|
||||
|
||||
|
@ -206,23 +147,14 @@ ns_txt.append("New Yo rk is big.")
|
|||
aligned_gt.append("New Yo@rk is big.")
|
||||
aligned_ns.append("New Yo rk is big.")
|
||||
|
||||
gt_to_noise_maps.append([
|
||||
[0],
|
||||
[1,2],
|
||||
[3],
|
||||
[4]
|
||||
])
|
||||
gt_to_noise_maps.append([[0], [1, 2], [3], [4]])
|
||||
|
||||
noise_to_gt_maps.append([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
noise_to_gt_maps.append([[0], [1], [1], [2], [3]])
|
||||
|
||||
# Format tests for pytest
|
||||
# Each test expect in the following format
|
||||
# (aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps)
|
||||
ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns))
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(
|
||||
aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps
|
||||
)
|
||||
ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns))
|
||||
|
|
|
@ -1,80 +1,98 @@
|
|||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import copy
|
||||
|
||||
MOCK_IMAGE_SHAPE = (4,3)
|
||||
from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE
|
||||
from genalog.degradation.degrader import Degrader, ImageState
|
||||
|
||||
MOCK_IMAGE_SHAPE = (4, 3)
|
||||
MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_degrader():
|
||||
effects = []
|
||||
return Degrader(effects)
|
||||
|
||||
@pytest.fixture(params=[
|
||||
[("blur", {"radius": 5})],
|
||||
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
|
||||
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
|
||||
[
|
||||
("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "open"}),
|
||||
("morphology", {"operation": "close"}),
|
||||
("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "dilate"}),
|
||||
("morphology", {"operation": "erode"}),
|
||||
],
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {
|
||||
"src": ImageState.CURRENT_STATE,
|
||||
"alpha": 0.7,
|
||||
"background": ImageState.ORIGINAL_STATE,
|
||||
}),
|
||||
("morphology", {
|
||||
"operation": "open",
|
||||
"kernel_shape": (3,3),
|
||||
"kernel_type": "ones"
|
||||
}),
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
[("blur", {"radius": 5})],
|
||||
[("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})],
|
||||
[("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})],
|
||||
[
|
||||
("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "open"}),
|
||||
("morphology", {"operation": "close"}),
|
||||
("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "dilate"}),
|
||||
("morphology", {"operation": "erode"}),
|
||||
],
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
(
|
||||
"bleed_through",
|
||||
{
|
||||
"src": ImageState.CURRENT_STATE,
|
||||
"alpha": 0.7,
|
||||
"background": ImageState.ORIGINAL_STATE,
|
||||
},
|
||||
),
|
||||
(
|
||||
"morphology",
|
||||
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"},
|
||||
),
|
||||
],
|
||||
]
|
||||
])
|
||||
)
|
||||
def degrader(request):
|
||||
effects = request.param
|
||||
return Degrader(effects)
|
||||
|
||||
def test_degrader_init(empty_degrader):
|
||||
|
||||
def test_empty_degrader_init(empty_degrader):
|
||||
assert empty_degrader.effects_to_apply == []
|
||||
|
||||
|
||||
def test_degrader_init(degrader):
|
||||
assert degrader.effects_to_apply is not []
|
||||
for effect_tuple in degrader.effects_to_apply:
|
||||
method_name, method_kwargs = effect_tuple
|
||||
assert DEFAULT_METHOD_PARAM_TO_INCLUDE in method_kwargs
|
||||
param_value = method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE]
|
||||
assert param_value is ImageState.ORIGINAL_STATE or param_value is ImageState.CURRENT_STATE
|
||||
assert (
|
||||
param_value is ImageState.ORIGINAL_STATE
|
||||
or param_value is ImageState.CURRENT_STATE
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("effects, error_thrown", [
|
||||
([], None), #Empty effect
|
||||
(None, TypeError),
|
||||
([("blur", {"radius": 5})], None), # Validate input
|
||||
([("not_a_func", {"radius": 5})], ValueError), # Invalid method name
|
||||
([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs
|
||||
([("blur")], ValueError), # Missing kwargs
|
||||
(
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"alpha":"0.8"}),
|
||||
("morphology", {"operation": "open"})
|
||||
], None
|
||||
), # Multiple effects
|
||||
(
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"not_argument":"0.8"}),
|
||||
("morphology", {"missing value"})
|
||||
], ValueError
|
||||
), # Multiple effects
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"effects, error_thrown",
|
||||
[
|
||||
([], None), # Empty effect
|
||||
(None, TypeError),
|
||||
([("blur", {"radius": 5})], None), # Validate input
|
||||
([("not_a_func", {"radius": 5})], ValueError), # Invalid method name
|
||||
([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs
|
||||
([("blur")], ValueError), # Missing kwargs
|
||||
(
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"alpha": "0.8"}),
|
||||
("morphology", {"operation": "open"}),
|
||||
],
|
||||
None,
|
||||
), # Multiple effects
|
||||
(
|
||||
[
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"not_argument": "0.8"}),
|
||||
("morphology", {"missing value"}),
|
||||
],
|
||||
ValueError,
|
||||
), # Multiple effects
|
||||
],
|
||||
)
|
||||
def test_degrader_validate_effects(effects, error_thrown):
|
||||
if error_thrown:
|
||||
with pytest.raises(error_thrown):
|
||||
|
@ -82,23 +100,26 @@ def test_degrader_validate_effects(effects, error_thrown):
|
|||
else:
|
||||
Degrader.validate_effects(effects)
|
||||
|
||||
|
||||
def test_degrader_apply_effects(degrader):
|
||||
method_names = [effect[0] for effect in degrader.effects_to_apply]
|
||||
with patch("genalog.degradation.effect") as mock_effect:
|
||||
degraded = degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
for method in method_names:
|
||||
assert mock_effect[method].is_called()
|
||||
# assert degraded.shape == MOCK_IMAGE_SHAPE
|
||||
|
||||
|
||||
def test_degrader_apply_effects_e2e(degrader):
|
||||
degraded = degrader.apply_effects(MOCK_IMAGE)
|
||||
assert degraded.shape == MOCK_IMAGE_SHAPE
|
||||
assert degraded.dtype == np.uint8
|
||||
|
||||
|
||||
def test_degrader_instructions(degrader):
|
||||
original_instruction = copy.deepcopy(degrader.effects_to_apply)
|
||||
degraded1 = degrader.apply_effects(MOCK_IMAGE)
|
||||
degraded2 = degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
degrader.apply_effects(MOCK_IMAGE)
|
||||
# Make sure the degradation instructions are not altered
|
||||
assert len(original_instruction) == len(degrader.effects_to_apply)
|
||||
for i in range(len(original_instruction)):
|
||||
|
@ -107,5 +128,5 @@ def test_degrader_instructions(degrader):
|
|||
assert org_method_name == method_name
|
||||
assert len(org_method_arg) == len(method_arg)
|
||||
for key in org_method_arg.keys():
|
||||
assert type(org_method_arg[key]) == type(method_arg[key])
|
||||
assert org_method_arg[key] == method_arg[key]
|
||||
assert isinstance(org_method_arg[key], type(method_arg[key]))
|
||||
assert org_method_arg[key] == method_arg[key]
|
||||
|
|
|
@ -1,30 +1,34 @@
|
|||
from genalog.degradation import effect
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from genalog.degradation import effect
|
||||
|
||||
NEW_IMG_SHAPE = (100, 100)
|
||||
MOCK_IMG_SHAPE = (100, 120)
|
||||
MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8)
|
||||
|
||||
|
||||
def test_blur():
|
||||
dst = effect.blur(MOCK_IMG, radius=3)
|
||||
assert dst.dtype == np.uint8 # preserve dtype
assert dst.shape == MOCK_IMG_SHAPE # preserve image size
assert dst.dtype == np.uint8  # preserve dtype
assert dst.shape == MOCK_IMG_SHAPE  # preserve image size
|
||||
|
||||
|
||||
def test_translation():
|
||||
offset_x = offset_y = 1
|
||||
# Test that border pixels are not white (<255)
|
||||
assert all([col_pixel < 255 for col_pixel in MOCK_IMG[:, 0]])
|
||||
assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0 ,:]])
|
||||
assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0, :]])
|
||||
dst = effect.translation(MOCK_IMG, offset_x, offset_y)
|
||||
# Test that border pixels are white (255)
|
||||
assert all([col_pixel == 255 for col_pixel in dst[:,0]])
|
||||
assert all([col_pixel == 255 for col_pixel in dst[:, 0]])
|
||||
assert all([row_pixel == 255 for row_pixel in dst[0, :]])
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_overlay_weighted():
|
||||
src = MOCK_IMG.copy()
|
||||
src[0][0] = 10
|
||||
|
@ -34,6 +38,7 @@ def test_overlay_weighted():
|
|||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
assert dst[0][0] == src[0][0] * alpha + src[0][0] * beta
|
||||
|
||||
|
||||
def test_overlay():
|
||||
src1 = MOCK_IMG.copy()
|
||||
src2 = MOCK_IMG.copy()
|
||||
|
@ -45,6 +50,7 @@ def test_overlay():
|
|||
assert dst[0][0] == 0
|
||||
assert dst[0][1] == 1
|
||||
|
||||
|
||||
@patch("genalog.degradation.effect.translation")
|
||||
def test_bleed_through_default(mock_translation):
|
||||
mock_translation.return_value = MOCK_IMG
|
||||
|
@ -53,11 +59,15 @@ def test_bleed_through_default(mock_translation):
|
|||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.mark.parametrize("foreground, background, error_thrown", [
|
||||
(MOCK_IMG, MOCK_IMG, None),
|
||||
# Test unmatched shape
|
||||
(MOCK_IMG, MOCK_IMG[:,:-1], Exception),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"foreground, background, error_thrown",
|
||||
[
|
||||
(MOCK_IMG, MOCK_IMG, None),
|
||||
# Test unmatched shape
|
||||
(MOCK_IMG, MOCK_IMG[:, :-1], Exception),
|
||||
],
|
||||
)
|
||||
def test_bleed_through_kwargs(foreground, background, error_thrown):
|
||||
if error_thrown:
|
||||
assert foreground.shape != background.shape
|
||||
|
@ -68,97 +78,125 @@ def test_bleed_through_kwargs(foreground, background, error_thrown):
|
|||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_pepper():
|
||||
dst = effect.pepper(MOCK_IMG, amount=0.1)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_salt():
|
||||
dst = effect.salt(MOCK_IMG, amount=0.1)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_salt_then_pepper():
|
||||
dst = effect.salt_then_pepper(MOCK_IMG, 0.5, 0.001)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_pepper_then_salt():
|
||||
dst = effect.pepper_then_salt(MOCK_IMG, 0.001, 0.5)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.mark.parametrize("kernel_shape, kernel_type", [
|
||||
((3,3), "NOT_VALID_TYPE"),
|
||||
(1, "ones"),
|
||||
((1,2,3), "ones")
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_shape, kernel_type",
|
||||
[((3, 3), "NOT_VALID_TYPE"), (1, "ones"), ((1, 2, 3), "ones")],
|
||||
)
|
||||
def test_create_2D_kernel_error(kernel_shape, kernel_type):
|
||||
with pytest.raises(Exception):
|
||||
effect.create_2D_kernel(kernel_shape, kernel_type)
|
||||
|
||||
@pytest.mark.parametrize("kernel_shape, kernel_type, expected_kernel", [
|
||||
((2,2), "ones", np.array([[1,1],[1,1]])), # sq kernel
|
||||
((1,2), "ones", np.array([[1,1]])), # horizontal
|
||||
((2,1), "ones", np.array([[1],[1]])), # vertical
|
||||
((2,2), "upper_triangle", np.array([[1,1],[0,1]])),
|
||||
((2,2), "lower_triangle", np.array([[1,0],[1,1]])),
|
||||
((2,2), "x", np.array([[1,1],[1,1]])),
|
||||
((3,3), "x", np.array([[1,0,1],[0,1,0],[1,0,1]])),
|
||||
((2,2), "plus", np.array([[0,1],[1,1]])),
|
||||
((3,3), "plus", np.array([[0,1,0],[1,1,1],[0,1,0]])),
|
||||
((3,3), "ellipse", np.array([[0,1,0],[1,1,1],[0,1,0]])),
|
||||
((5,5), "ellipse",
|
||||
np.array([
|
||||
[0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[0, 0, 1, 0, 0]
|
||||
])),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_shape, kernel_type, expected_kernel",
|
||||
[
|
||||
((2, 2), "ones", np.array([[1, 1], [1, 1]])), # sq kernel
|
||||
((1, 2), "ones", np.array([[1, 1]])), # horizontal
|
||||
((2, 1), "ones", np.array([[1], [1]])), # vertical
|
||||
((2, 2), "upper_triangle", np.array([[1, 1], [0, 1]])),
|
||||
((2, 2), "lower_triangle", np.array([[1, 0], [1, 1]])),
|
||||
((2, 2), "x", np.array([[1, 1], [1, 1]])),
|
||||
((3, 3), "x", np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])),
|
||||
((2, 2), "plus", np.array([[0, 1], [1, 1]])),
|
||||
((3, 3), "plus", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
|
||||
((3, 3), "ellipse", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),
|
||||
(
|
||||
(5, 5),
|
||||
"ellipse",
|
||||
np.array(
|
||||
[
|
||||
[0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[0, 0, 1, 0, 0],
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_create_2D_kernel(kernel_shape, kernel_type, expected_kernel):
|
||||
kernel = effect.create_2D_kernel(kernel_shape, kernel_type)
|
||||
assert np.array_equal(kernel, expected_kernel)
|
||||
|
||||
|
||||
def test_morphology_with_error():
|
||||
INVALID_OPERATION = "NOT_A_OPERATION"
|
||||
with pytest.raises(ValueError):
|
||||
effect.morphology(MOCK_IMG, operation=INVALID_OPERATION)
|
||||
|
||||
@pytest.mark.parametrize("operation, kernel_shape, kernel_type", [
|
||||
("open", (3,3), "ones"),
|
||||
("close", (3,3), "ones"),
|
||||
("dilate", (3,3), "ones"),
|
||||
("erode", (3,3), "ones"),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation, kernel_shape, kernel_type",
|
||||
[
|
||||
("open", (3, 3), "ones"),
|
||||
("close", (3, 3), "ones"),
|
||||
("dilate", (3, 3), "ones"),
|
||||
("erode", (3, 3), "ones"),
|
||||
],
|
||||
)
|
||||
def test_morphology(operation, kernel_shape, kernel_type):
|
||||
dst = effect.morphology(MOCK_IMG,
|
||||
operation=operation, kernel_shape=kernel_shape,
|
||||
kernel_type=kernel_type)
|
||||
dst = effect.morphology(
|
||||
MOCK_IMG,
|
||||
operation=operation,
|
||||
kernel_shape=kernel_shape,
|
||||
kernel_type=kernel_type,
|
||||
)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
@pytest.fixture(params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"])
|
||||
|
||||
@pytest.fixture(
|
||||
params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"]
|
||||
)
|
||||
def kernel(request):
|
||||
return effect.create_2D_kernel((5,5), request.param)
|
||||
return effect.create_2D_kernel((5, 5), request.param)
|
||||
|
||||
|
||||
def test_open(kernel):
|
||||
dst = effect.open(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_close(kernel):
|
||||
dst = effect.close(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_erode(kernel):
|
||||
dst = effect.erode(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
|
||||
|
||||
def test_dilate(kernel):
|
||||
dst = effect.dilate(MOCK_IMG, kernel)
|
||||
assert dst.dtype == np.uint8
|
||||
assert dst.shape == MOCK_IMG_SHAPE
|
||||
assert dst.shape == MOCK_IMG_SHAPE
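Note: for context, a minimal sketch of the effect API these tests exercise, using the same call signatures as above; the blank sample image is made up and not part of the test suite.

import numpy as np

from genalog.degradation import effect

img = np.full((64, 64), 255, dtype=np.uint8)  # hypothetical blank grayscale page
kernel = effect.create_2D_kernel((3, 3), "plus")  # same call as test_create_2D_kernel
assert kernel.shape == (3, 3)
eroded = effect.morphology(img, operation="erode", kernel_shape=(3, 3), kernel_type="plus")
noisy = effect.salt_then_pepper(img, 0.5, 0.001)  # amounts copied from test_salt_then_pepper
assert eroded.shape == img.shape and noisy.dtype == np.uint8  # same invariants the tests check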
@@ -1,15 +1,18 @@
from genalog.text import alignment, anchor, preprocess
|
||||
|
||||
import glob
|
||||
import pytest
|
||||
import difflib
|
||||
import glob
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
import pytest
|
||||
|
||||
from genalog.text import alignment, anchor, preprocess
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_*.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt")),
|
||||
),
|
||||
)
|
||||
def test_align_w_anchor_and_align(gt_file, ocr_file):
|
||||
gt_text = open(gt_file, "r").read()
|
||||
|
@@ -21,17 +24,22 @@ def test_align_w_anchor_and_align(gt_file, ocr_file):
|
|||
aligned_anchor_gt = aligned_anchor_gt.split(".")
|
||||
aligned_gt = aligned_gt.split(".")
|
||||
str_diff = "\n".join(difflib.unified_diff(aligned_gt, aligned_anchor_gt))
|
||||
warnings.warn(UserWarning(
|
||||
"\n"+ f"{str_diff}" +
|
||||
f"\n\n**** Inconsistent Alignment Results between align() and " +
|
||||
f"align_w_anchor(). Ignore this if the delta is not significant. ****\n"))
|
||||
warnings.warn(
|
||||
UserWarning(
|
||||
"\n"
|
||||
+ f"{str_diff}"
|
||||
+ "\n\n**** Inconsistent Alignment Results between align() and "
|
||||
+ "align_w_anchor(). Ignore this if the delta is not significant. ****\n"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_*.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_*.txt")),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("max_seg_length", [25, 50, 75, 100, 150])
|
||||
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
||||
|
@@ -39,7 +47,9 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
|||
ocr_text = open(ocr_file, "r").read()
|
||||
gt_tokens = preprocess.tokenize(gt_text)
|
||||
ocr_tokens = preprocess.tokenize(ocr_text)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
|
||||
# Ensure that each anchor word is the same word in both text
|
||||
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
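Note: a tiny sketch of the anchor API used in this test; the ground-truth/OCR strings and the max_seg_length value below are invented for illustration.

from genalog.text import anchor, preprocess

gt_tokens = preprocess.tokenize("New York is big .")
ocr_tokens = preprocess.tokenize("N ewYork kis big .")
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=2)
for gt_i, ocr_i in zip(gt_anchors, ocr_anchors):
    assert gt_tokens[gt_i] == ocr_tokens[ocr_i]  # each anchor pair points at the same word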
@@ -1,47 +1,62 @@
from genalog.text import conll_format
|
||||
from unittest import mock
|
||||
|
||||
import argparse
|
||||
import pytest
|
||||
import glob
|
||||
import itertools
|
||||
|
||||
@pytest.mark.parametrize("required_args", [
|
||||
(["tests/e2e/data/synthetic_dataset", "test_version"])
|
||||
])
|
||||
@pytest.mark.parametrize("optional_args", [
|
||||
(["--train_subset"]),
|
||||
(["--test_subset"]),
|
||||
(["--gt_folder", "shared"]),
|
||||
])
|
||||
import pytest
|
||||
|
||||
from genalog.text import conll_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"required_args", [(["tests/e2e/data/synthetic_dataset", "test_version"])]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"optional_args",
|
||||
[
|
||||
(["--train_subset"]),
|
||||
(["--test_subset"]),
|
||||
(["--gt_folder", "shared"]),
|
||||
],
|
||||
)
|
||||
def test_conll_format(required_args, optional_args):
|
||||
parser = conll_format.create_parser()
|
||||
arg_list = required_args + optional_args
|
||||
args = parser.parse_args(args=arg_list)
|
||||
conll_format.main(args)
|
||||
|
||||
|
||||
basepath = "tests/e2e/data/conll_formatter/"
|
||||
|
||||
@pytest.mark.parametrize("clean_label_filename, ocr_text_filename",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_label_filename, ocr_text_filename",
|
||||
zip(
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/clean_labels/*.txt")),
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")),
|
||||
),
|
||||
)
|
||||
def test_propagate_labels_sentence_single_file(clean_label_filename, ocr_text_filename):
|
||||
with open(clean_label_filename, 'r', encoding='utf-8') as clf:
|
||||
with open(clean_label_filename, "r", encoding="utf-8") as clf:
|
||||
tokens_labels_str = clf.readlines()
|
||||
clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2]
|
||||
clean_tokens = [
|
||||
line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2
|
||||
]
|
||||
clean_labels = [
|
||||
line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2
|
||||
]
|
||||
clean_sentences = conll_format.get_sentences_from_iob_format(tokens_labels_str)
|
||||
# read ocr tokens
|
||||
with open(ocr_text_filename, 'r', encoding='utf-8') as otf:
|
||||
ocr_text_str = ' '.join(otf.readlines())
|
||||
ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data
|
||||
with open(ocr_text_filename, "r", encoding="utf-8") as otf:
|
||||
ocr_text_str = " ".join(otf.readlines())
|
||||
ocr_tokens = [
|
||||
token.strip() for token in ocr_text_str.split()
|
||||
] # already tokenized in data
|
||||
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
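Note: a rough sketch of the label-propagation call tested here; the IOB lines and OCR tokens are invented, and the blank-line sentence separator is an assumption about the fixture format.

from genalog.text import conll_format

iob_lines = ["George B-PER\n", "Washington I-PER\n", "\n", "He O\n", "slept O\n"]  # made-up IOB sample
clean_tokens = [ln.split()[0] for ln in iob_lines if len(ln.split()) == 2]
clean_labels = [ln.split()[1] for ln in iob_lines if len(ln.split()) == 2]
clean_sentences = conll_format.get_sentences_from_iob_format(iob_lines)
ocr_tokens = ["George", "Washingtom", "He", "slept"]  # hypothetical OCR output
ocr_sentences, ocr_labels = conll_format.propagate_labels_sentences(
    clean_tokens, clean_labels, clean_sentences, ocr_tokens
)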
@@ -1,10 +1,13 @@
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
CONTENT = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH])
|
||||
import pytest
|
||||
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
CONTENT = CompositeContent(
|
||||
["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH]
|
||||
)
|
||||
UNSUPPORTED_CONTENT_FORMAT = ["foo bar"]
|
||||
UNSUPPORTED_CONTENT_TYPE = CompositeContent(["foo"], [ContentType.TITLE])
|
||||
|
||||
|
@@ -21,9 +24,10 @@ CUSTOM_STYLE = {
|
|||
"font_family": ["Calibri", "Times"],
|
||||
"font_size": ["10px"],
|
||||
"text_align": ["right"],
|
||||
"hyphenate": [True, False]
|
||||
"hyphenate": [True, False],
|
||||
}
|
||||
|
||||
|
||||
def test_default_template_generation():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(CONTENT, doc_gen.template_list)
|
||||
|
@@ -31,28 +35,36 @@ def test_default_template_generation():
|
|||
html_str = doc.render_html()
|
||||
assert "Unsupported Content Type:" not in html_str
|
||||
assert "No content loaded" not in html_str
|
||||
|
||||
|
||||
|
||||
def test_default_template_generation_w_unsupported_content_format():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list)
|
||||
generator = doc_gen.create_generator(
|
||||
UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list
|
||||
)
|
||||
for doc in generator:
|
||||
html_str = doc.render_html()
|
||||
assert "No content loaded" in html_str
|
||||
|
||||
|
||||
|
||||
def test_default_template_generation_w_unsupported_content_type():
|
||||
doc_gen = DocumentGenerator()
|
||||
generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"])
|
||||
generator = doc_gen.create_generator(
|
||||
UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"]
|
||||
)
|
||||
for doc in generator:
|
||||
html_str = doc.render_html()
|
||||
assert "Unsupported Content Type: ContentType.TITLE" in html_str
|
||||
|
||||
|
||||
def test_custom_template_generation():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
|
||||
doc = next(generator)
|
||||
result = doc.render_html()
|
||||
assert result == str(CONTENT)
|
||||
|
||||
|
||||
|
||||
def test_undefined_template_generation():
|
||||
doc_gen = DocumentGenerator()
|
||||
assert UNDEFINED_TEMPLATE_NAME not in doc_gen.template_list
|
||||
|
@@ -60,6 +72,7 @@ def test_undefined_template_generation():
|
|||
with pytest.raises(FileNotFoundError):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_custom_style_template_generation():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
assert len(doc_gen.styles_to_generate) == 1
|
||||
|
@@ -70,14 +83,16 @@ def test_custom_style_template_generation():
|
|||
result = doc.render_html()
|
||||
assert doc.styles["font_family"] == result
|
||||
|
||||
|
||||
def test_render_pdf_and_png():
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
|
||||
for doc in generator:
|
||||
pdf_bytes = doc.render_pdf()
|
||||
png_bytes = doc.render_png()
|
||||
assert pdf_bytes != None
|
||||
assert png_bytes != None
|
||||
assert pdf_bytes is not None
|
||||
assert png_bytes is not None
|
||||
|
||||
|
||||
def test_save_document_as_png():
|
||||
if not os.path.exists(TEST_OUTPUT_DIR):
|
||||
|
@@ -86,15 +101,16 @@ def test_save_document_as_png():
|
|||
generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME])
|
||||
for doc in generator:
|
||||
doc.render_png(target=FILE_DESTINATION, resolution=100)
|
||||
# Check if the document is saved in filepath
|
||||
# Check if the document is saved in filepath
|
||||
assert os.path.exists(FILE_DESTINATION)
|
||||
|
||||
|
||||
def test_save_document_as_separate_png():
|
||||
if not os.path.exists(TEST_OUTPUT_DIR):
|
||||
os.mkdir(TEST_OUTPUT_DIR)
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
generator = doc_gen.create_generator(CONTENT, [MULTI_PAGE_TEMPLATE_NAME])
|
||||
|
||||
|
||||
document = next(generator)
|
||||
document.render_png(target=FILE_DESTINATION, split_pages=True, resolution=100)
|
||||
# Check if the document is saved as separated .png files
|
||||
|
@@ -102,6 +118,7 @@ def test_save_document_as_separate_png():
|
|||
printed_doc_name = FILE_DESTINATION.replace(".png", f"_pg_{page_num}.png")
|
||||
assert os.path.exists(printed_doc_name)
|
||||
|
||||
|
||||
def test_overwriting_style():
|
||||
new_font = "NewFontFamily"
|
||||
doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH)
|
||||
|
@@ -111,4 +128,4 @@ def test_overwriting_style():
|
|||
assert doc.styles["font_family"] != new_font
|
||||
doc.update_style(font_family=new_font)
|
||||
result = doc.render_html()
|
||||
assert new_font == result
|
||||
assert new_font == result
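Note: a compact sketch of the DocumentGenerator calls covered above; the style value and output path are illustrative only, and the test_out/ folder is assumed to exist.

from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator

content = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH])
doc_gen = DocumentGenerator()  # default templates shipped with the package
doc = next(doc_gen.create_generator(content, ["text_block.html.jinja"]))
doc.update_style(font_family="Times")  # illustrative override
html_str = doc.render_html()
doc.render_png(target="test_out/sample.png", resolution=100)  # assumes test_out/ exists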
@@ -1,23 +1,25 @@
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.degradation.degrader import Degrader
|
||||
from genalog.degradation import effect
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
TEST_OUTPUT_DIR = "test_out/"
|
||||
SAMPLE_TXT = "Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , was picked on Thursday for the Scottish squad after a 20-month exile ."
|
||||
SAMPLE_TXT = """Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday ,
|
||||
was picked on Thursday for the Scottish squad after a 20-month exile ."""
|
||||
DEFAULT_TEMPLATE = "text_block.html.jinja"
|
||||
DEGRADATION_EFFECTS = [
|
||||
("blur", {"radius": 5}),
|
||||
("bleed_through", {"alpha": 0.8}),
|
||||
("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "plus"}),
|
||||
(
|
||||
"morphology",
|
||||
{"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"},
|
||||
),
|
||||
("morphology", {"operation": "close"}),
|
||||
("morphology", {"operation": "dilate"}),
|
||||
("morphology", {"operation": "erode"})
|
||||
("morphology", {"operation": "erode"}),
|
||||
]
|
||||
|
||||
|
||||
def test_generation_and_degradation():
|
||||
# Initiate content
|
||||
content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
|
||||
|
@@ -32,6 +34,4 @@ def test_generation_and_degradation():
|
|||
# get the image in bytes in RGBA channels
|
||||
src = doc.render_array(resolution=100, channel="GRAYSCALE")
|
||||
# run each degradation effect
|
||||
dst = degrader.apply_effects(src)
|
||||
|
||||
|
||||
degrader.apply_effects(src)
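Note: a minimal sketch of how a Degrader chains these effects onto a rendered page, assuming (as the degrader module documents) that the effect list is passed at construction time; the plain white array stands in for doc.render_array(..., channel="GRAYSCALE").

import numpy as np

from genalog.degradation.degrader import Degrader

degrader = Degrader([
    ("blur", {"radius": 5}),
    ("bleed_through", {"alpha": 0.8}),
    ("morphology", {"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"}),
])
src = np.full((300, 200), 255, dtype=np.uint8)  # stand-in for a rendered grayscale page
dst = degrader.apply_effects(src)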
@@ -1,41 +1,44 @@
from genalog.generation.document import DocumentGenerator
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pytest
|
||||
|
||||
from genalog.generation.content import CompositeContent, ContentType
|
||||
from genalog.generation.document import DocumentGenerator
|
||||
|
||||
TEMPLATE_PATH = "tests/e2e/templates"
|
||||
TEST_OUT_FOLDER = "test_out/"
|
||||
SAMPLE_TXT = "foo"
|
||||
CONTENT = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def doc_generator():
|
||||
return DocumentGenerator(template_path=TEMPLATE_PATH)
|
||||
|
||||
|
||||
def test_red_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
doc.update_style(background_color="red")
|
||||
doc.update_style(background_color="red")
|
||||
img_array = doc.render_array(resolution=100, channel="BGRA")
|
||||
# css "red" is rgb(255,0,0) or bgra(0,0,255,255)
|
||||
assert tuple(img_array[0][0]) == (0, 0, 255, 255)
|
||||
assert tuple(img_array[0][0]) == (0, 0, 255, 255)
|
||||
cv2.imwrite(TEST_OUT_FOLDER + "red.png", img_array)
|
||||
|
||||
|
||||
def test_green_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
doc.update_style(background_color="green")
|
||||
doc.update_style(background_color="green")
|
||||
img_array = doc.render_array(resolution=100, channel="BGRA")
|
||||
# css "green" is rgb(0,128,0) or bgra(0,128,0,255)
|
||||
assert tuple(img_array[0][0]) == (0, 128, 0, 255)
|
||||
cv2.imwrite(TEST_OUT_FOLDER + "green.png", img_array)
|
||||
|
||||
|
||||
def test_blue_channel(doc_generator):
|
||||
generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"])
|
||||
for doc in generator:
|
||||
doc.update_style(background_color="blue")
|
||||
doc.update_style(background_color="blue")
|
||||
img_array = doc.render_array(resolution=100, channel="BGRA")
|
||||
# css "blue" is rgb(0,0,255) or bgra(255,0,0,255)
|
||||
assert tuple(img_array[0][0]) == (255, 0, 0, 255)
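Note: the point of these channel tests is that render_array honors the requested channel order; below is a sketch reusing the same template folder and background override as test_red_channel.

from genalog.generation.content import CompositeContent, ContentType
from genalog.generation.document import DocumentGenerator

content = CompositeContent(["foo"], [ContentType.PARAGRAPH])
doc_gen = DocumentGenerator(template_path="tests/e2e/templates")
doc = next(doc_gen.create_generator(content, ["solid_bg.html.jinja"]))
doc.update_style(background_color="red")
img_array = doc.render_array(resolution=100, channel="BGRA")
b, g, r, a = img_array[0][0]  # BGRA order: css "red" unpacks to (0, 0, 255, 255)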
@@ -1,48 +1,60 @@
from genalog.ocr.rest_client import GrokRestClient
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from genalog.ocr.blob_client import GrokBlobClient
|
||||
from genalog.ocr.grok import Grok
|
||||
import requests
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv("tests/ocr/.env")
|
||||
|
||||
|
||||
class TestBlobClient:
|
||||
@pytest.mark.parametrize("use_async",[True, False])
|
||||
@pytest.mark.parametrize("use_async", [True, False])
|
||||
def test_upload_images(self, use_async):
|
||||
blob_client = GrokBlobClient.create_from_env_var()
|
||||
subfolder = "tests/ocr/data/img"
|
||||
file_prefix = subfolder.replace("/", "_")
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, use_async=use_async)
|
||||
subfolder.replace("/", "_")
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(
|
||||
subfolder, use_async=use_async
|
||||
)
|
||||
uploaded_items, _ = blob_client.list_blobs(dst_folder)
|
||||
uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name)
|
||||
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
|
||||
assert uploaded_items[0].name == f"{dst_folder}/0.png"
|
||||
assert uploaded_items[1].name == f"{dst_folder}/1.png"
|
||||
assert uploaded_items[2].name == f"{dst_folder}/11.png"
|
||||
blob_client.delete_blobs_folder(dst_folder)
|
||||
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
|
||||
assert (
|
||||
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
|
||||
), f"folder {dst_folder} was not deleted"
|
||||
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(subfolder, "test_images", use_async=use_async)
|
||||
assert dst_folder == "test_images"
|
||||
dst_folder, _ = blob_client.upload_images_to_blob(
|
||||
subfolder, "test_images", use_async=use_async
|
||||
)
|
||||
assert dst_folder == "test_images"
|
||||
uploaded_items, _ = blob_client.list_blobs(dst_folder)
|
||||
uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name)
|
||||
uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name)
|
||||
assert uploaded_items[0].name == f"{dst_folder}/0.png"
|
||||
assert uploaded_items[1].name == f"{dst_folder}/1.png"
|
||||
assert uploaded_items[2].name == f"{dst_folder}/11.png"
|
||||
blob_client.delete_blobs_folder(dst_folder)
|
||||
assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted"
|
||||
assert (
|
||||
len(list(blob_client.list_blobs(dst_folder)[0])) == 0
|
||||
), f"folder {dst_folder} was not deleted"
|
||||
|
||||
|
||||
class TestGROKe2e:
|
||||
@pytest.mark.parametrize("use_async",[False,True])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
def test_grok_e2e(self, tmpdir, use_async):
|
||||
grok = Grok.create_from_env_var()
|
||||
src_folder = "tests/ocr/data/img"
|
||||
grok.run_grok(src_folder, tmpdir, blob_dest_folder="testimages", use_async=use_async, cleanup=True)
|
||||
json_folder = "tests/ocr/data/json"
|
||||
json_hash = "521c38122f783673598856cd81d91c21"
|
||||
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
|
||||
grok.run_grok(
|
||||
src_folder,
|
||||
tmpdir,
|
||||
blob_dest_folder="testimages",
|
||||
use_async=use_async,
|
||||
cleanup=True,
|
||||
)
|
||||
assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"]
|
||||
assert json.load(open(f"{tmpdir}/1.json", "r"))[0]["text"]
|
||||
assert json.load(open(f"{tmpdir}/11.json", "r"))[0]["text"]
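Note: a rough sketch of the GROK OCR flow exercised here; credentials come from environment variables loaded from the tests' .env file, and the output folder below is a made-up scratch location.

from dotenv import load_dotenv

from genalog.ocr.grok import Grok

load_dotenv("tests/ocr/.env")  # expected to provide the blob storage and search service keys
grok = Grok.create_from_env_var()
grok.run_grok(
    "tests/ocr/data/img",   # local images to OCR (same folder as the test)
    "test_out/ocr_json",    # hypothetical destination for the per-image *.json results
    blob_dest_folder="testimages",
    use_async=False,
    cleanup=True,
)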
@@ -1,32 +1,43 @@
from genalog import pipeline
|
||||
|
||||
import pytest
|
||||
import glob
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog import pipeline
|
||||
|
||||
EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_analog_generator():
|
||||
return pipeline.AnalogDocumentGeneration()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_analog_generator():
|
||||
custom_styles = {"font_size": ["5px"]}
|
||||
custom_degradation = [("blur", {"radius": 3})]
|
||||
return pipeline.AnalogDocumentGeneration(
|
||||
styles=custom_styles,
|
||||
degradations=custom_degradation,
|
||||
resolution=300)
|
||||
styles=custom_styles, degradations=custom_degradation, resolution=300
|
||||
)
|
||||
|
||||
|
||||
def test_default_generate_img(default_analog_generator):
|
||||
example_template = default_analog_generator.list_templates()[0]
|
||||
img_array = default_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
|
||||
default_analog_generator.generate_img(
|
||||
EXAMPLE_TEXT_FILE, example_template, target_folder=None
|
||||
)
|
||||
|
||||
|
||||
def test_custom_generate_img(custom_analog_generator):
|
||||
example_template = custom_analog_generator.list_templates()[0]
|
||||
img_array = custom_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None)
|
||||
custom_analog_generator.generate_img(
|
||||
EXAMPLE_TEXT_FILE, example_template, target_folder=None
|
||||
)
|
||||
|
||||
|
||||
def test_generate_dataset_multiprocess():
|
||||
INPUT_TEXT_FILENAMES = glob.glob("tests/text/data/gt_*.txt")
|
||||
with pytest.deprecated_call():
|
||||
pipeline.generate_dataset_multiprocess(INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja")
|
||||
pipeline.generate_dataset_multiprocess(
|
||||
INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja"
|
||||
)
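Note: a short sketch of the pipeline entry point these tests touch, reusing the fixture's custom styles and degradations.

from genalog import pipeline

generator = pipeline.AnalogDocumentGeneration(
    styles={"font_size": ["5px"]},
    degradations=[("blur", {"radius": 3})],
    resolution=300,
)
template = generator.list_templates()[0]
img_array = generator.generate_img(
    "tests/text/data/gt_1.txt", template, target_folder=None  # target_folder=None as in the tests above
)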
@@ -1,7 +1,8 @@
import os
|
||||
import difflib
|
||||
import os
|
||||
|
||||
from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits
|
||||
|
||||
from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR
|
||||
|
||||
def _compare_content(file1, file2):
|
||||
txt1 = open(file1, "r").read()
|
||||
|
@@ -12,15 +13,32 @@ def _compare_content(file1, file2):
|
|||
str_diff = "\n".join(difflib.unified_diff(sentences_txt1, sentences_txt2))
|
||||
assert False, f"Delta between outputs: \n {str_diff}"
|
||||
|
||||
|
||||
def test_splitter(tmpdir):
|
||||
# tmpdir = "test_out"
|
||||
os.makedirs(f"{tmpdir}/clean_labels")
|
||||
os.makedirs(f"{tmpdir}/clean_text")
|
||||
|
||||
generate_splits("tests/e2e/data/splitter/example_conll2012.txt", tmpdir,
|
||||
doc_seperator=CONLL2003_DOC_SEPERATOR, sentence_seperator="")
|
||||
|
||||
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/0.txt", f"{tmpdir}/clean_text/0.txt")
|
||||
_compare_content("tests/e2e/data/splitter/example_splits/clean_text/1.txt", f"{tmpdir}/clean_text/1.txt")
|
||||
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/0.txt", f"{tmpdir}/clean_labels/0.txt")
|
||||
_compare_content("tests/e2e/data/splitter/example_splits/clean_labels/1.txt", f"{tmpdir}/clean_labels/1.txt")
|
||||
generate_splits(
|
||||
"tests/e2e/data/splitter/example_conll2012.txt",
|
||||
tmpdir,
|
||||
doc_seperator=CONLL2003_DOC_SEPERATOR,
|
||||
sentence_seperator="",
|
||||
)
|
||||
|
||||
_compare_content(
|
||||
"tests/e2e/data/splitter/example_splits/clean_text/0.txt",
|
||||
f"{tmpdir}/clean_text/0.txt",
|
||||
)
|
||||
_compare_content(
|
||||
"tests/e2e/data/splitter/example_splits/clean_text/1.txt",
|
||||
f"{tmpdir}/clean_text/1.txt",
|
||||
)
|
||||
_compare_content(
|
||||
"tests/e2e/data/splitter/example_splits/clean_labels/0.txt",
|
||||
f"{tmpdir}/clean_labels/0.txt",
|
||||
)
|
||||
_compare_content(
|
||||
"tests/e2e/data/splitter/example_splits/clean_labels/1.txt",
|
||||
f"{tmpdir}/clean_labels/1.txt",
|
||||
)
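Note: in short, generate_splits expects the clean_text/ and clean_labels/ subfolders to exist and then writes numbered splits into them; a sketch with a made-up output directory.

import os

from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits

out_dir = "test_out/splits"  # hypothetical output location
os.makedirs(f"{out_dir}/clean_labels", exist_ok=True)
os.makedirs(f"{out_dir}/clean_text", exist_ok=True)
generate_splits(
    "tests/e2e/data/splitter/example_conll2012.txt",
    out_dir,
    doc_seperator=CONLL2003_DOC_SEPERATOR,
    sentence_seperator="",
)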
@@ -1,62 +1,76 @@
from genalog.generation.content import *
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog.generation.content import CompositeContent, Content, ContentType
|
||||
from genalog.generation.content import Paragraph, Title
|
||||
|
||||
CONTENT_LIST = ["foo", "bar"]
|
||||
COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH]
|
||||
TEXT = "foo bar"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def content_base_class():
|
||||
return Content()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paragraph():
|
||||
return Paragraph(TEXT)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def title():
|
||||
return Title(TEXT)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def section():
|
||||
return CompositeContent(CONTENT_LIST, COMPOSITE_CONTENT_TYPE)
|
||||
|
||||
|
||||
|
||||
def test_content_set_content_type(content_base_class):
|
||||
with pytest.raises(TypeError):
|
||||
content_base_class.set_content_type("NOT VALID CONTENT TYPE")
|
||||
content_base_class.set_content_type(ContentType.PARAGRAPH)
|
||||
|
||||
|
||||
def test_paragraph_init(paragraph):
|
||||
with pytest.raises(TypeError):
|
||||
Paragraph([])
|
||||
assert paragraph.content_type == ContentType.PARAGRAPH
|
||||
|
||||
|
||||
def test_paragraph_print(paragraph):
|
||||
assert paragraph.__str__()
|
||||
|
||||
|
||||
def test_paragraph_iterable_indexable(paragraph):
|
||||
for index, character in enumerate(paragraph):
|
||||
assert character == paragraph[index]
|
||||
|
||||
|
||||
def test_title_init(title):
|
||||
with pytest.raises(TypeError):
|
||||
Title([])
|
||||
assert title.content_type == ContentType.TITLE
|
||||
|
||||
|
||||
def test_title_iterable_indexable(title):
|
||||
for index, character in enumerate(title):
|
||||
assert character == title[index]
|
||||
|
||||
|
||||
def test_composite_content_init(section):
|
||||
with pytest.raises(TypeError):
|
||||
CompositeContent((),[])
|
||||
CompositeContent((), [])
|
||||
assert section.content_type == ContentType.COMPOSITE
|
||||
|
||||
|
||||
def test_composite_content_iterable(section):
|
||||
for index, content in enumerate(section):
|
||||
assert content.content_type == COMPOSITE_CONTENT_TYPE[index]
|
||||
|
||||
|
||||
|
||||
def test_composite_content_print(section):
|
||||
assert "foo" in section.__str__()
|
||||
assert "bar" in section.__str__()
@@ -1,8 +1,10 @@
from genalog.generation.document import Document, DocumentGenerator
|
||||
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from genalog.generation.document import DEFAULT_DOCUMENT_STYLE
|
||||
from genalog.generation.document import Document, DocumentGenerator
|
||||
|
||||
|
||||
FRENCH = "fr"
|
||||
CONTENT = ["some text"]
|
||||
|
@@ -21,84 +23,106 @@ DEFAULT_TEMPLATE_NAME = "text_block.html.jinja"
|
|||
DEFAULT_PACKAGE_NAME = "genalog.generation"
|
||||
DEFAULT_TEMPLATE_FOLDER = "templates"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_document():
|
||||
mock_jinja_template = MagicMock()
|
||||
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
|
||||
return Document(CONTENT, mock_jinja_template)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def french_document():
|
||||
mock_jinja_template = MagicMock()
|
||||
mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT
|
||||
return Document(CONTENT, mock_jinja_template, language=FRENCH)
|
||||
|
||||
|
||||
def test_document_init(default_document):
|
||||
assert default_document.styles == DEFAULT_DOCUMENT_STYLE
|
||||
assert default_document._document != None
|
||||
assert default_document.compiled_html != None
|
||||
|
||||
assert default_document._document is not None
|
||||
assert default_document.compiled_html is not None
|
||||
|
||||
|
||||
def test_document_init_with_kwargs(french_document):
|
||||
assert french_document.styles["language"] == FRENCH
|
||||
assert french_document._document != None
|
||||
assert french_document.compiled_html != None
|
||||
assert french_document._document is not None
|
||||
assert french_document.compiled_html is not None
|
||||
|
||||
|
||||
def test_document_render_html(french_document):
|
||||
compiled_document = french_document.render_html()
|
||||
assert compiled_document == MOCK_COMPILED_DOCUMENT
|
||||
french_document.template.render.assert_called_with(content=CONTENT, **french_document.styles)
|
||||
french_document.template.render.assert_called_with(
|
||||
content=CONTENT, **french_document.styles
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_pdf(default_document):
|
||||
default_document._document = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_pdf(target=FILE_DESTINATION_PDF, zoom=2)
|
||||
default_document._document.write_pdf.assert_called_with(target=FILE_DESTINATION_PDF, zoom=2)
|
||||
default_document._document.write_pdf.assert_called_with(
|
||||
target=FILE_DESTINATION_PDF, zoom=2
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_png(default_document):
|
||||
default_document._document = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_png(target=FILE_DESTINATION_PNG, resolution=100)
|
||||
default_document._document.write_png.assert_called_with(target=FILE_DESTINATION_PNG, resolution=100)
|
||||
default_document._document.write_png.assert_called_with(
|
||||
target=FILE_DESTINATION_PNG, resolution=100
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_png_split_pages(default_document):
|
||||
default_document._document.copy = MagicMock()
|
||||
# run tested function
|
||||
default_document.render_png(target=FILE_DESTINATION_PNG, split_pages=True, resolution=100)
|
||||
default_document.render_png(
|
||||
target=FILE_DESTINATION_PNG, split_pages=True, resolution=100
|
||||
)
|
||||
result_destination = FILE_DESTINATION_PNG.replace(".png", "_pg_0.png")
|
||||
# assertion
|
||||
document_copy = default_document._document.copy.return_value
|
||||
document_copy.write_png.assert_called_with(target=result_destination, resolution=100)
|
||||
document_copy.write_png.assert_called_with(
|
||||
target=result_destination, resolution=100
|
||||
)
|
||||
|
||||
|
||||
def test_document_render_array_valid_args(default_document):
|
||||
# setup mock
|
||||
mock_surface = MagicMock()
|
||||
mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32
|
||||
mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32
|
||||
mock_surface.get_data = MagicMock(return_value=IMG_BYTES) # loading a 2x2 image
|
||||
mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2))
|
||||
default_document._document.write_image_surface = mock_write_image_surface
|
||||
|
||||
channel_types = ["RGBA", "RGB", "GRAYSCALE", "BGRA", "BGR"]
|
||||
expected_img_shape = [(2,2,4), (2,2,3), (2,2), (2,2,4), (2,2,3)]
|
||||
|
||||
expected_img_shape = [(2, 2, 4), (2, 2, 3), (2, 2), (2, 2, 4), (2, 2, 3)]
|
||||
|
||||
for channel_type, expected_img_shape in zip(channel_types, expected_img_shape):
|
||||
img_array = default_document.render_array(resolution=100, channel=channel_type)
|
||||
assert img_array.shape == expected_img_shape
|
||||
|
||||
|
||||
def test_document_render_array_invalid_args(default_document):
|
||||
invalid_channel_types = "INVALID"
|
||||
with pytest.raises(ValueError):
|
||||
default_document.render_array(resolution=100, channel=invalid_channel_types)
|
||||
|
||||
|
||||
def test_document_render_array_invalid_format(default_document):
|
||||
# setup mock
|
||||
mock_surface = MagicMock()
|
||||
mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32
|
||||
mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32
|
||||
mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2))
|
||||
default_document._document.write_image_surface = mock_write_image_surface
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
default_document.render_array(resolution=100)
|
||||
|
||||
|
||||
def test_document_update_style(default_document):
|
||||
new_style = {"language": FRENCH, "new_property": "some value"}
|
||||
# Ensure that a new property is not already defined
|
||||
|
@@ -111,11 +135,13 @@ def test_document_update_style(default_document):
|
|||
# Ensure that a new property is added
|
||||
assert default_document.styles["new_property"] == new_style["new_property"]
|
||||
|
||||
|
||||
@patch("genalog.generation.document.Environment")
|
||||
@patch("genalog.generation.document.PackageLoader")
|
||||
@patch("genalog.generation.document.FileSystemLoader")
|
||||
def test_document_generator_init_default_setting(mock_file_system_loader,
|
||||
mock_package_loader, mock_environment):
|
||||
def test_document_generator_init_default_setting(
|
||||
mock_file_system_loader, mock_package_loader, mock_environment
|
||||
):
|
||||
# setup mock template environment
|
||||
mock_environment_instance = mock_environment.return_value
|
||||
mock_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
|
||||
|
@@ -123,15 +149,19 @@ def test_document_generator_init_default_setting(mock_file_system_loader,
|
|||
document_generator = DocumentGenerator()
|
||||
# Ensure the right loader is called
|
||||
mock_file_system_loader.assert_not_called()
|
||||
mock_package_loader.assert_called_with(DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER)
|
||||
mock_package_loader.assert_called_with(
|
||||
DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER
|
||||
)
|
||||
# Ensure that the default template in the package is loaded
|
||||
assert DEFAULT_TEMPLATE_NAME in document_generator.template_list
|
||||
|
||||
|
||||
@patch("genalog.generation.document.Environment")
|
||||
@patch("genalog.generation.document.PackageLoader")
|
||||
@patch("genalog.generation.document.FileSystemLoader")
|
||||
def test_document_generator_init_custom_template(mock_file_system_loader,
|
||||
mock_package_loader, mock_environment):
|
||||
def test_document_generator_init_custom_template(
|
||||
mock_file_system_loader, mock_package_loader, mock_environment
|
||||
):
|
||||
# setup mock template environment
|
||||
mock_environment_instance = mock_environment.return_value
|
||||
mock_environment_instance.list_templates.return_value = [CUSTOM_TEMPLATE_NAME]
|
||||
|
@@ -143,67 +173,79 @@ def test_document_generator_init_custom_template(mock_file_system_loader,
|
|||
# Ensure that the expected template is registered
|
||||
assert CUSTOM_TEMPLATE_NAME in document_generator.template_list
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_document_generator():
|
||||
with patch("genalog.generation.document.Environment") as MockEnvironment:
|
||||
template_environment_instance = MockEnvironment.return_value
|
||||
template_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME]
|
||||
template_environment_instance.list_templates.return_value = [
|
||||
DEFAULT_TEMPLATE_NAME
|
||||
]
|
||||
template_environment_instance.get_template.return_value = MOCK_TEMPLATE
|
||||
doc_gen = DocumentGenerator()
|
||||
return doc_gen
|
||||
return doc_gen
|
||||
|
||||
|
||||
def test_document_generator_create_generator(default_document_generator):
|
||||
available_templates = default_document_generator.template_list
|
||||
assert len(available_templates) < 2
|
||||
generator = default_document_generator.create_generator(CONTENT, available_templates)
|
||||
doc = next(generator)
|
||||
generator = default_document_generator.create_generator(
|
||||
CONTENT, available_templates
|
||||
)
|
||||
next(generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_document_generator_create_generator_(default_document_generator):
|
||||
# setup test case
|
||||
available_templates = default_document_generator.template_list
|
||||
undefined_template = "NOT A VALID TEMPLATE"
|
||||
assert undefined_template not in available_templates
|
||||
|
||||
generator = default_document_generator.create_generator(CONTENT, [undefined_template])
|
||||
generator = default_document_generator.create_generator(
|
||||
CONTENT, [undefined_template]
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
doc = next(generator)
|
||||
next(generator)
|
||||
|
||||
@pytest.mark.parametrize("template_name, expected_output", [
|
||||
("base.html.jinja", False),
|
||||
("text_block.html.jinja", True),
|
||||
("text_block.css.jinja", False),
|
||||
("macro/dimension.css.jinja", False)
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template_name, expected_output",
|
||||
[
|
||||
("base.html.jinja", False),
|
||||
("text_block.html.jinja", True),
|
||||
("text_block.css.jinja", False),
|
||||
("macro/dimension.css.jinja", False),
|
||||
],
|
||||
)
|
||||
def test__keep_templates(template_name, expected_output):
|
||||
output = DocumentGenerator._keep_template(template_name)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_set_styles_to_generate(default_document_generator):
|
||||
assert len(default_document_generator.styles_to_generate) == 1
|
||||
default_document_generator.set_styles_to_generate({"foo": ["bar", "bar"]})
|
||||
assert len(default_document_generator.styles_to_generate) == 2
|
||||
|
||||
@pytest.mark.parametrize("styles, expected_output", [
|
||||
({}, []), # empty case
|
||||
({"size": ["10px"], "color": [] }, []), #empty value will result in null combinations
|
||||
(
|
||||
{"size": ["10px"], "color": ["red"] },
|
||||
[{"size":"10px", "color":"red"}]
|
||||
),
|
||||
(
|
||||
{"size": ["5px", "10px"]},
|
||||
[{"size": "5px"}, {"size": "10px"}]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"styles, expected_output",
|
||||
[
|
||||
({}, []), # empty case
|
||||
(
|
||||
{"size": ["10px"], "color": []},
|
||||
[],
|
||||
), # empty value will result in null combinations
|
||||
({"size": ["10px"], "color": ["red"]}, [{"size": "10px", "color": "red"}]),
|
||||
({"size": ["5px", "10px"]}, [{"size": "5px"}, {"size": "10px"}]),
|
||||
(
|
||||
{"size": ["10px", "15px"], "color": ["blue"]},
|
||||
[{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}],
|
||||
),
|
||||
(
|
||||
{"size": ["10px", "15px"], "color": ["blue"] },
|
||||
[
|
||||
{"size":"10px", "color": "blue"},
|
||||
{"size":"15px", "color": "blue"}
|
||||
]
|
||||
)
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_document_generator_expand_style_combinations(styles, expected_output):
|
||||
output = DocumentGenerator.expand_style_combinations(styles)
|
||||
assert output == expected_output
|
||||
assert output == expected_output
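Note: as the parametrized cases spell out, expand_style_combinations builds the cartesian product of the style lists, e.g.:

from genalog.generation.document import DocumentGenerator

combos = DocumentGenerator.expand_style_combinations(
    {"size": ["10px", "15px"], "color": ["blue"]}
)
# -> [{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}]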
@@ -1,212 +1,611 @@
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_metrics, get_stats
|
||||
from genalog.text.anchor import align_w_anchor
|
||||
from genalog.text.alignment import GAP_CHAR, align
|
||||
from genalog.text.ner_label import _find_gap_char_candidates
|
||||
from pandas._testing import assert_frame_equal
|
||||
|
||||
import pytest
|
||||
import genalog.ocr.metrics
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import json
|
||||
import pickle
|
||||
import os
|
||||
|
||||
import genalog.ocr.metrics
|
||||
from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_stats
|
||||
from genalog.text.alignment import align, GAP_CHAR
|
||||
from genalog.text.ner_label import _find_gap_char_candidates
|
||||
|
||||
|
||||
genalog.ocr.metrics.LOG_LEVEL = 0
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_stats",
|
||||
[
|
||||
("a worn coat", "a wom coat",
|
||||
{'edit_insert': 1, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
(" ", "a",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("a", " ",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("a", "a",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("ab", "ac",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("ac", "ab",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("New York is big.", "N ewYork kis big.",
|
||||
{'edit_insert': 0, 'edit_delete': 1, 'edit_replace': 0, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
|
||||
("B oston grea t", "Boston is great",
|
||||
{'edit_insert': 0, 'edit_delete': 2, 'edit_replace': 0, 'edit_insert_spacing': 2, 'edit_delete_spacing': 1}),
|
||||
("New York is big.", "N ewyork kis big",
|
||||
{'edit_insert': 1, 'edit_delete': 1, 'edit_replace': 1, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}),
|
||||
("dog", "d@g", # Test against default gap_char "@"
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}),
|
||||
("some@one.com", "some@one.com",
|
||||
{'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0})
|
||||
])
|
||||
def test_editops_stats(src_string, target, expected_stats):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
alignment = align(target, src_string)
|
||||
stats, actions = get_editops_stats(alignment, gap_char)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k,stats[k], expected_stats[k])
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions",
|
||||
[
|
||||
(
|
||||
"a worn coat", "a wom coat",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 11, 'total_words': 3,
|
||||
'matching_chars': 9, 'matching_words': 2,"matching_alnum_words" : 2, 'word_accuracy': 2/3, 'char_accuracy': 9/11},
|
||||
{('rn', 'm'): 1}
|
||||
),
|
||||
(
|
||||
"a c", "def",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1 , 'spacing': 0, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 0, 'matching_words': 0,"matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 0},
|
||||
{('a c', 'def'): 1}
|
||||
),
|
||||
(
|
||||
"a", "a b",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 2,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 0.5, 'char_accuracy': 1/3},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"a b", "b",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"a b", "a",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"b ..", "a b ..",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 6, 'total_words': 3, 'total_alnum_words': 2,
|
||||
'matching_chars': 4, 'matching_words': 2, "matching_alnum_words" : 1, 'word_accuracy': 2/3, 'char_accuracy': 4/6},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"taxi cab", "taxl c b",
|
||||
{'insert': 0, 'delete': 1, 'replace': 1, 'spacing': 1, 'total_chars': 9, 'total_words': 3,
|
||||
'matching_chars': 6, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 6/9},
|
||||
{('i','l'):1}
|
||||
),
|
||||
(
|
||||
"taxl c b ri de", "taxi cab ride",
|
||||
{'insert': 1, 'delete': 0, 'replace': 1, 'spacing': 6, 'total_chars': 18, 'total_words': 3,
|
||||
'matching_chars': 11, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 11/18},
|
||||
{('l','i'):1}
|
||||
),
|
||||
(
|
||||
"ab", "ac",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.5},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"a", "a",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 1, 'total_words': 1,
|
||||
'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewYork kis big.",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 2, 'total_chars': 17, 'total_words': 4,
|
||||
'matching_chars': 15, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1/4, 'char_accuracy': 15/17},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"B oston grea t", "Boston is great",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 3, 'total_chars': 15, 'total_words': 3,
|
||||
'matching_chars': 12, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.8},
|
||||
{}
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewyork kis big",
|
||||
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
|
||||
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
|
||||
{('Y', 'y'): 1}
|
||||
),
|
||||
(
|
||||
"dog", "d@g",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{('o', '@'): 1}
|
||||
),
|
||||
(
|
||||
"some@one.com", "some@one.com",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, {}
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats",
|
||||
[
|
||||
(
|
||||
"a worn coat",
|
||||
"a wom coat",
|
||||
{
|
||||
"edit_insert": 1,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
" ",
|
||||
"a",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
" ",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
"a",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"ac",
|
||||
"ab",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewYork kis big.",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 1,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 1,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"B oston grea t",
|
||||
"Boston is great",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 2,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 2,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"edit_insert": 1,
|
||||
"edit_delete": 1,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 1,
|
||||
"edit_delete_spacing": 1,
|
||||
},
|
||||
),
|
||||
(
|
||||
"dog",
|
||||
"d@g", # Test against default gap_char "@"
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 1,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"edit_insert": 0,
|
||||
"edit_delete": 0,
|
||||
"edit_replace": 0,
|
||||
"edit_insert_spacing": 0,
|
||||
"edit_delete_spacing": 0,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_editops_stats(src_string, target, expected_stats):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
])
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(target, src_string)
|
||||
stats, actions = get_editops_stats(alignment, gap_char)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats, expected_substitutions",
|
||||
[
|
||||
(
|
||||
"a worn coat",
|
||||
"a wom coat",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 11,
|
||||
"total_words": 3,
|
||||
"matching_chars": 9,
|
||||
"matching_words": 2,
|
||||
"matching_alnum_words": 2,
|
||||
"word_accuracy": 2 / 3,
|
||||
"char_accuracy": 9 / 11,
|
||||
},
|
||||
{("rn", "m"): 1},
|
||||
),
|
||||
(
|
||||
"a c",
|
||||
"def",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 0,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 0,
|
||||
},
|
||||
{("a c", "def"): 1},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
"a b",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 2,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 0.5,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a b",
|
||||
"b",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a b",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1,
|
||||
"char_accuracy": 1 / 3,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"b ..",
|
||||
"a b ..",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 1,
|
||||
"total_chars": 6,
|
||||
"total_words": 3,
|
||||
"total_alnum_words": 2,
|
||||
"matching_chars": 4,
|
||||
"matching_words": 2,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 2 / 3,
|
||||
"char_accuracy": 4 / 6,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"taxi cab",
|
||||
"taxl c b",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 1,
|
||||
"total_chars": 9,
|
||||
"total_words": 3,
|
||||
"matching_chars": 6,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 6 / 9,
|
||||
},
|
||||
{("i", "l"): 1},
|
||||
),
|
||||
(
|
||||
"taxl c b ri de",
|
||||
"taxi cab ride",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 6,
|
||||
"total_chars": 18,
|
||||
"total_words": 3,
|
||||
"matching_chars": 11,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 11 / 18,
|
||||
},
|
||||
{("l", "i"): 1},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 0.5,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"a",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 1,
|
||||
"total_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewYork kis big.",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 2,
|
||||
"total_chars": 17,
|
||||
"total_words": 4,
|
||||
"matching_chars": 15,
|
||||
"matching_words": 1,
|
||||
"matching_alnum_words": 1,
|
||||
"word_accuracy": 1 / 4,
|
||||
"char_accuracy": 15 / 17,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"B oston grea t",
|
||||
"Boston is great",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 3,
|
||||
"total_chars": 15,
|
||||
"total_words": 3,
|
||||
"matching_chars": 12,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 0.8,
|
||||
},
|
||||
{},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 2,
|
||||
"total_chars": 16,
|
||||
"total_words": 4,
|
||||
"matching_chars": 13,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 13 / 16,
|
||||
},
|
||||
{("Y", "y"): 1},
|
||||
),
|
||||
(
|
||||
"dog",
|
||||
"d@g",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{("o", "@"): 1},
|
||||
),
|
||||
(
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 12,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 12,
|
||||
"matching_alnum_words": 1,
|
||||
"matching_words": 1,
|
||||
"alnum_word_accuracy": 1.0,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_align_stats(src_string, target, expected_stats, expected_substitutions):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(src_string, target, gap_char=gap_char)
|
||||
stats, substitution_dict = get_align_stats(alignment, src_string, target, gap_char)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
|
||||
for k in expected_substitutions:
|
||||
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
|
||||
assert substitution_dict[k] == expected_substitutions[k], (
|
||||
substitution_dict,
|
||||
expected_substitutions,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions, expected_actions", [
|
||||
(
|
||||
"ab", "a",
|
||||
{'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
|
||||
{},
|
||||
{1: 'D'}
|
||||
),
|
||||
(
|
||||
"ab", "abb",
|
||||
{'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{},
|
||||
{2: ('I', 'b')}
|
||||
),
|
||||
(
|
||||
"ab", "ac",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2},
|
||||
{('b', 'c'): 1},
|
||||
{1: ('R', 'c')}
|
||||
),
|
||||
(
|
||||
"New York is big.", "N ewyork kis big",
|
||||
{'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4,
|
||||
'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16},
|
||||
{('Y', 'y'): 1},
|
||||
{1: ('I', ' '), 4: 'D', 5: ('R', 'y'), 10: ('I', 'k'), 17: 'D'}
|
||||
),
|
||||
(
|
||||
"dog", "d@g",
|
||||
{'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3},
|
||||
{('o', '@'): 1},
|
||||
{1: ('R', '@')}
|
||||
),
|
||||
(
|
||||
"some@one.com", "some@one.com",
|
||||
{'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0},
|
||||
{},
|
||||
{}
|
||||
)
|
||||
])
|
||||
def test_get_stats(src_string, target, expected_stats, expected_substitutions, expected_actions ):
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_stats, expected_substitutions, expected_actions",
|
||||
[
|
||||
(
|
||||
"ab",
|
||||
"a",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 1,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 1 / 2,
|
||||
},
|
||||
{},
|
||||
{1: "D"},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"abb",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{},
|
||||
{2: ("I", "b")},
|
||||
),
|
||||
(
|
||||
"ab",
|
||||
"ac",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 2,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 1,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 1 / 2,
|
||||
},
|
||||
{("b", "c"): 1},
|
||||
{1: ("R", "c")},
|
||||
),
|
||||
(
|
||||
"New York is big.",
|
||||
"N ewyork kis big",
|
||||
{
|
||||
"insert": 1,
|
||||
"delete": 1,
|
||||
"replace": 1,
|
||||
"spacing": 2,
|
||||
"total_chars": 16,
|
||||
"total_words": 4,
|
||||
"matching_chars": 13,
|
||||
"matching_words": 0,
|
||||
"matching_alnum_words": 0,
|
||||
"word_accuracy": 0,
|
||||
"char_accuracy": 13 / 16,
|
||||
},
|
||||
{("Y", "y"): 1},
|
||||
{1: ("I", " "), 4: "D", 5: ("R", "y"), 10: ("I", "k"), 17: "D"},
|
||||
),
|
||||
(
|
||||
"dog",
|
||||
"d@g",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 1,
|
||||
"spacing": 0,
|
||||
"total_chars": 3,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 2,
|
||||
"matching_alnum_words": 0,
|
||||
"matching_words": 0,
|
||||
"alnum_word_accuracy": 0.0,
|
||||
"word_accuracy": 0.0,
|
||||
"char_accuracy": 2 / 3,
|
||||
},
|
||||
{("o", "@"): 1},
|
||||
{1: ("R", "@")},
|
||||
),
|
||||
(
|
||||
"some@one.com",
|
||||
"some@one.com",
|
||||
{
|
||||
"insert": 0,
|
||||
"delete": 0,
|
||||
"replace": 0,
|
||||
"spacing": 0,
|
||||
"total_chars": 12,
|
||||
"total_words": 1,
|
||||
"total_alnum_words": 1,
|
||||
"matching_chars": 12,
|
||||
"matching_alnum_words": 1,
|
||||
"matching_words": 1,
|
||||
"alnum_word_accuracy": 1.0,
|
||||
"word_accuracy": 1.0,
|
||||
"char_accuracy": 1.0,
|
||||
},
|
||||
{},
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_stats(
|
||||
src_string, target, expected_stats, expected_substitutions, expected_actions
|
||||
):
|
||||
stats, substitution_dict, actions = get_stats(target, src_string)
|
||||
for k in expected_stats:
|
||||
assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k])
|
||||
for k in expected_substitutions:
|
||||
assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions)
|
||||
assert substitution_dict[k] == expected_substitutions[k], (
|
||||
substitution_dict,
|
||||
expected_substitutions,
|
||||
)
|
||||
for k in expected_actions:
|
||||
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
|
||||
|
||||
@pytest.mark.parametrize("src_string, target, expected_actions",
|
||||
[
|
||||
("dog and cat", "g and at",
|
||||
{0: ('I', 'd'), 1: ('I', 'o'), 8:('I', 'c')}),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src_string, target, expected_actions",
|
||||
[
|
||||
("dog and cat", "g and at", {0: ("I", "d"), 1: ("I", "o"), 8: ("I", "c")}),
|
||||
],
|
||||
)
|
||||
def test_actions_stats(src_string, target, expected_actions):
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target])
|
||||
gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
gap_char_candidates, input_char_set = _find_gap_char_candidates(
|
||||
[src_string], [target]
|
||||
)
|
||||
gap_char = (
|
||||
GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop()
|
||||
)
|
||||
alignment = align(target, src_string, gap_char=gap_char)
|
||||
_ , actions = get_editops_stats(alignment, gap_char)
|
||||
_, actions = get_editops_stats(alignment, gap_char)
|
||||
print(actions)
|
||||
|
||||
for k in expected_actions:
|
||||
assert actions[k] == expected_actions[k], (k,actions[k], expected_actions[k])
|
||||
assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k])
|
||||
|
|
|
@@ -1,13 +1,15 @@
|
|||
from genalog.ocr.rest_client import GrokRestClient
|
||||
from genalog.ocr.blob_client import GrokBlobClient
|
||||
from genalog.ocr.grok import Grok
|
||||
import requests
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from genalog.ocr.rest_client import GrokRestClient
|
||||
|
||||
|
||||
load_dotenv("tests/ocr/.env")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_monkeypatch(monkeypatch):
|
||||
def mock_http(*args, **kwargs):
|
||||
|
@@ -21,7 +23,6 @@ def setup_monkeypatch(monkeypatch):
|
|||
|
||||
|
||||
class MockedResponse:
|
||||
|
||||
def __init__(self, args, kwargs):
|
||||
self.url = args[0]
|
||||
self.text = "response"
|
||||
|
@@ -34,27 +35,39 @@ class MockedResponse:
|
|||
|
||||
if "search.windows.net/indexers/" in self.url:
|
||||
if "status" in self.url:
|
||||
return {
|
||||
"lastResult": {"status": "success"},
|
||||
"status": "finished"
|
||||
}
|
||||
return {"lastResult": {"status": "success"}, "status": "finished"}
|
||||
return {}
|
||||
|
||||
if "search.windows.net/indexes/" in self.url:
|
||||
if "docs/search" in self.url:
|
||||
return {
|
||||
"value" : [
|
||||
{
|
||||
"value": [
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_0.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", "r"))
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
{
|
||||
"metadata_storage_name":"521c38122f783673598856cd81d91c21_1.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", "r"))
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_1.png",
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
{
|
||||
"metadata_storage_name":"521c38122f783673598856cd81d91c21_11.png",
|
||||
"layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", "r"))
|
||||
{
|
||||
"metadata_storage_name": "521c38122f783673598856cd81d91c21_11.png",
|
||||
"layoutText": json.load(
|
||||
open(
|
||||
"tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json",
|
||||
"r",
|
||||
)
|
||||
),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
@@ -69,7 +82,6 @@ class MockedResponse:
|
|||
|
||||
|
||||
class TestGROK:
|
||||
|
||||
def test_creating_indexing_pipeline(self):
|
||||
grok_rest_client = GrokRestClient.create_from_env_var()
|
||||
grok_rest_client.create_indexing_pipeline()
|
||||
|
@@ -78,11 +90,11 @@ class TestGROK:
|
|||
def test_running_indexer(self):
|
||||
grok_rest_client = GrokRestClient.create_from_env_var()
|
||||
grok_rest_client.create_indexing_pipeline()
|
||||
|
||||
|
||||
indexer_status = grok_rest_client.get_indexer_status()
|
||||
if indexer_status["status"] == "error":
|
||||
raise RuntimeError(f"indexer error: {indexer_status}")
|
||||
|
||||
|
||||
# if not already running start the indexer
|
||||
if indexer_status["lastResult"]["status"] != "inProgress":
|
||||
grok_rest_client.run_indexer()
|
||||
|
@@ -90,5 +102,4 @@ class TestGROK:
|
|||
grok_rest_client.run_indexer()
|
||||
indexer_status = grok_rest_client.poll_indexer_till_complete()
|
||||
assert indexer_status["lastResult"]["status"] == "success"
|
||||
grok_rest_client.delete_indexer_pipeline()
|
||||
|
||||
grok_rest_client.delete_indexer_pipeline()
|
||||
|
|
|
@@ -1,43 +1,59 @@
|
|||
from genalog.text import alignment
|
||||
from genalog.text.alignment import MATCH_REWARD, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXT_PENALTY
|
||||
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES, ALIGNMENT_REGRESSION_TEST_CASES
|
||||
|
||||
import warnings
|
||||
from random import randint
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
from genalog.text import alignment
|
||||
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
|
||||
from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES
|
||||
|
||||
RANDOM_INT = randint(1, 100)
|
||||
MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)]
|
||||
|
||||
|
||||
# Setup mock for third-party library call
|
||||
@pytest.fixture
|
||||
def mock_pairwise2_align(monkeypatch):
|
||||
mock = MagicMock()
|
||||
|
||||
def mock_globalcs(*args, **kwargs):
|
||||
mock.globalcs(*args, **kwargs)
|
||||
return MOCK_ALIGNMENT_RESULT
|
||||
|
||||
# replace target method reference with the mock method
|
||||
monkeypatch.setattr("Bio.pairwise2.align.globalcs", mock_globalcs)
|
||||
return mock
|
||||
|
||||
|
||||
def test__align_seg(mock_pairwise2_align):
|
||||
# setup method input
|
||||
required_arg = ("A", "B")
|
||||
optional_arg = (alignment.MATCH_REWARD, alignment.MISMATCH_PENALTY, alignment.GAP_PENALTY, alignment.GAP_EXT_PENALTY)
|
||||
optional_kwarg = {"gap_char": alignment.GAP_CHAR, "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY}
|
||||
optional_arg = (
|
||||
alignment.MATCH_REWARD,
|
||||
alignment.MISMATCH_PENALTY,
|
||||
alignment.GAP_PENALTY,
|
||||
alignment.GAP_EXT_PENALTY,
|
||||
)
|
||||
optional_kwarg = {
|
||||
"gap_char": alignment.GAP_CHAR,
|
||||
"one_alignment_only": alignment.ONE_ALIGNMENT_ONLY,
|
||||
}
|
||||
# test method
|
||||
result = alignment._align_seg(*required_arg + optional_arg, **optional_kwarg)
|
||||
# assertion
|
||||
mock_pairwise2_align.globalcs.assert_called()
|
||||
assert result == MOCK_ALIGNMENT_RESULT
|
||||
|
||||
@pytest.mark.parametrize("alignments, target_num_tokens, raised_exception",
|
||||
[
|
||||
(MOCK_ALIGNMENT_RESULT, 1, None),
|
||||
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
|
||||
([("X", "XY", 0, 0, 1)], 1, ValueError)
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"alignments, target_num_tokens, raised_exception",
|
||||
[
|
||||
(MOCK_ALIGNMENT_RESULT, 1, None),
|
||||
(MOCK_ALIGNMENT_RESULT, 2, ValueError),
|
||||
([("X", "XY", 0, 0, 1)], 1, ValueError),
|
||||
],
|
||||
)
|
||||
def test__select_alignment_candidates(alignments, target_num_tokens, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@@ -46,24 +62,27 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
|
|||
result = alignment._select_alignment_candidates(alignments, target_num_tokens)
|
||||
assert result == MOCK_ALIGNMENT_RESULT[0]
|
||||
|
||||
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
("", -1, None, ValueError), # Empty case
|
||||
# Index at start of string
|
||||
(" token", 0, 2, None),
|
||||
("\t\ntoken", 0, 2, None),
|
||||
# Index reach end of string
|
||||
("token ", 5, 5, None),
|
||||
("token", 4, 4, None),
|
||||
# Index in-between tokens
|
||||
("token", 0, 0, None),
|
||||
("t1 t2", 2, 7, None),
|
||||
("t1 \t \n t2", 3, 7, None),
|
||||
# Gap char
|
||||
(" @", 0, 1, None),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
("", -1, None, ValueError), # Empty case
|
||||
# Index at start of string
|
||||
(" token", 0, 2, None),
|
||||
("\t\ntoken", 0, 2, None),
|
||||
# Index reach end of string
|
||||
("token ", 5, 5, None),
|
||||
("token", 4, 4, None),
|
||||
# Index in-between tokens
|
||||
("token", 0, 0, None),
|
||||
("t1 t2", 2, 7, None),
|
||||
("t1 \t \n t2", 3, 7, None),
|
||||
# Gap char
|
||||
(" @", 0, 1, None),
|
||||
],
|
||||
)
|
||||
def test__find_token_start(s, index, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@@ -72,25 +91,28 @@ def test__find_token_start(s, index, desired_output, raised_exception):
|
|||
output = alignment._find_token_start(s, index)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
("", -1, None, ValueError), # Empty case
|
||||
# Index at start of string
|
||||
(" ", 0, 0, None),
|
||||
("\t\ntoken", 0, 0, None),
|
||||
("token", 0, 4, None),
|
||||
("token\t", 0, 5, None),
|
||||
("token\n", 0, 5, None),
|
||||
# Index reach end of string
|
||||
("token ", 5, 5, None),
|
||||
("token", 4, 4, None),
|
||||
# Single Char
|
||||
(".", 0, 0, None),
|
||||
# Gap char
|
||||
("@@ @", 0, 2, None),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, index, desired_output, raised_exception",
|
||||
[
|
||||
# Test exceptions
|
||||
("s", 2, None, IndexError),
|
||||
("", -1, None, ValueError), # Empty case
|
||||
# Index at start of string
|
||||
(" ", 0, 0, None),
|
||||
("\t\ntoken", 0, 0, None),
|
||||
("token", 0, 4, None),
|
||||
("token\t", 0, 5, None),
|
||||
("token\n", 0, 5, None),
|
||||
# Index reach end of string
|
||||
("token ", 5, 5, None),
|
||||
("token", 4, 4, None),
|
||||
# Single Char
|
||||
(".", 0, 0, None),
|
||||
# Gap char
|
||||
("@@ @", 0, 2, None),
|
||||
],
|
||||
)
|
||||
def test__find_token_end(s, index, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@@ -99,61 +121,81 @@ def test__find_token_end(s, index, desired_output, raised_exception):
|
|||
output = alignment._find_token_end(s, index)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("s, start, desired_output",
|
||||
[
|
||||
("token", 0, (0,4)),
|
||||
("token\t", 0, (0,5)),
|
||||
("token \n", 0, (0,5)),
|
||||
(" token ", 0, (1,6)),
|
||||
# mix with GAP_CHAR
|
||||
(" @@@@ ", 0, (1,5)),
|
||||
("\n\t tok@n@@ \n\t", 0, (3,10)),
|
||||
# single character string
|
||||
("s", 0, (0,0)),
|
||||
# punctuation
|
||||
(" !,.: ", 0, (2,6))
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, start, desired_output",
|
||||
[
|
||||
("token", 0, (0, 4)),
|
||||
("token\t", 0, (0, 5)),
|
||||
("token \n", 0, (0, 5)),
|
||||
(" token ", 0, (1, 6)),
|
||||
# mix with GAP_CHAR
|
||||
(" @@@@ ", 0, (1, 5)),
|
||||
("\n\t tok@n@@ \n\t", 0, (3, 10)),
|
||||
# single character string
|
||||
("s", 0, (0, 0)),
|
||||
# punctuation
|
||||
(" !,.: ", 0, (2, 6)),
|
||||
],
|
||||
)
|
||||
def test__find_next_token(s, start, desired_output):
|
||||
output = alignment._find_next_token(s, start)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("token, desired_output",
|
||||
[
|
||||
# Valid tokens
|
||||
("\n\t token.!,:\n\t ", True),
|
||||
("token", True),
|
||||
(" @@@t@@@ ", True),
|
||||
("@@token@@", True),
|
||||
(" @@token@@ ", True),
|
||||
(f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), #i.e. 't1@t2'
|
||||
# Invalid tokens (i.e. multiples of the GAP_CHAR)
|
||||
("", False),
|
||||
(" ", False),
|
||||
("@@", False),
|
||||
(" @@ ", False),
|
||||
("\t\n@", False),
|
||||
(alignment.GAP_CHAR*1, False),
|
||||
(alignment.GAP_CHAR*RANDOM_INT, False),
|
||||
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False)
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"token, desired_output",
|
||||
[
|
||||
# Valid tokens
|
||||
("\n\t token.!,:\n\t ", True),
|
||||
("token", True),
|
||||
(" @@@t@@@ ", True),
|
||||
("@@token@@", True),
|
||||
(" @@token@@ ", True),
|
||||
(f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), # i.e. 't1@t2'
|
||||
# Invalid tokens (i.e. multiples of the GAP_CHAR)
|
||||
("", False),
|
||||
(" ", False),
|
||||
("@@", False),
|
||||
(" @@ ", False),
|
||||
("\t\n@", False),
|
||||
(alignment.GAP_CHAR * 1, False),
|
||||
(alignment.GAP_CHAR * RANDOM_INT, False),
|
||||
(f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False),
|
||||
],
|
||||
)
|
||||
def test__is_valid_token(token, desired_output):
|
||||
result = alignment._is_valid_token(token)
|
||||
assert result == desired_output
|
||||
|
||||
@pytest.mark.parametrize("aligned_gt, aligned_noise," +
|
||||
"expected_gt_to_noise_map, expected_noise_to_gt_map",
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
def test_parse_alignment(aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map):
|
||||
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(aligned_gt, aligned_noise)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"aligned_gt, aligned_noise," + "expected_gt_to_noise_map, expected_noise_to_gt_map",
|
||||
PARSE_ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_parse_alignment(
|
||||
aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map
|
||||
):
|
||||
gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(
|
||||
aligned_gt, aligned_noise
|
||||
)
|
||||
assert gt_to_noise_map == expected_gt_to_noise_map
|
||||
assert noise_to_gt_map == expected_noise_to_gt_map
|
||||
|
||||
@pytest.mark.parametrize("gt_txt, noisy_txt," +
|
||||
"expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_txt, noisy_txt," + "expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_align(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
|
||||
aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)
|
||||
if aligned_gt != expected_aligned_gt:
|
||||
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
|
||||
expected_alignment = alignment._format_alignment(
|
||||
expected_aligned_gt, expected_aligned_noise
|
||||
)
|
||||
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
|
||||
)
|
||||
)
|
||||
|
|
|
@@ -1,43 +1,56 @@
|
|||
import glob
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from genalog.text import alignment, anchor, preprocess
|
||||
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
|
||||
|
||||
import glob
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("tokens, case_sensitive, desired_output", [
|
||||
([], True, set()),
|
||||
([], False, set()),
|
||||
(["a", "A"], True, set(["a", "A"])),
|
||||
(["a", "A"], False, set()),
|
||||
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
|
||||
(["An", "an", "ab"], False, set(["ab"])),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, case_sensitive, desired_output",
|
||||
[
|
||||
([], True, set()),
|
||||
([], False, set()),
|
||||
(["a", "A"], True, set(["a", "A"])),
|
||||
(["a", "A"], False, set()),
|
||||
(["An", "an", "ab"], True, set(["An", "an", "ab"])),
|
||||
(["An", "an", "ab"], False, set(["ab"])),
|
||||
],
|
||||
)
|
||||
def test_get_unique_words(tokens, case_sensitive, desired_output):
|
||||
output = anchor.get_unique_words(tokens, case_sensitive=case_sensitive)
|
||||
assert desired_output == output
|
||||
|
||||
@pytest.mark.parametrize("tokens, desired_output", [
|
||||
([], 0),
|
||||
([""], 0),
|
||||
(["a", "b"], 2),
|
||||
(["abc.", "def!"], 8)
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, desired_output",
|
||||
[([], 0), ([""], 0), (["a", "b"], 2), (["abc.", "def!"], 8)],
|
||||
)
|
||||
def test_segment_len(tokens, desired_output):
|
||||
output = anchor.segment_len(tokens)
|
||||
assert desired_output == output
|
||||
|
||||
@pytest.mark.parametrize("unique_words, src_tokens, desired_output, raised_exception", [
|
||||
(set(), [], [], None),
|
||||
(set(), ["a"], [], None),
|
||||
(set("a"), [], [], ValueError), # unique word not in src_tokens
|
||||
(set("a"), ["b"], [], ValueError),
|
||||
(set("a"), ["A"], [], ValueError), # case sensitive
|
||||
(set("a"), ["an", "na", " a "], [], ValueError), # substring
|
||||
(set("a"), ["a"], [("a", 0)], None), # valid input
|
||||
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
|
||||
(set("ab"), ["c", "b", "a"], [("b", 1), ("a", 2)], None), # multiple matches ordered by index
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"unique_words, src_tokens, desired_output, raised_exception",
|
||||
[
|
||||
(set(), [], [], None),
|
||||
(set(), ["a"], [], None),
|
||||
(set("a"), [], [], ValueError), # unique word not in src_tokens
|
||||
(set("a"), ["b"], [], ValueError),
|
||||
(set("a"), ["A"], [], ValueError), # case sensitive
|
||||
(set("a"), ["an", "na", " a "], [], ValueError), # substring
|
||||
(set("a"), ["a"], [("a", 0)], None), # valid input
|
||||
(set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens
|
||||
(
|
||||
set("ab"),
|
||||
["c", "b", "a"],
|
||||
[("b", 1), ("a", 2)],
|
||||
None,
|
||||
), # multiple matches ordered by index
|
||||
],
|
||||
)
|
||||
def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
|
@@ -46,85 +59,150 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception
|
|||
output = anchor.get_word_map(unique_words, src_tokens)
|
||||
assert desired_output == output
|
||||
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no common unique words
|
||||
(["a", "a"], ["a"], ([], [])), # no unique words
|
||||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
|
||||
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
|
||||
(["a", "b", "c"], ["a", "b", "c"], # common unique words
|
||||
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)])),
|
||||
(["a", "b", "c"], ["c", "b", "a"], # common unique words but not in same order
|
||||
([("b", 1)], [("b", 1)])),
|
||||
(["b", "a", "c"], ["c", "b", "a"], # LCS has multiple results
|
||||
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])),
|
||||
(["c", "a", "b"], ["c", "b", "a"],
|
||||
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)])),
|
||||
(["c", "a", "b"], ["a", "c", "b"], # LCS has multiple results
|
||||
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)])),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, desired_output",
|
||||
[
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no common unique words
|
||||
(["a", "a"], ["a"], ([], [])), # no unique words
|
||||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist
|
||||
(["a"], ["b", "a"], ([("a", 0)], [("a", 1)])),
|
||||
(
|
||||
["a", "b", "c"],
|
||||
["a", "b", "c"], # common unique words
|
||||
([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)]),
|
||||
),
|
||||
(
|
||||
["a", "b", "c"],
|
||||
["c", "b", "a"], # common unique words but not in same order
|
||||
([("b", 1)], [("b", 1)]),
|
||||
),
|
||||
(
|
||||
["b", "a", "c"],
|
||||
["c", "b", "a"], # LCS has multiple results
|
||||
([("b", 0), ("a", 1)], [("b", 1), ("a", 2)]),
|
||||
),
|
||||
(
|
||||
["c", "a", "b"],
|
||||
["c", "b", "a"],
|
||||
([("c", 0), ("b", 2)], [("c", 0), ("b", 1)]),
|
||||
),
|
||||
(
|
||||
["c", "a", "b"],
|
||||
["a", "c", "b"], # LCS has multiple results
|
||||
([("a", 1), ("b", 2)], [("a", 0), ("b", 2)]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_anchor_map(gt_tokens, ocr_tokens, desired_output):
|
||||
desired_gt_map, desired_ocr_map = desired_output
|
||||
gt_map, ocr_map = anchor.get_anchor_map(gt_tokens, ocr_tokens)
|
||||
assert desired_gt_map == gt_map
|
||||
assert desired_ocr_map == ocr_map
|
||||
|
||||
|
||||
# max_seg_length does not change the following output
|
||||
@pytest.mark.parametrize("max_seg_length", [0, 1, 2, 3, 5, 4, 6])
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no anchors
|
||||
(["a", "a"], ["a"], ([], [])),
|
||||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([0], [0])), # anchors exist
|
||||
("a1 w w w".split(), "a1 w w w".split(), # no anchors in the subsequence [w w w]
|
||||
([0], [0])),
|
||||
("a1 w w w a2".split(), "a1 w w w a2".split(),
|
||||
([0, 4], [0, 4])),
|
||||
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(),
|
||||
([0, 4], [0, 4])),
|
||||
("a1 a2 a3".split(), "a1 a2 a3".split(), # all words are anchors
|
||||
([0, 1, 2], [0, 1, 2])),
|
||||
("a1 a2 a3".split(), "A1 A2 A3".split(), # anchor words must be in the same casing
|
||||
([], [])),
|
||||
("a1 w w a2".split(), "a1 w W a2".split(), # unique words are case insensitive
|
||||
([0, 3], [0, 3])),
|
||||
("a1 w w a2".split(), "A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
|
||||
([], [])),
|
||||
])
|
||||
def test_find_anchor_recur_various_seg_len(max_seg_length, gt_tokens, ocr_tokens, desired_output):
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, desired_output",
|
||||
[
|
||||
([], [], ([], [])), # empty
|
||||
([""], [""], ([], [])),
|
||||
(["a"], ["b"], ([], [])), # no anchors
|
||||
(["a", "a"], ["a"], ([], [])),
|
||||
(["a"], ["a", "a"], ([], [])),
|
||||
(["a"], ["a"], ([0], [0])), # anchors exist
|
||||
(
|
||||
"a1 w w w".split(),
|
||||
"a1 w w w".split(), # no anchors in the subsequence [w w w]
|
||||
([0], [0]),
|
||||
),
|
||||
("a1 w w w a2".split(), "a1 w w w a2".split(), ([0, 4], [0, 4])),
|
||||
("a1 w w w2 a2".split(), "a1 w w w3 a2".split(), ([0, 4], [0, 4])),
|
||||
(
|
||||
"a1 a2 a3".split(),
|
||||
"a1 a2 a3".split(), # all words are anchors
|
||||
([0, 1, 2], [0, 1, 2]),
|
||||
),
|
||||
(
|
||||
"a1 a2 a3".split(),
|
||||
"A1 A2 A3".split(), # anchor words must be in the same casing
|
||||
([], []),
|
||||
),
|
||||
(
|
||||
"a1 w w a2".split(),
|
||||
"a1 w W a2".split(), # unique words are case insensitive
|
||||
([0, 3], [0, 3]),
|
||||
),
|
||||
(
|
||||
"a1 w w a2".split(),
|
||||
"A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
|
||||
([], []),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_anchor_recur_various_seg_len(
|
||||
max_seg_length, gt_tokens, ocr_tokens, desired_output
|
||||
):
|
||||
desired_gt_anchors, desired_ocr_anchors = desired_output
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
assert desired_gt_anchors == gt_anchors
|
||||
assert desired_ocr_anchors == ocr_anchors
|
||||
|
||||
|
||||
# Test the recursion behavior
|
||||
@pytest.mark.parametrize("gt_tokens, ocr_tokens, max_seg_length, desired_output", [
|
||||
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6,
|
||||
([0, 3], [0, 3])),
|
||||
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 3, 4], [0, 3, 4])),
|
||||
("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5])),
|
||||
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 w_ a3".split(), 2, # missing ocr token
|
||||
([0, 3, 6], [0, 2, 4])),
|
||||
("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 W_ A3".split(), 2, # changing cases
|
||||
([0, 3], [0, 2])),
|
||||
])
|
||||
def test_find_anchor_recur_fixed_seg_len(gt_tokens, ocr_tokens, max_seg_length, desired_output):
|
||||
@pytest.mark.parametrize(
|
||||
"gt_tokens, ocr_tokens, max_seg_length, desired_output",
|
||||
[
|
||||
("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6, ([0, 3], [0, 3])),
|
||||
(
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 3, 4], [0, 3, 4]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
"a1 w_ w_ a2 a3 a2".split(),
|
||||
2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
|
||||
([0, 2, 3, 4, 5], [0, 2, 3, 4, 5]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 w_ w_ a3".split(),
|
||||
"a1 w_ a2 w_ a3".split(),
|
||||
2, # missing ocr token
|
||||
([0, 3, 6], [0, 2, 4]),
|
||||
),
|
||||
(
|
||||
"a1 w_ w_ a2 w_ w_ a3".split(),
|
||||
"a1 w_ a2 W_ A3".split(),
|
||||
2, # changing cases
|
||||
([0, 3], [0, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_anchor_recur_fixed_seg_len(
|
||||
gt_tokens, ocr_tokens, max_seg_length, desired_output
|
||||
):
|
||||
desired_gt_anchors, desired_ocr_anchors = desired_output
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
assert desired_gt_anchors == gt_anchors
|
||||
assert desired_ocr_anchors == ocr_anchors
|
||||
|
||||
@pytest.mark.parametrize("gt_file, ocr_file",
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_file, ocr_file",
|
||||
zip(
|
||||
sorted(glob.glob("tests/text/data/gt_1.txt")),
|
||||
sorted(glob.glob("tests/text/data/ocr_1.txt"))
|
||||
)
|
||||
sorted(glob.glob("tests/text/data/ocr_1.txt")),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("max_seg_length", [75])
|
||||
def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
||||
|
@@ -132,15 +210,27 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length):
|
|||
ocr_text = open(ocr_file, "r").read()
|
||||
gt_tokens = preprocess.tokenize(gt_text)
|
||||
ocr_tokens = preprocess.tokenize(ocr_text)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
|
||||
gt_anchors, ocr_anchors = anchor.find_anchor_recur(
|
||||
gt_tokens, ocr_tokens, max_seg_length=max_seg_length
|
||||
)
|
||||
for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors):
|
||||
# Ensure that each anchor word is the same word in both text
|
||||
assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor]
|
||||
|
||||
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
|
||||
ALIGNMENT_REGRESSION_TEST_CASES,
|
||||
)
|
||||
def test_align_w_anchor(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):
|
||||
aligned_gt, aligned_noise = anchor.align_w_anchor(gt_txt, noisy_txt)
|
||||
if aligned_gt != expected_aligned_gt:
|
||||
expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise)
|
||||
expected_alignment = alignment._format_alignment(
|
||||
expected_aligned_gt, expected_aligned_noise
|
||||
)
|
||||
result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
|
||||
)
|
||||
)
|
||||
|
|
|
@@ -1,157 +1,286 @@
|
|||
from genalog.text import conll_format
|
||||
import itertools
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import itertools
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception", [
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
|
||||
(["w1", "w3"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal tokens
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w3"]], ["w", "w"], ValueError), # Unequal tokens
|
||||
(["w1", "w3"], ["l1", "l2"], [["w1"]],["w", "w"], ValueError), # Unequal length
|
||||
(["w1"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal length
|
||||
])
|
||||
def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception):
|
||||
from genalog.text import conll_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception",
|
||||
[
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None),
|
||||
(["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment
|
||||
(
|
||||
["w1", "w3"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w2"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal tokens
|
||||
(
|
||||
["w1", "w2"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w3"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal tokens
|
||||
(
|
||||
["w1", "w3"],
|
||||
["l1", "l2"],
|
||||
[["w1"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal length
|
||||
(
|
||||
["w1"],
|
||||
["l1", "l2"],
|
||||
[["w1"], ["w2"]],
|
||||
["w", "w"],
|
||||
ValueError,
|
||||
), # Unequal length
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences_error(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception
|
||||
):
|
||||
if raised_exception:
|
||||
with pytest.raises(raised_exception):
|
||||
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
else:
|
||||
conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", [
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1", "a2", "b2"], # ocr token
|
||||
[["a1", "b1"], ["a2","b2"]], [["l1", "l2"], ["l3", "l4"]] # desired output
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1"], # Missing sentence 2
|
||||
# Ideally we would expect [["a1", "b1"], []]
|
||||
# But the limitation of text alignment, which yields
|
||||
# "a1 b1 a2 b2"
|
||||
# "a1 b1@@@@@@"
|
||||
# It is difficult to decide the location of "b1"
|
||||
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@"
|
||||
# NOTE: this is an improper behavior, but it is the best
|
||||
# solution to this corner case, since it preserves the number of OCR tokens.
|
||||
[["a1"], ["b1"]], [["l1"], ["l2"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
|
||||
[["a"], ["a2", "b2"]], [["l1"], ["l3", "l4"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
|
||||
[["a1", "b1"], ["a2"]], [["l1", "l2"], ["l3"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
|
||||
[["b1", "a2"], ["b2"]], [["l2", "l3"], ["l4"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
["a1", "b1", "a2", "b2"], # ocr token (missing c1 token at middle of sentence)
|
||||
[["a1"], ["b1", "a2"], ["b2"]], [["l1"], ["l2", "l4"], ["l5"]]
|
||||
),
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels",
|
||||
[
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1"], ["c1", "a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
|
||||
[["a1"], ["b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
|
||||
[[], ["a1", "c1", "a2"], ["b2"]], [[], ["l1", "l3", "l4"], ["l5"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2","b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2","b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]]
|
||||
),
|
||||
])
|
||||
def test_propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels):
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1", "a2", "b2"], # ocr token
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
[["l1", "l2"], ["l3", "l4"]], # desired output
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]], # clean sentences
|
||||
["a1", "b1"], # Missing sentence 2
|
||||
# Ideally we would expect [["a1", "b1"], []]
|
||||
# But the limitation of text alignment, which yields
|
||||
# "a1 b1 a2 b2"
|
||||
# "a1 b1@@@@@@"
|
||||
# It is difficult to decide the location of "b1"
|
||||
# when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@"
|
||||
# NOTE: this is an improper behavior, but it is the best
|
||||
# solution to this corner case, since it preserves the number of OCR tokens.
|
||||
[["a1"], ["b1"]],
|
||||
[["l1"], ["l2"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary)
|
||||
[["a"], ["a2", "b2"]],
|
||||
[["l1"], ["l3", "l4"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary)
|
||||
[["a1", "b1"], ["a2"]],
|
||||
[["l1", "l2"], ["l3"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 a2 b2".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a1", "b1"], ["a2", "b2"]],
|
||||
["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start)
|
||||
[["b1", "a2"], ["b2"]],
|
||||
[["l2", "l3"], ["l4"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
[
|
||||
"a1",
|
||||
"b1",
|
||||
"a2",
|
||||
"b2",
|
||||
], # ocr token (missing c1 token at middle of sentence)
|
||||
[["a1"], ["b1", "a2"], ["b2"]],
|
||||
[["l1"], ["l2", "l4"], ["l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1"], ["c1", "a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens)
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1"], ["b1", "c1", "a2"], ["b2"]],
|
||||
["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start)
|
||||
[[], ["a1", "c1", "a2"], ["b2"]],
|
||||
[[], ["l1", "l3", "l4"], ["l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
(
|
||||
"a1 b1 c1 a2 b2".split(),
|
||||
"l1 l2 l3 l4 l5".split(),
|
||||
[["a1", "b1", "c1"], ["a2", "b2"]],
|
||||
["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end)
|
||||
[["a1"], ["b1", "b2"]],
|
||||
[["l1"], ["l2", "l5"]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences(
|
||||
clean_tokens,
|
||||
clean_labels,
|
||||
clean_sentences,
|
||||
ocr_tokens,
|
||||
desired_sentences,
|
||||
desired_labels,
|
||||
):
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
|
||||
if desired_sentences != ocr_text_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
|
||||
)
|
||||
)
|
||||
if desired_labels != ocr_labels_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens," +
|
||||
"mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", [
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
[["a", "b"], ["c", "d"]],
|
||||
["a", "b"], # Sentence is empty
|
||||
[[0], [1], [], []],
|
||||
[[0], [1]],
|
||||
[["a", "b"], []],
|
||||
[["l1", "l2"], []]
|
||||
),
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
[["a", "b",], ["c", "d"]],
|
||||
["a", "b", "d"], # Missing sentence start
|
||||
[[0], [1], [], [2]],
|
||||
[[0], [1], [3]],
|
||||
[["a", "b"], ["d"]],
|
||||
[["l1", "l2"], ["l4"]]
|
||||
),
|
||||
(
|
||||
"a b c d".split(), "l1 l2 l3 l4".split(),
|
||||
[["a", "b",], ["c", "d"]],
|
||||
["a", "c", "d"], # Missing sentence end
|
||||
[[0], [], [1], [2]],
|
||||
[[0], [2], [3]],
|
||||
[["a"], ["c", "d"]],
|
||||
[["l1"], ["l3", "l4"]]
|
||||
),
|
||||
])
|
||||
def test_propagate_labels_sentences_text_alignment_corner_cases(clean_tokens, clean_labels, clean_sentences, ocr_tokens,
|
||||
mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clean_tokens, clean_labels, clean_sentences, ocr_tokens,"
|
||||
+ "mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels",
|
||||
[
|
||||
(
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[["a", "b"], ["c", "d"]],
|
||||
["a", "b"], # Sentence is empty
|
||||
[[0], [1], [], []],
|
||||
[[0], [1]],
|
||||
[["a", "b"], []],
|
||||
[["l1", "l2"], []],
|
||||
),
|
||||
(
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[
|
||||
[
|
||||
"a",
|
||||
"b",
|
||||
],
|
||||
["c", "d"],
|
||||
],
|
||||
["a", "b", "d"], # Missing sentence start
|
||||
[[0], [1], [], [2]],
|
||||
[[0], [1], [3]],
|
||||
[["a", "b"], ["d"]],
|
||||
[["l1", "l2"], ["l4"]],
|
||||
),
|
||||
(
|
||||
"a b c d".split(),
|
||||
"l1 l2 l3 l4".split(),
|
||||
[
|
||||
[
|
||||
"a",
|
||||
"b",
|
||||
],
|
||||
["c", "d"],
|
||||
],
|
||||
["a", "c", "d"], # Missing sentence end
|
||||
[[0], [], [1], [2]],
|
||||
[[0], [2], [3]],
|
||||
[["a"], ["c", "d"]],
|
||||
[["l1"], ["l3", "l4"]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_propagate_labels_sentences_text_alignment_corner_cases(
|
||||
clean_tokens,
|
||||
clean_labels,
|
||||
clean_sentences,
|
||||
ocr_tokens,
|
||||
mock_gt_to_ocr_mapping,
|
||||
mock_ocr_to_gt_mapping,
|
||||
desired_sentences,
|
||||
desired_labels,
|
||||
):
|
||||
with patch("genalog.text.alignment.parse_alignment") as mock_alignment:
|
||||
mock_alignment.return_value = (mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping)
|
||||
ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens)
|
||||
(
|
||||
ocr_text_sentences,
|
||||
ocr_labels_sentences,
|
||||
) = conll_format.propagate_labels_sentences(
|
||||
clean_tokens, clean_labels, clean_sentences, ocr_tokens
|
||||
)
|
||||
ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences))
|
||||
assert len(ocr_text_sentences) == len(clean_sentences)
|
||||
assert len(ocr_text_sentences) == len(ocr_labels_sentences)
|
||||
assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens
|
||||
assert len(ocr_sentences_flatten) == len(
|
||||
ocr_tokens
|
||||
) # ensure aligned ocr tokens == ocr tokens
|
||||
if desired_sentences != ocr_text_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}"
|
||||
)
|
||||
)
|
||||
if desired_labels != ocr_labels_sentences:
|
||||
warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"))
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}"
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("s, desired_output", [
|
||||
("", []),
|
||||
("\n\n", []),
|
||||
("a1\tb1\na2\tb2", [["a1", "a2"]]),
|
||||
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
|
||||
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"s, desired_output",
|
||||
[
|
||||
("", []),
|
||||
("\n\n", []),
|
||||
("a1\tb1\na2\tb2", [["a1", "a2"]]),
|
||||
("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
|
||||
("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
|
||||
],
|
||||
)
|
||||
def test_get_sentences_from_iob_format(s, desired_output):
|
||||
output = conll_format.get_sentences_from_iob_format(s.splitlines(True))
|
||||
assert desired_output == output
|
||||
assert desired_output == output
|
||||
|
|
|
@@ -1,37 +1,54 @@
|
|||
from genalog.text.lcs import LCS
|
||||
|
||||
import pytest
|
||||
|
||||
@pytest.fixture(params=[
|
||||
("", ""), # empty
|
||||
("abcde", "ace"), # naive case
|
||||
])
|
||||
from genalog.text.lcs import LCS
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
("", ""), # empty
|
||||
("abcde", "ace"), # naive case
|
||||
]
|
||||
)
|
||||
def lcs(request):
|
||||
str1, str2 = request.param
|
||||
return LCS(str1, str2)
|
||||
|
||||
|
||||
def test_lcs_init(lcs):
|
||||
assert lcs._lcs_len is not None
|
||||
assert lcs._lcs is not None
|
||||
|
||||
@pytest.mark.parametrize("str1, str2, expected_len, expected_lcs", [
|
||||
("", "", 0, ""), # empty
|
||||
("abc", "abc", 3, "abc"),
|
||||
("abcde", "ace", 3, "ace"), # naive case
|
||||
("a", "", 0, ""), # no results
|
||||
("abc", "cba", 1, "c"), # multiple cases
|
||||
("abcdgh", "aedfhr", 3, "adh"),
|
||||
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
|
||||
("New York @", "New @ York", len("New York"), "New York"), # with space-separated, tokens
|
||||
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
|
||||
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
|
||||
# mixed order with similar tokens
|
||||
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
|
||||
# casing
|
||||
("Is A Big City IS a", "IS a Big City Is A", len("I Big City I "), "I Big City I "),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"str1, str2, expected_len, expected_lcs",
|
||||
[
|
||||
("", "", 0, ""), # empty
|
||||
("abc", "abc", 3, "abc"),
|
||||
("abcde", "ace", 3, "ace"), # naive case
|
||||
("a", "", 0, ""), # no results
|
||||
("abc", "cba", 1, "c"), # multiple cases
|
||||
("abcdgh", "aedfhr", 3, "adh"),
|
||||
("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
|
||||
(
|
||||
"New York @",
|
||||
"New @ York",
|
||||
len("New York"),
|
||||
"New York",
|
||||
), # with space-separated, tokens
|
||||
("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
|
||||
("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
|
||||
# mixed order with similar tokens
|
||||
("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
|
||||
# casing
|
||||
(
|
||||
"Is A Big City IS a",
|
||||
"IS a Big City Is A",
|
||||
len("I Big City I "),
|
||||
"I Big City I ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_lcs_e2e(str1, str2, expected_len, expected_lcs):
|
||||
lcs = LCS(str1, str2)
|
||||
assert expected_lcs == lcs.get_str()
|
||||
assert expected_len == lcs.get_len()
|
||||
|
|
@@ -1,171 +1,295 @@
|
|||
import pytest
|
||||
|
||||
from genalog.text import ner_label
|
||||
from genalog.text import alignment
|
||||
from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES
|
||||
|
||||
import pytest
|
||||
import string
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
|
||||
# Positive Cases
|
||||
("B-org", True), (" B-org ", True), #whitespae tolerant
|
||||
("\tB-ORG\n", True),
|
||||
# Negative Cases
|
||||
("I-ORG", False), ("O", False), ("other-B-label", False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"label, desired_output",
|
||||
[
|
||||
# Positive Cases
|
||||
("B-org", True),
|
||||
(" B-org ", True), # whitespae tolerant
|
||||
("\tB-ORG\n", True),
|
||||
# Negative Cases
|
||||
("I-ORG", False),
|
||||
("O", False),
|
||||
("other-B-label", False),
|
||||
],
|
||||
)
|
||||
def test__is_begin_label(label, desired_output):
|
||||
output = ner_label._is_begin_label(label)
|
||||
assert output == desired_output
|
||||
|
||||
@pytest.mark.parametrize("label, desired_output", [
    # Positive Cases
    ("I-ORG", True), (" \t I-ORG ", True),
    # Negative Cases
    ("O", False), ("B-LOC", False), ("B-ORG", False),
])

@pytest.mark.parametrize(
    "label, desired_output",
    [
        # Positive Cases
        ("I-ORG", True),
        (" \t I-ORG ", True),
        # Negative Cases
        ("O", False),
        ("B-LOC", False),
        ("B-ORG", False),
    ],
)
def test__is_inside_label(label, desired_output):
    output = ner_label._is_inside_label(label)
    assert output == desired_output

@pytest.mark.parametrize("label, desired_output", [
    # Positive Cases
    ("I-ORG", True), ("B-ORG", True),
    # Negative Cases
    ("O", False)
])

@pytest.mark.parametrize(
    "label, desired_output",
    [
        # Positive Cases
        ("I-ORG", True),
        ("B-ORG", True),
        # Negative Cases
        ("O", False),
    ],
)
def test__is_multi_token_label(label, desired_output):
    output = ner_label._is_multi_token_label(label)
    assert output == desired_output

@pytest.mark.parametrize("label, desired_output", [
    # Positive Cases
    ("I-Place", "B-Place"), (" \t I-place ", "B-place"),
    # Negative Cases
    ("O", "O"), ("B-LOC", "B-LOC"), (" B-ORG ", " B-ORG ")
])

@pytest.mark.parametrize(
    "label, desired_output",
    [
        # Positive Cases
        ("I-Place", "B-Place"),
        (" \t I-place ", "B-place"),
        # Negative Cases
        ("O", "O"),
        ("B-LOC", "B-LOC"),
        (" B-ORG ", " B-ORG "),
    ],
)
def test__convert_to_begin_label(label, desired_output):
    output = ner_label._convert_to_begin_label(label)
    assert output == desired_output

@pytest.mark.parametrize("label, desired_output", [
    # Positive Cases
    ("B-LOC", "I-LOC"),
    (" B-ORG ", "I-ORG"),
    # Negative Cases
    ("", ""), ("O", "O"), ("I-Place", "I-Place"),
    (" \t I-place ", " \t I-place ")
])

@pytest.mark.parametrize(
    "label, desired_output",
    [
        # Positive Cases
        ("B-LOC", "I-LOC"),
        (" B-ORG ", "I-ORG"),
        # Negative Cases
        ("", ""),
        ("O", "O"),
        ("I-Place", "I-Place"),
        (" \t I-place ", " \t I-place "),
    ],
)
def test__convert_to_inside_label(label, desired_output):
    output = ner_label._convert_to_inside_label(label)
    assert output == desired_output


@pytest.mark.parametrize("begin_label, inside_label, desired_output", [
    # Positive Cases
    ("", "I-LOC", True),
    ("B-LOC", "I-ORG", True),
    ("", "I-ORG", True),
    # Negative Cases
    ("", "", False), ("O", "O", False), ("", "", False),
    ("B-LOC", "O", False),
    ("B-LOC", "B-ORG", False),
    ("B-LOC", "I-LOC", False),
    (" B-ORG ", "I-ORG", False),
])
@pytest.mark.parametrize(
    "begin_label, inside_label, desired_output",
    [
        # Positive Cases
        ("", "I-LOC", True),
        ("B-LOC", "I-ORG", True),
        ("", "I-ORG", True),
        # Negative Cases
        ("", "", False),
        ("O", "O", False),
        ("", "", False),
        ("B-LOC", "O", False),
        ("B-LOC", "B-ORG", False),
        ("B-LOC", "I-LOC", False),
        (" B-ORG ", "I-ORG", False),
    ],
)
def test__is_missing_begin_label(begin_label, inside_label, desired_output):
    output = ner_label._is_missing_begin_label(begin_label, inside_label)
    assert output == desired_output

@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_input_char_set", [
    (["a","b"], ["c", "d"], set("abcd")),
    (["New", "York"], ["is", "big"], set("NewYorkisbig")),
    (["word1", "word2"], ["word1", "word2"], set("word12")),
])

@pytest.mark.parametrize(
    "gt_tokens, ocr_tokens, desired_input_char_set",
    [
        (["a", "b"], ["c", "d"], set("abcd")),
        (["New", "York"], ["is", "big"], set("NewYorkisbig")),
        (["word1", "word2"], ["word1", "word2"], set("word12")),
    ],
)
def test__find_gap_char_candidates(gt_tokens, ocr_tokens, desired_input_char_set):
    gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
    gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
        gt_tokens, ocr_tokens
    )
    assert input_char_set == desired_input_char_set
    assert ner_label.GAP_CHAR_SET.difference(input_char_set) == gap_char_candidates

@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception",
    [
        (["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token
        (["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token
        (["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR
        (["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens
        (["o"], ["@"], ["New"], ner_label.GapCharError), # invalid token with gap char only (gt_token)
        (["o"], ["New"], ["@"], ner_label.GapCharError), # invalid token with gap char only (ocr_token)
        (["o", "o"], ["New", "@"], ["New", "@"], ner_label.GapCharError), # invalid token (both)
        (["o"], [" \n\t@@"], ["New"], ner_label.GapCharError), # invalid token with gap char and space chars (gt_token)
        (["o"], ["New"], [" \n\t@"], ner_label.GapCharError), # invalid token with gap char and space chars (ocr_token)
        (["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token)
        (["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token)
        (["o"], [" \n\t"], ["New"], ValueError), # invalid token: space characters only (gt_token)
        (["o"], ["New"], [" \n\t"], ValueError), # invalid token: space characters only (ocr_token)
        (["o"], ["New"], ["New"], None), # positive case
        (["o"], ["New@"], ["New"], None), # positive case with gap char
        (["o"], ["New"], ["@@New"], None), # positive case with gap char
    ])
def test__propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):

@pytest.mark.parametrize(
    "gt_labels, gt_tokens, ocr_tokens, raised_exception",
    [
        (["o"], ["New York"], ["NewYork"], ValueError),  # non-atomic gt_token
        (["o"], ["NewYork"], ["New York"], ValueError),  # non-atomic ocr_token
        (["o"], [" @ New"], ["@ @"], ValueError),  # non-atomic tokens with GAP_CHAR
        (["o", "o"], ["New"], ["New"], ValueError),  # num gt_labels != num gt_tokens
        (
            ["o"],
            ["@"],
            ["New"],
            ner_label.GapCharError,
        ),  # invalid token with gap char only (gt_token)
        (
            ["o"],
            ["New"],
            ["@"],
            ner_label.GapCharError,
        ),  # invalid token with gap char only (ocr_token)
        (
            ["o", "o"],
            ["New", "@"],
            ["New", "@"],
            ner_label.GapCharError,
        ),  # invalid token (both)
        (
            ["o"],
            [" \n\t@@"],
            ["New"],
            ner_label.GapCharError,
        ),  # invalid token with gap char and space chars (gt_token)
        (
            ["o"],
            ["New"],
            [" \n\t@"],
            ner_label.GapCharError,
        ),  # invalid token with gap char and space chars (ocr_token)
        (["o"], [""], ["New"], ValueError),  # invalid token: empty string (gt_token)
        (["o"], ["New"], [""], ValueError),  # invalid token: empty string (ocr_token)
        (
            ["o"],
            [" \n\t"],
            ["New"],
            ValueError,
        ),  # invalid token: space characters only (gt_token)
        (
            ["o"],
            ["New"],
            [" \n\t"],
            ValueError,
        ),  # invalid token: space characters only (ocr_token)
        (["o"], ["New"], ["New"], None),  # positive case
        (["o"], ["New@"], ["New"], None),  # positive case with gap char
        (["o"], ["New"], ["@@New"], None),  # positive case with gap char
    ],
)
def test__propagate_label_to_ocr_error(
    gt_labels, gt_tokens, ocr_tokens, raised_exception
):
    if raised_exception:
        with pytest.raises(raised_exception):
            ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
            ner_label._propagate_label_to_ocr(
                gt_labels, gt_tokens, ocr_tokens, gap_char="@"
            )
    else:
        ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@')
        ner_label._propagate_label_to_ocr(
            gt_labels, gt_tokens, ocr_tokens, gap_char="@"
        )
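The test__find_gap_char_candidates case above pins down that helper's contract: it returns the characters still free to serve as the alignment gap character, together with the set of characters seen in the input tokens. A small sketch using values lifted from the test cases (a private helper, shown only to explain the assertions):

from genalog.text import ner_label

gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(
    ["New", "York"], ["is", "big"]
)
assert input_char_set == set("NewYorkisbig")
# Any gap character that does not collide with the input characters is a candidate.
assert gap_char_candidates == ner_label.GAP_CHAR_SET.difference(input_char_set)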
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES)

@pytest.mark.parametrize(
    "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test__propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
    gap_char_candidates, _ = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens)
    # run regression test for each GAP_CHAR candidate to make sure
    # run regression test for each GAP_CHAR candidate to make sure
    # label propagation is functioning correctly
    for gap_char in gap_char_candidates:
        ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char)
    for gap_char in gap_char_candidates:
        ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(
            gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char
        )
        assert ocr_labels == desired_ocr_labels
@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", [
    (["o"], ["New"], ["New"], None),  # positive case
    (["o"], ["New@"], ["New"], None),  # positive case with gap char
    (["o"], ["New"], ["@@New"], None),  # positive case with gap char
    (["o"], list(ner_label.GAP_CHAR_SET), [""], ner_label.GapCharError),  # input char set == GAP_CHAR_SET
    (["o"], [""], list(ner_label.GAP_CHAR_SET), ner_label.GapCharError),  # input char set == GAP_CHAR_SET
    # all possible gap chars set split between ocr and gt tokens
    (["o"], list(ner_label.GAP_CHAR_SET)[:10], list(ner_label.GAP_CHAR_SET)[10:], ner_label.GapCharError),
])
def test_propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception):

@pytest.mark.parametrize(
    "gt_labels, gt_tokens, ocr_tokens, raised_exception",
    [
        (["o"], ["New"], ["New"], None),  # positive case
        (["o"], ["New@"], ["New"], None),  # positive case with gap char
        (["o"], ["New"], ["@@New"], None),  # positive case with gap char
        (
            ["o"],
            list(ner_label.GAP_CHAR_SET),
            [""],
            ner_label.GapCharError,
        ),  # input char set == GAP_CHAR_SET
        (
            ["o"],
            [""],
            list(ner_label.GAP_CHAR_SET),
            ner_label.GapCharError,
        ),  # input char set == GAP_CHAR_SET
        # all possible gap chars set split between ocr and gt tokens
        (
            ["o"],
            list(ner_label.GAP_CHAR_SET)[:10],
            list(ner_label.GAP_CHAR_SET)[10:],
            ner_label.GapCharError,
        ),
    ],
)
def test_propagate_label_to_ocr_error(
    gt_labels, gt_tokens, ocr_tokens, raised_exception
):
    if raised_exception:
        with pytest.raises(raised_exception):
            ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
    else:
        ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)

@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES)

@pytest.mark.parametrize(
    "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
    ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
    ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(
        gt_labels, gt_tokens, ocr_tokens
    )
    assert ocr_labels == desired_ocr_labels

@pytest.mark.parametrize("tokens, labels, label_top, desired_output",
    [
        (
            ["New", "York", "is", "big"],
            ["B-place", "I-place", "o", "o"],
            True,
            "B-place I-place o o \n" +
            "New York is big \n"
        ),
        (
            ["New", "York", "is", "big"],
            ["B-place", "I-place", "o", "o"],
            False,
            "New York is big \n" +
            "B-place I-place o o \n"
        )
    ])

@pytest.mark.parametrize(
    "tokens, labels, label_top, desired_output",
    [
        (
            ["New", "York", "is", "big"],
            ["B-place", "I-place", "o", "o"],
            True,
            "B-place I-place o o \n" + "New York is big \n",
        ),
        (
            ["New", "York", "is", "big"],
            ["B-place", "I-place", "o", "o"],
            False,
            "New York is big \n" + "B-place I-place o o \n",
        ),
    ],
)
def test_format_label(tokens, labels, label_top, desired_output):
    output = ner_label.format_labels(tokens, labels, label_top=label_top)
    assert output == desired_output

@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES)

@pytest.mark.parametrize(
    "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels",
    LABEL_PROPAGATION_REGRESSION_TEST_CASES,
)
def test_format_gt_ocr_w_labels(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels):
    ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens)
    ner_label.format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr)
    ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
        gt_labels, gt_tokens, ocr_tokens
    )
    ner_label.format_label_propagation(
        gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr
    )
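For readers skimming the reformatted tests, a short sketch of the public API they exercise. The tokens and labels below are made up; only the function names, the call signature, and the 4-tuple return are taken from the tests above:

from genalog.text import ner_label

gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["New", "Yerk", "is", "big"]  # simulated OCR noise

# Returns the propagated OCR labels plus both aligned strings and the gap char used.
ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)
# Render tokens, labels and alignment side by side, as test_format_gt_ocr_w_labels does.
ner_label.format_label_propagation(
    gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr
)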
@@ -1,124 +1,148 @@
from genalog.text import preprocess
from genalog.text.alignment import GAP_CHAR
import pytest

@pytest.mark.parametrize("token, replacement, desired_output",
    [
        ("", "_", ""), # Do nothing to empty string
        (" ", "_", " "), # Do nothing to whitespaces
        (" \n\t", "_", " \n\t"),
        ("ascii", "_", "ascii"),
        ("a s\nc\tii", "_", "a s\nc\tii"),
        ("ascii·", "_", "ascii"), # Tokens with non-ASCII values
        ("·", "_", "_"), # Tokens with non-ASCII values
    ])
import pytest

from genalog.text import preprocess
from genalog.text.alignment import GAP_CHAR


@pytest.mark.parametrize(
    "token, replacement, desired_output",
    [
        ("", "_", ""),  # Do nothing to empty string
        (" ", "_", " "),  # Do nothing to whitespaces
        (" \n\t", "_", " \n\t"),
        ("ascii", "_", "ascii"),
        ("a s\nc\tii", "_", "a s\nc\tii"),
        ("ascii·", "_", "ascii"),  # Tokens with non-ASCII values
        ("·", "_", "_"),  # Tokens with non-ASCII values
    ],
)
def test_remove_non_ascii(token, replacement, desired_output):
    for code in range(128, 1000): # non-ASCII values
    for code in range(128, 1000):  # non-ASCII values
        token.replace("·", chr(code))
        output = preprocess.remove_non_ascii(token, replacement)
        assert output == desired_output

@pytest.mark.parametrize("s, desired_output",
    [
        (
            " New \t \n",
            ["New"]
        ),

@pytest.mark.parametrize(
    "s, desired_output",
    [
        (" New \t \n", ["New"]),
        # Mixed in gap char "@"
        (
            " @ @",
            ["@", "@"]
        ),
        (
            "New York is big",
            ["New", "York", "is", "big"]
        ),
        (" @ @", ["@", "@"]),
        ("New York is big", ["New", "York", "is", "big"]),
        # Mixed multiple spaces and tabs
        (
            " New York \t is \t big",
            ["New", "York", "is", "big"]
        ),
        (" New York \t is \t big", ["New", "York", "is", "big"]),
        # Mixed in punctuation
        (
            "New .York is, big !",
            ["New", ".York", "is,", "big", "!"]
        ),
        ("New .York is, big !", ["New", ".York", "is,", "big", "!"]),
        # Mixed in gap char "@"
        (
            "@N@ew York@@@is,\t big@@@@@",
            ["@N@ew", "York@@@is,", "big@@@@@"]
        )
    ])
        ("@N@ew York@@@is,\t big@@@@@", ["@N@ew", "York@@@is,", "big@@@@@"]),
    ],
)
def test_tokenize(s, desired_output):
    output = preprocess.tokenize(s)
    assert output == desired_output

@pytest.mark.parametrize("tokens, desired_output",
    [
        (
            ["New", "York", "is", "big"],
            "New York is big",
        ),

@pytest.mark.parametrize(
    "tokens, desired_output",
    [
        (
            ["New", "York", "is", "big"],
            "New York is big",
        ),
        # Mixed in punctuation
        (
            ["New", ".York", "is,", "big", "!"],
            "New .York is, big !",
        ),
        (
            ["New", ".York", "is,", "big", "!"],
            "New .York is, big !",
        ),
        # Mixed in gap char "@"
        (
            ["@N@ew", "York@@@is,", "big@@@@@"],
            "@N@ew York@@@is, big@@@@@",
        )
    ])
        (
            ["@N@ew", "York@@@is,", "big@@@@@"],
            "@N@ew York@@@is, big@@@@@",
        ),
    ],
)
def test_join_tokens(tokens, desired_output):
    output = preprocess.join_tokens(tokens)
    assert output == desired_output

@pytest.mark.parametrize("c, desired_output",
    [
        # Gap char
        (GAP_CHAR, False),
        # Alphabet char
        ('a', False), ('A', False),
        # Punctuation
        ('.', False), ('!', False), (',', False), ('-', False),
        # Token separators
        (' ', True), ('\n', True), ('\t', True)
    ])

@pytest.mark.parametrize(
    "c, desired_output",
    [
        # Gap char
        (GAP_CHAR, False),
        # Alphabet char
        ("a", False),
        ("A", False),
        # Punctuation
        (".", False),
        ("!", False),
        (",", False),
        ("-", False),
        # Token separators
        (" ", True),
        ("\n", True),
        ("\t", True),
    ],
)
def test__is_spacing(c, desired_output):
    assert desired_output == preprocess._is_spacing(c)

@pytest.mark.parametrize("text, desired_output", [
    ("", ""),
    ("w .", "w ."), ("w !", "w !"), ("w ?", "w ?"),
    ("w /.", "w /."), ("w /!", "w /!"), ("w /?", "w /?"),
    ("w1 , w2 .", "w1 , w2 ."),
    ("w1 . w2 .", "w1 . \nw2 ."), ("w1 /. w2 /.", "w1 /. \nw2 /."),
    ("w1 ! w2 .", "w1 ! \nw2 ."), ("w1 /! w2 /.", "w1 /! \nw2 /."),
    ("w1 ? w2 .", "w1 ? \nw2 ."), ("w1 /? w2 /.", "w1 /? \nw2 /."),
    ("U.S. . w2 .", "U.S. . \nw2 ."),
    ("w1 ??? w2 .", "w1 ??? w2 ."),  # not splitting
    ("w1 !!! w2 .", "w1 !!! w2 ."),
    ("w1 ... . w2 .", "w1 ... . \nw2 ."),
    ("w1 ... /. w2 /.", "w1 ... /. \nw2 /."),
    ("w1 /. /. w2 .", "w1 /. /. \nw2 ."),
    ("w1 /. /.", "w1 /. \n/."),
    ("w1 /. /. ", "w1 /. /. \n"),
    ("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."),
    ("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
    ("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
    ("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
])

@pytest.mark.parametrize(
    "text, desired_output",
    [
        ("", ""),
        ("w .", "w ."),
        ("w !", "w !"),
        ("w ?", "w ?"),
        ("w /.", "w /."),
        ("w /!", "w /!"),
        ("w /?", "w /?"),
        ("w1 , w2 .", "w1 , w2 ."),
        ("w1 . w2 .", "w1 . \nw2 ."),
        ("w1 /. w2 /.", "w1 /. \nw2 /."),
        ("w1 ! w2 .", "w1 ! \nw2 ."),
        ("w1 /! w2 /.", "w1 /! \nw2 /."),
        ("w1 ? w2 .", "w1 ? \nw2 ."),
        ("w1 /? w2 /.", "w1 /? \nw2 /."),
        ("U.S. . w2 .", "U.S. . \nw2 ."),
        ("w1 ??? w2 .", "w1 ??? w2 ."),  # not splitting
        ("w1 !!! w2 .", "w1 !!! w2 ."),
        ("w1 ... . w2 .", "w1 ... . \nw2 ."),
        ("w1 ... /. w2 /.", "w1 ... /. \nw2 /."),
        ("w1 /. /. w2 .", "w1 /. /. \nw2 ."),
        ("w1 /. /.", "w1 /. \n/."),
        ("w1 /. /. ", "w1 /. /. \n"),
        ("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."),
        ("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."),
        ("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."),
        ("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."),
    ],
)
def test_split_sentences(text, desired_output):
    assert desired_output == preprocess.split_sentences(text)

@pytest.mark.parametrize("token, desired_output", [
    ("", False), (" ", False), ("\n", False), ("\t", False),
    (" \n \t", False),
    ("...", False),
    ("???", False), ("!!!", False),
    (".", True), ("!", True), ("?", True),
    ("/.", True), ("/!", True), ("/?", True),
])

@pytest.mark.parametrize(
    "token, desired_output",
    [
        ("", False),
        (" ", False),
        ("\n", False),
        ("\t", False),
        (" \n \t", False),
        ("...", False),
        ("???", False),
        ("!!!", False),
        (".", True),
        ("!", True),
        ("?", True),
        ("/.", True),
        ("/!", True),
        ("/?", True),
    ],
)
def test_is_sentence_separator(token, desired_output):
    assert desired_output == preprocess.is_sentence_separator(token)
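A compact sketch of the preprocess helpers covered above, with inputs and expected results lifted straight from the parametrized cases:

from genalog.text import preprocess

tokens = preprocess.tokenize(" New York \t is \t big")
assert tokens == ["New", "York", "is", "big"]
assert preprocess.join_tokens(tokens) == "New York is big"
# split_sentences() inserts a newline after each sentence-ending separator token.
assert preprocess.split_sentences("w1 . w2 .") == "w1 . \nw2 ."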
@@ -1,14 +1,16 @@
import random
import pytest
import warnings

import pytest

from genalog.text import alignment
from genalog.text.alignment import GAP_CHAR
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES


def random_utf8_char(byte_len=1):
    if byte_len == 1:
        return chr(random.randint(0,0x007F))
        return chr(random.randint(0, 0x007F))
    elif byte_len == 2:
        return chr(random.randint(0x007F, 0x07FF))
    elif byte_len == 3:
@@ -16,33 +18,61 @@ def random_utf8_char(byte_len=1):
    elif byte_len == 4:
        return chr(random.randint(0xFFFF, 0x10FFFF))
    else:
        raise ValueError(f"Invalid byte length: {byte_len}." +
            "utf-8 does not encode characters with more than 4 bytes in length")
        raise ValueError(
            f"Invalid byte length: {byte_len}."
            + "utf-8 does not encode characters with more than 4 bytes in length"
        )

@pytest.mark.parametrize("num_utf_char_to_test", [100])  # Number of char per byte length
@pytest.mark.parametrize("byte_len", [1,2,3,4])  # UTF does not encode with more than 4 bytes
@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES)
def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise):

    invalid_char = set(gt_txt).union(set(GAP_CHAR))  # character to replace to cannot be in this set

@pytest.mark.parametrize(
    "num_utf_char_to_test", [100]
)  # Number of char per byte length
@pytest.mark.parametrize(
    "byte_len", [1, 2, 3, 4]
)  # UTF does not encode with more than 4 bytes
@pytest.mark.parametrize(
    "gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise",
    ALIGNMENT_REGRESSION_TEST_CASES,
)
def test_align(
    num_utf_char_to_test,
    byte_len,
    gt_txt,
    noisy_txt,
    expected_aligned_gt,
    expected_aligned_noise,
):

    invalid_char = set(gt_txt).union(
        set(GAP_CHAR)
    )  # character to replace to cannot be in this set
    for _ in range(num_utf_char_to_test):
        utf_char = random_utf8_char(byte_len)
        while utf_char in invalid_char:  # find a utf char not in the input string and not GAP_CHAR
        while (
            utf_char in invalid_char
        ):  # find a utf char not in the input string and not GAP_CHAR
            utf_char = random_utf8_char(byte_len)
        char_to_replace = random.choice(list(invalid_char)) if gt_txt else ""

        gt_txt_sub = gt_txt.replace(char_to_replace, utf_char)
        noisy_txt_sub = noisy_txt.replace(char_to_replace, utf_char)
        gt_txt.replace(char_to_replace, utf_char)
        noisy_txt.replace(char_to_replace, utf_char)
        expected_aligned_gt_sub = expected_aligned_gt.replace(char_to_replace, utf_char)
        expected_aligned_noise_sub = expected_aligned_noise.replace(char_to_replace, utf_char)

        expected_aligned_noise_sub = expected_aligned_noise.replace(
            char_to_replace, utf_char
        )

        # Run alignment
        aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt)

        aligned_gt = aligned_gt.replace(char_to_replace, utf_char)
        aligned_noise = aligned_noise.replace(char_to_replace, utf_char)
        if aligned_gt != expected_aligned_gt_sub:
            expected_alignment = alignment._format_alignment(expected_aligned_gt_sub, expected_aligned_noise_sub)
            expected_alignment = alignment._format_alignment(
                expected_aligned_gt_sub, expected_aligned_noise_sub
            )
            result_alignment = alignment._format_alignment(aligned_gt, aligned_noise)
            warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"))

            warnings.warn(
                RuntimeWarning(
                    f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}"
                )
            )
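The regression test above checks alignment.align against stored cases; a minimal sketch of the call with made-up input strings (the exact gap placement is whatever the aligner picks, so nothing about the output content is asserted):

from genalog.text import alignment
from genalog.text.alignment import GAP_CHAR

# Both returned strings are padded with GAP_CHAR where they disagree so they
# can be compared position by position.
aligned_gt, aligned_noise = alignment.align("New York is big.", "New Yerk is big")
print(aligned_gt)
print(aligned_noise)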
@@ -0,0 +1,36 @@
; [tox]
; envlist = flake8, py36  # add other python versions if necessary

; [testenv]
; # Reading additional dependencies to run the test
; # https://tox.readthedocs.io/en/latest/example/basic.html#depending-on-requirements-txt-or-defining-constraints
; deps = -rdev-requirements.txt
; commands =
;     pytest

; [testenv:flake8]
; deps = flake8
; skip_install = True
; commands = flake8 .

; # Configurations for running pytest
[pytest]
junit_family=xunit2
testpaths =
    tests
addopts =
    -rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml

[flake8]
# Configs for flake8-import-order, see https://pypi.org/project/flake8-import-order/ for more info.
import-order-style=edited
application-import-names=genalog, tests
# Native flake8 configs
max-line-length = 140
exclude =
    build, dist
    .env*,.venv*  # local virtual environments
    .tox

; [mypy]
; ignore_missing_imports = True
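The [flake8] section above is what drives the import reshuffling visible throughout this diff: flake8-import-order with import-order-style=edited and application-import-names=genalog, tests expects, roughly, standard-library imports first, third-party packages next, and the project's own packages last. A rough illustration, mirroring the test modules above:

# Standard library
import random
import string

# Third-party
import pytest

# First-party (packages listed in application-import-names)
from genalog.text import alignment
from genalog.text import ner_label
from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES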