From 96620b4c434bf18b5608d870d9d00da0150547c1 Mon Sep 17 00:00:00 2001
From: "Jianjie Liu (MAIDAP)"
Date: Mon, 25 Jan 2021 13:56:39 -0500
Subject: [PATCH 1/5] Fix linting issues in tests

---
 tests/cases/label_propagation.py          |  16 +-
 tests/cases/text_alignment.py             | 148 ++---
 tests/degradation/test_degrader.py        | 126 ++--
 tests/degradation/test_effect.py          | 129 ++--
 tests/e2e/test_anchor_e2e.py              |  31 +-
 tests/e2e/test_conll_format_e2e.py        |  62 +-
 tests/e2e/test_document_generation.py     |  46 +-
 tests/e2e/test_generaton_n_degradation.py |  18 +-
 tests/e2e/test_image_channel.py           |  19 +-
 tests/e2e/test_ocr_e2e.py                 |  48 +-
 tests/e2e/test_pipeline.py                |  22 +-
 tests/e2e/test_splitter.py                |  32 +-
 tests/generation/test_content.py          |  22 +-
 tests/generation/test_document.py         | 142 ++--
 tests/ocr/test_metrics.py                 | 768 ++++++++++++++++------
 tests/ocr/test_ocr.py                     |  59 +-
 tests/text/test_alignment.py              | 224 ++++---
 tests/text/test_anchor.py                 | 267 +++++---
 tests/text/test_conll_format.py           | 386 +++++++----
 tests/text/test_lcs.py                    |  59 +-
 tests/text/test_ner_label.py              | 352 ++++++----
 tests/text/test_preprocess.py             | 209 +++---
 tests/text/test_utf8.py                   |  63 +-
 23 files changed, 2120 insertions(+), 1128 deletions(-)

diff --git a/tests/cases/label_propagation.py b/tests/cases/label_propagation.py
index 9c957ad..3d9330e 100644
--- a/tests/cases/label_propagation.py
+++ b/tests/cases/label_propagation.py
@@ -1,10 +1,10 @@
 # Test cases for genalog.text.ner_label.propagate_label_to_ocr() method.
-# For READABILITY purpose, ground truth and noisy text are presented as 
+# For READABILITY purpose, ground truth and noisy text are presented as
 # a whole string, not in their tokenized format.
 # Notice the `propagate_label_to_ocr()` method has the contract of
-# (list, list, list) -> (list, list, list) 
-# consuming both ground truth text and noisy text as lists of tokens. 
# We will use `genalog.text.preprocess.tokenize()` to tokenize these strings from genalog.text import preprocess @@ -106,12 +106,14 @@ desired_ocr_labels.append(["O", "B-FRUIT", "I-FRUIT", "O", "O"]) ner_labels.append(["O", "O", "ENTERTAINMENT", "O"]) gt_txt.append("@ new TV !") ns_txt.append("@ n ow T\\/ |") -desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT" ,"O"]) +desired_ocr_labels.append(["O", "O", "O", "ENTERTAINMENT", "O"]) -# Tokenize ground truth and noisy text strings +# Tokenize ground truth and noisy text strings gt_tokens = [preprocess.tokenize(txt) for txt in gt_txt] ns_tokens = [preprocess.tokenize(txt) for txt in ns_txt] -# test function expect params in tuple of +# test function expect params in tuple of # (gt_label, gt_tokens, ocr_tokens, desired_ocr_labels) -LABEL_PROPAGATION_REGRESSION_TEST_CASES = list(zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels)) +LABEL_PROPAGATION_REGRESSION_TEST_CASES = list( + zip(ner_labels, gt_tokens, ns_tokens, desired_ocr_labels) +) diff --git a/tests/cases/text_alignment.py b/tests/cases/text_alignment.py index 9fa67de..f5008fb 100644 --- a/tests/cases/text_alignment.py +++ b/tests/cases/text_alignment.py @@ -1,4 +1,3 @@ - # Initializing test cases # For extensibility, all parameters in a case are append to following arrays gt_txt = [] @@ -17,31 +16,35 @@ ns_txt.append("N ewYork kis big.") aligned_gt.append("N@ew York @is big.") aligned_ns.append("N ew@York kis big.") gt_to_noise_maps.append( -[ - # This shows that the first token in gt "New" maps to the - # first ("N") and second ("ewYork") token in the noise - [0,1], - [1], - [2], - [3] -]) + [ + # This shows that the first token in gt "New" maps to the + # first ("N") and second ("ewYork") token in the noise + [0, 1], + [1], + [2], + [3], + ] +) -noise_to_gt_maps.append([ - [0], - # Similarly, the following shows that the second token in noise "ewYork" maps to the - # first ("New") and second ("York") token in gt - [0,1], - [2], - [3] -]) + +noise_to_gt_maps.append( + [ + [0], + # Similarly, the following shows that the second token in noise "ewYork" maps to the + # first ("New") and second ("York") token in gt + [0, 1], + [2], + [3], + ] +) ############################################################################################## -# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account +# SPECIAL CASE: noisy text does not contain sufficient whitespaces to account # for missing tokens # Notice there's only 1 whitespace b/w 'oston' and 'grea' # The ideal situation is that there are 2 whitespaces. 
Ex: -# ("B oston grea t") +# ("B oston grea t") ns_txt.append("B oston grea t") gt_txt.append("Boston is great") @@ -52,18 +55,9 @@ gt_txt.append("Boston is great") aligned_ns.append("B oston@@@ grea t") aligned_gt.append("B@oston is grea@t") -gt_to_noise_maps.append([ - [0,1], - [1], # 'is' is also mapped to 'oston' - [2,3] -]) +gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) # 'is' is also mapped to 'oston' -noise_to_gt_maps.append([ - [0], - [0,1], # 'oston' is to 'Boston' and 'is' - [2], - [2] -]) +noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) # 'oston' is to 'Boston' and 'is' ############################################################################################ # Empty Cases: @@ -95,18 +89,9 @@ ns_txt.append("B oston bi g") aligned_gt.append("B@oston is bi@g") aligned_ns.append("B oston@@@ bi g") -gt_to_noise_maps.append([ - [0,1], - [1], - [2,3] -]) +gt_to_noise_maps.append([[0, 1], [1], [2, 3]]) -noise_to_gt_maps.append([ - [0], - [0,1], - [2], - [2] -]) +noise_to_gt_maps.append([[0], [0, 1], [2], [2]]) ############################################################################################ gt_txt.append("New York is big.") @@ -115,17 +100,9 @@ ns_txt.append("NewYork big") aligned_gt.append("New York is big.") aligned_ns.append("New@York @@@big@") -gt_to_noise_maps.append([ - [0], - [0], - [1], - [1] -]) +gt_to_noise_maps.append([[0], [0], [1], [1]]) -noise_to_gt_maps.append([ - [0, 1], - [2, 3] -]) +noise_to_gt_maps.append([[0, 1], [2, 3]]) ############################################################################################# gt_txt.append("politicians who lag superfluous on the") @@ -134,23 +111,9 @@ ns_txt.append("politicians who kg superfluous on the") aligned_gt.append("politicians who lag superfluous on the") aligned_ns.append("politicians who @kg superfluous on the") -gt_to_noise_maps.append([ - [0], - [1], - [2], - [3], - [4], - [5] -]) +gt_to_noise_maps.append([[0], [1], [2], [3], [4], [5]]) -noise_to_gt_maps.append([ - [0], - [1], - [2], - [3], - [4], - [5] -]) +noise_to_gt_maps.append([[0], [1], [2], [3], [4], [5]]) ############################################################################################ @@ -160,20 +123,9 @@ ns_txt.append("faithei uifoimtdon the subject") aligned_gt.append("farther @informed on the subject.") aligned_ns.append("faithei ui@foimtd@on the subject@") -gt_to_noise_maps.append([ - [0], - [1], - [1], - [2], - [3] -]) +gt_to_noise_maps.append([[0], [1], [1], [2], [3]]) -noise_to_gt_maps.append([ - [0], - [1,2], - [3], - [4] -]) +noise_to_gt_maps.append([[0], [1, 2], [3], [4]]) ############################################################################################ @@ -183,20 +135,9 @@ ns_txt.append("New Yorkis big .") aligned_gt.append("New York is big .") aligned_ns.append("New York@is big .") -gt_to_noise_maps.append([ - [0], - [1], - [1], - [2], - [3] -]) +gt_to_noise_maps.append([[0], [1], [1], [2], [3]]) -noise_to_gt_maps.append([ - [0], - [1,2], - [3], - [4] -]) +noise_to_gt_maps.append([[0], [1, 2], [3], [4]]) ############################################################################################ @@ -206,23 +147,14 @@ ns_txt.append("New Yo rk is big.") aligned_gt.append("New Yo@rk is big.") aligned_ns.append("New Yo rk is big.") -gt_to_noise_maps.append([ - [0], - [1,2], - [3], - [4] -]) +gt_to_noise_maps.append([[0], [1, 2], [3], [4]]) -noise_to_gt_maps.append([ - [0], - [1], - [1], - [2], - [3] -]) +noise_to_gt_maps.append([[0], [1], [1], [2], [3]]) # Format tests for pytest # Each test expect in the 
following format # (aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps) -PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip(aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps) -ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns)) \ No newline at end of file +PARSE_ALIGNMENT_REGRESSION_TEST_CASES = zip( + aligned_gt, aligned_ns, gt_to_noise_maps, noise_to_gt_maps +) +ALIGNMENT_REGRESSION_TEST_CASES = list(zip(gt_txt, ns_txt, aligned_gt, aligned_ns)) diff --git a/tests/degradation/test_degrader.py b/tests/degradation/test_degrader.py index c22ff14..699aeb4 100644 --- a/tests/degradation/test_degrader.py +++ b/tests/degradation/test_degrader.py @@ -6,75 +6,92 @@ import numpy as np import pytest import copy -MOCK_IMAGE_SHAPE = (4,3) +MOCK_IMAGE_SHAPE = (4, 3) MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE) + @pytest.fixture def empty_degrader(): effects = [] return Degrader(effects) -@pytest.fixture(params=[ - [("blur", {"radius": 5})], - [("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})], - [("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})], - [ - ("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "open"}), - ("morphology", {"operation": "close"}), - ("morphology", {"src": ImageState.ORIGINAL_STATE,"operation": "dilate"}), - ("morphology", {"operation": "erode"}), - ], - [ - ("blur", {"radius": 5}), - ("bleed_through", { - "src": ImageState.CURRENT_STATE, - "alpha": 0.7, - "background": ImageState.ORIGINAL_STATE, - }), - ("morphology", { - "operation": "open", - "kernel_shape": (3,3), - "kernel_type": "ones" - }), + +@pytest.fixture( + params=[ + [("blur", {"radius": 5})], + [("blur", {"src": ImageState.ORIGINAL_STATE, "radius": 5})], + [("blur", {"src": ImageState.CURRENT_STATE, "radius": 5})], + [ + ("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "open"}), + ("morphology", {"operation": "close"}), + ("morphology", {"src": ImageState.ORIGINAL_STATE, "operation": "dilate"}), + ("morphology", {"operation": "erode"}), + ], + [ + ("blur", {"radius": 5}), + ( + "bleed_through", + { + "src": ImageState.CURRENT_STATE, + "alpha": 0.7, + "background": ImageState.ORIGINAL_STATE, + }, + ), + ( + "morphology", + {"operation": "open", "kernel_shape": (3, 3), "kernel_type": "ones"}, + ), + ], ] -]) +) def degrader(request): effects = request.param return Degrader(effects) -def test_degrader_init(empty_degrader): + +def test_empty_degrader_init(empty_degrader): assert empty_degrader.effects_to_apply == [] + def test_degrader_init(degrader): assert degrader.effects_to_apply is not [] for effect_tuple in degrader.effects_to_apply: method_name, method_kwargs = effect_tuple assert DEFAULT_METHOD_PARAM_TO_INCLUDE in method_kwargs param_value = method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] - assert param_value is ImageState.ORIGINAL_STATE or param_value is ImageState.CURRENT_STATE + assert ( + param_value is ImageState.ORIGINAL_STATE + or param_value is ImageState.CURRENT_STATE + ) -@pytest.mark.parametrize("effects, error_thrown", [ - ([], None), #Empty effect - (None, TypeError), - ([("blur", {"radius": 5})], None), # Validate input - ([("not_a_func", {"radius": 5})], ValueError), # Invalid method name - ([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs - ([("blur")], ValueError), # Missing kwargs - ( - [ - ("blur", {"radius": 5}), - ("bleed_through", {"alpha":"0.8"}), - ("morphology", {"operation": "open"}) - ], None - ), # Multiple effects - ( - [ - ("blur", {"radius": 5}), - 
("bleed_through", {"not_argument":"0.8"}), - ("morphology", {"missing value"}) - ], ValueError - ), # Multiple effects -]) + +@pytest.mark.parametrize( + "effects, error_thrown", + [ + ([], None), # Empty effect + (None, TypeError), + ([("blur", {"radius": 5})], None), # Validate input + ([("not_a_func", {"radius": 5})], ValueError), # Invalid method name + ([("blur", {"not_a_argument": 5})], ValueError), # Invalid kwargs + ([("blur")], ValueError), # Missing kwargs + ( + [ + ("blur", {"radius": 5}), + ("bleed_through", {"alpha": "0.8"}), + ("morphology", {"operation": "open"}), + ], + None, + ), # Multiple effects + ( + [ + ("blur", {"radius": 5}), + ("bleed_through", {"not_argument": "0.8"}), + ("morphology", {"missing value"}), + ], + ValueError, + ), # Multiple effects + ], +) def test_degrader_validate_effects(effects, error_thrown): if error_thrown: with pytest.raises(error_thrown): @@ -82,23 +99,26 @@ def test_degrader_validate_effects(effects, error_thrown): else: Degrader.validate_effects(effects) + def test_degrader_apply_effects(degrader): method_names = [effect[0] for effect in degrader.effects_to_apply] with patch("genalog.degradation.effect") as mock_effect: - degraded = degrader.apply_effects(MOCK_IMAGE) + degrader.apply_effects(MOCK_IMAGE) for method in method_names: assert mock_effect[method].is_called() # assert degraded.shape == MOCK_IMAGE_SHAPE + def test_degrader_apply_effects_e2e(degrader): degraded = degrader.apply_effects(MOCK_IMAGE) assert degraded.shape == MOCK_IMAGE_SHAPE assert degraded.dtype == np.uint8 + def test_degrader_instructions(degrader): original_instruction = copy.deepcopy(degrader.effects_to_apply) - degraded1 = degrader.apply_effects(MOCK_IMAGE) - degraded2 = degrader.apply_effects(MOCK_IMAGE) + degrader.apply_effects(MOCK_IMAGE) + degrader.apply_effects(MOCK_IMAGE) # Make sure the degradation instructions are not altered assert len(original_instruction) == len(degrader.effects_to_apply) for i in range(len(original_instruction)): @@ -107,5 +127,5 @@ def test_degrader_instructions(degrader): assert org_method_name == method_name assert len(org_method_arg) == len(method_arg) for key in org_method_arg.keys(): - assert type(org_method_arg[key]) == type(method_arg[key]) - assert org_method_arg[key] == method_arg[key] \ No newline at end of file + assert isinstance(org_method_arg[key], type(method_arg[key])) + assert org_method_arg[key] == method_arg[key] diff --git a/tests/degradation/test_effect.py b/tests/degradation/test_effect.py index a218837..b127995 100644 --- a/tests/degradation/test_effect.py +++ b/tests/degradation/test_effect.py @@ -8,23 +8,26 @@ NEW_IMG_SHAPE = (100, 100) MOCK_IMG_SHAPE = (100, 120) MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8) + def test_blur(): dst = effect.blur(MOCK_IMG, radius=3) - assert dst.dtype == np.uint8 # preverse dtype - assert dst.shape == MOCK_IMG_SHAPE # preverse image size + assert dst.dtype == np.uint8 # preverse dtype + assert dst.shape == MOCK_IMG_SHAPE # preverse image size + def test_translation(): offset_x = offset_y = 1 # Test that border pixels are not white (<255) assert all([col_pixel < 255 for col_pixel in MOCK_IMG[:, 0]]) - assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0 ,:]]) + assert all([row_pixel < 255 for row_pixel in MOCK_IMG[0, :]]) dst = effect.translation(MOCK_IMG, offset_x, offset_y) # Test that border pixels are white (255) - assert all([col_pixel == 255 for col_pixel in dst[:,0]]) + assert all([col_pixel == 255 for col_pixel in dst[:, 0]]) assert all([row_pixel == 255 for 
row_pixel in dst[0, :]]) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_overlay_weighted(): src = MOCK_IMG.copy() src[0][0] = 10 @@ -34,6 +37,7 @@ def test_overlay_weighted(): assert dst.shape == MOCK_IMG_SHAPE assert dst[0][0] == src[0][0] * alpha + src[0][0] * beta + def test_overlay(): src1 = MOCK_IMG.copy() src2 = MOCK_IMG.copy() @@ -45,6 +49,7 @@ def test_overlay(): assert dst[0][0] == 0 assert dst[0][1] == 1 + @patch("genalog.degradation.effect.translation") def test_bleed_through_default(mock_translation): mock_translation.return_value = MOCK_IMG @@ -53,11 +58,15 @@ def test_bleed_through_default(mock_translation): assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE -@pytest.mark.parametrize("foreground, background, error_thrown", [ - (MOCK_IMG, MOCK_IMG, None), - # Test unmatched shape - (MOCK_IMG, MOCK_IMG[:,:-1], Exception), -]) + +@pytest.mark.parametrize( + "foreground, background, error_thrown", + [ + (MOCK_IMG, MOCK_IMG, None), + # Test unmatched shape + (MOCK_IMG, MOCK_IMG[:, :-1], Exception), + ], +) def test_bleed_through_kwargs(foreground, background, error_thrown): if error_thrown: assert foreground.shape != background.shape @@ -68,97 +77,125 @@ def test_bleed_through_kwargs(foreground, background, error_thrown): assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_pepper(): dst = effect.pepper(MOCK_IMG, amount=0.1) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_salt(): dst = effect.salt(MOCK_IMG, amount=0.1) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_salt_then_pepper(): dst = effect.salt_then_pepper(MOCK_IMG, 0.5, 0.001) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_pepper_then_salt(): dst = effect.pepper_then_salt(MOCK_IMG, 0.001, 0.5) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE -@pytest.mark.parametrize("kernel_shape, kernel_type", [ - ((3,3), "NOT_VALID_TYPE"), - (1, "ones"), - ((1,2,3), "ones") -]) + +@pytest.mark.parametrize( + "kernel_shape, kernel_type", + [((3, 3), "NOT_VALID_TYPE"), (1, "ones"), ((1, 2, 3), "ones")], +) def test_create_2D_kernel_error(kernel_shape, kernel_type): with pytest.raises(Exception): effect.create_2D_kernel(kernel_shape, kernel_type) -@pytest.mark.parametrize("kernel_shape, kernel_type, expected_kernel", [ - ((2,2), "ones", np.array([[1,1],[1,1]])), # sq kernel - ((1,2), "ones", np.array([[1,1]])), # horizontal - ((2,1), "ones", np.array([[1],[1]])), # vertical - ((2,2), "upper_triangle", np.array([[1,1],[0,1]])), - ((2,2), "lower_triangle", np.array([[1,0],[1,1]])), - ((2,2), "x", np.array([[1,1],[1,1]])), - ((3,3), "x", np.array([[1,0,1],[0,1,0],[1,0,1]])), - ((2,2), "plus", np.array([[0,1],[1,1]])), - ((3,3), "plus", np.array([[0,1,0],[1,1,1],[0,1,0]])), - ((3,3), "ellipse", np.array([[0,1,0],[1,1,1],[0,1,0]])), - ((5,5), "ellipse", - np.array([ - [0, 0, 1, 0, 0], - [1, 1, 1, 1, 1], - [1, 1, 1, 1, 1], - [1, 1, 1, 1, 1], - [0, 0, 1, 0, 0] - ])), -]) + +@pytest.mark.parametrize( + "kernel_shape, kernel_type, expected_kernel", + [ + ((2, 2), "ones", np.array([[1, 1], [1, 1]])), # sq kernel + ((1, 2), "ones", np.array([[1, 1]])), # horizontal + ((2, 1), "ones", np.array([[1], [1]])), # vertical + ((2, 2), "upper_triangle", np.array([[1, 1], [0, 1]])), + ((2, 2), "lower_triangle", np.array([[1, 0], [1, 1]])), + ((2, 2), "x", np.array([[1, 1], [1, 1]])), + ((3, 3), "x", np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])), + ((2, 2), "plus", np.array([[0, 1], [1, 1]])), + ((3, 
3), "plus", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])), + ((3, 3), "ellipse", np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])), + ( + (5, 5), + "ellipse", + np.array( + [ + [0, 0, 1, 0, 0], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [0, 0, 1, 0, 0], + ] + ), + ), + ], +) def test_create_2D_kernel(kernel_shape, kernel_type, expected_kernel): kernel = effect.create_2D_kernel(kernel_shape, kernel_type) assert np.array_equal(kernel, expected_kernel) + def test_morphology_with_error(): INVALID_OPERATION = "NOT_A_OPERATION" with pytest.raises(ValueError): effect.morphology(MOCK_IMG, operation=INVALID_OPERATION) -@pytest.mark.parametrize("operation, kernel_shape, kernel_type", [ - ("open", (3,3), "ones"), - ("close", (3,3), "ones"), - ("dilate", (3,3), "ones"), - ("erode", (3,3), "ones"), -]) + +@pytest.mark.parametrize( + "operation, kernel_shape, kernel_type", + [ + ("open", (3, 3), "ones"), + ("close", (3, 3), "ones"), + ("dilate", (3, 3), "ones"), + ("erode", (3, 3), "ones"), + ], +) def test_morphology(operation, kernel_shape, kernel_type): - dst = effect.morphology(MOCK_IMG, - operation=operation, kernel_shape=kernel_shape, - kernel_type=kernel_type) + dst = effect.morphology( + MOCK_IMG, + operation=operation, + kernel_shape=kernel_shape, + kernel_type=kernel_type, + ) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE -@pytest.fixture(params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"]) + +@pytest.fixture( + params=["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"] +) def kernel(request): - return effect.create_2D_kernel((5,5), request.param) + return effect.create_2D_kernel((5, 5), request.param) + def test_open(kernel): dst = effect.open(MOCK_IMG, kernel) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_close(kernel): dst = effect.close(MOCK_IMG, kernel) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_erode(kernel): dst = effect.erode(MOCK_IMG, kernel) assert dst.dtype == np.uint8 assert dst.shape == MOCK_IMG_SHAPE + def test_dilate(kernel): dst = effect.dilate(MOCK_IMG, kernel) assert dst.dtype == np.uint8 - assert dst.shape == MOCK_IMG_SHAPE \ No newline at end of file + assert dst.shape == MOCK_IMG_SHAPE diff --git a/tests/e2e/test_anchor_e2e.py b/tests/e2e/test_anchor_e2e.py index 53a369c..1c8d0ee 100644 --- a/tests/e2e/test_anchor_e2e.py +++ b/tests/e2e/test_anchor_e2e.py @@ -5,11 +5,13 @@ import pytest import difflib import warnings -@pytest.mark.parametrize("gt_file, ocr_file", + +@pytest.mark.parametrize( + "gt_file, ocr_file", zip( sorted(glob.glob("tests/text/data/gt_*.txt")), - sorted(glob.glob("tests/text/data/ocr_*.txt")) - ) + sorted(glob.glob("tests/text/data/ocr_*.txt")), + ), ) def test_align_w_anchor_and_align(gt_file, ocr_file): gt_text = open(gt_file, "r").read() @@ -21,17 +23,22 @@ def test_align_w_anchor_and_align(gt_file, ocr_file): aligned_anchor_gt = aligned_anchor_gt.split(".") aligned_gt = aligned_gt.split(".") str_diff = "\n".join(difflib.unified_diff(aligned_gt, aligned_anchor_gt)) - warnings.warn(UserWarning( - "\n"+ f"{str_diff}" + - f"\n\n**** Inconsistent Alignment Results between align() and " + - f"align_w_anchor(). Ignore this if the delta is not significant. ****\n")) + warnings.warn( + UserWarning( + "\n" + + f"{str_diff}" + + "\n\n**** Inconsistent Alignment Results between align() and " + + "align_w_anchor(). Ignore this if the delta is not significant. 
****\n" + ) + ) -@pytest.mark.parametrize("gt_file, ocr_file", +@pytest.mark.parametrize( + "gt_file, ocr_file", zip( sorted(glob.glob("tests/text/data/gt_*.txt")), - sorted(glob.glob("tests/text/data/ocr_*.txt")) - ) + sorted(glob.glob("tests/text/data/ocr_*.txt")), + ), ) @pytest.mark.parametrize("max_seg_length", [25, 50, 75, 100, 150]) def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length): @@ -39,7 +46,9 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length): ocr_text = open(ocr_file, "r").read() gt_tokens = preprocess.tokenize(gt_text) ocr_tokens = preprocess.tokenize(ocr_text) - gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length) + gt_anchors, ocr_anchors = anchor.find_anchor_recur( + gt_tokens, ocr_tokens, max_seg_length=max_seg_length + ) for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors): # Ensure that each anchor word is the same word in both text assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor] diff --git a/tests/e2e/test_conll_format_e2e.py b/tests/e2e/test_conll_format_e2e.py index 1826737..4f9c49a 100644 --- a/tests/e2e/test_conll_format_e2e.py +++ b/tests/e2e/test_conll_format_e2e.py @@ -1,47 +1,61 @@ -from genalog.text import conll_format -from unittest import mock - -import argparse import pytest import glob import itertools -@pytest.mark.parametrize("required_args", [ - (["tests/e2e/data/synthetic_dataset", "test_version"]) -]) -@pytest.mark.parametrize("optional_args", [ - (["--train_subset"]), - (["--test_subset"]), - (["--gt_folder", "shared"]), -]) +from genalog.text import conll_format + + +@pytest.mark.parametrize( + "required_args", [(["tests/e2e/data/synthetic_dataset", "test_version"])] +) +@pytest.mark.parametrize( + "optional_args", + [ + (["--train_subset"]), + (["--test_subset"]), + (["--gt_folder", "shared"]), + ], +) def test_conll_format(required_args, optional_args): parser = conll_format.create_parser() arg_list = required_args + optional_args args = parser.parse_args(args=arg_list) conll_format.main(args) + basepath = "tests/e2e/data/conll_formatter/" -@pytest.mark.parametrize("clean_label_filename, ocr_text_filename", + +@pytest.mark.parametrize( + "clean_label_filename, ocr_text_filename", zip( sorted(glob.glob("tests/e2e/data/conll_formatter/clean_labels/*.txt")), - sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")) - ) + sorted(glob.glob("tests/e2e/data/conll_formatter/ocr_text/*.txt")), + ), ) def test_propagate_labels_sentence_single_file(clean_label_filename, ocr_text_filename): - with open(clean_label_filename, 'r', encoding='utf-8') as clf: + with open(clean_label_filename, "r", encoding="utf-8") as clf: tokens_labels_str = clf.readlines() - clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2] - clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2] + clean_tokens = [ + line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2 + ] + clean_labels = [ + line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2 + ] clean_sentences = conll_format.get_sentences_from_iob_format(tokens_labels_str) # read ocr tokens - with open(ocr_text_filename, 'r', encoding='utf-8') as otf: - ocr_text_str = ' '.join(otf.readlines()) - ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data + with open(ocr_text_filename, "r", encoding="utf-8") as otf: + ocr_text_str = " ".join(otf.readlines()) + 
ocr_tokens = [ + token.strip() for token in ocr_text_str.split() + ] # already tokenized in data - ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences)) assert len(ocr_text_sentences) == len(clean_sentences) assert len(ocr_text_sentences) == len(ocr_labels_sentences) - assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens - + assert len(ocr_sentences_flatten) == len( + ocr_tokens + ) # ensure aligned ocr tokens == ocr tokens diff --git a/tests/e2e/test_document_generation.py b/tests/e2e/test_document_generation.py index 0a38b0a..abdb90b 100644 --- a/tests/e2e/test_document_generation.py +++ b/tests/e2e/test_document_generation.py @@ -1,10 +1,12 @@ -from genalog.generation.document import DocumentGenerator -from genalog.generation.content import CompositeContent, ContentType - import pytest import os -CONTENT = CompositeContent(["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH]) +from genalog.generation.document import DocumentGenerator +from genalog.generation.content import CompositeContent, ContentType + +CONTENT = CompositeContent( + ["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH] +) UNSUPPORTED_CONTENT_FORMAT = ["foo bar"] UNSUPPORTED_CONTENT_TYPE = CompositeContent(["foo"], [ContentType.TITLE]) @@ -21,9 +23,10 @@ CUSTOM_STYLE = { "font_family": ["Calibri", "Times"], "font_size": ["10px"], "text_align": ["right"], - "hyphenate": [True, False] + "hyphenate": [True, False], } + def test_default_template_generation(): doc_gen = DocumentGenerator() generator = doc_gen.create_generator(CONTENT, doc_gen.template_list) @@ -31,28 +34,36 @@ def test_default_template_generation(): html_str = doc.render_html() assert "Unsupported Content Type:" not in html_str assert "No content loaded" not in html_str - + + def test_default_template_generation_w_unsupported_content_format(): doc_gen = DocumentGenerator() - generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list) + generator = doc_gen.create_generator( + UNSUPPORTED_CONTENT_FORMAT, doc_gen.template_list + ) for doc in generator: html_str = doc.render_html() assert "No content loaded" in html_str - + + def test_default_template_generation_w_unsupported_content_type(): doc_gen = DocumentGenerator() - generator = doc_gen.create_generator(UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"]) + generator = doc_gen.create_generator( + UNSUPPORTED_CONTENT_TYPE, ["text_block.html.jinja"] + ) for doc in generator: html_str = doc.render_html() assert "Unsupported Content Type: ContentType.TITLE" in html_str + def test_custom_template_generation(): doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH) generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME]) doc = next(generator) result = doc.render_html() assert result == str(CONTENT) - + + def test_undefined_template_generation(): doc_gen = DocumentGenerator() assert UNDEFINED_TEMPLATE_NAME not in doc_gen.template_list @@ -60,6 +71,7 @@ def test_undefined_template_generation(): with pytest.raises(FileNotFoundError): next(generator) + def test_custom_style_template_generation(): doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH) assert len(doc_gen.styles_to_generate) == 1 @@ 
-70,14 +82,16 @@ def test_custom_style_template_generation(): result = doc.render_html() assert doc.styles["font_family"] == result + def test_render_pdf_and_png(): doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH) generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME]) for doc in generator: pdf_bytes = doc.render_pdf() png_bytes = doc.render_png() - assert pdf_bytes != None - assert png_bytes != None + assert pdf_bytes is not None + assert png_bytes is not None + def test_save_document_as_png(): if not os.path.exists(TEST_OUTPUT_DIR): @@ -86,15 +100,16 @@ def test_save_document_as_png(): generator = doc_gen.create_generator(CONTENT, [CUSTOM_TEMPLATE_NAME]) for doc in generator: doc.render_png(target=FILE_DESTINATION, resolution=100) - # Check if the document is saved in filepath + # Check if the document is saved in filepath assert os.path.exists(FILE_DESTINATION) + def test_save_document_as_separate_png(): if not os.path.exists(TEST_OUTPUT_DIR): os.mkdir(TEST_OUTPUT_DIR) doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH) generator = doc_gen.create_generator(CONTENT, [MULTI_PAGE_TEMPLATE_NAME]) - + document = next(generator) document.render_png(target=FILE_DESTINATION, split_pages=True, resolution=100) # Check if the document is saved as separated .png files @@ -102,6 +117,7 @@ def test_save_document_as_separate_png(): printed_doc_name = FILE_DESTINATION.replace(".png", f"_pg_{page_num}.png") assert os.path.exists(printed_doc_name) + def test_overwriting_style(): new_font = "NewFontFamily" doc_gen = DocumentGenerator(template_path=CUSTOM_TEMPLATE_PATH) @@ -111,4 +127,4 @@ def test_overwriting_style(): assert doc.styles["font_family"] != new_font doc.update_style(font_family=new_font) result = doc.render_html() - assert new_font == result \ No newline at end of file + assert new_font == result diff --git a/tests/e2e/test_generaton_n_degradation.py b/tests/e2e/test_generaton_n_degradation.py index 80090f3..3b16793 100644 --- a/tests/e2e/test_generaton_n_degradation.py +++ b/tests/e2e/test_generaton_n_degradation.py @@ -1,23 +1,25 @@ from genalog.generation.document import DocumentGenerator from genalog.generation.content import CompositeContent, ContentType from genalog.degradation.degrader import Degrader -from genalog.degradation import effect -import numpy as np -import cv2 TEST_OUTPUT_DIR = "test_out/" -SAMPLE_TXT = "Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , was picked on Thursday for the Scottish squad after a 20-month exile ." 
+SAMPLE_TXT = """Everton 's Duncan Ferguson , who scored twice against Manchester United on Wednesday , + was picked on Thursday for the Scottish squad after a 20-month exile .""" DEFAULT_TEMPLATE = "text_block.html.jinja" DEGRADATION_EFFECTS = [ ("blur", {"radius": 5}), ("bleed_through", {"alpha": 0.8}), - ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "plus"}), + ( + "morphology", + {"operation": "open", "kernel_shape": (3, 3), "kernel_type": "plus"}, + ), ("morphology", {"operation": "close"}), ("morphology", {"operation": "dilate"}), - ("morphology", {"operation": "erode"}) + ("morphology", {"operation": "erode"}), ] + def test_generation_and_degradation(): # Initiate content content = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH]) @@ -32,6 +34,4 @@ def test_generation_and_degradation(): # get the image in bytes in RGBA channels src = doc.render_array(resolution=100, channel="GRAYSCALE") # run each degradation effect - dst = degrader.apply_effects(src) - - + degrader.apply_effects(src) diff --git a/tests/e2e/test_image_channel.py b/tests/e2e/test_image_channel.py index 58a2177..dfbd1fe 100644 --- a/tests/e2e/test_image_channel.py +++ b/tests/e2e/test_image_channel.py @@ -1,41 +1,44 @@ +import pytest +import cv2 + from genalog.generation.document import DocumentGenerator from genalog.generation.content import CompositeContent, ContentType -import pytest -import numpy as np -import cv2 - TEMPLATE_PATH = "tests/e2e/templates" TEST_OUT_FOLDER = "test_out/" SAMPLE_TXT = "foo" CONTENT = CompositeContent([SAMPLE_TXT], [ContentType.PARAGRAPH]) + @pytest.fixture def doc_generator(): return DocumentGenerator(template_path=TEMPLATE_PATH) + def test_red_channel(doc_generator): generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"]) for doc in generator: - doc.update_style(background_color="red") + doc.update_style(background_color="red") img_array = doc.render_array(resolution=100, channel="BGRA") # css "red" is rgb(255,0,0) or bgra(0,0,255,255) - assert tuple(img_array[0][0]) == (0, 0, 255, 255) + assert tuple(img_array[0][0]) == (0, 0, 255, 255) cv2.imwrite(TEST_OUT_FOLDER + "red.png", img_array) + def test_green_channel(doc_generator): generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"]) for doc in generator: - doc.update_style(background_color="green") + doc.update_style(background_color="green") img_array = doc.render_array(resolution=100, channel="BGRA") # css "green" is rgb(0,128,0) or bgra(0,128,0,255) assert tuple(img_array[0][0]) == (0, 128, 0, 255) cv2.imwrite(TEST_OUT_FOLDER + "green.png", img_array) + def test_blue_channel(doc_generator): generator = doc_generator.create_generator(CONTENT, ["solid_bg.html.jinja"]) for doc in generator: - doc.update_style(background_color="blue") + doc.update_style(background_color="blue") img_array = doc.render_array(resolution=100, channel="BGRA") # css "blue" is rgb(0,0,255) or bgra(255,0,0,255) assert tuple(img_array[0][0]) == (255, 0, 0, 255) diff --git a/tests/e2e/test_ocr_e2e.py b/tests/e2e/test_ocr_e2e.py index 9ab2061..f0929c6 100644 --- a/tests/e2e/test_ocr_e2e.py +++ b/tests/e2e/test_ocr_e2e.py @@ -1,48 +1,58 @@ -from genalog.ocr.rest_client import GrokRestClient from genalog.ocr.blob_client import GrokBlobClient from genalog.ocr.grok import Grok -import requests import pytest -import time import json -import os from dotenv import load_dotenv + load_dotenv("tests/ocr/.env") + class TestBlobClient: - @pytest.mark.parametrize("use_async",[True, False]) + 
@pytest.mark.parametrize("use_async", [True, False]) def test_upload_images(self, use_async): blob_client = GrokBlobClient.create_from_env_var() subfolder = "tests/ocr/data/img" - file_prefix = subfolder.replace("/", "_") - dst_folder, _ = blob_client.upload_images_to_blob(subfolder, use_async=use_async) + subfolder.replace("/", "_") + dst_folder, _ = blob_client.upload_images_to_blob( + subfolder, use_async=use_async + ) uploaded_items, _ = blob_client.list_blobs(dst_folder) - uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name) + uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name) assert uploaded_items[0].name == f"{dst_folder}/0.png" assert uploaded_items[1].name == f"{dst_folder}/1.png" assert uploaded_items[2].name == f"{dst_folder}/11.png" blob_client.delete_blobs_folder(dst_folder) - assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted" + assert ( + len(list(blob_client.list_blobs(dst_folder)[0])) == 0 + ), f"folder {dst_folder} was not deleted" - dst_folder, _ = blob_client.upload_images_to_blob(subfolder, "test_images", use_async=use_async) - assert dst_folder == "test_images" + dst_folder, _ = blob_client.upload_images_to_blob( + subfolder, "test_images", use_async=use_async + ) + assert dst_folder == "test_images" uploaded_items, _ = blob_client.list_blobs(dst_folder) - uploaded_items = sorted(list(uploaded_items), key = lambda x : x.name) + uploaded_items = sorted(list(uploaded_items), key=lambda x: x.name) assert uploaded_items[0].name == f"{dst_folder}/0.png" assert uploaded_items[1].name == f"{dst_folder}/1.png" assert uploaded_items[2].name == f"{dst_folder}/11.png" blob_client.delete_blobs_folder(dst_folder) - assert len(list(blob_client.list_blobs(dst_folder)[0])) == 0, f"folder {dst_folder} was not deleted" + assert ( + len(list(blob_client.list_blobs(dst_folder)[0])) == 0 + ), f"folder {dst_folder} was not deleted" + class TestGROKe2e: - @pytest.mark.parametrize("use_async",[False,True]) + @pytest.mark.parametrize("use_async", [False, True]) def test_grok_e2e(self, tmpdir, use_async): grok = Grok.create_from_env_var() src_folder = "tests/ocr/data/img" - grok.run_grok(src_folder, tmpdir, blob_dest_folder="testimages", use_async=use_async, cleanup=True) - json_folder = "tests/ocr/data/json" - json_hash = "521c38122f783673598856cd81d91c21" - assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"] + grok.run_grok( + src_folder, + tmpdir, + blob_dest_folder="testimages", + use_async=use_async, + cleanup=True, + ) + assert json.load(open(f"{tmpdir}/0.json", "r"))[0]["text"] assert json.load(open(f"{tmpdir}/1.json", "r"))[0]["text"] assert json.load(open(f"{tmpdir}/11.json", "r"))[0]["text"] - diff --git a/tests/e2e/test_pipeline.py b/tests/e2e/test_pipeline.py index 079e1db..b9e8270 100644 --- a/tests/e2e/test_pipeline.py +++ b/tests/e2e/test_pipeline.py @@ -5,28 +5,38 @@ import glob EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt" + @pytest.fixture def default_analog_generator(): return pipeline.AnalogDocumentGeneration() + @pytest.fixture def custom_analog_generator(): custom_styles = {"font_size": ["5px"]} custom_degradation = [("blur", {"radius": 3})] return pipeline.AnalogDocumentGeneration( - styles=custom_styles, - degradations=custom_degradation, - resolution=300) + styles=custom_styles, degradations=custom_degradation, resolution=300 + ) + def test_default_generate_img(default_analog_generator): example_template = default_analog_generator.list_templates()[0] - img_array = 
default_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None) + default_analog_generator.generate_img( + EXAMPLE_TEXT_FILE, example_template, target_folder=None + ) + def test_custom_generate_img(custom_analog_generator): example_template = custom_analog_generator.list_templates()[0] - img_array = custom_analog_generator.generate_img(EXAMPLE_TEXT_FILE, example_template, target_folder=None) + custom_analog_generator.generate_img( + EXAMPLE_TEXT_FILE, example_template, target_folder=None + ) + def test_generate_dataset_multiprocess(): INPUT_TEXT_FILENAMES = glob.glob("tests/text/data/gt_*.txt") with pytest.deprecated_call(): - pipeline.generate_dataset_multiprocess(INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja") \ No newline at end of file + pipeline.generate_dataset_multiprocess( + INPUT_TEXT_FILENAMES, "test_out", {}, [], "text_block.html.jinja" + ) diff --git a/tests/e2e/test_splitter.py b/tests/e2e/test_splitter.py index 4ab42c4..317b803 100644 --- a/tests/e2e/test_splitter.py +++ b/tests/e2e/test_splitter.py @@ -3,6 +3,7 @@ import difflib from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR + def _compare_content(file1, file2): txt1 = open(file1, "r").read() txt2 = open(file2, "r").read() @@ -12,15 +13,32 @@ def _compare_content(file1, file2): str_diff = "\n".join(difflib.unified_diff(sentences_txt1, sentences_txt2)) assert False, f"Delta between outputs: \n {str_diff}" + def test_splitter(tmpdir): # tmpdir = "test_out" os.makedirs(f"{tmpdir}/clean_labels") os.makedirs(f"{tmpdir}/clean_text") - - generate_splits("tests/e2e/data/splitter/example_conll2012.txt", tmpdir, - doc_seperator=CONLL2003_DOC_SEPERATOR, sentence_seperator="") - _compare_content("tests/e2e/data/splitter/example_splits/clean_text/0.txt", f"{tmpdir}/clean_text/0.txt") - _compare_content("tests/e2e/data/splitter/example_splits/clean_text/1.txt", f"{tmpdir}/clean_text/1.txt") - _compare_content("tests/e2e/data/splitter/example_splits/clean_labels/0.txt", f"{tmpdir}/clean_labels/0.txt") - _compare_content("tests/e2e/data/splitter/example_splits/clean_labels/1.txt", f"{tmpdir}/clean_labels/1.txt") + generate_splits( + "tests/e2e/data/splitter/example_conll2012.txt", + tmpdir, + doc_seperator=CONLL2003_DOC_SEPERATOR, + sentence_seperator="", + ) + + _compare_content( + "tests/e2e/data/splitter/example_splits/clean_text/0.txt", + f"{tmpdir}/clean_text/0.txt", + ) + _compare_content( + "tests/e2e/data/splitter/example_splits/clean_text/1.txt", + f"{tmpdir}/clean_text/1.txt", + ) + _compare_content( + "tests/e2e/data/splitter/example_splits/clean_labels/0.txt", + f"{tmpdir}/clean_labels/0.txt", + ) + _compare_content( + "tests/e2e/data/splitter/example_splits/clean_labels/1.txt", + f"{tmpdir}/clean_labels/1.txt", + ) diff --git a/tests/generation/test_content.py b/tests/generation/test_content.py index eadf8ae..768128e 100644 --- a/tests/generation/test_content.py +++ b/tests/generation/test_content.py @@ -1,4 +1,5 @@ -from genalog.generation.content import * +from genalog.generation.content import ContentType, Content, CompositeContent +from genalog.generation.content import Paragraph, Title import pytest @@ -6,57 +7,70 @@ CONTENT_LIST = ["foo", "bar"] COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH] TEXT = "foo bar" + @pytest.fixture def content_base_class(): return Content() + @pytest.fixture def paragraph(): return Paragraph(TEXT) + @pytest.fixture def title(): return Title(TEXT) + @pytest.fixture def section(): return 
CompositeContent(CONTENT_LIST, COMPOSITE_CONTENT_TYPE) - + + def test_content_set_content_type(content_base_class): with pytest.raises(TypeError): content_base_class.set_content_type("NOT VALID CONTENT TYPE") content_base_class.set_content_type(ContentType.PARAGRAPH) + def test_paragraph_init(paragraph): with pytest.raises(TypeError): Paragraph([]) assert paragraph.content_type == ContentType.PARAGRAPH + def test_paragraph_print(paragraph): assert paragraph.__str__() + def test_paragraph_iterable_indexable(paragraph): for index, character in enumerate(paragraph): assert character == paragraph[index] + def test_title_init(title): with pytest.raises(TypeError): Title([]) assert title.content_type == ContentType.TITLE + def test_title_iterable_indexable(title): for index, character in enumerate(title): assert character == title[index] + def test_composite_content_init(section): with pytest.raises(TypeError): - CompositeContent((),[]) + CompositeContent((), []) assert section.content_type == ContentType.COMPOSITE + def test_composite_content_iterable(section): for index, content in enumerate(section): assert content.content_type == COMPOSITE_CONTENT_TYPE[index] - + + def test_composite_content_print(section): assert "foo" in section.__str__() assert "bar" in section.__str__() diff --git a/tests/generation/test_document.py b/tests/generation/test_document.py index 19499fa..0e6e601 100644 --- a/tests/generation/test_document.py +++ b/tests/generation/test_document.py @@ -21,84 +21,106 @@ DEFAULT_TEMPLATE_NAME = "text_block.html.jinja" DEFAULT_PACKAGE_NAME = "genalog.generation" DEFAULT_TEMPLATE_FOLDER = "templates" + @pytest.fixture def default_document(): mock_jinja_template = MagicMock() mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT return Document(CONTENT, mock_jinja_template) + @pytest.fixture def french_document(): mock_jinja_template = MagicMock() mock_jinja_template.render.return_value = MOCK_COMPILED_DOCUMENT return Document(CONTENT, mock_jinja_template, language=FRENCH) + def test_document_init(default_document): assert default_document.styles == DEFAULT_DOCUMENT_STYLE - assert default_document._document != None - assert default_document.compiled_html != None - + assert default_document._document is not None + assert default_document.compiled_html is not None + + def test_document_init_with_kwargs(french_document): assert french_document.styles["language"] == FRENCH - assert french_document._document != None - assert french_document.compiled_html != None + assert french_document._document is not None + assert french_document.compiled_html is not None + def test_document_render_html(french_document): compiled_document = french_document.render_html() assert compiled_document == MOCK_COMPILED_DOCUMENT - french_document.template.render.assert_called_with(content=CONTENT, **french_document.styles) + french_document.template.render.assert_called_with( + content=CONTENT, **french_document.styles + ) + def test_document_render_pdf(default_document): default_document._document = MagicMock() # run tested function default_document.render_pdf(target=FILE_DESTINATION_PDF, zoom=2) - default_document._document.write_pdf.assert_called_with(target=FILE_DESTINATION_PDF, zoom=2) + default_document._document.write_pdf.assert_called_with( + target=FILE_DESTINATION_PDF, zoom=2 + ) + def test_document_render_png(default_document): default_document._document = MagicMock() # run tested function default_document.render_png(target=FILE_DESTINATION_PNG, resolution=100) - 
default_document._document.write_png.assert_called_with(target=FILE_DESTINATION_PNG, resolution=100) + default_document._document.write_png.assert_called_with( + target=FILE_DESTINATION_PNG, resolution=100 + ) + def test_document_render_png_split_pages(default_document): default_document._document.copy = MagicMock() # run tested function - default_document.render_png(target=FILE_DESTINATION_PNG, split_pages=True, resolution=100) + default_document.render_png( + target=FILE_DESTINATION_PNG, split_pages=True, resolution=100 + ) result_destination = FILE_DESTINATION_PNG.replace(".png", "_pg_0.png") # assertion document_copy = default_document._document.copy.return_value - document_copy.write_png.assert_called_with(target=result_destination, resolution=100) + document_copy.write_png.assert_called_with( + target=result_destination, resolution=100 + ) + def test_document_render_array_valid_args(default_document): # setup mock mock_surface = MagicMock() - mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32 + mock_surface.get_format.return_value = 0 # 0 == cairocffi.FORMAT_ARGB32 mock_surface.get_data = MagicMock(return_value=IMG_BYTES) # loading a 2x2 image mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2)) default_document._document.write_image_surface = mock_write_image_surface channel_types = ["RGBA", "RGB", "GRAYSCALE", "BGRA", "BGR"] - expected_img_shape = [(2,2,4), (2,2,3), (2,2), (2,2,4), (2,2,3)] - + expected_img_shape = [(2, 2, 4), (2, 2, 3), (2, 2), (2, 2, 4), (2, 2, 3)] + for channel_type, expected_img_shape in zip(channel_types, expected_img_shape): img_array = default_document.render_array(resolution=100, channel=channel_type) assert img_array.shape == expected_img_shape + def test_document_render_array_invalid_args(default_document): invalid_channel_types = "INVALID" with pytest.raises(ValueError): default_document.render_array(resolution=100, channel=invalid_channel_types) + def test_document_render_array_invalid_format(default_document): # setup mock mock_surface = MagicMock() - mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32 + mock_surface.get_format.return_value = 1 # 1 != cairocffi.FORMAT_ARGB32 mock_write_image_surface = MagicMock(return_value=(mock_surface, 2, 2)) default_document._document.write_image_surface = mock_write_image_surface - + with pytest.raises(RuntimeError): default_document.render_array(resolution=100) + def test_document_update_style(default_document): new_style = {"language": FRENCH, "new_property": "some value"} # Ensure that a new property is not already defined @@ -111,11 +133,13 @@ def test_document_update_style(default_document): # Ensure that a new property is added assert default_document.styles["new_property"] == new_style["new_property"] + @patch("genalog.generation.document.Environment") @patch("genalog.generation.document.PackageLoader") @patch("genalog.generation.document.FileSystemLoader") -def test_document_generator_init_default_setting(mock_file_system_loader, - mock_package_loader, mock_environment): +def test_document_generator_init_default_setting( + mock_file_system_loader, mock_package_loader, mock_environment +): # setup mock template environment mock_environment_instance = mock_environment.return_value mock_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME] @@ -123,15 +147,19 @@ def test_document_generator_init_default_setting(mock_file_system_loader, document_generator = DocumentGenerator() # Ensure the right loader is called 
mock_file_system_loader.assert_not_called() - mock_package_loader.assert_called_with(DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER) + mock_package_loader.assert_called_with( + DEFAULT_PACKAGE_NAME, DEFAULT_TEMPLATE_FOLDER + ) # Ensure that the default template in the package is loaded assert DEFAULT_TEMPLATE_NAME in document_generator.template_list + @patch("genalog.generation.document.Environment") @patch("genalog.generation.document.PackageLoader") @patch("genalog.generation.document.FileSystemLoader") -def test_document_generator_init_custom_template(mock_file_system_loader, - mock_package_loader, mock_environment): +def test_document_generator_init_custom_template( + mock_file_system_loader, mock_package_loader, mock_environment +): # setup mock template environment mock_environment_instance = mock_environment.return_value mock_environment_instance.list_templates.return_value = [CUSTOM_TEMPLATE_NAME] @@ -143,67 +171,79 @@ def test_document_generator_init_custom_template(mock_file_system_loader, # Ensure that the expected template is registered assert CUSTOM_TEMPLATE_NAME in document_generator.template_list + @pytest.fixture def default_document_generator(): with patch("genalog.generation.document.Environment") as MockEnvironment: template_environment_instance = MockEnvironment.return_value - template_environment_instance.list_templates.return_value = [DEFAULT_TEMPLATE_NAME] + template_environment_instance.list_templates.return_value = [ + DEFAULT_TEMPLATE_NAME + ] template_environment_instance.get_template.return_value = MOCK_TEMPLATE doc_gen = DocumentGenerator() - return doc_gen + return doc_gen + def test_document_generator_create_generator(default_document_generator): available_templates = default_document_generator.template_list assert len(available_templates) < 2 - generator = default_document_generator.create_generator(CONTENT, available_templates) - doc = next(generator) + generator = default_document_generator.create_generator( + CONTENT, available_templates + ) + next(generator) with pytest.raises(StopIteration): next(generator) + def test_document_generator_create_generator_(default_document_generator): # setup test case available_templates = default_document_generator.template_list undefined_template = "NOT A VALID TEMPLATE" assert undefined_template not in available_templates - generator = default_document_generator.create_generator(CONTENT, [undefined_template]) + generator = default_document_generator.create_generator( + CONTENT, [undefined_template] + ) with pytest.raises(FileNotFoundError): - doc = next(generator) + next(generator) -@pytest.mark.parametrize("template_name, expected_output", [ - ("base.html.jinja", False), - ("text_block.html.jinja", True), - ("text_block.css.jinja", False), - ("macro/dimension.css.jinja", False) -]) + +@pytest.mark.parametrize( + "template_name, expected_output", + [ + ("base.html.jinja", False), + ("text_block.html.jinja", True), + ("text_block.css.jinja", False), + ("macro/dimension.css.jinja", False), + ], +) def test__keep_templates(template_name, expected_output): output = DocumentGenerator._keep_template(template_name) assert output == expected_output + def test_set_styles_to_generate(default_document_generator): assert len(default_document_generator.styles_to_generate) == 1 default_document_generator.set_styles_to_generate({"foo": ["bar", "bar"]}) assert len(default_document_generator.styles_to_generate) == 2 -@pytest.mark.parametrize("styles, expected_output", [ - ({}, []), # empty case - ({"size": ["10px"], "color": [] }, []), 
#empty value will result in null combinations - ( - {"size": ["10px"], "color": ["red"] }, - [{"size":"10px", "color":"red"}] - ), - ( - {"size": ["5px", "10px"]}, - [{"size": "5px"}, {"size": "10px"}] + +@pytest.mark.parametrize( + "styles, expected_output", + [ + ({}, []), # empty case + ( + {"size": ["10px"], "color": []}, + [], + ), # empty value will result in null combinations + ({"size": ["10px"], "color": ["red"]}, [{"size": "10px", "color": "red"}]), + ({"size": ["5px", "10px"]}, [{"size": "5px"}, {"size": "10px"}]), + ( + {"size": ["10px", "15px"], "color": ["blue"]}, + [{"size": "10px", "color": "blue"}, {"size": "15px", "color": "blue"}], ), - ( - {"size": ["10px", "15px"], "color": ["blue"] }, - [ - {"size":"10px", "color": "blue"}, - {"size":"15px", "color": "blue"} - ] - ) -]) + ], +) def test_document_generator_expand_style_combinations(styles, expected_output): output = DocumentGenerator.expand_style_combinations(styles) - assert output == expected_output \ No newline at end of file + assert output == expected_output diff --git a/tests/ocr/test_metrics.py b/tests/ocr/test_metrics.py index 118498c..a96405d 100644 --- a/tests/ocr/test_metrics.py +++ b/tests/ocr/test_metrics.py @@ -1,212 +1,614 @@ -from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_metrics, get_stats -from genalog.text.anchor import align_w_anchor +from genalog.ocr.metrics import ( + get_align_stats, + get_editops_stats, + get_stats, +) from genalog.text.alignment import GAP_CHAR, align from genalog.text.ner_label import _find_gap_char_candidates -from pandas._testing import assert_frame_equal import pytest -import genalog.ocr.metrics -import pandas as pd -import numpy as np -import json -import pickle -import os +import genalog.ocr.metrics genalog.ocr.metrics.LOG_LEVEL = 0 -@pytest.mark.parametrize("src_string, target, expected_stats", -[ - ("a worn coat", "a wom coat", - {'edit_insert': 1, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - (" ", "a", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("a", " ", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("a", "a", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("ab", "ac", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("ac", "ab", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("New York is big.", "N ewYork kis big.", - {'edit_insert': 0, 'edit_delete': 1, 'edit_replace': 0, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}), - ("B oston grea t", "Boston is great", - {'edit_insert': 0, 'edit_delete': 2, 'edit_replace': 0, 'edit_insert_spacing': 2, 'edit_delete_spacing': 1}), - ("New York is big.", "N ewyork kis big", - {'edit_insert': 1, 'edit_delete': 1, 'edit_replace': 1, 'edit_insert_spacing': 1, 'edit_delete_spacing': 1}), - ("dog", "d@g", # Test against default gap_char "@" - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 1, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}), - ("some@one.com", "some@one.com", - {'edit_insert': 0, 'edit_delete': 0, 'edit_replace': 0, 'edit_insert_spacing': 0, 'edit_delete_spacing': 0}) -]) -def test_editops_stats(src_string, target, expected_stats): - gap_char_candidates, input_char_set = 
_find_gap_char_candidates([src_string], [target]) - gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() - alignment = align(target, src_string) - stats, actions = get_editops_stats(alignment, gap_char) - for k in expected_stats: - assert stats[k] == expected_stats[k], (k,stats[k], expected_stats[k]) -@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions", -[ - ( - "a worn coat", "a wom coat", - {'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 11, 'total_words': 3, - 'matching_chars': 9, 'matching_words': 2,"matching_alnum_words" : 2, 'word_accuracy': 2/3, 'char_accuracy': 9/11}, - {('rn', 'm'): 1} - ), - ( - "a c", "def", - {'insert': 0, 'delete': 0, 'replace': 1 , 'spacing': 0, 'total_chars': 3, 'total_words': 1, - 'matching_chars': 0, 'matching_words': 0,"matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 0}, - {('a c', 'def'): 1} - ), - ( - "a", "a b", - {'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 2, - 'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 0.5, 'char_accuracy': 1/3}, - {} - ), - ( - "a b", "b", - {'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1, - 'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3}, - {} - ), - ( - "a b", "a", - {'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 1, 'total_chars': 3, 'total_words': 1, - 'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1, 'char_accuracy': 1/3}, - {} - ), - ( - "b ..", "a b ..", - {'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 1, 'total_chars': 6, 'total_words': 3, 'total_alnum_words': 2, - 'matching_chars': 4, 'matching_words': 2, "matching_alnum_words" : 1, 'word_accuracy': 2/3, 'char_accuracy': 4/6}, - {} - ), - ( - "taxi cab", "taxl c b", - {'insert': 0, 'delete': 1, 'replace': 1, 'spacing': 1, 'total_chars': 9, 'total_words': 3, - 'matching_chars': 6, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 6/9}, - {('i','l'):1} - ), - ( - "taxl c b ri de", "taxi cab ride", - {'insert': 1, 'delete': 0, 'replace': 1, 'spacing': 6, 'total_chars': 18, 'total_words': 3, - 'matching_chars': 11, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 11/18}, - {('l','i'):1} - ), - ( - "ab", "ac", - {'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, - 'matching_chars': 1, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.5}, - {} - ), - ( - "a", "a", - {'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 1, 'total_words': 1, - 'matching_chars': 1, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, - {} - ), - ( - "New York is big.", "N ewYork kis big.", - {'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 2, 'total_chars': 17, 'total_words': 4, - 'matching_chars': 15, 'matching_words': 1, "matching_alnum_words" : 1, 'word_accuracy': 1/4, 'char_accuracy': 15/17}, - {} - ), - ( - "B oston grea t", "Boston is great", - {'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 3, 'total_chars': 15, 'total_words': 3, - 'matching_chars': 12, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0.0, 'char_accuracy': 0.8}, - {} - ), - ( - "New York is big.", "N ewyork kis big", - {'insert': 1, 'delete': 1, 
'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4, - 'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16}, - {('Y', 'y'): 1} - ), - ( - "dog", "d@g", - {'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3}, - {('o', '@'): 1} - ), - ( - "some@one.com", "some@one.com", - {'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, {} +@pytest.mark.parametrize( + "src_string, target, expected_stats", + [ + ( + "a worn coat", + "a wom coat", + { + "edit_insert": 1, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + " ", + "a", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "a", + " ", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "a", + "a", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 0, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "ab", + "ac", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "ac", + "ab", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "New York is big.", + "N ewYork kis big.", + { + "edit_insert": 0, + "edit_delete": 1, + "edit_replace": 0, + "edit_insert_spacing": 1, + "edit_delete_spacing": 1, + }, + ), + ( + "B oston grea t", + "Boston is great", + { + "edit_insert": 0, + "edit_delete": 2, + "edit_replace": 0, + "edit_insert_spacing": 2, + "edit_delete_spacing": 1, + }, + ), + ( + "New York is big.", + "N ewyork kis big", + { + "edit_insert": 1, + "edit_delete": 1, + "edit_replace": 1, + "edit_insert_spacing": 1, + "edit_delete_spacing": 1, + }, + ), + ( + "dog", + "d@g", # Test against default gap_char "@" + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 1, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ( + "some@one.com", + "some@one.com", + { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 0, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + }, + ), + ], +) +def test_editops_stats(src_string, target, expected_stats): + gap_char_candidates, input_char_set = _find_gap_char_candidates( + [src_string], [target] ) -]) + gap_char = ( + GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + ) + alignment = align(target, src_string) + stats, actions = get_editops_stats(alignment, gap_char) + for k in expected_stats: + assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k]) + + +@pytest.mark.parametrize( + "src_string, target, expected_stats, expected_substitutions", + [ + ( + "a worn coat", + "a wom coat", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 11, + "total_words": 3, + "matching_chars": 9, + "matching_words": 2, + "matching_alnum_words": 2, + "word_accuracy": 2 / 3, + "char_accuracy": 9 / 11, + }, + {("rn", "m"): 1}, + ), 
+ ( + "a c", + "def", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 3, + "total_words": 1, + "matching_chars": 0, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0, + "char_accuracy": 0, + }, + {("a c", "def"): 1}, + ), + ( + "a", + "a b", + { + "insert": 1, + "delete": 0, + "replace": 0, + "spacing": 1, + "total_chars": 3, + "total_words": 2, + "matching_chars": 1, + "matching_words": 1, + "matching_alnum_words": 1, + "word_accuracy": 0.5, + "char_accuracy": 1 / 3, + }, + {}, + ), + ( + "a b", + "b", + { + "insert": 0, + "delete": 1, + "replace": 0, + "spacing": 1, + "total_chars": 3, + "total_words": 1, + "matching_chars": 1, + "matching_words": 1, + "matching_alnum_words": 1, + "word_accuracy": 1, + "char_accuracy": 1 / 3, + }, + {}, + ), + ( + "a b", + "a", + { + "insert": 0, + "delete": 1, + "replace": 0, + "spacing": 1, + "total_chars": 3, + "total_words": 1, + "matching_chars": 1, + "matching_words": 1, + "matching_alnum_words": 1, + "word_accuracy": 1, + "char_accuracy": 1 / 3, + }, + {}, + ), + ( + "b ..", + "a b ..", + { + "insert": 1, + "delete": 0, + "replace": 0, + "spacing": 1, + "total_chars": 6, + "total_words": 3, + "total_alnum_words": 2, + "matching_chars": 4, + "matching_words": 2, + "matching_alnum_words": 1, + "word_accuracy": 2 / 3, + "char_accuracy": 4 / 6, + }, + {}, + ), + ( + "taxi cab", + "taxl c b", + { + "insert": 0, + "delete": 1, + "replace": 1, + "spacing": 1, + "total_chars": 9, + "total_words": 3, + "matching_chars": 6, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0, + "char_accuracy": 6 / 9, + }, + {("i", "l"): 1}, + ), + ( + "taxl c b ri de", + "taxi cab ride", + { + "insert": 1, + "delete": 0, + "replace": 1, + "spacing": 6, + "total_chars": 18, + "total_words": 3, + "matching_chars": 11, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0, + "char_accuracy": 11 / 18, + }, + {("l", "i"): 1}, + ), + ( + "ab", + "ac", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 2, + "total_words": 1, + "matching_chars": 1, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0.0, + "char_accuracy": 0.5, + }, + {}, + ), + ( + "a", + "a", + { + "insert": 0, + "delete": 0, + "replace": 0, + "spacing": 0, + "total_chars": 1, + "total_words": 1, + "matching_chars": 1, + "matching_words": 1, + "matching_alnum_words": 1, + "word_accuracy": 1.0, + "char_accuracy": 1.0, + }, + {}, + ), + ( + "New York is big.", + "N ewYork kis big.", + { + "insert": 1, + "delete": 0, + "replace": 0, + "spacing": 2, + "total_chars": 17, + "total_words": 4, + "matching_chars": 15, + "matching_words": 1, + "matching_alnum_words": 1, + "word_accuracy": 1 / 4, + "char_accuracy": 15 / 17, + }, + {}, + ), + ( + "B oston grea t", + "Boston is great", + { + "insert": 1, + "delete": 0, + "replace": 0, + "spacing": 3, + "total_chars": 15, + "total_words": 3, + "matching_chars": 12, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0.0, + "char_accuracy": 0.8, + }, + {}, + ), + ( + "New York is big.", + "N ewyork kis big", + { + "insert": 1, + "delete": 1, + "replace": 1, + "spacing": 2, + "total_chars": 16, + "total_words": 4, + "matching_chars": 13, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0, + "char_accuracy": 13 / 16, + }, + {("Y", "y"): 1}, + ), + ( + "dog", + "d@g", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 3, + "total_words": 1, + 
"total_alnum_words": 1, + "matching_chars": 2, + "matching_alnum_words": 0, + "matching_words": 0, + "alnum_word_accuracy": 0.0, + "word_accuracy": 0.0, + "char_accuracy": 2 / 3, + }, + {("o", "@"): 1}, + ), + ( + "some@one.com", + "some@one.com", + { + "insert": 0, + "delete": 0, + "replace": 0, + "spacing": 0, + "total_chars": 12, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 12, + "matching_alnum_words": 1, + "matching_words": 1, + "alnum_word_accuracy": 1.0, + "word_accuracy": 1.0, + "char_accuracy": 1.0, + }, + {}, + ), + ], +) def test_align_stats(src_string, target, expected_stats, expected_substitutions): - gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target]) - gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + gap_char_candidates, input_char_set = _find_gap_char_candidates( + [src_string], [target] + ) + gap_char = ( + GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + ) alignment = align(src_string, target, gap_char=gap_char) stats, substitution_dict = get_align_stats(alignment, src_string, target, gap_char) for k in expected_stats: assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k]) for k in expected_substitutions: - assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions) + assert substitution_dict[k] == expected_substitutions[k], ( + substitution_dict, + expected_substitutions, + ) -@pytest.mark.parametrize("src_string, target, expected_stats, expected_substitutions, expected_actions", [ - ( - "ab", "a", - {'insert': 0, 'delete': 1, 'replace': 0, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2}, - {}, - {1: 'D'} - ), - ( - "ab", "abb", - {'insert': 1, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3}, - {}, - {2: ('I', 'b')} - ), - ( - "ab", "ac", - {'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 2, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 1, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 1/2}, - {('b', 'c'): 1}, - {1: ('R', 'c')} - ), - ( - "New York is big.", "N ewyork kis big", - {'insert': 1, 'delete': 1, 'replace': 1, 'spacing': 2, 'total_chars': 16, 'total_words': 4, - 'matching_chars': 13, 'matching_words': 0, "matching_alnum_words" : 0, 'word_accuracy': 0, 'char_accuracy': 13/16}, - {('Y', 'y'): 1}, - {1: ('I', ' '), 4: 'D', 5: ('R', 'y'), 10: ('I', 'k'), 17: 'D'} - ), - ( - "dog", "d@g", - {'insert': 0, 'delete': 0, 'replace': 1, 'spacing': 0, 'total_chars': 3, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 2, 'matching_alnum_words': 0, 'matching_words': 0, 'alnum_word_accuracy': 0.0, 'word_accuracy': 0.0, 'char_accuracy': 2/3}, - {('o', '@'): 1}, - {1: ('R', '@')} - ), - ( - "some@one.com", "some@one.com", - {'insert': 0, 'delete': 0, 'replace': 0, 'spacing': 0, 'total_chars': 12, 'total_words': 1, 'total_alnum_words': 1, 'matching_chars': 12, 'matching_alnum_words': 1, 'matching_words': 1, 'alnum_word_accuracy': 1.0, 'word_accuracy': 1.0, 'char_accuracy': 1.0}, - {}, - {} - ) -]) -def test_get_stats(src_string, 
target, expected_stats, expected_substitutions, expected_actions ): +@pytest.mark.parametrize( + "src_string, target, expected_stats, expected_substitutions, expected_actions", + [ + ( + "ab", + "a", + { + "insert": 0, + "delete": 1, + "replace": 0, + "spacing": 0, + "total_chars": 2, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 1, + "matching_alnum_words": 0, + "matching_words": 0, + "alnum_word_accuracy": 0.0, + "word_accuracy": 0.0, + "char_accuracy": 1 / 2, + }, + {}, + {1: "D"}, + ), + ( + "ab", + "abb", + { + "insert": 1, + "delete": 0, + "replace": 0, + "spacing": 0, + "total_chars": 3, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 2, + "matching_alnum_words": 0, + "matching_words": 0, + "alnum_word_accuracy": 0.0, + "word_accuracy": 0.0, + "char_accuracy": 2 / 3, + }, + {}, + {2: ("I", "b")}, + ), + ( + "ab", + "ac", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 2, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 1, + "matching_alnum_words": 0, + "matching_words": 0, + "alnum_word_accuracy": 0.0, + "word_accuracy": 0.0, + "char_accuracy": 1 / 2, + }, + {("b", "c"): 1}, + {1: ("R", "c")}, + ), + ( + "New York is big.", + "N ewyork kis big", + { + "insert": 1, + "delete": 1, + "replace": 1, + "spacing": 2, + "total_chars": 16, + "total_words": 4, + "matching_chars": 13, + "matching_words": 0, + "matching_alnum_words": 0, + "word_accuracy": 0, + "char_accuracy": 13 / 16, + }, + {("Y", "y"): 1}, + {1: ("I", " "), 4: "D", 5: ("R", "y"), 10: ("I", "k"), 17: "D"}, + ), + ( + "dog", + "d@g", + { + "insert": 0, + "delete": 0, + "replace": 1, + "spacing": 0, + "total_chars": 3, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 2, + "matching_alnum_words": 0, + "matching_words": 0, + "alnum_word_accuracy": 0.0, + "word_accuracy": 0.0, + "char_accuracy": 2 / 3, + }, + {("o", "@"): 1}, + {1: ("R", "@")}, + ), + ( + "some@one.com", + "some@one.com", + { + "insert": 0, + "delete": 0, + "replace": 0, + "spacing": 0, + "total_chars": 12, + "total_words": 1, + "total_alnum_words": 1, + "matching_chars": 12, + "matching_alnum_words": 1, + "matching_words": 1, + "alnum_word_accuracy": 1.0, + "word_accuracy": 1.0, + "char_accuracy": 1.0, + }, + {}, + {}, + ), + ], +) +def test_get_stats( + src_string, target, expected_stats, expected_substitutions, expected_actions +): stats, substitution_dict, actions = get_stats(target, src_string) for k in expected_stats: assert stats[k] == expected_stats[k], (k, stats[k], expected_stats[k]) for k in expected_substitutions: - assert substitution_dict[k] == expected_substitutions[k], (substitution_dict, expected_substitutions) + assert substitution_dict[k] == expected_substitutions[k], ( + substitution_dict, + expected_substitutions, + ) for k in expected_actions: assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k]) -@pytest.mark.parametrize("src_string, target, expected_actions", -[ - ("dog and cat", "g and at", - {0: ('I', 'd'), 1: ('I', 'o'), 8:('I', 'c')}), -]) + +@pytest.mark.parametrize( + "src_string, target, expected_actions", + [ + ("dog and cat", "g and at", {0: ("I", "d"), 1: ("I", "o"), 8: ("I", "c")}), + ], +) def test_actions_stats(src_string, target, expected_actions): - gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target]) - gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + gap_char_candidates, input_char_set = _find_gap_char_candidates( + 
[src_string], [target] + ) + gap_char = ( + GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + ) alignment = align(target, src_string, gap_char=gap_char) - _ , actions = get_editops_stats(alignment, gap_char) + _, actions = get_editops_stats(alignment, gap_char) print(actions) for k in expected_actions: - assert actions[k] == expected_actions[k], (k,actions[k], expected_actions[k]) + assert actions[k] == expected_actions[k], (k, actions[k], expected_actions[k]) diff --git a/tests/ocr/test_ocr.py b/tests/ocr/test_ocr.py index bc1c6c8..d6385c7 100644 --- a/tests/ocr/test_ocr.py +++ b/tests/ocr/test_ocr.py @@ -1,13 +1,15 @@ -from genalog.ocr.rest_client import GrokRestClient -from genalog.ocr.blob_client import GrokBlobClient -from genalog.ocr.grok import Grok import requests -import pytest -import time import json + +import pytest from dotenv import load_dotenv + +from genalog.ocr.rest_client import GrokRestClient + + load_dotenv("tests/ocr/.env") + @pytest.fixture(autouse=True) def setup_monkeypatch(monkeypatch): def mock_http(*args, **kwargs): @@ -21,7 +23,6 @@ def setup_monkeypatch(monkeypatch): class MockedResponse: - def __init__(self, args, kwargs): self.url = args[0] self.text = "response" @@ -34,27 +35,39 @@ class MockedResponse: if "search.windows.net/indexers/" in self.url: if "status" in self.url: - return { - "lastResult": {"status": "success"}, - "status": "finished" - } + return {"lastResult": {"status": "success"}, "status": "finished"} return {} if "search.windows.net/indexes/" in self.url: if "docs/search" in self.url: return { - "value" : [ - { + "value": [ + { "metadata_storage_name": "521c38122f783673598856cd81d91c21_0.png", - "layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", "r")) + "layoutText": json.load( + open( + "tests/ocr/data/json/521c38122f783673598856cd81d91c21_0.png.json", + "r", + ) + ), }, - { - "metadata_storage_name":"521c38122f783673598856cd81d91c21_1.png", - "layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", "r")) + { + "metadata_storage_name": "521c38122f783673598856cd81d91c21_1.png", + "layoutText": json.load( + open( + "tests/ocr/data/json/521c38122f783673598856cd81d91c21_1.png.json", + "r", + ) + ), }, - { - "metadata_storage_name":"521c38122f783673598856cd81d91c21_11.png", - "layoutText" : json.load(open("tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", "r")) + { + "metadata_storage_name": "521c38122f783673598856cd81d91c21_11.png", + "layoutText": json.load( + open( + "tests/ocr/data/json/521c38122f783673598856cd81d91c21_11.png.json", + "r", + ) + ), }, ] } @@ -69,7 +82,6 @@ class MockedResponse: class TestGROK: - def test_creating_indexing_pipeline(self): grok_rest_client = GrokRestClient.create_from_env_var() grok_rest_client.create_indexing_pipeline() @@ -78,11 +90,11 @@ class TestGROK: def test_running_indexer(self): grok_rest_client = GrokRestClient.create_from_env_var() grok_rest_client.create_indexing_pipeline() - + indexer_status = grok_rest_client.get_indexer_status() if indexer_status["status"] == "error": raise RuntimeError(f"indexer error: {indexer_status}") - + # if not already running start the indexer if indexer_status["lastResult"]["status"] != "inProgress": grok_rest_client.run_indexer() @@ -90,5 +102,4 @@ class TestGROK: grok_rest_client.run_indexer() indexer_status = grok_rest_client.poll_indexer_till_complete() assert indexer_status["lastResult"]["status"] == "success" - 
grok_rest_client.delete_indexer_pipeline()
-
+        grok_rest_client.delete_indexer_pipeline()
diff --git a/tests/text/test_alignment.py b/tests/text/test_alignment.py
index c7a1ea9..65eef01 100644
--- a/tests/text/test_alignment.py
+++ b/tests/text/test_alignment.py
@@ -1,6 +1,9 @@
 from genalog.text import alignment
-from genalog.text.alignment import MATCH_REWARD, MISMATCH_PENALTY, GAP_PENALTY, GAP_EXT_PENALTY
-from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES, ALIGNMENT_REGRESSION_TEST_CASES
+
+from tests.cases.text_alignment import (
+    PARSE_ALIGNMENT_REGRESSION_TEST_CASES,
+    ALIGNMENT_REGRESSION_TEST_CASES,
+)
 from random import randint
 from unittest.mock import MagicMock
@@ -10,34 +13,49 @@ import warnings
 RANDOM_INT = randint(1, 100)
 MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)]
+
 # Set up mock for third party library call
 @pytest.fixture
 def mock_pairwise2_align(monkeypatch):
     mock = MagicMock()
+
     def mock_globalcs(*args, **kwargs):
         mock.globalcs(*args, **kwargs)
         return MOCK_ALIGNMENT_RESULT
+
     # replace target method reference with the mock method
     monkeypatch.setattr("Bio.pairwise2.align.globalcs", mock_globalcs)
     return mock
+
 def test__align_seg(mock_pairwise2_align):
     # setup method input
     required_arg = ("A", "B")
-    optional_arg = (alignment.MATCH_REWARD, alignment.MISMATCH_PENALTY, alignment.GAP_PENALTY, alignment.GAP_EXT_PENALTY)
-    optional_kwarg = {"gap_char": alignment.GAP_CHAR, "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY}
+    optional_arg = (
+        alignment.MATCH_REWARD,
+        alignment.MISMATCH_PENALTY,
+        alignment.GAP_PENALTY,
+        alignment.GAP_EXT_PENALTY,
+    )
+    optional_kwarg = {
+        "gap_char": alignment.GAP_CHAR,
+        "one_alignment_only": alignment.ONE_ALIGNMENT_ONLY,
+    }
     # test method
     result = alignment._align_seg(*required_arg + optional_arg, **optional_kwarg)
    # assertion
     mock_pairwise2_align.globalcs.assert_called()
     assert result == MOCK_ALIGNMENT_RESULT
-@pytest.mark.parametrize("alignments, target_num_tokens, raised_exception",
-[
-    (MOCK_ALIGNMENT_RESULT, 1, None),
-    (MOCK_ALIGNMENT_RESULT, 2, ValueError),
-    ([("X", "XY", 0, 0, 1)], 1, ValueError)
-])
+
+@pytest.mark.parametrize(
+    "alignments, target_num_tokens, raised_exception",
+    [
+        (MOCK_ALIGNMENT_RESULT, 1, None),
+        (MOCK_ALIGNMENT_RESULT, 2, ValueError),
+        ([("X", "XY", 0, 0, 1)], 1, ValueError),
+    ],
+)
 def test__select_alignment_candidates(alignments, target_num_tokens, raised_exception):
     if raised_exception:
         with pytest.raises(raised_exception):
@@ -46,24 +64,27 @@ def test__select_alignment_candidates(alignments, target_num_tokens, raised_exce
     result = alignment._select_alignment_candidates(alignments, target_num_tokens)
     assert result == MOCK_ALIGNMENT_RESULT[0]
-@pytest.mark.parametrize("s, index, desired_output, raised_exception",
-[
-    # Test exceptions
-    ("s", 2, None, IndexError),
-    ("", -1, None, ValueError), # Empty case
-    # Index at start of string
-    (" token", 0, 2, None),
-    ("\t\ntoken", 0, 2, None),
-    # Index reach end of string
-    ("token ", 5, 5, None),
-    ("token", 4, 4, None),
-    # Index in-between tokens
-    ("token", 0, 0, None),
-    ("t1 t2", 2, 7, None),
-    ("t1 \t \n t2", 3, 7, None),
-    # Gap char
-    (" @", 0, 1, None),
-])
+
+@pytest.mark.parametrize(
+    "s, index, desired_output, raised_exception",
+    [
+        # Test exceptions
+        ("s", 2, None, IndexError),
+        ("", -1, None, ValueError),  # Empty case
+        # Index at start of string
+        (" token", 0, 2, None),
+        ("\t\ntoken", 0, 2, None),
+        # Index reach end of string
+        ("token ", 5, 5, None),
+        ("token", 4, 4, None),
+        # Index in-between
tokens + ("token", 0, 0, None), + ("t1 t2", 2, 7, None), + ("t1 \t \n t2", 3, 7, None), + # Gap char + (" @", 0, 1, None), + ], +) def test__find_token_start(s, index, desired_output, raised_exception): if raised_exception: with pytest.raises(raised_exception): @@ -72,25 +93,28 @@ def test__find_token_start(s, index, desired_output, raised_exception): output = alignment._find_token_start(s, index) assert output == desired_output -@pytest.mark.parametrize("s, index, desired_output, raised_exception", -[ - # Test exceptions - ("s", 2, None, IndexError), - ("", -1, None, ValueError), # Empty case - # Index at start of string - (" ", 0, 0, None), - ("\t\ntoken", 0, 0, None), - ("token", 0, 4, None), - ("token\t", 0, 5, None), - ("token\n", 0, 5, None), - # Index reach end of string - ("token ", 5, 5, None), - ("token", 4, 4, None), - # Single Char - (".", 0, 0, None), - # Gap char - ("@@ @", 0, 2, None), -]) + +@pytest.mark.parametrize( + "s, index, desired_output, raised_exception", + [ + # Test exceptions + ("s", 2, None, IndexError), + ("", -1, None, ValueError), # Empty case + # Index at start of string + (" ", 0, 0, None), + ("\t\ntoken", 0, 0, None), + ("token", 0, 4, None), + ("token\t", 0, 5, None), + ("token\n", 0, 5, None), + # Index reach end of string + ("token ", 5, 5, None), + ("token", 4, 4, None), + # Single Char + (".", 0, 0, None), + # Gap char + ("@@ @", 0, 2, None), + ], +) def test__find_token_end(s, index, desired_output, raised_exception): if raised_exception: with pytest.raises(raised_exception): @@ -99,61 +123,81 @@ def test__find_token_end(s, index, desired_output, raised_exception): output = alignment._find_token_end(s, index) assert output == desired_output -@pytest.mark.parametrize("s, start, desired_output", -[ - ("token", 0, (0,4)), - ("token\t", 0, (0,5)), - ("token \n", 0, (0,5)), - (" token ", 0, (1,6)), - # mix with GAP_CHAR - (" @@@@ ", 0, (1,5)), - ("\n\t tok@n@@ \n\t", 0, (3,10)), - # single character string - ("s", 0, (0,0)), - # punctuation - (" !,.: ", 0, (2,6)) -]) + +@pytest.mark.parametrize( + "s, start, desired_output", + [ + ("token", 0, (0, 4)), + ("token\t", 0, (0, 5)), + ("token \n", 0, (0, 5)), + (" token ", 0, (1, 6)), + # mix with GAP_CHAR + (" @@@@ ", 0, (1, 5)), + ("\n\t tok@n@@ \n\t", 0, (3, 10)), + # single character string + ("s", 0, (0, 0)), + # punctuation + (" !,.: ", 0, (2, 6)), + ], +) def test__find_next_token(s, start, desired_output): output = alignment._find_next_token(s, start) assert output == desired_output -@pytest.mark.parametrize("token, desired_output", -[ - # Valid tokens - ("\n\t token.!,:\n\t ", True), - ("token", True), - (" @@@t@@@ ", True), - ("@@token@@", True), - (" @@token@@ ", True), - (f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), #i.e. 't1@t2' - # Invalid tokens (i.e. multiples of the GAP_CHAR) - ("", False), - (" ", False), - ("@@", False), - (" @@ ", False), - ("\t\n@", False), - (alignment.GAP_CHAR*1, False), - (alignment.GAP_CHAR*RANDOM_INT, False), - (f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False) -]) + +@pytest.mark.parametrize( + "token, desired_output", + [ + # Valid tokens + ("\n\t token.!,:\n\t ", True), + ("token", True), + (" @@@t@@@ ", True), + ("@@token@@", True), + (" @@token@@ ", True), + (f"t1{alignment.GAP_CHAR*RANDOM_INT}t2", True), # i.e. 't1@t2' + # Invalid tokens (i.e. 
multiples of the GAP_CHAR) + ("", False), + (" ", False), + ("@@", False), + (" @@ ", False), + ("\t\n@", False), + (alignment.GAP_CHAR * 1, False), + (alignment.GAP_CHAR * RANDOM_INT, False), + (f"\n\t {alignment.GAP_CHAR*RANDOM_INT} \n\t", False), + ], +) def test__is_valid_token(token, desired_output): result = alignment._is_valid_token(token) assert result == desired_output -@pytest.mark.parametrize("aligned_gt, aligned_noise," + - "expected_gt_to_noise_map, expected_noise_to_gt_map", - PARSE_ALIGNMENT_REGRESSION_TEST_CASES) -def test_parse_alignment(aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map): - gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment(aligned_gt, aligned_noise) + +@pytest.mark.parametrize( + "aligned_gt, aligned_noise," + "expected_gt_to_noise_map, expected_noise_to_gt_map", + PARSE_ALIGNMENT_REGRESSION_TEST_CASES, +) +def test_parse_alignment( + aligned_gt, aligned_noise, expected_gt_to_noise_map, expected_noise_to_gt_map +): + gt_to_noise_map, noise_to_gt_map = alignment.parse_alignment( + aligned_gt, aligned_noise + ) assert gt_to_noise_map == expected_gt_to_noise_map assert noise_to_gt_map == expected_noise_to_gt_map -@pytest.mark.parametrize("gt_txt, noisy_txt," + - "expected_aligned_gt, expected_aligned_noise", - ALIGNMENT_REGRESSION_TEST_CASES) + +@pytest.mark.parametrize( + "gt_txt, noisy_txt," + "expected_aligned_gt, expected_aligned_noise", + ALIGNMENT_REGRESSION_TEST_CASES, +) def test_align(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise): aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt) if aligned_gt != expected_aligned_gt: - expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise) + expected_alignment = alignment._format_alignment( + expected_aligned_gt, expected_aligned_noise + ) result_alignment = alignment._format_alignment(aligned_gt, aligned_noise) - warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}")) \ No newline at end of file + warnings.warn( + RuntimeWarning( + f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}" + ) + ) diff --git a/tests/text/test_anchor.py b/tests/text/test_anchor.py index 561f6a5..1c63595 100644 --- a/tests/text/test_anchor.py +++ b/tests/text/test_anchor.py @@ -5,39 +5,51 @@ import glob import pytest import warnings -@pytest.mark.parametrize("tokens, case_sensitive, desired_output", [ - ([], True, set()), - ([], False, set()), - (["a", "A"], True, set(["a", "A"])), - (["a", "A"], False, set()), - (["An", "an", "ab"], True, set(["An", "an", "ab"])), - (["An", "an", "ab"], False, set(["ab"])), -]) + +@pytest.mark.parametrize( + "tokens, case_sensitive, desired_output", + [ + ([], True, set()), + ([], False, set()), + (["a", "A"], True, set(["a", "A"])), + (["a", "A"], False, set()), + (["An", "an", "ab"], True, set(["An", "an", "ab"])), + (["An", "an", "ab"], False, set(["ab"])), + ], +) def test_get_unique_words(tokens, case_sensitive, desired_output): output = anchor.get_unique_words(tokens, case_sensitive=case_sensitive) assert desired_output == output -@pytest.mark.parametrize("tokens, desired_output", [ - ([], 0), - ([""], 0), - (["a", "b"], 2), - (["abc.", "def!"], 8) -]) + +@pytest.mark.parametrize( + "tokens, desired_output", + [([], 0), ([""], 0), (["a", "b"], 2), (["abc.", "def!"], 8)], +) def test_segment_len(tokens, desired_output): output = anchor.segment_len(tokens) assert 
desired_output == output -@pytest.mark.parametrize("unique_words, src_tokens, desired_output, raised_exception", [ - (set(), [], [], None), - (set(), ["a"], [], None), - (set("a"), [], [], ValueError), # unique word not in src_tokens - (set("a"), ["b"], [], ValueError), - (set("a"), ["A"], [], ValueError), # case sensitive - (set("a"), ["an", "na", " a "], [], ValueError), # substring - (set("a"), ["a"], [("a", 0)], None), # valid input - (set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens - (set("ab"), ["c", "b", "a"], [("b", 1), ("a", 2)], None), # multiple matches ordered by index -]) + +@pytest.mark.parametrize( + "unique_words, src_tokens, desired_output, raised_exception", + [ + (set(), [], [], None), + (set(), ["a"], [], None), + (set("a"), [], [], ValueError), # unique word not in src_tokens + (set("a"), ["b"], [], ValueError), + (set("a"), ["A"], [], ValueError), # case sensitive + (set("a"), ["an", "na", " a "], [], ValueError), # substring + (set("a"), ["a"], [("a", 0)], None), # valid input + (set("a"), ["c", "b", "a"], [("a", 2)], None), # multiple src_tokens + ( + set("ab"), + ["c", "b", "a"], + [("b", 1), ("a", 2)], + None, + ), # multiple matches ordered by index + ], +) def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception): if raised_exception: with pytest.raises(raised_exception): @@ -46,85 +58,150 @@ def test_get_word_map(unique_words, src_tokens, desired_output, raised_exception output = anchor.get_word_map(unique_words, src_tokens) assert desired_output == output -@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [ - ([], [], ([], [])), # empty - ([""], [""], ([], [])), - (["a"], ["b"], ([], [])), # no common unique words - (["a", "a"], ["a"], ([], [])), # no unique words - (["a"], ["a", "a"], ([], [])), - (["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist - (["a"], ["b", "a"], ([("a", 0)], [("a", 1)])), - (["a", "b", "c"], ["a", "b", "c"], # common unique words - ([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)])), - (["a", "b", "c"], ["c", "b", "a"], # common unique words but not in same order - ([("b", 1)], [("b", 1)])), - (["b", "a", "c"], ["c", "b", "a"], # LCS has multiple results - ([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])), - (["c", "a", "b"], ["c", "b", "a"], - ([("c", 0), ("b", 2)], [("c", 0), ("b", 1)])), - (["c", "a", "b"], ["a", "c", "b"], # LCS has multiple results - ([("a", 1), ("b", 2)], [("a", 0), ("b", 2)])), -]) + +@pytest.mark.parametrize( + "gt_tokens, ocr_tokens, desired_output", + [ + ([], [], ([], [])), # empty + ([""], [""], ([], [])), + (["a"], ["b"], ([], [])), # no common unique words + (["a", "a"], ["a"], ([], [])), # no unique words + (["a"], ["a", "a"], ([], [])), + (["a"], ["a"], ([("a", 0)], [("a", 0)])), # common unique word exist + (["a"], ["b", "a"], ([("a", 0)], [("a", 1)])), + ( + ["a", "b", "c"], + ["a", "b", "c"], # common unique words + ([("a", 0), ("b", 1), ("c", 2)], [("a", 0), ("b", 1), ("c", 2)]), + ), + ( + ["a", "b", "c"], + ["c", "b", "a"], # common unique words but not in same order + ([("b", 1)], [("b", 1)]), + ), + ( + ["b", "a", "c"], + ["c", "b", "a"], # LCS has multiple results + ([("b", 0), ("a", 1)], [("b", 1), ("a", 2)]), + ), + ( + ["c", "a", "b"], + ["c", "b", "a"], + ([("c", 0), ("b", 2)], [("c", 0), ("b", 1)]), + ), + ( + ["c", "a", "b"], + ["a", "c", "b"], # LCS has multiple results + ([("a", 1), ("b", 2)], [("a", 0), ("b", 2)]), + ), + ], +) def test_get_anchor_map(gt_tokens, ocr_tokens, desired_output): 
desired_gt_map, desired_ocr_map = desired_output
     gt_map, ocr_map = anchor.get_anchor_map(gt_tokens, ocr_tokens)
     assert desired_gt_map == gt_map
     assert desired_ocr_map == ocr_map
+
 # max_seg_length does not change the following output
 @pytest.mark.parametrize("max_seg_length", [0, 1, 2, 3, 5, 4, 6])
-@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_output", [
-    ([], [], ([], [])), # empty
-    ([""], [""], ([], [])),
-    (["a"], ["b"], ([], [])), # no anchors
-    (["a", "a"], ["a"], ([], [])),
-    (["a"], ["a", "a"], ([], [])),
-    (["a"], ["a"], ([0], [0])), # anchors exist
-    ("a1 w w w".split(), "a1 w w w".split(), # no anchors in the subsequence [w w w]
-    ([0], [0])),
-    ("a1 w w w a2".split(), "a1 w w w a2".split(),
-    ([0, 4], [0, 4])),
-    ("a1 w w w2 a2".split(), "a1 w w w3 a2".split(),
-    ([0, 4], [0, 4])),
-    ("a1 a2 a3".split(), "a1 a2 a3".split(), # all words are anchors
-    ([0, 1, 2], [0, 1, 2])),
-    ("a1 a2 a3".split(), "A1 A2 A3".split(), # anchor words must be in the same casing
-    ([], [])),
-    ("a1 w w a2".split(), "a1 w W a2".split(), # unique words are case insensitive
-    ([0, 3], [0, 3])),
-    ("a1 w w a2".split(), "A1 w W A2".split(), # unique words are case insensitive, but anchor are case sensitive
-    ([], [])),
-])
-def test_find_anchor_recur_various_seg_len(max_seg_length, gt_tokens, ocr_tokens, desired_output):
+@pytest.mark.parametrize(
+    "gt_tokens, ocr_tokens, desired_output",
+    [
+        ([], [], ([], [])),  # empty
+        ([""], [""], ([], [])),
+        (["a"], ["b"], ([], [])),  # no anchors
+        (["a", "a"], ["a"], ([], [])),
+        (["a"], ["a", "a"], ([], [])),
+        (["a"], ["a"], ([0], [0])),  # anchors exist
+        (
+            "a1 w w w".split(),
+            "a1 w w w".split(),  # no anchors in the subsequence [w w w]
+            ([0], [0]),
+        ),
+        ("a1 w w w a2".split(), "a1 w w w a2".split(), ([0, 4], [0, 4])),
+        ("a1 w w w2 a2".split(), "a1 w w w3 a2".split(), ([0, 4], [0, 4])),
+        (
+            "a1 a2 a3".split(),
+            "a1 a2 a3".split(),  # all words are anchors
+            ([0, 1, 2], [0, 1, 2]),
+        ),
+        (
+            "a1 a2 a3".split(),
+            "A1 A2 A3".split(),  # anchor words must be in the same casing
+            ([], []),
+        ),
+        (
+            "a1 w w a2".split(),
+            "a1 w W a2".split(),  # unique words are case insensitive
+            ([0, 3], [0, 3]),
+        ),
+        (
+            "a1 w w a2".split(),
+            "A1 w W A2".split(),  # unique words are case insensitive, but anchors are case sensitive
+            ([], []),
+        ),
+    ],
+)
+def test_find_anchor_recur_various_seg_len(
+    max_seg_length, gt_tokens, ocr_tokens, desired_output
+):
     desired_gt_anchors, desired_ocr_anchors = desired_output
-    gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length)
+    gt_anchors, ocr_anchors = anchor.find_anchor_recur(
+        gt_tokens, ocr_tokens, max_seg_length=max_seg_length
+    )
     assert desired_gt_anchors == gt_anchors
     assert desired_ocr_anchors == ocr_anchors
+
 # Test the recursion behavior
-@pytest.mark.parametrize("gt_tokens, ocr_tokens, max_seg_length, desired_output", [
-    ("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6,
-    ([0, 3], [0, 3])),
-    ("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
-    ([0, 3, 4], [0, 3, 4])),
-    ("a1 w_ w_ a2 a3 a2".split(), "a1 w_ w_ a2 a3 a2".split(), 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3]
-    ([0, 2, 3, 4, 5], [0, 2, 3, 4, 5])),
-    ("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 w_ a3".split(), 2, # missing ocr token
-    ([0, 3, 6], [0, 2, 4])),
-    ("a1 w_ w_ a2 w_ w_ a3".split(), "a1 w_ a2 W_ A3".split(), 2, # changing cases
-    ([0, 3], [0, 2])),
-])
-def test_find_anchor_recur_fixed_seg_len(gt_tokens,
ocr_tokens, max_seg_length, desired_output): +@pytest.mark.parametrize( + "gt_tokens, ocr_tokens, max_seg_length, desired_output", + [ + ("a1 w_ w_ a3".split(), "a1 w_ w_ a3".split(), 6, ([0, 3], [0, 3])), + ( + "a1 w_ w_ a2 a3 a2".split(), + "a1 w_ w_ a2 a3 a2".split(), + 4, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3] + ([0, 3, 4], [0, 3, 4]), + ), + ( + "a1 w_ w_ a2 a3 a2".split(), + "a1 w_ w_ a2 a3 a2".split(), + 2, # a2 is anchor word in subsequence [a1 w_ w_ a2 a3] + ([0, 2, 3, 4, 5], [0, 2, 3, 4, 5]), + ), + ( + "a1 w_ w_ a2 w_ w_ a3".split(), + "a1 w_ a2 w_ a3".split(), + 2, # missing ocr token + ([0, 3, 6], [0, 2, 4]), + ), + ( + "a1 w_ w_ a2 w_ w_ a3".split(), + "a1 w_ a2 W_ A3".split(), + 2, # changing cases + ([0, 3], [0, 2]), + ), + ], +) +def test_find_anchor_recur_fixed_seg_len( + gt_tokens, ocr_tokens, max_seg_length, desired_output +): desired_gt_anchors, desired_ocr_anchors = desired_output - gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length) + gt_anchors, ocr_anchors = anchor.find_anchor_recur( + gt_tokens, ocr_tokens, max_seg_length=max_seg_length + ) assert desired_gt_anchors == gt_anchors assert desired_ocr_anchors == ocr_anchors -@pytest.mark.parametrize("gt_file, ocr_file", + +@pytest.mark.parametrize( + "gt_file, ocr_file", zip( sorted(glob.glob("tests/text/data/gt_1.txt")), - sorted(glob.glob("tests/text/data/ocr_1.txt")) - ) + sorted(glob.glob("tests/text/data/ocr_1.txt")), + ), ) @pytest.mark.parametrize("max_seg_length", [75]) def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length): @@ -132,15 +209,27 @@ def test_find_anchor_recur_e2e(gt_file, ocr_file, max_seg_length): ocr_text = open(ocr_file, "r").read() gt_tokens = preprocess.tokenize(gt_text) ocr_tokens = preprocess.tokenize(ocr_text) - gt_anchors, ocr_anchors = anchor.find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length) + gt_anchors, ocr_anchors = anchor.find_anchor_recur( + gt_tokens, ocr_tokens, max_seg_length=max_seg_length + ) for gt_anchor, ocr_anchor in zip(gt_anchors, ocr_anchors): # Ensure that each anchor word is the same word in both text assert gt_tokens[gt_anchor] == ocr_tokens[ocr_anchor] -@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES) + +@pytest.mark.parametrize( + "gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", + ALIGNMENT_REGRESSION_TEST_CASES, +) def test_align_w_anchor(gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise): aligned_gt, aligned_noise = anchor.align_w_anchor(gt_txt, noisy_txt) if aligned_gt != expected_aligned_gt: - expected_alignment = alignment._format_alignment(expected_aligned_gt, expected_aligned_noise) + expected_alignment = alignment._format_alignment( + expected_aligned_gt, expected_aligned_noise + ) result_alignment = alignment._format_alignment(aligned_gt, aligned_noise) - warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}")) + warnings.warn( + RuntimeWarning( + f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}" + ) + ) diff --git a/tests/text/test_conll_format.py b/tests/text/test_conll_format.py index f496f63..462f5ec 100644 --- a/tests/text/test_conll_format.py +++ b/tests/text/test_conll_format.py @@ -5,153 +5,281 @@ import itertools import pytest import warnings -@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, 
ocr_tokens, raised_exception", [ - (["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None), - (["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment - (["w1", "w3"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal tokens - (["w1", "w2"], ["l1", "l2"], [["w1"], ["w3"]], ["w", "w"], ValueError), # Unequal tokens - (["w1", "w3"], ["l1", "l2"], [["w1"]],["w", "w"], ValueError), # Unequal length - (["w1"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], ValueError), # Unequal length -]) -def test_propagate_labels_sentences_error(clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception): + +@pytest.mark.parametrize( + "clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception", + [ + (["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], ["w", "w"], None), + (["w1", "w2"], ["l1", "l2"], [["w1"], ["w2"]], [], ValueError), # No alignment + ( + ["w1", "w3"], + ["l1", "l2"], + [["w1"], ["w2"]], + ["w", "w"], + ValueError, + ), # Unequal tokens + ( + ["w1", "w2"], + ["l1", "l2"], + [["w1"], ["w3"]], + ["w", "w"], + ValueError, + ), # Unequal tokens + ( + ["w1", "w3"], + ["l1", "l2"], + [["w1"]], + ["w", "w"], + ValueError, + ), # Unequal length + ( + ["w1"], + ["l1", "l2"], + [["w1"], ["w2"]], + ["w", "w"], + ValueError, + ), # Unequal length + ], +) +def test_propagate_labels_sentences_error( + clean_tokens, clean_labels, clean_sentences, ocr_tokens, raised_exception +): if raised_exception: with pytest.raises(raised_exception): - conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + conll_format.propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) else: - conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + conll_format.propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) -@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", [ - ( - "a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(), - [["a1", "b1"], ["a2", "b2"]], # clean sentences - ["a1", "b1", "a2", "b2"], # ocr token - [["a1", "b1"], ["a2","b2"]], [["l1", "l2"], ["l3", "l4"]] # desired output - ), - ( - "a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(), - [["a1", "b1"], ["a2", "b2"]], # clean sentences - ["a1", "b1"], # Missing sentence 2 - # Ideally we would expect [["a1", "b1"], []] - # But the limitation of text alignment, which yield - # "a1 b1 a2 b2" - # "a1 b1@@@@@@" - # It is difficult to decide the location of "b1" - # when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@" - # NOTE: this is a improper behavior but the best - # solution to this corner case by preserving the number of OCR tokens. 
- [["a1"], ["b1"]], [["l1"], ["l2"]] - ), - ( - "a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(), - [["a1", "b1"], ["a2", "b2"]], - ["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary) - [["a"], ["a2", "b2"]], [["l1"], ["l3", "l4"]] - ), - ( - "a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(), - [["a1", "b1"], ["a2", "b2"]], - ["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary) - [["a1", "b1"], ["a2"]], [["l1", "l2"], ["l3"]] - ), - ( - "a1 b1 a2 b2".split(), "l1 l2 l3 l4".split(), - [["a1", "b1"], ["a2", "b2"]], - ["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start) - [["b1", "a2"], ["b2"]], [["l2", "l3"], ["l4"]] - ), - ( - "a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(), - [["a1"], ["b1", "c1", "a2"], ["b2"]], - ["a1", "b1", "a2", "b2"], # ocr token (missing c1 token at middle of sentence) - [["a1"], ["b1", "a2"], ["b2"]], [["l1"], ["l2", "l4"], ["l5"]] - ), + +@pytest.mark.parametrize( + "clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels", + [ ( - "a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(), - [["a1", "b1"], ["c1", "a2", "b2"]], - ["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens) - [["a1"], ["b1", "b2"]], [["l1"], ["l2", "l5"]] - ), - ( - "a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(), - [["a1"], ["b1", "c1", "a2"], ["b2"]], - ["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start) - [[], ["a1", "c1", "a2"], ["b2"]], [[], ["l1", "l3", "l4"], ["l5"]] - ), - ( - "a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(), - [["a1", "b1", "c1"], ["a2","b2"]], - ["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end) - [["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]] - ), - ( - "a1 b1 c1 a2 b2".split(), "l1 l2 l3 l4 l5".split(), - [["a1", "b1", "c1"], ["a2","b2"]], - ["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end) - [["a1"], [ "b1", "b2"]], [["l1"], ["l2", "l5"]] - ), -]) -def test_propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens, desired_sentences, desired_labels): - ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + "a1 b1 a2 b2".split(), + "l1 l2 l3 l4".split(), + [["a1", "b1"], ["a2", "b2"]], # clean sentences + ["a1", "b1", "a2", "b2"], # ocr token + [["a1", "b1"], ["a2", "b2"]], + [["l1", "l2"], ["l3", "l4"]], # desired output + ), + ( + "a1 b1 a2 b2".split(), + "l1 l2 l3 l4".split(), + [["a1", "b1"], ["a2", "b2"]], # clean sentences + ["a1", "b1"], # Missing sentence 2 + # Ideally we would expect [["a1", "b1"], []] + # But the limitation of text alignment, which yield + # "a1 b1 a2 b2" + # "a1 b1@@@@@@" + # It is difficult to decide the location of "b1" + # when all tokens "b1" "a2" "b2" are aligned to "b1@@@@@@" + # NOTE: this is a improper behavior but the best + # solution to this corner case by preserving the number of OCR tokens. 
+ [["a1"], ["b1"]], + [["l1"], ["l2"]], + ), + ( + "a1 b1 a2 b2".split(), + "l1 l2 l3 l4".split(), + [["a1", "b1"], ["a2", "b2"]], + ["a", "a2", "b2"], # ocr token (missing b1 token at sentence boundary) + [["a"], ["a2", "b2"]], + [["l1"], ["l3", "l4"]], + ), + ( + "a1 b1 a2 b2".split(), + "l1 l2 l3 l4".split(), + [["a1", "b1"], ["a2", "b2"]], + ["a1", "b1", "a2"], # ocr token (missing b2 token at sentence boundary) + [["a1", "b1"], ["a2"]], + [["l1", "l2"], ["l3"]], + ), + ( + "a1 b1 a2 b2".split(), + "l1 l2 l3 l4".split(), + [["a1", "b1"], ["a2", "b2"]], + ["b1", "a2", "b2"], # ocr token (missing a1 token at sentence start) + [["b1", "a2"], ["b2"]], + [["l2", "l3"], ["l4"]], + ), + ( + "a1 b1 c1 a2 b2".split(), + "l1 l2 l3 l4 l5".split(), + [["a1"], ["b1", "c1", "a2"], ["b2"]], + [ + "a1", + "b1", + "a2", + "b2", + ], # ocr token (missing c1 token at middle of sentence) + [["a1"], ["b1", "a2"], ["b2"]], + [["l1"], ["l2", "l4"], ["l5"]], + ), + ( + "a1 b1 c1 a2 b2".split(), + "l1 l2 l3 l4 l5".split(), + [["a1", "b1"], ["c1", "a2", "b2"]], + ["a1", "b1", "b2"], # ocr token (missing c1 a2 tokens) + [["a1"], ["b1", "b2"]], + [["l1"], ["l2", "l5"]], + ), + ( + "a1 b1 c1 a2 b2".split(), + "l1 l2 l3 l4 l5".split(), + [["a1"], ["b1", "c1", "a2"], ["b2"]], + ["a1", "c1", "a2", "b2"], # ocr token (missing b1 token at sentence start) + [[], ["a1", "c1", "a2"], ["b2"]], + [[], ["l1", "l3", "l4"], ["l5"]], + ), + ( + "a1 b1 c1 a2 b2".split(), + "l1 l2 l3 l4 l5".split(), + [["a1", "b1", "c1"], ["a2", "b2"]], + ["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end) + [["a1"], ["b1", "b2"]], + [["l1"], ["l2", "l5"]], + ), + ( + "a1 b1 c1 a2 b2".split(), + "l1 l2 l3 l4 l5".split(), + [["a1", "b1", "c1"], ["a2", "b2"]], + ["a1", "b1", "b2"], # ocr token (missing c1 and a2 token at sentence end) + [["a1"], ["b1", "b2"]], + [["l1"], ["l2", "l5"]], + ), + ], +) +def test_propagate_labels_sentences( + clean_tokens, + clean_labels, + clean_sentences, + ocr_tokens, + desired_sentences, + desired_labels, +): + ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences)) assert len(ocr_text_sentences) == len(clean_sentences) assert len(ocr_text_sentences) == len(ocr_labels_sentences) - assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens + assert len(ocr_sentences_flatten) == len( + ocr_tokens + ) # ensure aligned ocr tokens == ocr tokens if desired_sentences != ocr_text_sentences: - warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}")) + warnings.warn( + RuntimeWarning( + f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}" + ) + ) if desired_labels != ocr_labels_sentences: - warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}")) + warnings.warn( + RuntimeWarning( + f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}" + ) + ) -@pytest.mark.parametrize("clean_tokens, clean_labels, clean_sentences, ocr_tokens," + -"mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", [ - ( - "a b c d".split(), "l1 l2 l3 l4".split(), - [["a", "b"], ["c", "d"]], - ["a", "b"], # Sentence is empty - 
[[0], [1], [], []], - [[0], [1]], - [["a", "b"], []], - [["l1", "l2"], []] - ), - ( - "a b c d".split(), "l1 l2 l3 l4".split(), - [["a", "b",], ["c", "d"]], - ["a", "b", "d"], # Missing sentence start - [[0], [1], [], [2]], - [[0], [1], [3]], - [["a", "b"], ["d"]], - [["l1", "l2"], ["l4"]] - ), - ( - "a b c d".split(), "l1 l2 l3 l4".split(), - [["a", "b",], ["c", "d"]], - ["a", "c", "d"], # Missing sentence end - [[0], [], [1], [2]], - [[0], [2], [3]], - [["a"], ["c", "d"]], - [["l1"], ["l3", "l4"]] - ), -]) -def test_propagate_labels_sentences_text_alignment_corner_cases(clean_tokens, clean_labels, clean_sentences, ocr_tokens, - mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels): + +@pytest.mark.parametrize( + "clean_tokens, clean_labels, clean_sentences, ocr_tokens," + + "mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping, desired_sentences, desired_labels", + [ + ( + "a b c d".split(), + "l1 l2 l3 l4".split(), + [["a", "b"], ["c", "d"]], + ["a", "b"], # Sentence is empty + [[0], [1], [], []], + [[0], [1]], + [["a", "b"], []], + [["l1", "l2"], []], + ), + ( + "a b c d".split(), + "l1 l2 l3 l4".split(), + [ + [ + "a", + "b", + ], + ["c", "d"], + ], + ["a", "b", "d"], # Missing sentence start + [[0], [1], [], [2]], + [[0], [1], [3]], + [["a", "b"], ["d"]], + [["l1", "l2"], ["l4"]], + ), + ( + "a b c d".split(), + "l1 l2 l3 l4".split(), + [ + [ + "a", + "b", + ], + ["c", "d"], + ], + ["a", "c", "d"], # Missing sentence end + [[0], [], [1], [2]], + [[0], [2], [3]], + [["a"], ["c", "d"]], + [["l1"], ["l3", "l4"]], + ), + ], +) +def test_propagate_labels_sentences_text_alignment_corner_cases( + clean_tokens, + clean_labels, + clean_sentences, + ocr_tokens, + mock_gt_to_ocr_mapping, + mock_ocr_to_gt_mapping, + desired_sentences, + desired_labels, +): with patch("genalog.text.alignment.parse_alignment") as mock_alignment: mock_alignment.return_value = (mock_gt_to_ocr_mapping, mock_ocr_to_gt_mapping) - ocr_text_sentences, ocr_labels_sentences = conll_format.propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + ( + ocr_text_sentences, + ocr_labels_sentences, + ) = conll_format.propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) ocr_sentences_flatten = list(itertools.chain(*ocr_text_sentences)) assert len(ocr_text_sentences) == len(clean_sentences) assert len(ocr_text_sentences) == len(ocr_labels_sentences) - assert len(ocr_sentences_flatten) == len(ocr_tokens) # ensure aligned ocr tokens == ocr tokens + assert len(ocr_sentences_flatten) == len( + ocr_tokens + ) # ensure aligned ocr tokens == ocr tokens if desired_sentences != ocr_text_sentences: - warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}")) + warnings.warn( + RuntimeWarning( + f"\n\n****Expect propagation returns sentences:****\n{desired_sentences} \n****But got:****\n{ocr_text_sentences}" + ) + ) if desired_labels != ocr_labels_sentences: - warnings.warn(RuntimeWarning(f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}")) + warnings.warn( + RuntimeWarning( + f"\n\n****Expect propagation returns labels:****\n{desired_labels} \n****But got:****\n{ocr_labels_sentences}" + ) + ) -@pytest.mark.parametrize("s, desired_output", [ - ("", []), - ("\n\n", []), - ("a1\tb1\na2\tb2", [["a1", "a2"]]), - ("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]), - ("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]), -]) + 
+@pytest.mark.parametrize(
+    "s, desired_output",
+    [
+        ("", []),
+        ("\n\n", []),
+        ("a1\tb1\na2\tb2", [["a1", "a2"]]),
+        ("a1\tb1\n\na2\tb2", [["a1"], ["a2"]]),
+        ("\n\n\na1\tb1\n\na2\tb2\n\n\n", [["a1"], ["a2"]]),
+    ],
+)
 def test_get_sentences_from_iob_format(s, desired_output):
     output = conll_format.get_sentences_from_iob_format(s.splitlines(True))
-    assert desired_output == output
\ No newline at end of file
+    assert desired_output == output
diff --git a/tests/text/test_lcs.py b/tests/text/test_lcs.py
index 5518451..1ce8c7b 100644
--- a/tests/text/test_lcs.py
+++ b/tests/text/test_lcs.py
@@ -2,36 +2,53 @@ from genalog.text.lcs import LCS
 import pytest
-@pytest.fixture(params=[
-    ("", ""), # empty
-    ("abcde", "ace"), # naive case
-])
+
+@pytest.fixture(
+    params=[
+        ("", ""),  # empty
+        ("abcde", "ace"),  # naive case
+    ]
+)
 def lcs(request):
     str1, str2 = request.param
     return LCS(str1, str2)
+
 def test_lcs_init(lcs):
     assert lcs._lcs_len is not None
     assert lcs._lcs is not None
-@pytest.mark.parametrize("str1, str2, expected_len, expected_lcs", [
-    ("", "", 0, ""), # empty
-    ("abc", "abc", 3, "abc"),
-    ("abcde", "ace", 3, "ace"), # naive case
-    ("a", "", 0, ""), # no results
-    ("abc", "cba", 1, "c"), # multiple cases
-    ("abcdgh", "aedfhr", 3, "adh"),
-    ("abc.!\t\nd", "dxab", 2, "ab"), # with punctuations
-    ("New York @", "New @ York", len("New York"), "New York"), # with space-separated, tokens
-    ("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
-    ("Is A Big City", "City Big Is A", len(" Big "), " Big "), # reversed order
-    # mixed order with similar tokens
-    ("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
-    # casing
-    ("Is A Big City IS a", "IS a Big City Is A", len("I Big City I "), "I Big City I "),
-])
+
+@pytest.mark.parametrize(
+    "str1, str2, expected_len, expected_lcs",
+    [
+        ("", "", 0, ""),  # empty
+        ("abc", "abc", 3, "abc"),
+        ("abcde", "ace", 3, "ace"),  # naive case
+        ("a", "", 0, ""),  # no results
+        ("abc", "cba", 1, "c"),  # multiple cases
+        ("abcdgh", "aedfhr", 3, "adh"),
+        ("abc.!\t\nd", "dxab", 2, "ab"),  # with punctuations
+        (
+            "New York @",
+            "New @ York",
+            len("New York"),
+            "New York",
+        ),  # with space-separated, tokens
+        ("Is A Big City", "A Big City Is", len("A Big City"), "A Big City"),
+        ("Is A Big City", "City Big Is A", len(" Big "), " Big "),  # reversed order
+        # mixed order with similar tokens
+        ("Is A Big City IS", "IS Big A City Is", len("I Big City I"), "I Big City I"),
+        # casing
+        (
+            "Is A Big City IS a",
+            "IS a Big City Is A",
+            len("I Big City I "),
+            "I Big City I ",
+        ),
+    ],
+)
 def test_lcs_e2e(str1, str2, expected_len, expected_lcs):
     lcs = LCS(str1, str2)
     assert expected_lcs == lcs.get_str()
     assert expected_len == lcs.get_len()
-    
\ No newline at end of file
diff --git a/tests/text/test_ner_label.py b/tests/text/test_ner_label.py
index d3cd463..255c89b 100644
--- a/tests/text/test_ner_label.py
+++ b/tests/text/test_ner_label.py
@@ -1,171 +1,295 @@
 from genalog.text import ner_label
-from genalog.text import alignment
 from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES
 import pytest
-import string
-@pytest.mark.parametrize("label, desired_output", [
-    # Positive Cases
-    ("B-org", True), (" B-org ", True), #whitespae tolerant
-    ("\tB-ORG\n", True),
-    # Negative Cases
-    ("I-ORG", False), ("O", False), ("other-B-label", False),
-])
+
+@pytest.mark.parametrize(
+    "label, desired_output",
+    [
+        # Positive Cases
+        ("B-org", True),
+        (" B-org ", True),  # whitespace tolerant
+        
("\tB-ORG\n", True), + # Negative Cases + ("I-ORG", False), + ("O", False), + ("other-B-label", False), + ], +) def test__is_begin_label(label, desired_output): output = ner_label._is_begin_label(label) assert output == desired_output -@pytest.mark.parametrize("label, desired_output", [ - # Positive Cases - ("I-ORG", True), (" \t I-ORG ", True), - # Negative Cases - ("O", False), ("B-LOC", False),("B-ORG", False), -]) + +@pytest.mark.parametrize( + "label, desired_output", + [ + # Positive Cases + ("I-ORG", True), + (" \t I-ORG ", True), + # Negative Cases + ("O", False), + ("B-LOC", False), + ("B-ORG", False), + ], +) def test__is_inside_label(label, desired_output): output = ner_label._is_inside_label(label) assert output == desired_output -@pytest.mark.parametrize("label, desired_output", [ - # Positive Cases - ("I-ORG", True), ("B-ORG", True), - # Negative Cases - ("O", False) -]) + +@pytest.mark.parametrize( + "label, desired_output", + [ + # Positive Cases + ("I-ORG", True), + ("B-ORG", True), + # Negative Cases + ("O", False), + ], +) def test__is_multi_token_label(label, desired_output): output = ner_label._is_multi_token_label(label) assert output == desired_output -@pytest.mark.parametrize("label, desired_output", [ - # Positive Cases - ("I-Place", "B-Place"), (" \t I-place ", "B-place"), - # Negative Cases - ("O", "O"), ("B-LOC", "B-LOC"), (" B-ORG ", " B-ORG ") -]) + +@pytest.mark.parametrize( + "label, desired_output", + [ + # Positive Cases + ("I-Place", "B-Place"), + (" \t I-place ", "B-place"), + # Negative Cases + ("O", "O"), + ("B-LOC", "B-LOC"), + (" B-ORG ", " B-ORG "), + ], +) def test__convert_to_begin_label(label, desired_output): output = ner_label._convert_to_begin_label(label) assert output == desired_output -@pytest.mark.parametrize("label, desired_output", [ - # Positive Cases - ("B-LOC", "I-LOC"), - (" B-ORG ", "I-ORG"), - # Negative Cases - ("", ""), ("O", "O"), ("I-Place", "I-Place"), - (" \t I-place ", " \t I-place ") -]) + +@pytest.mark.parametrize( + "label, desired_output", + [ + # Positive Cases + ("B-LOC", "I-LOC"), + (" B-ORG ", "I-ORG"), + # Negative Cases + ("", ""), + ("O", "O"), + ("I-Place", "I-Place"), + (" \t I-place ", " \t I-place "), + ], +) def test__convert_to_inside_label(label, desired_output): output = ner_label._convert_to_inside_label(label) assert output == desired_output -@pytest.mark.parametrize("begin_label, inside_label, desired_output", [ - # Positive Cases - ("", "I-LOC", True), - ("B-LOC", "I-ORG", True), - ("", "I-ORG", True), - # Negative Cases - ("", "", False), ("O", "O", False), ("", "", False), - ("B-LOC", "O", False), - ("B-LOC", "B-ORG", False), - ("B-LOC", "I-LOC", False), - (" B-ORG ", "I-ORG", False), -]) +@pytest.mark.parametrize( + "begin_label, inside_label, desired_output", + [ + # Positive Cases + ("", "I-LOC", True), + ("B-LOC", "I-ORG", True), + ("", "I-ORG", True), + # Negative Cases + ("", "", False), + ("O", "O", False), + ("", "", False), + ("B-LOC", "O", False), + ("B-LOC", "B-ORG", False), + ("B-LOC", "I-LOC", False), + (" B-ORG ", "I-ORG", False), + ], +) def test__is_missing_begin_label(begin_label, inside_label, desired_output): output = ner_label._is_missing_begin_label(begin_label, inside_label) assert output == desired_output -@pytest.mark.parametrize("gt_tokens, ocr_tokens, desired_input_char_set", [ - (["a","b"], ["c", "d"], set("abcd")), - (["New", "York"], ["is", "big"], set("NewYorkisbig")), - (["word1", "word2"], ["word1", "word2"], set("word12")), -]) + +@pytest.mark.parametrize( + 
"gt_tokens, ocr_tokens, desired_input_char_set", + [ + (["a", "b"], ["c", "d"], set("abcd")), + (["New", "York"], ["is", "big"], set("NewYorkisbig")), + (["word1", "word2"], ["word1", "word2"], set("word12")), + ], +) def test__find_gap_char_candidates(gt_tokens, ocr_tokens, desired_input_char_set): - gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens) + gap_char_candidates, input_char_set = ner_label._find_gap_char_candidates( + gt_tokens, ocr_tokens + ) assert input_char_set == desired_input_char_set assert ner_label.GAP_CHAR_SET.difference(input_char_set) == gap_char_candidates -@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", -[ - (["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token - (["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token - (["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR - (["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens - (["o"], ["@"], ["New"], ner_label.GapCharError), # invalid token with gap char only (gt_token) - (["o"], ["New"], ["@"], ner_label.GapCharError), # invalid token with gap char only (ocr_token) - (["o", "o"], ["New", "@"], ["New", "@"], ner_label.GapCharError), # invalid token (both) - (["o"], [" \n\t@@"], ["New"], ner_label.GapCharError), # invalid token with gap char and space chars (gt_token) - (["o"], ["New"], [" \n\t@"], ner_label.GapCharError), # invalid token with gap char and space chars (ocr_token) - (["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token) - (["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token) - (["o"], [" \n\t"], ["New"], ValueError), # invalid token: space characters only (gt_token) - (["o"], ["New"], [" \n\t"], ValueError), # invalid token: space characters only (ocr_token) - (["o"], ["New"], ["New"], None), # positive case - (["o"], ["New@"], ["New"], None), # positive case with gap char - (["o"], ["New"], ["@@New"], None), # positive case with gap char -]) -def test__propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception): + +@pytest.mark.parametrize( + "gt_labels, gt_tokens, ocr_tokens, raised_exception", + [ + (["o"], ["New York"], ["NewYork"], ValueError), # non-atomic gt_token + (["o"], ["NewYork"], ["New York"], ValueError), # non-atomic ocr_token + (["o"], [" @ New"], ["@ @"], ValueError), # non-atomic tokens with GAP_CHAR + (["o", "o"], ["New"], ["New"], ValueError), # num gt_labels != num gt_tokens + ( + ["o"], + ["@"], + ["New"], + ner_label.GapCharError, + ), # invalid token with gap char only (gt_token) + ( + ["o"], + ["New"], + ["@"], + ner_label.GapCharError, + ), # invalid token with gap char only (ocr_token) + ( + ["o", "o"], + ["New", "@"], + ["New", "@"], + ner_label.GapCharError, + ), # invalid token (both) + ( + ["o"], + [" \n\t@@"], + ["New"], + ner_label.GapCharError, + ), # invalid token with gap char and space chars (gt_token) + ( + ["o"], + ["New"], + [" \n\t@"], + ner_label.GapCharError, + ), # invalid token with gap char and space chars (ocr_token) + (["o"], [""], ["New"], ValueError), # invalid token: empty string (gt_token) + (["o"], ["New"], [""], ValueError), # invalid token: empty string (ocr_token) + ( + ["o"], + [" \n\t"], + ["New"], + ValueError, + ), # invalid token: space characters only (gt_token) + ( + ["o"], + ["New"], + [" \n\t"], + ValueError, + ), # invalid token: space characters only (ocr_token) + (["o"], ["New"], ["New"], None), # positive case + 
(["o"], ["New@"], ["New"], None), # positive case with gap char + (["o"], ["New"], ["@@New"], None), # positive case with gap char + ], +) +def test__propagate_label_to_ocr_error( + gt_labels, gt_tokens, ocr_tokens, raised_exception +): if raised_exception: with pytest.raises(raised_exception): - ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@') + ner_label._propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens, gap_char="@" + ) else: - ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char='@') + ner_label._propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens, gap_char="@" + ) -@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", - LABEL_PROPAGATION_REGRESSION_TEST_CASES) + +@pytest.mark.parametrize( + "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", + LABEL_PROPAGATION_REGRESSION_TEST_CASES, +) def test__propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels): gap_char_candidates, _ = ner_label._find_gap_char_candidates(gt_tokens, ocr_tokens) - # run regression test for each GAP_CHAR candidate to make sure + # run regression test for each GAP_CHAR candidate to make sure # label propagate is function correctly - for gap_char in gap_char_candidates: - ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char) + for gap_char in gap_char_candidates: + ocr_labels, _, _, _ = ner_label._propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char + ) assert ocr_labels == desired_ocr_labels -@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, raised_exception", [ - (["o"], ["New"], ["New"], None), # positive case - (["o"], ["New@"], ["New"], None), # positive case with gap char - (["o"], ["New"], ["@@New"], None), # positive case with gap char - (["o"], list(ner_label.GAP_CHAR_SET), [""], ner_label.GapCharError), # input char set == GAP_CHAR_SET - (["o"], [""], list(ner_label.GAP_CHAR_SET), ner_label.GapCharError), # input char set == GAP_CHAR_SET - # all possible gap chars set split between ocr and gt tokens - (["o"], list(ner_label.GAP_CHAR_SET)[:10], list(ner_label.GAP_CHAR_SET)[10:], ner_label.GapCharError), -]) -def test_propagate_label_to_ocr_error(gt_labels, gt_tokens, ocr_tokens, raised_exception): + +@pytest.mark.parametrize( + "gt_labels, gt_tokens, ocr_tokens, raised_exception", + [ + (["o"], ["New"], ["New"], None), # positive case + (["o"], ["New@"], ["New"], None), # positive case with gap char + (["o"], ["New"], ["@@New"], None), # positive case with gap char + ( + ["o"], + list(ner_label.GAP_CHAR_SET), + [""], + ner_label.GapCharError, + ), # input char set == GAP_CHAR_SET + ( + ["o"], + [""], + list(ner_label.GAP_CHAR_SET), + ner_label.GapCharError, + ), # input char set == GAP_CHAR_SET + # all possible gap chars set split between ocr and gt tokens + ( + ["o"], + list(ner_label.GAP_CHAR_SET)[:10], + list(ner_label.GAP_CHAR_SET)[10:], + ner_label.GapCharError, + ), + ], +) +def test_propagate_label_to_ocr_error( + gt_labels, gt_tokens, ocr_tokens, raised_exception +): if raised_exception: with pytest.raises(raised_exception): ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens) else: ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens) -@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", - LABEL_PROPAGATION_REGRESSION_TEST_CASES) + +@pytest.mark.parametrize( + "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", + 
LABEL_PROPAGATION_REGRESSION_TEST_CASES, +) def test_propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels): - ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens) + ocr_labels, _, _, _ = ner_label.propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens + ) assert ocr_labels == desired_ocr_labels -@pytest.mark.parametrize("tokens, labels, label_top, desired_output", -[ - ( - ["New", "York", "is", "big"], - ["B-place", "I-place", "o", "o"], - True, - "B-place I-place o o \n" + - "New York is big \n" - ), - ( - ["New", "York", "is", "big"], - ["B-place", "I-place", "o", "o"], - False, - "New York is big \n" + - "B-place I-place o o \n" - ) -]) + +@pytest.mark.parametrize( + "tokens, labels, label_top, desired_output", + [ + ( + ["New", "York", "is", "big"], + ["B-place", "I-place", "o", "o"], + True, + "B-place I-place o o \n" + "New York is big \n", + ), + ( + ["New", "York", "is", "big"], + ["B-place", "I-place", "o", "o"], + False, + "New York is big \n" + "B-place I-place o o \n", + ), + ], +) def test_format_label(tokens, labels, label_top, desired_output): output = ner_label.format_labels(tokens, labels, label_top=label_top) assert output == desired_output -@pytest.mark.parametrize("gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", - LABEL_PROPAGATION_REGRESSION_TEST_CASES) + +@pytest.mark.parametrize( + "gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels", + LABEL_PROPAGATION_REGRESSION_TEST_CASES, +) def test_format_gt_ocr_w_labels(gt_labels, gt_tokens, ocr_tokens, desired_ocr_labels): - ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens) - ner_label.format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr) \ No newline at end of file + ocr_labels, aligned_gt, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens + ) + ner_label.format_label_propagation( + gt_tokens, gt_labels, ocr_tokens, ocr_labels, aligned_gt, aligned_ocr + ) diff --git a/tests/text/test_preprocess.py b/tests/text/test_preprocess.py index 4f144a8..5e313e2 100644 --- a/tests/text/test_preprocess.py +++ b/tests/text/test_preprocess.py @@ -2,123 +2,146 @@ from genalog.text import preprocess from genalog.text.alignment import GAP_CHAR import pytest -@pytest.mark.parametrize("token, replacement, desired_output", -[ - ("", "_", ""), # Do nothing to empty string - (" ", "_", " "), # Do nothing to whitespaces - (" \n\t", "_", " \n\t"), - ("ascii", "_", "ascii"), - ("a s\nc\tii", "_", "a s\nc\tii"), - ("ascii·", "_", "ascii"), # Tokens with non-ASCII values - ("·", "_", "_"), # Tokens with non-ASCII values -]) + +@pytest.mark.parametrize( + "token, replacement, desired_output", + [ + ("", "_", ""), # Do nothing to empty string + (" ", "_", " "), # Do nothing to whitespaces + (" \n\t", "_", " \n\t"), + ("ascii", "_", "ascii"), + ("a s\nc\tii", "_", "a s\nc\tii"), + ("ascii·", "_", "ascii"), # Tokens with non-ASCII values + ("·", "_", "_"), # Tokens with non-ASCII values + ], +) def test_remove_non_ascii(token, replacement, desired_output): - for code in range(128, 1000): # non-ASCII values + for code in range(128, 1000): # non-ASCII values token.replace("·", chr(code)) output = preprocess.remove_non_ascii(token, replacement) assert output == desired_output -@pytest.mark.parametrize("s, desired_output", -[ - ( - " New \t \n", - ["New"] - ), + +@pytest.mark.parametrize( + "s, desired_output", + [ + (" New \t \n", 
["New"]), # Mixed in gap char "@" - ( - " @ @", - ["@", "@"] - ), - ( - "New York is big", - ["New", "York", "is", "big"] - ), + (" @ @", ["@", "@"]), + ("New York is big", ["New", "York", "is", "big"]), # Mixed multiple spaces and tabs - ( - " New York \t is \t big", - ["New", "York", "is", "big"] - ), + (" New York \t is \t big", ["New", "York", "is", "big"]), # Mixed in punctuation - ( - "New .York is, big !", - ["New", ".York", "is,", "big", "!"] - ), + ("New .York is, big !", ["New", ".York", "is,", "big", "!"]), # Mixed in gap char "@" - ( - "@N@ew York@@@is,\t big@@@@@", - ["@N@ew", "York@@@is,", "big@@@@@"] - ) -]) + ("@N@ew York@@@is,\t big@@@@@", ["@N@ew", "York@@@is,", "big@@@@@"]), + ], +) def test_tokenize(s, desired_output): output = preprocess.tokenize(s) assert output == desired_output -@pytest.mark.parametrize("tokens, desired_output", -[ - ( - ["New", "York", "is", "big"], - "New York is big", - ), + +@pytest.mark.parametrize( + "tokens, desired_output", + [ + ( + ["New", "York", "is", "big"], + "New York is big", + ), # Mixed in punctuation - ( - ["New", ".York", "is,", "big", "!"], - "New .York is, big !", - ), + ( + ["New", ".York", "is,", "big", "!"], + "New .York is, big !", + ), # Mixed in gap char "@" - ( - ["@N@ew", "York@@@is,", "big@@@@@"], - "@N@ew York@@@is, big@@@@@", - ) -]) + ( + ["@N@ew", "York@@@is,", "big@@@@@"], + "@N@ew York@@@is, big@@@@@", + ), + ], +) def test_join_tokens(tokens, desired_output): output = preprocess.join_tokens(tokens) assert output == desired_output -@pytest.mark.parametrize("c, desired_output", -[ - # Gap char - (GAP_CHAR, False), - # Alphabet char - ('a', False), ('A', False), - # Punctuation - ('.', False), ('!', False), (',', False), ('-', False), - # Token separators - (' ', True), ('\n', True), ('\t', True) -]) + +@pytest.mark.parametrize( + "c, desired_output", + [ + # Gap char + (GAP_CHAR, False), + # Alphabet char + ("a", False), + ("A", False), + # Punctuation + (".", False), + ("!", False), + (",", False), + ("-", False), + # Token separators + (" ", True), + ("\n", True), + ("\t", True), + ], +) def test__is_spacing(c, desired_output): assert desired_output == preprocess._is_spacing(c) -@pytest.mark.parametrize("text, desired_output", [ - ("", ""), - ("w .", "w ."), ("w !", "w !"), ("w ?", "w ?"), - ("w /.", "w /."), ("w /!", "w /!"), ("w /?", "w /?"), - ("w1 , w2 .", "w1 , w2 ."), - ("w1 . w2 .", "w1 . \nw2 ."), ("w1 /. w2 /.", "w1 /. \nw2 /."), - ("w1 ! w2 .", "w1 ! \nw2 ."), ("w1 /! w2 /.", "w1 /! \nw2 /."), - ("w1 ? w2 .", "w1 ? \nw2 ."), ("w1 /? w2 /.", "w1 /? \nw2 /."), - ("U.S. . w2 .", "U.S. . \nw2 ."), - ("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting - ("w1 !!! w2 .", "w1 !!! w2 ."), - ("w1 ... . w2 .", "w1 ... . \nw2 ."), - ("w1 ... /. w2 /.", "w1 ... /. \nw2 /."), - ("w1 /. /. w2 .", "w1 /. /. \nw2 ."), - ("w1 /. /.", "w1 /. \n/."), - ("w1 /. /. ", "w1 /. /. \n"), - ("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."), - ("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."), - ("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."), - ("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."), -]) + +@pytest.mark.parametrize( + "text, desired_output", + [ + ("", ""), + ("w .", "w ."), + ("w !", "w !"), + ("w ?", "w ?"), + ("w /.", "w /."), + ("w /!", "w /!"), + ("w /?", "w /?"), + ("w1 , w2 .", "w1 , w2 ."), + ("w1 . w2 .", "w1 . \nw2 ."), + ("w1 /. w2 /.", "w1 /. \nw2 /."), + ("w1 ! w2 .", "w1 ! \nw2 ."), + ("w1 /! w2 /.", "w1 /! \nw2 /."), + ("w1 ? w2 .", "w1 ? \nw2 ."), + ("w1 /? w2 /.", "w1 /? \nw2 /."), + ("U.S. . w2 .", "U.S. . 
\nw2 ."), + ("w1 ??? w2 .", "w1 ??? w2 ."), # not splitting + ("w1 !!! w2 .", "w1 !!! w2 ."), + ("w1 ... . w2 .", "w1 ... . \nw2 ."), + ("w1 ... /. w2 /.", "w1 ... /. \nw2 /."), + ("w1 /. /. w2 .", "w1 /. /. \nw2 ."), + ("w1 /. /.", "w1 /. \n/."), + ("w1 /. /. ", "w1 /. /. \n"), + ("w1 ? ? ? ? w2 .", "w1 ? ? ? ? \nw2 ."), + ("w1 /? /? /? /? w2 /.", "w1 /? /? /? /? \nw2 /."), + ("w1 ! ! ! ! w2 .", "w1 ! ! ! ! \nw2 ."), + ("w1 /! /! /! /! w2 /.", "w1 /! /! /! /! \nw2 /."), + ], +) def test_split_sentences(text, desired_output): assert desired_output == preprocess.split_sentences(text) -@pytest.mark.parametrize("token, desired_output", [ - ("", False), (" ", False), ("\n", False), ("\t", False), - (" \n \t", False), - ("...", False), - ("???", False), ("!!!", False), - (".", True), ("!", True), ("?", True), - ("/.", True), ("/!", True), ("/?", True), -]) + +@pytest.mark.parametrize( + "token, desired_output", + [ + ("", False), + (" ", False), + ("\n", False), + ("\t", False), + (" \n \t", False), + ("...", False), + ("???", False), + ("!!!", False), + (".", True), + ("!", True), + ("?", True), + ("/.", True), + ("/!", True), + ("/?", True), + ], +) def test_is_sentence_separator(token, desired_output): assert desired_output == preprocess.is_sentence_separator(token) diff --git a/tests/text/test_utf8.py b/tests/text/test_utf8.py index 8bfe412..934fc20 100644 --- a/tests/text/test_utf8.py +++ b/tests/text/test_utf8.py @@ -6,9 +6,10 @@ from genalog.text import alignment from genalog.text.alignment import GAP_CHAR from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES + def random_utf8_char(byte_len=1): if byte_len == 1: - return chr(random.randint(0,0x007F)) + return chr(random.randint(0, 0x007F)) elif byte_len == 2: return chr(random.randint(0x007F, 0x07FF)) elif byte_len == 3: @@ -16,33 +17,61 @@ def random_utf8_char(byte_len=1): elif byte_len == 4: return chr(random.randint(0xFFFF, 0x10FFFF)) else: - raise ValueError(f"Invalid byte length: {byte_len}." + - "utf-8 does not encode characters with more than 4 bytes in length") + raise ValueError( + f"Invalid byte length: {byte_len}." 
+ + "utf-8 does not encode characters with more than 4 bytes in length" + ) -@pytest.mark.parametrize("num_utf_char_to_test", [100]) # Number of char per byte length -@pytest.mark.parametrize("byte_len", [1,2,3,4]) # UTF does not encode with more than 4 bytes -@pytest.mark.parametrize("gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", ALIGNMENT_REGRESSION_TEST_CASES) -def test_align(num_utf_char_to_test, byte_len, gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise): - - invalid_char = set(gt_txt).union(set(GAP_CHAR)) # character to replace to cannot be in this set + +@pytest.mark.parametrize( + "num_utf_char_to_test", [100] +) # Number of char per byte length +@pytest.mark.parametrize( + "byte_len", [1, 2, 3, 4] +) # UTF does not encode with more than 4 bytes +@pytest.mark.parametrize( + "gt_txt, noisy_txt, expected_aligned_gt, expected_aligned_noise", + ALIGNMENT_REGRESSION_TEST_CASES, +) +def test_align( + num_utf_char_to_test, + byte_len, + gt_txt, + noisy_txt, + expected_aligned_gt, + expected_aligned_noise, +): + + invalid_char = set(gt_txt).union( + set(GAP_CHAR) + ) # character to replace to cannot be in this set for _ in range(num_utf_char_to_test): utf_char = random_utf8_char(byte_len) - while utf_char in invalid_char: # find a utf char not in the input string and not GAP_CHAR + while ( + utf_char in invalid_char + ): # find a utf char not in the input string and not GAP_CHAR utf_char = random_utf8_char(byte_len) char_to_replace = random.choice(list(invalid_char)) if gt_txt else "" - gt_txt_sub = gt_txt.replace(char_to_replace, utf_char) - noisy_txt_sub = noisy_txt.replace(char_to_replace, utf_char) + gt_txt.replace(char_to_replace, utf_char) + noisy_txt.replace(char_to_replace, utf_char) expected_aligned_gt_sub = expected_aligned_gt.replace(char_to_replace, utf_char) - expected_aligned_noise_sub = expected_aligned_noise.replace(char_to_replace, utf_char) - + expected_aligned_noise_sub = expected_aligned_noise.replace( + char_to_replace, utf_char + ) + # Run alignment aligned_gt, aligned_noise = alignment.align(gt_txt, noisy_txt) aligned_gt = aligned_gt.replace(char_to_replace, utf_char) aligned_noise = aligned_noise.replace(char_to_replace, utf_char) if aligned_gt != expected_aligned_gt_sub: - expected_alignment = alignment._format_alignment(expected_aligned_gt_sub, expected_aligned_noise_sub) + expected_alignment = alignment._format_alignment( + expected_aligned_gt_sub, expected_aligned_noise_sub + ) result_alignment = alignment._format_alignment(aligned_gt, aligned_noise) - warnings.warn(RuntimeWarning(f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}")) - + warnings.warn( + RuntimeWarning( + f"\n\n****Expect alignment returns:****\n{expected_alignment} \n****But got:****\n{result_alignment}" + ) + ) From 5b87f4eddb7176f901f1deb7727cf0c5ba1c90ad Mon Sep 17 00:00:00 2001 From: "Jianjie Liu (MAIDAP)" Date: Mon, 25 Jan 2021 16:08:22 -0500 Subject: [PATCH 2/5] Fix flake8 issues --- genalog/degradation/degrader.py | 64 ++++--- genalog/degradation/effect.py | 162 ++++++++++-------- genalog/generation/content.py | 19 ++- genalog/generation/document.py | 159 +++++++++-------- genalog/ocr/blob_client.py | 221 +++++++++++++++++------- genalog/ocr/common.py | 2 +- genalog/ocr/grok.py | 48 ++++-- genalog/ocr/metrics.py | 292 ++++++++++++++++++++++---------- genalog/ocr/rest_client.py | 150 +++++++++------- genalog/pipeline.py | 88 +++++++--- genalog/text/__init__.py | 1 - genalog/text/alignment.py | 261 
++++++++++++++++------------ genalog/text/anchor.py | 153 ++++++++++------- genalog/text/conll_format.py | 284 +++++++++++++++++++++---------- genalog/text/lcs.py | 23 ++- genalog/text/ner_label.py | 266 +++++++++++++++++------------ genalog/text/preprocess.py | 32 ++-- genalog/text/splitter.py | 151 +++++++++++------ 18 files changed, 1518 insertions(+), 858 deletions(-) diff --git a/genalog/degradation/degrader.py b/genalog/degradation/degrader.py index ff16b04..1024449 100644 --- a/genalog/degradation/degrader.py +++ b/genalog/degradation/degrader.py @@ -5,21 +5,24 @@ import inspect DEFAULT_METHOD_PARAM_TO_INCLUDE = "src" + class ImageState(Enum): ORIGINAL_STATE = "ORIGINAL_STATE" CURRENT_STATE = "CURRENT_STATE" -class Degrader(): + +class Degrader: """ An object for applying multiple degradation effects onto an image""" + def __init__(self, effects): - """ Initialize a Degrader object + """Initialize a Degrader object Arguments: effects {list} -- a list of 2-element tuple that defines: - + (method_name, method_kwargs) - 1. method_name: the name of the degradation method + 1. method_name: the name of the degradation method (method must be defined in 'genalog.degradation.effect') 2. method_kwargs: the keyword arguments of the corresponding method @@ -28,10 +31,10 @@ class Degrader(): [ ("blur", {"radius": 3}), ("bleed_through", {"alpha": 0.8), - ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}), + ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}), ] - The example above will apply degradation effects to the images + The example above will apply degradation effects to the images in the following sequence: blur -> bleed_through -> morphological operation (open) @@ -42,14 +45,14 @@ class Degrader(): @staticmethod def validate_effects(effects): - """ Validate the effects list + """Validate the effects list Arguments: effects {list} -- a list of 2-element tuple that defines: - + (method_name, method_kwargs) - 1. method_name: the name of the degradation method + 1. method_name: the name of the degradation method (method must be defined in 'genalog.degradation.effect') 2. method_kwargs: the keyword arguments of the corresponding method @@ -58,11 +61,11 @@ class Degrader(): [ ("blur", {"radius": "3"}), ("bleed_through", {"alpha":"0.8"}), - ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}), + ("morphology", {"operation": "open", "kernel_shape": (3,3), "kernel_type": "ones"}), ] Raises: - ValueError: raise this error when + ValueError: raise this error when 1. method_name not defined in "genalog.degradation.effect" 2. method_kwargs is not a valid keyword arguments in the corresponding method @@ -73,37 +76,46 @@ class Degrader(): # Try to find corresponding degradation method in the module method = getattr(effect, method_name) except AttributeError: - raise ValueError(f"Method '{method_name}' is not defined in 'genalog.degradation.effect'") + raise ValueError( + f"Method '{method_name}' is not defined in 'genalog.degradation.effect'" + ) # Get the method signatures method_sign = inspect.signature(method) # Check if method parameters are valid - for param_name in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...] - if not param_name in method_sign.parameters: + for ( + param_name + ) in method_kwargs.keys(): # i.e. ["operation", "kernel_shape", ...] 
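+                # any keyword that is not a parameter of the effect function is rejected below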
+ if param_name not in method_sign.parameters: method_args = [param for param in method_sign.parameters] - raise ValueError(f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. Method parameter names are: {method_args}") + raise ValueError( + f"Invalid parameter name '{param_name}' for method 'genalog.degradation.effect.{method_name}()'. " + + f"Method parameter names are: {method_args}" + ) def _add_default_method_param(self): - """ All methods in "genalog.degradation.effect" module have a required + """All methods in "genalog.degradation.effect" module have a required method parameter named "src". This parameter will be included if not provided by the input keyword argument dictionary. """ for effect_tuple in self.effects_to_apply: method_name, method_kwargs = effect_tuple if DEFAULT_METHOD_PARAM_TO_INCLUDE not in method_kwargs: - method_kwargs[DEFAULT_METHOD_PARAM_TO_INCLUDE] = ImageState.CURRENT_STATE - + method_kwargs[ + DEFAULT_METHOD_PARAM_TO_INCLUDE + ] = ImageState.CURRENT_STATE + def apply_effects(self, src): - """ Apply degradation effects in sequence - + """Apply degradation effects in sequence + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Returns: a copy of the source image {numpy.ndarray} after apply the effects """ self.original_state = src self.current_state = src - # Preserve the original effect instructions + # Preserve the original effect instructions effects_to_apply = copy.deepcopy(self.effects_to_apply) for effect_tuple in effects_to_apply: method_name, method_kwargs = effect_tuple @@ -115,7 +127,7 @@ class Degrader(): return self.current_state def insert_image_state(self, kwargs): - """ Replace the enumeration (ImageState) with the actual image in + """Replace the enumeration (ImageState) with the actual image in the keyword argument dictionary Arguments: @@ -124,12 +136,12 @@ class Degrader(): Ex: {"src": ImageState.ORIGINAL_STATE, "radius": 5} Returns: - return keyword argument dictionary replaced with - reference to the image + return keyword argument dictionary replaced with + reference to the image """ for keyword, argument in kwargs.items(): if argument is ImageState.ORIGINAL_STATE: kwargs[keyword] = self.original_state.copy() if argument is ImageState.CURRENT_STATE: kwargs[keyword] = self.current_state.copy() - return kwargs \ No newline at end of file + return kwargs diff --git a/genalog/degradation/effect.py b/genalog/degradation/effect.py index 22bd5ce..a3c87b1 100644 --- a/genalog/degradation/effect.py +++ b/genalog/degradation/effect.py @@ -2,30 +2,32 @@ import cv2 import numpy as np from math import floor + def blur(src, radius=5): - """ Wrapper function for cv2.GaussianBlur - + """Wrapper function for cv2.GaussianBlur + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Keyword Arguments: - radius {int} -- size of the square kernel, + radius {int} -- size of the square kernel, MUST be an odd integer (default: {5}) - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ return cv2.GaussianBlur(src, (radius, radius), cv2.BORDER_DEFAULT) + def overlay_weighted(src, background, alpha, beta, gamma=0): - """ overlay two images together, pixels from each image is weighted as follow + """overlay two images together, pixels from each image is weighted as follow dst[i] = alpha*src[i] + beta*background[i] + gamma Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) background {numpy.ndarray} -- background image. 
Must be in same shape are `src` - alpha {float} -- transparent factor for the foreground + alpha {float} -- transparent factor for the foreground beta {float} -- transparent factor for the background Keyword Arguments: @@ -36,8 +38,9 @@ def overlay_weighted(src, background, alpha, beta, gamma=0): """ return cv2.addWeighted(src, alpha, background, beta, gamma).astype(np.uint8) + def overlay(src, background): - """ Overlay two images together via bitwise-and: + """Overlay two images together via bitwise-and: dst[i] = src[i] & background[i] @@ -50,33 +53,35 @@ def overlay(src, background): """ return cv2.bitwise_and(src, background).astype(np.uint8) + def translation(src, offset_x, offset_y): - """ Shift the image in x, y direction - + """Shift the image in x, y direction + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - offset_x {int} -- pixels in the x direction. + offset_x {int} -- pixels in the x direction. Positive value shifts right and negative shifts right. offset_y {int} -- pixels in the y direction. Positive value shifts down and negative shifts up. - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ rows, cols = src.shape - trans_matrix = np.float32([[1,0,offset_x], [0,1,offset_y]]) + trans_matrix = np.float32([[1, 0, offset_x], [0, 1, offset_y]]) # size of the output image should be in the form of (width, height) dst = cv2.warpAffine(src, trans_matrix, (cols, rows), borderValue=255) return dst.astype(np.uint8) + def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y=5): - """ Apply bleed through effect, background is flipped horizontally. - + """Apply bleed through effect, background is flipped horizontally. + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Keyword Arguments: - background {numpy.ndarray} -- background image. Must be in same + background {numpy.ndarray} -- background image. Must be in same shape as foreground (default: {None}) alpha {float} -- transparent factor for the foreground (default: {0.8}) gamma {int} -- luminance constant (default: {0}) @@ -84,30 +89,31 @@ def bleed_through(src, background=None, alpha=0.8, gamma=0, offset_x=0, offset_y Positive value shifts right and negative shifts right. offset_y {int} -- background translation offset (default: {5}) Positive value shifts down and negative shifts up. - + Returns: a copy of the source image {numpy.ndarray} after apply the effect. Pixel value ranges [0, 255] """ if background is None: background = src.copy() - background = cv2.flip(background, 1) # flipped horizontally + background = cv2.flip(background, 1) # flipped horizontally background = translation(background, offset_x, offset_y) beta = 1 - alpha return overlay_weighted(src, background, alpha, beta, gamma) + def pepper(src, amount=0.05): - """ Randomly sprinkle dark pixels on src image. + """Randomly sprinkle dark pixels on src image. Wrapper function for skimage.util.noise.random_noise(). See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Keyword Arguments: amount {float} -- proportion of pixels in range [0, 1] to apply the effect (default: {0.05}) - + Returns: a copy of the source image {numpy.ndarray} after apply the effect. Pixel value ranges [0, 255] as uint8. @@ -118,18 +124,19 @@ def pepper(src, amount=0.05): dst[noise < amount] = 0 return dst.astype(np.uint8) + def salt(src, amount=0.3): - """ Randomly sprinkle white pixels on src image. 
+ """Randomly sprinkle white pixels on src image. Wrapper function for skimage.util.noise.random_noise(). See https://scikit-image.org/docs/stable/api/skimage.util.html#random-noise - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Keyword Arguments: amount {float} -- proportion of pixels in range [0, 1] to apply the effect (default: {0.05}) - + Returns: a copy of the source image {numpy.ndarray} after apply the effect. Pixel value ranges [0, 255] @@ -140,18 +147,19 @@ def salt(src, amount=0.3): dst[noise < amount] = 255 return dst.astype(np.uint8) + def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05): - """ Randomly add salt then add pepper onto the image. - + """Randomly add salt then add pepper onto the image. + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - salt_amount {float} -- proportion of pixels in range [0, 1] to + salt_amount {float} -- proportion of pixels in range [0, 1] to apply the salt effect (default: {0.1}) - pepper_amount {float} -- proportion of pixels in range [0, 1] to - apply the pepper effect + pepper_amount {float} -- proportion of pixels in range [0, 1] to + apply the pepper effect (default: {0.05}) - + Returns: a copy of the source image {numpy.ndarray} after apply the effect. Pixel value ranges [0, 255] as uint8. @@ -159,18 +167,19 @@ def salt_then_pepper(src, salt_amount=0.1, pepper_amount=0.05): salted = salt(src, amount=salt_amount) return pepper(salted, amount=pepper_amount) + def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1): - """ Randomly add pepper then salt onto the image. - + """Randomly add pepper then salt onto the image. + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - pepper_amount {float} -- proportion of pixels in range [0, 1] to + pepper_amount {float} -- proportion of pixels in range [0, 1] to apply the pepper effect. - (default: {0.05}) - salt_amount {float} -- proportion of pixels in range [0, 1] to + (default: {0.05}) + salt_amount {float} -- proportion of pixels in range [0, 1] to apply the salt effect. (default: {0.1}) - + Returns: a copy of the source image {numpy.ndarray} after apply the effect. Pixel value ranges [0, 255] as uint8. @@ -178,16 +187,17 @@ def pepper_then_salt(src, pepper_amount=0.05, salt_amount=0.1): peppered = pepper(src, amount=pepper_amount) return salt(peppered, amount=salt_amount) + def create_2D_kernel(kernel_shape, kernel_type="ones"): - """ Create 2D kernel for morphological operations. - + """Create 2D kernel for morphological operations. + Arguments: kernel_shape {tuple} -- shape of the kernel (rows, cols) - + Keyword Arguments: - kernel_type {str} -- type of kernel (default: {"ones"}). + kernel_type {str} -- type of kernel (default: {"ones"}). All supported kernel types are below: - + "ones": kernel is filled with all 1s in shape (rows, cols) [[1,1,1], [1,1,1], @@ -215,9 +225,9 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"): [1, 1, 1, 1, 1], [0, 0, 1, 0, 0]] Raises: - ValueError: if kernel is not a 2-element tuple or + ValueError: if kernel is not a 2-element tuple or kernel_type is not one of the supported values - + Returns: a 2D array {numpy.ndarray} of shape `kernel_shape`. 
""" @@ -233,37 +243,47 @@ def create_2D_kernel(kernel_shape, kernel_type="ones"): elif kernel_type == "x": diagonal = np.eye(kernel_rows, kernel_cols) kernel = np.add(diagonal, np.fliplr(diagonal)) - kernel[kernel>1] = 1 + kernel[kernel > 1] = 1 elif kernel_type == "plus": kernel = np.zeros(kernel_shape) - center_col = floor(kernel.shape[0]/2) - center_row = floor(kernel.shape[1]/2) + center_col = floor(kernel.shape[0] / 2) + center_row = floor(kernel.shape[1] / 2) kernel[:, center_col] = 1 kernel[center_row, :] = 1 elif kernel_type == "ellipse": kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape) else: - valid_kernel_types = {"ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"} - raise ValueError(f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}") + valid_kernel_types = { + "ones", + "upper_triangle", + "lower_triangle", + "x", + "plus", + "ellipse", + } + raise ValueError( + f"Invalid kernel_type: {kernel_type}. Valid types are {valid_kernel_types}" + ) return kernel.astype(np.uint8) -def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"): - """ Dynamic calls different morphological operations + +def morphology(src, operation="open", kernel_shape=(3, 3), kernel_type="ones"): + """Dynamic calls different morphological operations ("open", "close", "dilate" and "erode") with the given parameters - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) - + Keyword Arguments: operation {str} -- name of a morphological operation: ("open", "close", "dilate", "erode") (default: {"open"}) - kernel_shape {tuple} -- shape of the kernel (rows, cols) + kernel_shape {tuple} -- shape of the kernel (rows, cols) (default: {(3,3)}) kernel_type {str} -- type of kernel (default: {"ones"}) Supported kernel_types are: - ["ones", "upper_triangle", "lower_triangle", + ["ones", "upper_triangle", "lower_triangle", "x", "plus", "ellipse"] Returns: @@ -280,7 +300,10 @@ def morphology(src, operation="open", kernel_shape=(3,3), kernel_type="ones"): return erode(src, kernel) else: valid_operations = ["open", "close", "dilate", "erode"] - raise ValueError(f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}") + raise ValueError( + f"Invalid morphology operation '{operation}'. Valid morphological operations are {valid_operations}" + ) + def open(src, kernel): """ "open" morphological operation. Like morphological "erosion", it removes @@ -289,59 +312,62 @@ def open(src, kernel): For more information see: 1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html 2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/open.htm - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ return cv2.morphologyEx(src, cv2.MORPH_OPEN, kernel) + def close(src, kernel): - """ "close" morphological operation. Like morphological "dilation", it grows the + """ "close" morphological operation. Like morphological "dilation", it grows the boundary of the foreground (white pixels), however, it is less destructive than dilation of the original boundary shape. For more information see: 1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html 2. 
http://homepages.inf.ed.ac.uk/rbf/HIPR2/close.htm - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ return cv2.morphologyEx(src, cv2.MORPH_CLOSE, kernel) + def erode(src, kernel): """ "erode" morphological operation. Erodes foreground pixels (white pixels). For more information see: 1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html 2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/erode.htm - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ return cv2.erode(src, kernel) + def dilate(src, kernel): - """ "dilate" morphological operation. Grows foreground pixels (white pixels). + """ "dilate" morphological operation. Grows foreground pixels (white pixels). For more information see: 1. https://docs.opencv.org/master/d9/d61/tutorial_py_morphological_ops.html 2. http://homepages.inf.ed.ac.uk/rbf/HIPR2/dilate.htm - + Arguments: src {numpy.ndarray} -- source image of shape (rows, cols) kernel {numpy.ndarray} -- a 2D array for structuring the morphological effect - + Returns: a copy of the source image {numpy.ndarray} after apply the effect """ diff --git a/genalog/generation/content.py b/genalog/generation/content.py index 226af20..b39c369 100644 --- a/genalog/generation/content.py +++ b/genalog/generation/content.py @@ -1,19 +1,23 @@ from enum import Enum, auto + class ContentType(Enum): PARAGRAPH = auto() TITLE = auto() IMAGE = auto() COMPOSITE = auto() -class Content(): + +class Content: def __init__(self): self.iterable = True self._content = None def set_content_type(self, content_type): if type(content_type) != ContentType: - raise TypeError(f"Invalid content type: {content_type}, valid types are {list(ContentType)}") + raise TypeError( + f"Invalid content type: {content_type}, valid types are {list(ContentType)}" + ) self.content_type = content_type def validate_content(self): @@ -21,13 +25,14 @@ class Content(): def __str__(self): return self._content.__str__() - + def __iter__(self): return self._content.__iter__() def __getitem__(self, key): return self._content.__getitem__(key) + class Paragraph(Content): def __init__(self, content): self.set_content_type(ContentType.PARAGRAPH) @@ -38,16 +43,18 @@ class Paragraph(Content): if not isinstance(content, str): raise TypeError(f"Expect a str, but got {type(content)}") + class Title(Content): def __init__(self, content): self.set_content_type(ContentType.TITLE) self.validate_content(content) self._content = content - + def validate_content(self, content): if not isinstance(content, str): raise TypeError(f"Expect a str, but got {type(content)}") + class CompositeContent(Content): def __init__(self, content_list, content_type_list): self.set_content_type(ContentType.COMPOSITE) @@ -68,7 +75,7 @@ class CompositeContent(Content): self._content.append(Paragraph(content)) else: raise NotImplementedError(f"{content_type} is not currently supported") - + def insert_content(self, new_content, index): NotImplementedError @@ -82,5 +89,5 @@ class CompositeContent(Content): """get a string transparent of the nested object types""" transparent_str = "[" for content in self._content: - transparent_str += '"' + content.__str__() + "\", " + transparent_str += '"' + content.__str__() 
+ '", ' return transparent_str + "]" diff --git a/genalog/generation/document.py b/genalog/generation/document.py index 73fe61f..4651911 100644 --- a/genalog/generation/document.py +++ b/genalog/generation/document.py @@ -27,19 +27,21 @@ DEFAULT_STYLE_COMBINATION = { "hyphenate": [False], } + class Document(object): """ A composite object that represents a document """ + def __init__(self, content, template, **styles): """Initialize a Document object with source template and content - + Arguments: content {CompositeContent} -- a iterable object whose elements template {Template} -- a jinja2.Template object - + Optional Argument: - styles [dict] -- a kwargs dictionary (context) whose keys and values are + styles [dict] -- a kwargs dictionary (context) whose keys and values are the template variable and their respective values - + Example: { "font_family": "Calibri", @@ -48,24 +50,24 @@ class Document(object): } Note that this assumes that "font_family", "font_size", "hyphenate" are valid - variables declared in the loaded template. There will be **NO SIDE-EFFECT** + variables declared in the loaded template. There will be **NO SIDE-EFFECT** providing an variable undefined in the template. - + You can also provide these key-value pairs via Python keyword arguments: - + Document(content, template, font_family="Calibri, font_size="10px", hyphenate=True) """ self.content = content self.template = template self.styles = DEFAULT_DOCUMENT_STYLE.copy() # This is a rendered document ready to be painted on a cairo surface - self._document = None # weasyprint.document.Document object + self._document = None # weasyprint.document.Document object self.compiled_html = None # Update the default styles and initialize self._document object self.update_style(**styles) def render_html(self): - """ Wrapper function for Jinjia2.Template.render(). Each template + """Wrapper function for Jinjia2.Template.render(). Each template declare its template variables. This method assigns each variable to its respective value and compiles the template. @@ -77,8 +79,8 @@ class Document(object): return self.template.render(content=self.content, **self.styles) def render_pdf(self, target=None, zoom=1): - """ Wrapper function for WeasyPrint.Document.write_pdf - + """Wrapper function for WeasyPrint.Document.write_pdf + Arguments: target -- a filename, file-like object, or None split_pages {bool} -- true if saving each document page as a separate file. @@ -92,32 +94,37 @@ class Document(object): return self._document.write_pdf(target=target, zoom=zoom) def render_png(self, target=None, split_pages=False, resolution=300): - """ Wrapper function for WeasyPrint.Document.write_png - + """Wrapper function for WeasyPrint.Document.write_png + Arguments: target -- a filename, file-like object, or None split_pages {bool} -- true if save each document page as a separate file. - resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), PNG pixels match the CSS px unit. - + resolution {int} -- the output resolution in PNG pixels per CSS inch. At 300 dpi (the default), + PNG pixels match the CSS px unit. 
+ Returns: The image as bytes if target is not provided or None, otherwise None (the PDF is written to target) """ - if target != None and split_pages: + if target is not None and split_pages: # get destination filename and extension filename, ext = os.path.splitext(target) for page_num, page in enumerate(self._document.pages): page_name = filename + f"_pg_{page_num}" + ext - self._document.copy([page]).write_png(target=page_name, resolution=resolution) + self._document.copy([page]).write_png( + target=page_name, resolution=resolution + ) return None - elif target == None: + elif target is None: # return image bytes string if no target is specified - png_bytes, png_width, png_height = self._document.write_png(target=target, resolution=resolution) + png_bytes, png_width, png_height = self._document.write_png( + target=target, resolution=resolution + ) return png_bytes else: return self._document.write_png(target=target, resolution=resolution) def render_array(self, resolution=300, channel="GRAYSCALE"): - """ Render document as a numpy.ndarray. + """Render document as a numpy.ndarray. Keyword Arguments: resolution {int} -- in units dpi (default: {300}) @@ -125,24 +132,29 @@ class Document(object): available values are: "GRAYSCALE", "RGB", "RGBA", "BGRA", "BGR" Note that "RGB" is 3-channel, "RGBA" is 4-channel and "GRAYSCALE" is single channel - + Returns: A numpy.ndarray representation of the document. """ # Method below returns a cairocffi.ImageSurface object # https://cairocffi.readthedocs.io/en/latest/api.html#cairocffi.ImageSurface - surface, width, height = self._document.write_image_surface(resolution=resolution) + surface, width, height = self._document.write_image_surface( + resolution=resolution + ) img_format = surface.get_format() - + # This is BGRA channel in little endian (reverse) if img_format != FORMAT_ARGB32: - raise RuntimeError(f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}. Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'") - + raise RuntimeError( + f"Expect surface format to be 'cairocffi.FORMAT_ARGB32', but got {img_format}." + + "Please check the underlining implementation of 'weasyprint.document.Document.write_image_surface()'" + ) + img_buffer = surface.get_data() # Returns image array in "BGRA" channel - img_array = np.ndarray(shape=(height, width, 4), - dtype=np.uint8, - buffer=img_buffer) + img_array = np.ndarray( + shape=(height, width, 4), dtype=np.uint8, buffer=img_buffer + ) if channel == "GRAYSCALE": return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY) elif channel == "RGBA": @@ -155,13 +167,15 @@ class Document(object): return cv2.cvtColor(img_array, cv2.COLOR_BGRA2BGR) else: valid_channels = ["GRAYSCALE", "RGB", "RGBA", "BGR", "BGRA"] - raise ValueError(f"Invalid channel code {channel}. Valid values are: {valid_channels}.") + raise ValueError( + f"Invalid channel code {channel}. Valid values are: {valid_channels}." + ) def update_style(self, **style): - """ Update template variables that controls the document style and re-compile the document to reflect the style change. - + """Update template variables that controls the document style and re-compile the document to reflect the style change. 
+ Optional Arguments: - style {dict} -- a kwargs dictionary whose keys and values are + style {dict} -- a kwargs dictionary whose keys and values are the template variable and their respective values Example: @@ -175,44 +189,49 @@ class Document(object): self.styles.update(style) # Recompile the html template and the document obj self.compiled_html = self.render_html() - self._document = HTML(string=self.compiled_html).render() # weasyprinter.document.Document object + self._document = HTML( + string=self.compiled_html + ).render() # weasyprinter.document.Document object -class DocumentGenerator(): +class DocumentGenerator: """ Document generator class """ + def __init__(self, template_path=None): - """ Initialize a DocumentGenerator class - + """Initialize a DocumentGenerator class + Keyword Arguments: template_path {str} -- filepath of custom templates (default: {None}) - *** Important *** if not set, will use the default templates from the + *** Important *** if not set, will use the default templates from the package "genalog.generation.templates". """ if template_path: self.template_env = Environment( - loader=FileSystemLoader(template_path), - autoescape=select_autoescape(['html', 'xml']) - ) + loader=FileSystemLoader(template_path), + autoescape=select_autoescape(["html", "xml"]), + ) self.template_list = self.template_env.list_templates() else: # Loading built-in templates from the genalog package self.template_env = Environment( - loader=PackageLoader("genalog.generation", "templates"), - autoescape=select_autoescape(['html', 'xml']) - ) + loader=PackageLoader("genalog.generation", "templates"), + autoescape=select_autoescape(["html", "xml"]), + ) # Remove macros and css templates from rendering - self.template_list = self.template_env.list_templates(filter_func=DocumentGenerator._keep_template) + self.template_list = self.template_env.list_templates( + filter_func=DocumentGenerator._keep_template + ) self.set_styles_to_generate(DEFAULT_STYLE_COMBINATION) @staticmethod def _keep_template(template_name): - """ Auxiliary function for Jinja2.Environment.list_templates(). + """Auxiliary function for Jinja2.Environment.list_templates(). This function filters out non-html templates and base templates - + Arguments: template_name {str} -- target of the template - + Returns: [bool] -- True if keeping the template in the list. False otherwise. """ @@ -220,16 +239,16 @@ class DocumentGenerator(): if any(name in template_name for name in TEMPLATES_TO_REMOVE): return False return True - + def set_styles_to_generate(self, style_combinations): """ Set new styles to generate. - + Arguments: - style_combination {dict} -- a dictionary {str: list} enlisting the combinations - of values to generate per style property + style_combination {dict} -- a dictionary {str: list} enlisting the combinations + of values to generate per style property (default: {None}) - + Example: { "font_family": ["Calibri", "Times"], @@ -248,14 +267,16 @@ class DocumentGenerator(): variables declared in the loaded template. There will be NO side-effect providing an variable UNDEFINED in the template. 
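+        The provided value lists are expanded into every possible style combination
+        via `DocumentGenerator.expand_style_combinations()`.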
- If this parameter is not provided, generator will use default document + If this parameter is not provided, generator will use default document styles: DEFAULT_STYLE_COMBINATION """ - self.styles_to_generate = DocumentGenerator.expand_style_combinations(style_combinations) + self.styles_to_generate = DocumentGenerator.expand_style_combinations( + style_combinations + ) def create_generator(self, content, templates_to_render): - """ Create a Document generator - + """Create a Document generator + Arguments: content {list} -- a list [str] of string to populate the template templates_to_render {list} -- a list [str] or templates to render @@ -266,15 +287,17 @@ class DocumentGenerator(): """ for template_name in templates_to_render: if template_name not in self.template_list: - raise FileNotFoundError(f"File '{template_name}' not found. Available templates are {self.template_list}") + raise FileNotFoundError( + f"File '{template_name}' not found. Available templates are {self.template_list}" + ) template = self.template_env.get_template(template_name) for style in self.styles_to_generate: yield Document(content, template, **style) @staticmethod def expand_style_combinations(styles): - """ Expand the list of style values into all possible style combinations - + """Expand the list of style values into all possible style combinations + Example: styles = { @@ -291,12 +314,12 @@ class DocumentGenerator(): {"font_family": "Times", "font_size": "12px", "hyphenate":True } ] - The result dictionaries are intended to be used as a kwargs to initialize a + The result dictionaries are intended to be used as a kwargs to initialize a Document Object: Example: Document(template, content, **{"font_family": "Calibri", "font_size": ...}) - + Arguments: styles {dict} -- a dictionary {str: list} enlisting the combinations of values to generate per style property @@ -308,10 +331,14 @@ class DocumentGenerator(): if not styles: return [] # Python 2.x+ guarantees that the order in keys() and values() is preserved - style_properties = styles.keys() # ex) ["font_family", "font_size", "hyphenate"] - property_values = styles.values() # ex) [["Calibri", "Times"], ["10px", "12px"], [True]] - - # Generate all possible combinations: + style_properties = ( + styles.keys() + ) # ex) ["font_family", "font_size", "hyphenate"] + property_values = ( + styles.values() + ) # ex) [["Calibri", "Times"], ["10px", "12px"], [True]] + + # Generate all possible combinations: # [("Calibri", "10px", True), ("Calibri", "12px", True), ...] 
property_value_combinations = itertools.product(*property_values) @@ -328,4 +355,4 @@ class DocumentGenerator(): style_dict[style_property] = property_value style_combinations.append(style_dict) - return style_combinations \ No newline at end of file + return style_combinations diff --git a/genalog/ocr/blob_client.py b/genalog/ocr/blob_client.py index 326bf76..547b3d4 100644 --- a/genalog/ocr/blob_client.py +++ b/genalog/ocr/blob_client.py @@ -2,13 +2,11 @@ see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python """ import os -import time import hashlib import json import asyncio import random -from multiprocessing import Pool -from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient +from azure.storage.blob import BlobServiceClient from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from tqdm import tqdm @@ -24,13 +22,21 @@ FILE_SEMAPHORE = asyncio.Semaphore(500) MAX_RETRIES = 5 + class GrokBlobClient: """This class is a client that is used to upload and delete files from Azure Blob storage https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python """ - def __init__(self, datasource_container_name, blob_account_name, blob_key, projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME): + + def __init__( + self, + datasource_container_name, + blob_account_name, + blob_key, + projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME, + ): """Creates the blob storage client given the key and storage account name - + Args: datasource_container_name (str): container name. This container does not need to be existing projections_container_name (str): projections container to store ocr projections. @@ -42,38 +48,54 @@ class GrokBlobClient: self.PROJECTIONS_CONTAINER_NAME = projections_container_name self.BLOB_NAME = blob_account_name self.BLOB_KEY = blob_key - self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \ + self.BLOB_CONNECTION_STRING = ( + f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net" + ) @staticmethod def create_from_env_var(): """Created the blob client using values in the environment variables - + Returns: GrokBlobClient: the new blob client """ DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"] BLOB_NAME = os.environ["BLOB_NAME"] BLOB_KEY = os.environ["BLOB_KEY"] - PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME) - client = GrokBlobClient(DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY, projections_container_name=PROJECTIONS_CONTAINER_NAME) + PROJECTIONS_CONTAINER_NAME = os.environ.get( + "PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME + ) + client = GrokBlobClient( + DATASOURCE_CONTAINER_NAME, + BLOB_NAME, + BLOB_KEY, + projections_container_name=PROJECTIONS_CONTAINER_NAME, + ) return client - def upload_images_to_blob(self, src_folder_path, dest_folder_name=None,check_existing_cache=True, use_async=True): + def upload_images_to_blob( + self, + src_folder_path, + dest_folder_name=None, + check_existing_cache=True, + use_async=True, + ): """Uploads images from the src_folder_path to blob storage at the destination folder. The destination folder is created if it doesn't exist. If a destination folder is not given a folder is created named by the md5 hash of the files. 
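+        When `check_existing_cache` is True, files whose blob names already exist under the
+        destination folder are skipped. `use_async` chooses between the asyncio upload path
+        and the synchronous one.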
- + Args: src_folder_path (src): path to local folder that has images dest_folder_name (str, optional): destination folder name. Defaults to None. - + Returns: str: the destination folder name """ self._create_container() blob_service_client = BlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) + self.BLOB_CONNECTION_STRING + ) if dest_folder_name is None: dest_folder_name = self.get_folder_hash(src_folder_path) @@ -94,30 +116,50 @@ class GrokBlobClient: return (upload_file_path, blob_name) if check_existing_cache: - existing_blobs,_ = self.list_blobs(dest_folder_name or "") - existing_blobs = list(map(lambda blob: blob["name"] , existing_blobs)) - file_blob_names = filter(lambda file_blob_names: not file_blob_names[1] in existing_blobs, zip(files_to_upload, blob_names)) - job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in file_blob_names ] + existing_blobs, _ = self.list_blobs(dest_folder_name or "") + existing_blobs = list(map(lambda blob: blob["name"], existing_blobs)) + file_blob_names = filter( + lambda file_blob_names: not file_blob_names[1] in existing_blobs, + zip(files_to_upload, blob_names), + ) + job_args = [ + get_job_args(file_path, blob_name) + for file_path, blob_name in file_blob_names + ] else: - job_args = [get_job_args(file_path, blob_name) for file_path, blob_name in zip(files_to_upload, blob_names)] + job_args = [ + get_job_args(file_path, blob_name) + for file_path, blob_name in zip(files_to_upload, blob_names) + ] print("uploading ", len(job_args), "files") if not use_async: blob_service_client = BlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) - blob_container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME) + self.BLOB_CONNECTION_STRING + ) + blob_container_client = blob_service_client.get_container_client( + self.DATASOURCE_CONTAINER_NAME + ) jobs = [(blob_container_client,) + x for x in job_args] - for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs) ): + for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)): pass else: async_blob_service_client = asyncBlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) + self.BLOB_CONNECTION_STRING + ) async def async_upload(): async with async_blob_service_client: - async_blob_container_client = async_blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME) + async_blob_container_client = ( + async_blob_service_client.get_container_client( + self.DATASOURCE_CONTAINER_NAME + ) + ) jobs = [(async_blob_container_client,) + x for x in job_args] - for f in tqdm(asyncio.as_completed(map(_upload_worker_async,jobs)), total=len(jobs)): + for f in tqdm( + asyncio.as_completed(map(_upload_worker_async, jobs)), + total=len(jobs), + ): await f loop = asyncio.get_event_loop() @@ -150,7 +192,7 @@ class GrokBlobClient: def delete_blobs_folder(self, folder_name): """Deletes all blobs in a folder - + Args: folder_name (str): folder to delete """ @@ -158,53 +200,83 @@ class GrokBlobClient: blobs_list, blob_service_client = self.list_blobs(folder_name) for blob in blobs_list: blob_client = blob_service_client.get_blob_client( - container=self.DATASOURCE_CONTAINER_NAME, blob=blob) + container=self.DATASOURCE_CONTAINER_NAME, blob=blob + ) blob_client.delete_blob() def list_blobs(self, folder_name): blob_service_client = BlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) - container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME) - return 
container_client.list_blobs(name_starts_with=folder_name), blob_service_client - + self.BLOB_CONNECTION_STRING + ) + container_client = blob_service_client.get_container_client( + self.DATASOURCE_CONTAINER_NAME + ) + return ( + container_client.list_blobs(name_starts_with=folder_name), + blob_service_client, + ) def _create_container(self): - """Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist. - """ + """Creates the container named {self.DATASOURCE_CONTAINER_NAME} if it doesn't exist.""" # Create the BlobServiceClient object which will be used to create a container client blob_service_client = BlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) + self.BLOB_CONNECTION_STRING + ) try: - blob_service_client.create_container( - self.DATASOURCE_CONTAINER_NAME) + blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME) except ResourceExistsError: print("container already exists:", self.DATASOURCE_CONTAINER_NAME) # create the container for storing ocr projections try: - print("creating projections storage container:", self.PROJECTIONS_CONTAINER_NAME) - blob_service_client.create_container( - self.PROJECTIONS_CONTAINER_NAME) + print( + "creating projections storage container:", + self.PROJECTIONS_CONTAINER_NAME, + ) + blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME) except ResourceExistsError: print("container already exists:", self.PROJECTIONS_CONTAINER_NAME) def get_ocr_json(self, remote_path, output_folder, use_async=True): blob_service_client = BlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) - container_client = blob_service_client.get_container_client(self.DATASOURCE_CONTAINER_NAME) + self.BLOB_CONNECTION_STRING + ) + container_client = blob_service_client.get_container_client( + self.DATASOURCE_CONTAINER_NAME + ) blobs_list = list(container_client.list_blobs(name_starts_with=remote_path)) container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}" if use_async: async_blob_service_client = asyncBlobServiceClient.from_connection_string( - self.BLOB_CONNECTION_STRING) + self.BLOB_CONNECTION_STRING + ) + async def async_download(): async with async_blob_service_client: - async_projection_container_client = async_blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME) - jobs = list(map(lambda blob :(blob, async_projection_container_client, container_uri, output_folder), blobs_list )) - for f in tqdm(asyncio.as_completed(map(_download_worker_async,jobs)), total=len(jobs)): + async_projection_container_client = ( + async_blob_service_client.get_container_client( + self.PROJECTIONS_CONTAINER_NAME + ) + ) + jobs = list( + map( + lambda blob: ( + blob, + async_projection_container_client, + container_uri, + output_folder, + ), + blobs_list, + ) + ) + for f in tqdm( + asyncio.as_completed(map(_download_worker_async, jobs)), + total=len(jobs), + ): await f + loop = asyncio.get_event_loop() if loop.is_running(): result = loop.create_task(async_download()) @@ -212,11 +284,24 @@ class GrokBlobClient: result = loop.run_until_complete(async_download()) return result else: - projection_container_client = blob_service_client.get_container_client(self.PROJECTIONS_CONTAINER_NAME) - jobs = list(map(lambda blob : (blob, projection_container_client, container_uri, output_folder), blobs_list)) + projection_container_client = blob_service_client.get_container_client( + self.PROJECTIONS_CONTAINER_NAME + ) + jobs = list( + map( + lambda blob: ( + blob, 
+ projection_container_client, + container_uri, + output_folder, + ), + blobs_list, + ) + ) print("downloading", len(jobs), "files") - for _ in tqdm(map( _download_worker_sync, jobs), total=len(jobs)): - pass + for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)): + pass + def _get_projection_path(container_uri, blob): blob_uri = f"{container_uri}/{blob.name}" @@ -226,13 +311,16 @@ def _get_projection_path(container_uri, blob): # hopefully this doesn't change soon otherwise we will have to do linear search over all docs to find # the projections we want projection_path = base64.b64encode(blob_uri.encode()).decode() - projection_path = projection_path.replace("=","") + str(projection_path.count("=")) + projection_path = projection_path.replace("=", "") + str(projection_path.count("=")) return projection_path + def _download_worker_sync(args): blob, projection_container_client, container_uri, output_folder = args - projection_path = _get_projection_path(container_uri,blob) - blob_client = projection_container_client.get_blob_client(blob=f"{projection_path}/document.json") + projection_path = _get_projection_path(container_uri, blob) + blob_client = projection_container_client.get_blob_client( + blob=f"{projection_path}/document.json" + ) doc = json.loads(blob_client.download_blob().readall()) file_name = os.path.basename(blob.name) base_name, ext = os.path.splitext(file_name) @@ -242,10 +330,13 @@ def _download_worker_sync(args): json.dump(text, open(output_file, "w", encoding="utf-8"), ensure_ascii=False) return output_file + async def _download_worker_async(args): blob, async_projection_container_client, container_uri, output_folder = args - projection_path = _get_projection_path(container_uri,blob) - async_blob_client = async_projection_container_client.get_blob_client( blob=f"{projection_path}/document.json") + projection_path = _get_projection_path(container_uri, blob) + async_blob_client = async_projection_container_client.get_blob_client( + blob=f"{projection_path}/document.json" + ) file_name = os.path.basename(blob.name) base_name, ext = os.path.splitext(file_name) for retry in range(MAX_RETRIES): @@ -260,14 +351,15 @@ async def _download_worker_async(args): json.dump(text, open(output_file, "w")) return output_file except ResourceNotFoundError: - print(f"blob doesn't exist in OCR projection. try rerunning OCR", blob.name) + print(f"Blob '{blob.name}'' doesn't exist in OCR projection. try rerunning OCR") return except Exception as e: print("error getting blob OCR projection", blob.name, e) - + # sleep for a bit then retry asyncio.sleep(2 * random.random()) + async def _upload_worker_async(args): async_blob_container_client, upload_file_path, blob_name = args async with FILE_SEMAPHORE: @@ -276,24 +368,33 @@ async def _upload_worker_async(args): for retry in range(MAX_RETRIES): async with REQUEST_SEMAPHORE: try: - await async_blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data) + await async_blob_container_client.upload_blob( + name=blob_name, max_concurrency=8, data=data + ) return blob_name except ResourceExistsError: print("blob already exists:", blob_name) - return + return except Exception as e: - print(f"blob upload error. retry count: {retry}/{MAX_RETRIES} :", blob_name, e) + print( + f"blob upload error. 
retry count: {retry}/{MAX_RETRIES} :", + blob_name, + e, + ) # sleep for a bit then retry asyncio.sleep(2 * random.random()) return blob_name + def _upload_worker_sync(args): blob_container_client, upload_file_path, blob_name = args with open(upload_file_path, "rb") as data: try: - blob_container_client.upload_blob(name=blob_name, max_concurrency=8, data=data) + blob_container_client.upload_blob( + name=blob_name, max_concurrency=8, data=data + ) except ResourceExistsError: print("blob already exists:", blob_name) except Exception as e: print("blob upload error:", blob_name, e) - return blob_name \ No newline at end of file + return blob_name diff --git a/genalog/ocr/common.py b/genalog/ocr/common.py index db1a1bf..dab589e 100644 --- a/genalog/ocr/common.py +++ b/genalog/ocr/common.py @@ -1 +1 @@ -DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections" \ No newline at end of file +DEFAULT_PROJECTIONS_CONTAINER_NAME = "ocrprojections" diff --git a/genalog/ocr/grok.py b/genalog/ocr/grok.py index 97dab50..b51f0cb 100644 --- a/genalog/ocr/grok.py +++ b/genalog/ocr/grok.py @@ -4,11 +4,10 @@ import time class Grok: - @staticmethod def create_from_env_var(): """Initializes Grok based on keys in the environment variables. - + Returns: Grok: the Grok client """ @@ -16,31 +15,41 @@ class Grok: grok_blob_client = GrokBlobClient.create_from_env_var() return Grok(grok_rest_client, grok_blob_client) - def __init__(self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient): + def __init__( + self, grok_rest_client: GrokRestClient, grok_blob_client: GrokBlobClient + ): self.grok_rest_client = grok_rest_client self.grok_blob_client = grok_blob_client - def run_grok(self, src_folder_path, dest_folder_path, blob_dest_folder=None,cleanup=False,use_async=True): - """Uploads images in the source folder to blob, sets up an indexing pipeline to run + def run_grok( + self, + src_folder_path, + dest_folder_path, + blob_dest_folder=None, + cleanup=False, + use_async=True, + ): + """Uploads images in the source folder to blob, sets up an indexing pipeline to run GROK OCR on this blob storage as a source, then dowloads the OCR output json to the destination - folder. There resulting json files are of the same name as the original images except prefixed + folder. There resulting json files are of the same name as the original images except prefixed with the name of their folder on the blob storages and suffixed with the .json extension. - + Args: src_folder_path (str): Path to folder holding the images. This folder must only contain png or jpg files dest_folder_path (str): Path to folder where OCR json files will be placed blob_dest_folder (str, optional): Folder tag to use on the blob storage. If set to None, a hash is generated based on the names of files in the src folder. Defaults to None. - cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are + cleanup (bool, optional): If set to True, the indexing pipeline is deleted, and the files uploaded to the blob are deleted from blob after running. Defaults to True. use_multiprocessing (boo, optional): If set to True, this will use multiprocessing to increase blob transfers speed. 
- + Returns: indexer_status json, blob folder name """ print("uploading images to blob") - blob_folder_name,_ = self.grok_blob_client.upload_images_to_blob( - src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async) + blob_folder_name, _ = self.grok_blob_client.upload_images_to_blob( + src_folder_path, dest_folder_name=blob_dest_folder, use_async=use_async + ) print(f"images upload under folder {blob_folder_name}") try: print("creating and running indexer") @@ -50,10 +59,13 @@ class Grok: indexer_status = self.grok_rest_client.get_indexer_status() if indexer_status["status"] == "error": raise RuntimeError(f"indexer error: {indexer_status}") - + # if not already running start the indexer print("indexer_status", indexer_status) - if indexer_status["lastResult"] == None or indexer_status["lastResult"]["status"] != "inProgress": + if ( + indexer_status["lastResult"] is None + or indexer_status["lastResult"]["status"] != "inProgress" + ): self.grok_rest_client.run_indexer() time.sleep(1) @@ -62,9 +74,13 @@ class Grok: if indexer_status["lastResult"]["status"] == "success": time.sleep(30) print("fetching ocr json results.") - self.grok_blob_client.get_ocr_json(blob_folder_name, dest_folder_path, use_async=use_async) + self.grok_blob_client.get_ocr_json( + blob_folder_name, dest_folder_path, use_async=use_async + ) print(f"indexer status {indexer_status}") - print(f"finished running indexer. json files saved to {dest_folder_path}") + print( + f"finished running indexer. json files saved to {dest_folder_path}" + ) else: print("GROK failed", indexer_status["status"]) raise RuntimeError("GROK failed", indexer_status["status"]) @@ -77,7 +93,7 @@ class Grok: def cleanup(self, folder_name): """Deletes the indexing pipeline (index, indexer, datasource, skillset) from the search service. Deletes uploaded files from the blob - + Args: folder_name (str): blob folder name tag to remove """ diff --git a/genalog/ocr/metrics.py b/genalog/ocr/metrics.py index a8afec4..5466f6d 100644 --- a/genalog/ocr/metrics.py +++ b/genalog/ocr/metrics.py @@ -2,19 +2,19 @@ Utility functions to support getting OCR metrics OCR Metrics -1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412. +1. word/character accuracy like in this paper https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6065412. Accuracy = Correct Words/Total Words (in target strings) -2. Count of edit distance ops: - insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". This is based on Levenshtein edit distance. +2. Count of edit distance ops: + insert, delete, substitutions; like in the paper "Deep Statistical Analysis of OCR Errors for Effective Post-OCR Processing". + This is based on Levenshtein edit distance. -3. By looking at the gaps in alignment we also generate substitution dicts: -e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution +3. By looking at the gaps in alignment we also generate substitution dicts: +e.g: if we have text "a worn coat" and ocr is "a wom coat" , "rn" -> "m" will be captured as a substitution since the rest of the segments align.The assumption here is that we do not expect to have very long gaps in alignment, hence collecting and counting these substitutions will be managable. 
""" -import string import os import re import json @@ -30,26 +30,29 @@ from genalog.text.anchor import align_w_anchor LOG_LEVEL = 0 WORKERS_PER_CPU = 2 + def _log(*args, **kwargs): if LOG_LEVEL: print(args) + def _trim_whitespace(src_string): return re.sub(r"\s+", " ", src_string.strip()) + def _update_align_stats(src, target, align_stats, substitution_dict, gap_char): """Given two string that differ and have no alignment at all, update the alignment dict and fill in substitution if replacements are found. update alignment stats with counts of the edit operation to transform the source string to the targes - + Args: src (str): source string - target (str): target string at the + target (str): target string at the align_stats (dict): key-value dictionary that stores the counts of inserts, deletes, spacing and replacements substitution_dict (dict): store the counts of mapping from one substring to another of - the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2 + the replacement edit operation. e.g if 'rm' in source needs to map to 'm' in the target 2 times this will be { ('rm','m'): 2} gap_char (str): gap character used in alignment """ @@ -70,30 +73,40 @@ def _update_align_stats(src, target, align_stats, substitution_dict, gap_char): else: align_stats["replace"] += 1 _log("replacing", source_substr, target_substr) - substitution_dict[source_substr, target_substr] = substitution_dict.get( - (source_substr, target_substr), 0) + 1 + substitution_dict[source_substr, target_substr] = ( + substitution_dict.get((source_substr, target_substr), 0) + 1 + ) _log("spacing count", spacing_count) align_stats["spacing"] += spacing_count -def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matching_chars_count, \ - matching_words_count, matching_alnum_words_count): + +def _update_word_stats( + aligned_src, + aligned_target, + gap_char, + start, + end, + matching_chars_count, + matching_words_count, + matching_alnum_words_count, +): """Given two string segments that align. 
update the counts of matching words and characters - + Args: aligned_src (str): full source string aligned_target (str): full target string - gap_char (str): gap character used in alignment + gap_char (str): gap character used in alignment start (int): start position of alignment end (int): end position of alignment matching_chars_count (int): current count of matching characters matching_words_count (int): current count of matching words matching_alnum_words_count (int): current count of alphanumeric matching words - + Returns: tuple(int,int,int): the updated matching_chars_count, matching_words_count, matching_alnum_words_count """ aligned_part = aligned_src[start:end] - matching_chars_count += (end-start) + matching_chars_count += end - start # aligned_part = seq.strip() _log("aligned", aligned_part, start, end) if len(aligned_src) != len(aligned_target): @@ -112,10 +125,18 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi # to be compared with the full string to see if they have space before or after if i == 0: - if start != 0 and (aligned_target[start] != " " or aligned_src[start] != " " ): + if start != 0 and ( + aligned_target[start] != " " or aligned_src[start] != " " + ): # if this was the start of the string in the target or source - if not(aligned_src[:start].replace(gap_char,"").replace(" ","") == "" and aligned_target[start-1] == " ") and \ - not(aligned_target[:start].replace(gap_char,"").replace(" ","") == "" and aligned_src[start-1] == " "): + if not ( + aligned_src[:start].replace(gap_char, "").replace(" ", "") == "" + and aligned_target[start - 1] == " " + ) and not ( + aligned_target[:start].replace(gap_char, "").replace(" ", "") + == "" + and aligned_src[start - 1] == " " + ): # beginning word not matching completely _log("removing first match word from count", word, aligned_part) matching_words_count -= 1 @@ -123,34 +144,43 @@ def _update_word_stats(aligned_src, aligned_target, gap_char, start, end, matchi matching_alnum_words_count -= 1 continue - if i == len(words)-1: - if end != len(aligned_target) and (aligned_target[end] != " " or aligned_src[end] != " " ): + if i == len(words) - 1: + if end != len(aligned_target) and ( + aligned_target[end] != " " or aligned_src[end] != " " + ): # this was not the end of the string in the src and not end of string in target - if not(aligned_src[end:].replace(gap_char,"").replace(" ","") == "" and aligned_target[end] == " ") and \ - not(aligned_target[end:].replace(gap_char,"").replace(" ","") == ""and aligned_src[end] == " "): + if not ( + aligned_src[end:].replace(gap_char, "").replace(" ", "") == "" + and aligned_target[end] == " " + ) and not ( + aligned_target[end:].replace(gap_char, "").replace(" ", "") + == "" + and aligned_src[end] == " " + ): # last word not matching completely - _log("removing last match word from count",word, aligned_part) + _log("removing last match word from count", word, aligned_part) matching_words_count -= 1 if re.search(r"\w", word): matching_alnum_words_count -= 1 _log("matched count", matching_words_count) _log("matched alnum count", matching_alnum_words_count) - return matching_chars_count, matching_words_count, matching_alnum_words_count - + return matching_chars_count, matching_words_count, matching_alnum_words_count + + def _get_align_stats(alignment, src_string, target, gap_char): - """Given an alignment, this function get the align stats and substitution mapping to + """Given an alignment, this function get the align stats and substitution mapping to 
transform the source string to the target string - + Args: alignment (tuple(str, str)): the result of calling align on the two strings src_source (str): the source string target (str) : the target string gap_char (str) : the gap character used in alignment - + Raises: ValueError: if any of the aligned string are empty - + Returns: tuple(dict, dict): align stats dict, substitution mappings dict """ @@ -167,17 +197,24 @@ def _get_align_stats(alignment, src_string, target, gap_char): word_count = len(words) # alphanumeric words are defined here as words with at least one alphanumeric character - alnum_words_count = len(list(filter(lambda x: re.search(r"\w",x) , words ))) - + alnum_words_count = len(list(filter(lambda x: re.search(r"\w", x), words))) + char_count = max(len(target), len(src_string)) matching_chars_count = 0 matching_words_count = 0 matching_alnum_words_count = 0 - align_stats = {"insert": 0, "delete": 0, "replace": 0, "spacing": 0, - "total_chars": char_count, "total_words": word_count, "total_alnum_words": alnum_words_count} + align_stats = { + "insert": 0, + "delete": 0, + "replace": 0, + "spacing": 0, + "total_chars": char_count, + "total_words": word_count, + "total_alnum_words": alnum_words_count, + } start = 0 - + _log("######### Alignment ############") _log(aligned_src) _log(aligned_target) @@ -190,57 +227,105 @@ def _get_align_stats(alignment, src_string, target, gap_char): # since this substring aligns, simple count the number of matching words and chars in and update # the word stats end = i - _log("sequences", aligned_src[start:end], aligned_target[start:end], start, end) + _log( + "sequences", + aligned_src[start:end], + aligned_target[start:end], + start, + end, + ) assert aligned_src[start:end] == aligned_target[start:end] - matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src, - aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count) + ( + matching_chars_count, + matching_words_count, + matching_alnum_words_count, + ) = _update_word_stats( + aligned_src, + aligned_target, + gap_char, + start, + end, + matching_chars_count, + matching_words_count, + matching_alnum_words_count, + ) start = end + 1 if gap_start is None: gap_start = end else: gap_end = i - if not gap_start is None: + if gap_start is not None: # since characters now match gap_start:i contains a substring of the characters that didnt align before # handle this gap alignment by calling _update_align_stats - _log("gap", aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], gap_start, gap_end) - _update_align_stats(aligned_src[gap_start:gap_end], aligned_target[gap_start:gap_end], align_stats, substitution_dict, gap_char) + _log( + "gap", + aligned_src[gap_start:gap_end], + aligned_target[gap_start:gap_end], + gap_start, + gap_end, + ) + _update_align_stats( + aligned_src[gap_start:gap_end], + aligned_target[gap_start:gap_end], + align_stats, + substitution_dict, + gap_char, + ) gap_start = None # Now compare any left overs string segments from the for loop if gap_start is not None: # handle last alignment gap _log("last gap", aligned_src[gap_start:], aligned_target[gap_start:]) - _update_align_stats(aligned_src[gap_start:], aligned_target[gap_start:], align_stats, substitution_dict, gap_char) + _update_align_stats( + aligned_src[gap_start:], + aligned_target[gap_start:], + align_stats, + substitution_dict, + gap_char, + ) else: # handle last aligned substring seq = aligned_src[start:] 
aligned_part = seq.strip() end = len(aligned_src) _log("last aligned", aligned_part) - matching_chars_count, matching_words_count, matching_alnum_words_count = _update_word_stats(aligned_src, - aligned_target, gap_char, start, end, matching_chars_count,matching_words_count, matching_alnum_words_count) + ( + matching_chars_count, + matching_words_count, + matching_alnum_words_count, + ) = _update_word_stats( + aligned_src, + aligned_target, + gap_char, + start, + end, + matching_chars_count, + matching_words_count, + matching_alnum_words_count, + ) align_stats["matching_chars"] = matching_chars_count align_stats["matching_alnum_words"] = matching_alnum_words_count align_stats["matching_words"] = matching_words_count - align_stats["alnum_word_accuracy"] = matching_alnum_words_count/alnum_words_count - align_stats["word_accuracy"] = matching_words_count/word_count - align_stats["char_accuracy"] = matching_chars_count/char_count + align_stats["alnum_word_accuracy"] = matching_alnum_words_count / alnum_words_count + align_stats["word_accuracy"] = matching_words_count / word_count + align_stats["char_accuracy"] = matching_chars_count / char_count return align_stats, substitution_dict def get_editops_stats(alignment, gap_char): - """Get stats for character level edit operations that need to be done to + """Get stats for character level edit operations that need to be done to transform the source string to the target string. Inputs must not be empty and must be the result of calling the runing the align function. - + Args: alignment (tuple(str, str)): the results from the string alignment biopy function gap_char (str): gap character used in alignment - + Raises: ValueError: If any of the string in the alignment are empty - + Returns: [type]: [description] """ @@ -248,42 +333,49 @@ def get_editops_stats(alignment, gap_char): aligned_src, aligned_target = alignment if aligned_src == "" or aligned_target == "": raise ValueError("one of the input strings is empty") - stats = {"edit_insert": 0, "edit_delete": 0, "edit_replace": 0, - "edit_insert_spacing": 0, "edit_delete_spacing": 0} + stats = { + "edit_insert": 0, + "edit_delete": 0, + "edit_replace": 0, + "edit_insert_spacing": 0, + "edit_delete_spacing": 0, + } actions = {} for i, (char_1, char_2) in enumerate(zip(aligned_src, aligned_target)): - if LOG_LEVEL > 1: _log(char_1, char_2) + if LOG_LEVEL > 1: + _log(char_1, char_2) if char_1 == gap_char: # insert if char_2 == " ": stats["edit_insert_spacing"] += 1 else: stats["edit_insert"] += 1 - actions[i] = ("I",char_2) + actions[i] = ("I", char_2) elif char_2 == gap_char: # delete if char_1 == " ": stats["edit_delete_spacing"] += 1 else: stats["edit_delete"] += 1 - actions[i] = ("D") + actions[i] = "D" elif char_2 != char_1: stats["edit_replace"] += 1 - actions[i] = ("R",char_2) + actions[i] = ("R", char_2) return stats, actions + def get_align_stats(alignment, src_string, target, gap_char): - """Get alignment stats - + """Get alignment stats + Args: alignment (tuple(str,str)): the result of calling the align function src_string (str): the original source string target (str): the original target string gap_char (str): the gap character used in alignment - + Raises: ValueError: if any of the strings are empty - + Returns: tuple(dict, dict): dict of the align starts and dict of the substitution mappings """ @@ -293,52 +385,61 @@ def get_align_stats(alignment, src_string, target, gap_char): _log("alignment results") _log(alignment) align_stats, substitution_dict = _get_align_stats( - alignment, 
src_string, target, gap_char) + alignment, src_string, target, gap_char + ) return align_stats, substitution_dict + def get_stats(target, src_string): - """Get align stats, edit stats, and substitution mappings for transforming the + """Get align stats, edit stats, and substitution mappings for transforming the source string to the target string. Edit stats refers to character level edit operation required to transform the source to target. Align stats referers to substring level operation required to transform the source to target. Align stats have keys insert,replace,delete and the special key spacing which counts spacing differences between the two strings. Edit stats have the keys edit_insert, edit_replace, edit_delete which count the character level edits. - + Args: src_string (str): the source string target (str): the target string - + Returns: tuple(str, str): One dict containing the edit and align stats, another dict containing the substitutions """ - gap_char_candidates, input_char_set = _find_gap_char_candidates([src_string], [target]) - gap_char = GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + gap_char_candidates, input_char_set = _find_gap_char_candidates( + [src_string], [target] + ) + gap_char = ( + GAP_CHAR if GAP_CHAR in gap_char_candidates else gap_char_candidates.pop() + ) alignment = align_w_anchor(src_string, target, gap_char=gap_char) - align_stats, substitution_dict = get_align_stats(alignment,src_string, target, gap_char) + align_stats, substitution_dict = get_align_stats( + alignment, src_string, target, gap_char + ) edit_stats, actions = get_editops_stats(alignment, gap_char) _log("alignment", align_stats) return {**edit_stats, **align_stats}, substitution_dict, actions -def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True): - """Given a path to the folder containing the source text and a folder containing - the output OCR json, this generates the metrics for all files in the source folder. +def get_metrics( + src_text_path, ocr_json_path, folder_hash=None, use_multiprocessing=True +): + """Given a path to the folder containing the source text and a folder containing + the output OCR json, this generates the metrics for all files in the source folder. This assumes that the files json folder are of the same name the text files except they are prefixed by the parameter folder_hash followed by underscore and suffixed by .png.json. - + Args: src_text_path (str): path to source txt files ocr_json_path (str): path to OCR json files folder_hash (str): prefix for OCR json files use_multiprocessing (bool): use multiprocessing - + Returns: tuple(pandas.DataFrame, dict): A pandas dataframe of the metrics with each file in a row, a dict containing the substitions mappings for each file. the key to the dict is the filename and the values are dicts of the substition mappings for that file. 
""" - rows = [] substitutions = {} actions_map = {} @@ -347,16 +448,25 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess cpu_count = multiprocessing.cpu_count() n_workers = WORKERS_PER_CPU * cpu_count - job_args = list(map(lambda f: (f, src_text_path, ocr_json_path, folder_hash) , os.listdir(src_text_path))) + job_args = list( + map( + lambda f: (f, src_text_path, ocr_json_path, folder_hash), + os.listdir(src_text_path), + ) + ) if use_multiprocessing: with Pool(n_workers) as pool: - for f, stats, actions, subs in tqdm(pool.imap_unordered(_worker, job_args), total=len(job_args)): + for f, stats, actions, subs in tqdm( + pool.imap_unordered(_worker, job_args), total=len(job_args) + ): substitutions[f] = subs actions_map[f] = actions - rows.append(stats) + rows.append(stats) else: - for f, stats, actions, subs in tqdm(map(_worker, job_args), total=len(job_args)): + for f, stats, actions, subs in tqdm( + map(_worker, job_args), total=len(job_args) + ): substitutions[f] = subs actions_map[f] = actions rows.append(stats) @@ -364,16 +474,17 @@ def get_metrics(src_text_path, ocr_json_path, folder_hash=None, use_multiprocess df = pd.DataFrame(rows) return df, substitutions, actions_map + def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash): src_filename = os.path.join(src_text_path, f) if folder_hash: ocr_filename = os.path.join( - ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}") + ocr_json_path, f"{folder_hash}_{f.split('txt')[0] + 'json'}" + ) else: - ocr_filename = os.path.join( - ocr_json_path, f"{f.split('txt')[0] + 'json'}") + ocr_filename = os.path.join(ocr_json_path, f"{f.split('txt')[0] + 'json'}") try: - src_string = open(src_filename, "r", errors='ignore', encoding="utf8").read() + src_string = open(src_filename, "r", errors="ignore", encoding="utf8").read() except FileNotFoundError: print(f"File not found: {src_filename}, skipping this file.") return f, {}, {}, {} @@ -395,34 +506,41 @@ def get_file_metrics(f, src_text_path, ocr_json_path, folder_hash): stats["filename"] = f return f, stats, actions, subs + def _worker(args): (f, src_text_path, ocr_json_path, folder_hash) = args return get_file_metrics(f, src_text_path, ocr_json_path, folder_hash) + def _get_sorted_text(ocr_json): if "lines" in ocr_json[0]: lines = ocr_json[0]["lines"] - sorted_lines = sorted(lines, key=lambda line : line["boundingBox"][0]["y"]) + sorted_lines = sorted(lines, key=lambda line: line["boundingBox"][0]["y"]) return " ".join([line["text"] for line in sorted_lines]) else: return ocr_json[0]["text"] + def substitution_dict_to_json(substitution_dict): """Converts substitution dict to list of tuples of (source_substring, target_substring, count) - + Args: substitution_dict ([type]): [description] """ - to_tuple = lambda x: [(k +(x[k],)) for k in x] + to_tuple = lambda x: [(k + (x[k],)) for k in x] # noqa: E731 out = {} for filename in substitution_dict: out[filename] = to_tuple(substitution_dict[filename]) return out + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("src", help="path to folder with text files.") - parser.add_argument("ocr", help="folder with ocr json. the filename must match the text filename prefixed by ocr_prefix.") + parser.add_argument( + "ocr", + help="folder with ocr json. 
the filename must match the text filename prefixed by ocr_prefix.", + ) parser.add_argument("--ocr_prefix", help="the prefix of the ocr files") parser.add_argument("--output", help="output names of metrics files") diff --git a/genalog/ocr/rest_client.py b/genalog/ocr/rest_client.py index 21d8e8c..dbcdce8 100644 --- a/genalog/ocr/rest_client.py +++ b/genalog/ocr/rest_client.py @@ -5,29 +5,39 @@ import requests import os import pkgutil import json -from dotenv import load_dotenv import time import sys from itertools import cycle from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME -API_VERSION = '?api-version=2019-05-06-Preview' +API_VERSION = "?api-version=2019-05-06-Preview" # 15 min schedule -SCHEDULE_INTERVAL= "PT15M" +SCHEDULE_INTERVAL = "PT15M" class GrokRestClient: """This is a REST client. It is a wrapper around the REST api for the Azure Search Service see: https://docs.microsoft.com/en-us/rest/api/searchservice/ - This class can be used to create an indexing pipeline and can be used to run and monitor + This class can be used to create an indexing pipeline and can be used to run and monitor ongoing indexers. The indexing pipeline can allow you to run batch OCR enrichment of documents. """ - def __init__(self, cognitive_service_key, search_service_key, search_service_name, skillset_name, - index_name, indexer_name, datasource_name, datasource_container_name, blob_account_name, blob_key, - projections_container_name = DEFAULT_PROJECTIONS_CONTAINER_NAME): + def __init__( + self, + cognitive_service_key, + search_service_key, + search_service_name, + skillset_name, + index_name, + indexer_name, + datasource_name, + datasource_container_name, + blob_account_name, + blob_key, + projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME, + ): """Creates the REST client Args: @@ -36,7 +46,7 @@ class GrokRestClient: search_service_name (str): name of the search service account skillset_name (str): name of the skillset index_name (str): name of the index - indexer_name (str): the name of indexer + indexer_name (str): the name of indexer datasource_name (str): the name to give the the attached blob storage source datasource_container_name (str): the container in the blob storage that host the files blob_account_name (str): blob storage account name that will host the documents to push though the pipeline @@ -70,8 +80,10 @@ class GrokRestClient: self.API_VERSION = API_VERSION - self.BLOB_CONNECTION_STRING = f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" \ + self.BLOB_CONNECTION_STRING = ( + f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};" f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net" + ) @staticmethod def create_from_env_var(): @@ -85,51 +97,68 @@ class GrokRestClient: DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"] BLOB_NAME = os.environ["BLOB_NAME"] BLOB_KEY = os.environ["BLOB_KEY"] - PROJECTIONS_CONTAINER_NAME = os.environ.get("PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME) + PROJECTIONS_CONTAINER_NAME = os.environ.get( + "PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME + ) - client = GrokRestClient(COGNITIVE_SERVICE_KEY, SEARCH_SERVICE_KEY, SEARCH_SERVICE_NAME, SKILLSET_NAME, INDEX_NAME, - INDEXER_NAME, DATASOURCE_NAME, DATASOURCE_CONTAINER_NAME, BLOB_NAME, BLOB_KEY,projections_container_name=PROJECTIONS_CONTAINER_NAME) + client = GrokRestClient( + COGNITIVE_SERVICE_KEY, + SEARCH_SERVICE_KEY, + SEARCH_SERVICE_NAME, + SKILLSET_NAME, + INDEX_NAME, + INDEXER_NAME, + DATASOURCE_NAME, + 
DATASOURCE_CONTAINER_NAME, + BLOB_NAME, + BLOB_KEY, + projections_container_name=PROJECTIONS_CONTAINER_NAME, + ) return client def create_skillset(self): - """Adds a skillset that performs OCR on images - """ + """Adds a skillset that performs OCR on images""" headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } - skillset_json = json.loads(pkgutil.get_data( - __name__, "templates/skillset.json")) + skillset_json = json.loads( + pkgutil.get_data(__name__, "templates/skillset.json") + ) skillset_json["name"] = self.SKILLSET_NAME skillset_json["cognitiveServices"]["key"] = self.COGNITIVE_SERVICE_KEY - knowledge_store_json = json.loads(pkgutil.get_data( - __name__, "templates/knowledge_store.json")) + knowledge_store_json = json.loads( + pkgutil.get_data(__name__, "templates/knowledge_store.json") + ) knowledge_store_json["storageConnectionString"] = self.BLOB_CONNECTION_STRING - knowledge_store_json["projections"][0]["objects"][0]["storageContainer"] = self.PROJECTIONS_CONTAINER_NAME + knowledge_store_json["projections"][0]["objects"][0][ + "storageContainer" + ] = self.PROJECTIONS_CONTAINER_NAME skillset_json["knowledgeStore"] = knowledge_store_json print(skillset_json) endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}" - r = requests.put(endpoint + self.API_VERSION, - json.dumps(skillset_json), headers=headers) + r = requests.put( + endpoint + self.API_VERSION, json.dumps(skillset_json), headers=headers + ) print("skillset response", r.text) r.raise_for_status() print("added skillset", self.SKILLSET_NAME, r) def create_datasource(self): - """Attaches the blob data store to the search service as a source for image documents - """ + """Attaches the blob data store to the search service as a source for image documents""" headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } - datasource_json = json.loads(pkgutil.get_data( - __name__, "templates/datasource.json")) + datasource_json = json.loads( + pkgutil.get_data(__name__, "templates/datasource.json") + ) datasource_json["name"] = self.DATASOURCE_NAME datasource_json["credentials"]["connectionString"] = self.BLOB_CONNECTION_STRING datasource_json["type"] = "azureblob" @@ -137,27 +166,27 @@ class GrokRestClient: endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}" - r = requests.put(endpoint + self.API_VERSION, - json.dumps(datasource_json), headers=headers) + r = requests.put( + endpoint + self.API_VERSION, json.dumps(datasource_json), headers=headers + ) print("datasource response", r.text) r.raise_for_status() print("added datasource", self.DATASOURCE_NAME, r) def create_index(self): - """Create an index with the layoutText column to store OCR output from the enrichment - """ + """Create an index with the layoutText column to store OCR output from the enrichment""" headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } - index_json = json.loads(pkgutil.get_data( - __name__, "templates/index.json")) + index_json = json.loads(pkgutil.get_data(__name__, "templates/index.json")) index_json["name"] = self.INDEX_NAME endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}" - r = 
requests.put(endpoint + self.API_VERSION, - json.dumps(index_json), headers=headers) + r = requests.put( + endpoint + self.API_VERSION, json.dumps(index_json), headers=headers + ) print("index response", r.text) r.raise_for_status() print("created index", self.INDEX_NAME, r) @@ -167,24 +196,26 @@ class GrokRestClient: The enriched results are pushed to the index. """ headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } - indexer_json = json.loads(pkgutil.get_data( - __name__, "templates/indexer.json")) + indexer_json = json.loads(pkgutil.get_data(__name__, "templates/indexer.json")) indexer_json["name"] = self.INDEXER_NAME indexer_json["skillsetName"] = self.SKILLSET_NAME indexer_json["targetIndexName"] = self.INDEX_NAME indexer_json["dataSourceName"] = self.DATASOURCE_NAME indexer_json["schedule"] = {"interval": SCHEDULE_INTERVAL} - indexer_json["parameters"]["configuration"]["excludedFileNameExtensions"] = extension_to_exclude + indexer_json["parameters"]["configuration"][ + "excludedFileNameExtensions" + ] = extension_to_exclude endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}" - r = requests.put(endpoint + self.API_VERSION, - json.dumps(indexer_json), headers=headers) + r = requests.put( + endpoint + self.API_VERSION, json.dumps(indexer_json), headers=headers + ) print("indexer response", r.text) r.raise_for_status() print("created indexer", self.INDEXER_NAME, r) @@ -203,14 +234,14 @@ class GrokRestClient: created """ headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } endpoints = [ f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}", f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexes/{self.INDEX_NAME}", f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/datasources/{self.DATASOURCE_NAME}", - f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}" + f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/skillsets/{self.SKILLSET_NAME}", ] for endpoint in endpoints: @@ -220,8 +251,8 @@ class GrokRestClient: def run_indexer(self): headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } endpoint = f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/run" @@ -235,23 +266,26 @@ class GrokRestClient: i = 0 while True: # attempt a call every 100 steps - if i % 100 == 0: + if i % 100 == 0: request_json = self.get_indexer_status() if request_json["status"] == "error": raise RuntimeError("Indexer failed") - if request_json["lastResult"] and not request_json["lastResult"]["status"] == "inProgress": + if ( + request_json["lastResult"] + and not request_json["lastResult"]["status"] == "inProgress" + ): print(request_json["lastResult"]["status"], self.INDEXER_NAME) return request_json - + sys.stdout.write(next(progress)) sys.stdout.flush() time.sleep(0.05) - i = (1+i) % 1000 # to avoid overflow + i = (1 + i) % 1000 # to avoid overflow def get_indexer_status(self): headers = { - 'Content-Type': 'application/json', - 'api-key': self.SEARCH_SERVICE_KEY, + "Content-Type": "application/json", + "api-key": self.SEARCH_SERVICE_KEY, } endpoint = 
f"https://{self.SEARCH_SERVICE_NAME}.search.windows.net/indexers/{self.INDEXER_NAME}/status" response = requests.get(endpoint + self.API_VERSION, headers=headers) @@ -259,5 +293,5 @@ class GrokRestClient: return response.json() def _checkArg(self, name, value): - if not(value): + if not (value): raise ValueError(f"argument {name} is not set") diff --git a/genalog/pipeline.py b/genalog/pipeline.py index 2ce2296..8a8534b 100644 --- a/genalog/pipeline.py +++ b/genalog/pipeline.py @@ -10,14 +10,22 @@ import timeit import cv2 import os + class ImageStateEncoder(JSONEncoder): def default(self, obj): if isinstance(obj, ImageState): return obj.value return JSONEncoder.default(self, obj) + class AnalogDocumentGeneration(object): - def __init__(self, template_path=None, styles=DEFAULT_STYLE_COMBINATION, degradations=[], resolution=300): + def __init__( + self, + template_path=None, + styles=DEFAULT_STYLE_COMBINATION, + degradations=[], + resolution=300, + ): self.doc_generator = DocumentGenerator(template_path=template_path) self.doc_generator.set_styles_to_generate(styles) self.degrader = Degrader(degradations) @@ -25,7 +33,7 @@ class AnalogDocumentGeneration(object): self.resolution = resolution def list_templates(self): - """ List available templates to generate documents from + """List available templates to generate documents from Returns: list -- a list of template names @@ -33,7 +41,7 @@ class AnalogDocumentGeneration(object): return self.doc_generator.template_list def generate_img(self, full_text_path, template, target_folder=None): - """ Generate synthetic images given the filepath of a text document + """Generate synthetic images given the filepath of a text document Arguments: full_text_path {str} -- full filepath of a text document (i.e /dataset/doc.txt) @@ -44,18 +52,18 @@ class AnalogDocumentGeneration(object): target_folder {str} -- folder path in which the generated images are stored (default: {None}) resolution {int} -- resolution in dpi (default: {300}) - """ - with open(full_text_path, "r",encoding="utf8") as f: # read file + """ + with open(full_text_path, "r", encoding="utf8") as f: # read file text = f.read() content = CompositeContent([text], [ContentType.PARAGRAPH]) - + generator = self.doc_generator.create_generator(content, [template]) # Generate the image doc = next(generator) src = doc.render_array(resolution=self.resolution, channel="GRAYSCALE") # Degrade the image dst = self.degrader.apply_effects(src) - + if not target_folder: # return the analog document as numpy.ndarray return dst @@ -67,45 +75,71 @@ class AnalogDocumentGeneration(object): cv2.imwrite(img_dst_path, dst) return + def _divide_batches(a, batch_size): - for i in range(0, len(a), batch_size): - yield a[i:i + batch_size] + for i in range(0, len(a), batch_size): + yield a[i: i + batch_size] + def _setup_folder(output_folder): os.makedirs(os.path.join(output_folder, "img"), exist_ok=True) - + + def batch_img_generate(args): input_files, output_folder, styles, degradations, template, resolution = args - generator = AnalogDocumentGeneration(styles=styles, degradations=degradations, resolution=resolution) + generator = AnalogDocumentGeneration( + styles=styles, degradations=degradations, resolution=resolution + ) for file in input_files: generator.generate_img(file, template, target_folder=output_folder) -def _set_batch_generate_args(file_batches, output_folder, styles, degradations, template, resolution): - return list(map( - lambda batch: - (batch, output_folder, styles, degradations, template, resolution), 
- file_batches - )) + +def _set_batch_generate_args( + file_batches, output_folder, styles, degradations, template, resolution +): + return list( + map( + lambda batch: ( + batch, + output_folder, + styles, + degradations, + template, + resolution, + ), + file_batches, + ) + ) + def generate_dataset_multiprocess( - input_text_files, output_folder, styles, degradations, template, - resolution=300, batch_size=25 - ): + input_text_files, + output_folder, + styles, + degradations, + template, + resolution=300, + batch_size=25, +): _setup_folder(output_folder) - print(f"Storing generated images in {output_folder}") + print(f"Storing generated images in {output_folder}") - batches = list(_divide_batches(input_text_files, batch_size)) - print(f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}") + batches = list(_divide_batches(input_text_files, batch_size)) + print( + f"Splitting {len(input_text_files)} documents into {len(batches)} batches with size {batch_size}" + ) - batch_img_generate_args = _set_batch_generate_args(batches, output_folder, styles, degradations, template, resolution) + batch_img_generate_args = _set_batch_generate_args( + batches, output_folder, styles, degradations, template, resolution + ) # Default to the number of processors on the machine start_time = timeit.default_timer() with concurrent.futures.ProcessPoolExecutor() as executor: batch_iterator = executor.map(batch_img_generate, batch_img_generate_args) - for _ in tqdm(batch_iterator, total=len(batch_img_generate_args)): # wrapping tqdm for progress report + for _ in tqdm( + batch_iterator, total=len(batch_img_generate_args) + ): # wrapping tqdm for progress report pass elapsed = timeit.default_timer() - start_time print(f"Time to generate {len(input_text_files)} documents: {elapsed:.3f} sec") - - diff --git a/genalog/text/__init__.py b/genalog/text/__init__.py index 8b13789..e69de29 100644 --- a/genalog/text/__init__.py +++ b/genalog/text/__init__.py @@ -1 +0,0 @@ - diff --git a/genalog/text/alignment.py b/genalog/text/alignment.py index a357939..a12d253 100644 --- a/genalog/text/alignment.py +++ b/genalog/text/alignment.py @@ -7,36 +7,45 @@ MATCH_REWARD = 1 GAP_PENALTY = -0.5 GAP_EXT_PENALTY = -0.5 MISMATCH_PENALTY = -0.5 -GAP_CHAR = '@' +GAP_CHAR = "@" ONE_ALIGNMENT_ONLY = False -SPACE_MISMATCH_PENALTY = .1 +SPACE_MISMATCH_PENALTY = 0.1 + def _join_char_list(alignment_tuple): """ Post-process alignment results for unicode support """ gt_char_list, noise_char_list, score, start, end = alignment_tuple return "".join(gt_char_list), "".join(noise_char_list), score, start, end -def _align_seg(gt, noise, - match_reward=MATCH_REWARD, mismatch_pen=MISMATCH_PENALTY, - gap_pen=GAP_PENALTY, gap_ext_pen=GAP_EXT_PENALTY, space_mismatch_penalty=SPACE_MISMATCH_PENALTY, - gap_char=GAP_CHAR, one_alignment_only=ONE_ALIGNMENT_ONLY): - """ Wrapper function for Bio.pairwise2.align.globalms(), which + +def _align_seg( + gt, + noise, + match_reward=MATCH_REWARD, + mismatch_pen=MISMATCH_PENALTY, + gap_pen=GAP_PENALTY, + gap_ext_pen=GAP_EXT_PENALTY, + space_mismatch_penalty=SPACE_MISMATCH_PENALTY, + gap_char=GAP_CHAR, + one_alignment_only=ONE_ALIGNMENT_ONLY, +): + """Wrapper function for Bio.pairwise2.align.globalms(), which calls the sequence alignment algorithm (Needleman-Wunsch) Arguments: gt {str} -- a ground truth string noise {str} -- a string with ocr noise - + Keyword Arguments: match_reward {int} -- reward for matching characters (default: {MATCH_REWARD}) mismatch_pen {int} -- penalty 
for mistmatching characters (default: {MISMATCH_PENALTY}) gap_pen {int} -- penalty for creating a gap (default: {GAP_PENALTY}) gap_ext_pen {int} -- penalty for extending a gap (default: {GAP_EXT_PENALTY}) - + Returns: list -- a list of alignment tuples. Each alignment tuple - is one possible alignment candidate. - + is one possible alignment candidate. + A tuple (str, str, int, int, int) contains the following information: (aligned_gt, aligned_noise, alignment_score, alignment_start, alignment_end) @@ -47,22 +56,32 @@ def _align_seg(gt, noise, ... ] """ - def match_reward_fn (x,y) : + + def match_reward_fn(x, y): if x == y: return match_reward - elif x == " " or y == " ": + elif x == " " or y == " ": # mismatch of a character with a space get a stronger penalty - return mismatch_pen - space_mismatch_penalty + return mismatch_pen - space_mismatch_penalty else: return mismatch_pen - # NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters - alignments = pairwise2.align.globalcs(list(gt), list(noise), match_reward_fn, - gap_pen, gap_ext_pen, gap_char=[gap_char], one_alignment_only=ONE_ALIGNMENT_ONLY) + + # NOTE: Work-around to enable support full Unicode character set - passing string as a list of characters + alignments = pairwise2.align.globalcs( + list(gt), + list(noise), + match_reward_fn, + gap_pen, + gap_ext_pen, + gap_char=[gap_char], + one_alignment_only=ONE_ALIGNMENT_ONLY, + ) # Alignment result is a list of char instead of string because of the work-around return list(map(_join_char_list, alignments)) + def _select_alignment_candidates(alignments, target_num_gt_tokens): - """ Return an alignment that contains the desired number + """Return an alignment that contains the desired number of ground truth tokens from a list of possible alignments Case Analysis: @@ -70,20 +89,20 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens): be guaranteed by the nature of text alignment. Invariant 2: we should not expect alignment introducing additional ground truth tokens. - However, in some cases, the alignment algorithm can - introduce a group of GAP_CHARs as a separate token at the + However, in some cases, the alignment algorithm can + introduce a group of GAP_CHARs as a separate token at the end of string, especially if there are lingering whitespaces. E.g: gt: "Boston is big " (num_tokens = 3) - noise: "B oston bi g" + noise: "B oston bi g" aligned_gt: "B@oston is big @" (num_tokens = 4) aligned_noise: "B oston @@@bi@ g" - Remember, the example above is just one out of the many possible alignment + Remember, the example above is just one out of the many possible alignment candidates, and we need to search for the one with the target number of gt_tokens E.g: gt: "Boston is big " (num_tokens = 3) - noise: "B oston bi g" + noise: "B oston bi g" aligned_gt: "B@oston is bi@g " (num_tokens = 3) aligned_noise: "B oston @@@bi g@" @@ -93,12 +112,12 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens): alignments {list} -- a list of alignment tuples as follows: [(str1, str2, alignment_score, alignment_start, alignment_end), (str1, str2, ...), ...] target_num_gt_tokens {int} -- the number of token in the aligned ground truth string should have - + Raises: - ValueError: raises this error if + ValueError: raises this error if 1. all the alignment candidates does NOT have the target number of tokens OR - 2. the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length - + 2. 
the aligned strings (str1 and str2) in the selected candidate are NOT EQUAL in length + Returns: an alignment tuple (str, str, int, int, int) with following information: (str1, str2, alignment_score, alignment_start, alignment_end) @@ -111,29 +130,34 @@ def _select_alignment_candidates(alignments, target_num_gt_tokens): if num_aligned_gt_tokens == target_num_gt_tokens: # Invariant 1 if len(aligned_gt) != len(aligned_noise): - raise ValueError(f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n") + raise ValueError( + f"Aligned strings are not equal in length: \naligned_gt: '{aligned_gt}'\naligned_noise '{aligned_noise}'\n" + ) # Returns the FIRST candidate that satisfies the invariant return alignment - - raise ValueError(f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}") + + raise ValueError( + f"No alignment candidates with {target_num_gt_tokens} tokens. Total candidates: {len(alignments)}" + ) + def align(gt, noise, gap_char=GAP_CHAR): """Align two text segments via sequence alignment algorithm - NOTE: this algorithm is O(N^2) and is NOT efficient for longer text. + NOTE: this algorithm is O(N^2) and is NOT efficient for longer text. Please refer to `genalog.text.anchor` for faster alignment on longer strings. - + Arguments: gt {str} -- ground true text (should not contain GAP_CHAR) noise {str} -- str with ocr noise (should not contain GAP_CHAR) - + Keyword Arguments: gap_char {char} -- gap char used in alignment algorithm (default: {GAP_CHAR}) - + Returns: a tuple (str, str) of aligned ground truth and noise: (aligned_gt, aligned_noise) - + Invariants: The returned aligned strings will satisfy the following invariants: 1. len(aligned_gt) == len(aligned_noise) @@ -143,34 +167,39 @@ def align(gt, noise, gap_char=GAP_CHAR): aligned_gt: "N@ew @@York @is big@@" (num_tokens = 4) """ - if not gt and not noise: # Both inputs are empty string - return '', '' - elif not gt: # Either is empty - return gap_char*len(noise), noise + if not gt and not noise: # Both inputs are empty string + return "", "" + elif not gt: # Either is empty + return gap_char * len(noise), noise elif not noise: - return gt, gap_char*len(gt) + return gt, gap_char * len(gt) else: num_gt_tokens = len(tokenize(gt)) alignments = _align_seg(gt, noise, gap_char=gap_char) try: - aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates(alignments, num_gt_tokens) + aligned_gt, aligned_noise, _, _, _ = _select_alignment_candidates( + alignments, num_gt_tokens + ) except ValueError as e: - raise ValueError(f"Error with input strings '{gt}' and '{noise}': \n{str(e)}") + raise ValueError( + f"Error with input strings '{gt}' and '{noise}': \n{str(e)}" + ) return aligned_gt, aligned_noise + def _format_alignment(align1, align2): """Wrapper function for Bio.pairwise2.format_alignment() - + Arguments: align1 {str} -- alignment str align2 {str} -- second str for alignment Returns: - a string with formatted alignment. + a string with formatted alignment. '|' is for matching character '.' is for substition '-' indicates gap - + For example: " New York is big. @@ -178,44 +207,50 @@ def _format_alignment(align1, align2): New Yerk@is big. 
" """ - formatted_str = pairwise2.format_alignment(align1, align2, 0, 0, len(align1), full_sequences=True) + formatted_str = pairwise2.format_alignment( + align1, align2, 0, 0, len(align1), full_sequences=True + ) # Remove the "Score=0" from the str formatted_str_no_score = formatted_str.replace("\n Score=0", "") return formatted_str_no_score + def _find_token_start(s, index): - """Find the position of the start of token - + """Find the position of the start of token + Arguments: s {str} -- string to search in index {int} -- index to begin search from - + Returns: - position {int} of the first non-whitespace character - + Raises: ValueError: if input s is an empty string IndexError: if is out-of-bound index """ max_index = len(s) - 1 - if len(s) == 0: raise ValueError("Cannot search in an empty string") - if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}") + if len(s) == 0: + raise ValueError("Cannot search in an empty string") + if index > max_index: + raise IndexError(f"Out-of-bound index: {index} in string: {s}") while index < max_index and _is_spacing(s[index]): index += 1 return index + def _find_token_end(s, index): """Find the position of the end of a token - + *** Important *** This method ALWAYS return index within the bound of the string. - So, for single character string (eg. "c"), it will return 0. - + So, for single character string (eg. "c"), it will return 0. + Arguments: s {str} -- string to search in index {int} -- index to begin search from - + Returns: - position {int} of the first non-whitespace character @@ -224,40 +259,44 @@ def _find_token_end(s, index): IndexError: if is out-of-bound index """ max_index = len(s) - 1 - if len(s) == 0: raise ValueError("Cannot search in an empty string") - if index > max_index: raise IndexError(f"Out-of-bound index: {index} in string: {s}") + if len(s) == 0: + raise ValueError("Cannot search in an empty string") + if index > max_index: + raise IndexError(f"Out-of-bound index: {index} in string: {s}") while index < max_index and not _is_spacing(s[index]): index += 1 return index + def _find_next_token(s, start): - """ Return the start and end index of a token in a string - + """Return the start and end index of a token in a string + *** Important *** This method ALWAYS return indices within the bound of the string. So, for single character string (eg. "c"), it will return (0,0) Arguments: s {str} -- the string to search token in - start {int} -- the starting index to start search in - + start {int} -- the starting index to start search in + Returns: a tuple of (int, int) responding to the start and end indices of - a token in the given s. + a token in the given s. """ token_start = _find_token_start(s, start) token_end = _find_token_end(s, token_start) return token_start, token_end + def _is_valid_token(token, gap_char=GAP_CHAR): - """ Returns true if token is valid (i.e. compose of non-gap characters) - Invalid tokens are + """Returns true if token is valid (i.e. compose of non-gap characters) + Invalid tokens are 1. multiple occurrences of the GAP_CHAR (e.g. '@@@') 2. empty string ("") 3. 
string with spaces (" ") - **Important: this method expects one token and not multiple space-separated tokens + **Important: this method expects one token and not multiple space-separated tokens Arguments: token {str} -- input string token @@ -269,12 +308,15 @@ def _is_valid_token(token, gap_char=GAP_CHAR): bool-- True if is a valid token, false otherwise """ # Matches multiples of 'gap_char' that are padded with whitespace characters on either end - INVALID_TOKEN_REGEX = rf'^\s*{re.escape(gap_char)}*\s*$' # Escape special regex chars + INVALID_TOKEN_REGEX = ( + rf"^\s*{re.escape(gap_char)}*\s*$" # Escape special regex chars + ) return not re.match(INVALID_TOKEN_REGEX, token) + def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): """Parse alignment to pair ground truth tokens with noise tokens - + Case 1: Case 2: Case 3: Case 4: Case 5: one-to-many many-to-one many-to-many missing tokens one-to-one (Case 1&2 Composite) @@ -295,10 +337,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): Returns: a tuple (list, list) of two 2D int arrays as follows: - + (gt_to_noise_mapping, noise_to_gt_mapping) - - where each array defines the mapping between aligned gt tokens + + where each array defines the mapping between aligned gt tokens to noise tokens and vice versa. For example: @@ -316,44 +358,44 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): # tk_start_gt=12 tk_index_gt = 4 total_tokens = 4 # | tk_end_gt=15 tk_index_noise = 3 total_tokens = 3 # | | - # "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]] + # "New York is big " gt_token:big gt_to_noise_mapping: [[0][0][][2]] # "New@york @@ big " noise_token:big noise_to_gt_mapping: [[0][][3]] # | | # | tk_end_noise=15 INVALID TOKENS: @* # tk_start_noise=12 # 1. Initialization: - #1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow - #2. find the first gt_token and the first noise_token - #3. tk_index_gt = tk_index_noise = 0 + # 1. IMPORTANT: add whitespace padding (' ') to both end of aligned_gt and aligned_noise to avoid overflow + # 2. find the first gt_token and the first noise_token + # 3. tk_index_gt = tk_index_noise = 0 # 2. While tk_index_gt < total_tk_gt and tk_index_noise < total_tk_noise: - #1. if tk_end_gt == tk_end_noise (1-1 case) - #1. check if the two tokens are valid - #1. if so, register tokens in mapping - #2. find next gt_token token and next noise_token - #3. tk_index_gt ++, tk_index_noise ++ - #3. if tk_end_gt < tk_end_noise (many-1 case) - #1. while tk_end_gt < tk_end_noise - #1. check if gt_token and noise_token are BOTH valid - #1. if so register tokens in mapping - #2. find next gt_token - #3. tk_index_gt ++ - #4. if tk_end_gt > tk_end_noise (1-many case) - #1. while tk_end_gt > tk_end_noise - #1. check if gt_token and noise_token are BOTH valid - #1. if so register tokens in mapping - #2. find next noise token - #3. tk_index_noise ++ + # 1. if tk_end_gt == tk_end_noise (1-1 case) + # 1. check if the two tokens are valid + # 1. if so, register tokens in mapping + # 2. find next gt_token token and next noise_token + # 3. tk_index_gt ++, tk_index_noise ++ + # 3. if tk_end_gt < tk_end_noise (many-1 case) + # 1. while tk_end_gt < tk_end_noise + # 1. check if gt_token and noise_token are BOTH valid + # 1. if so register tokens in mapping + # 2. find next gt_token + # 3. tk_index_gt ++ + # 4. if tk_end_gt > tk_end_noise (1-many case) + # 1. while tk_end_gt > tk_end_noise + # 1. 
check if gt_token and noise_token are BOTH valid + # 1. if so register tokens in mapping + # 2. find next noise token + # 3. tk_index_noise ++ # sanity check if len(aligned_gt) != len(aligned_noise): - raise ValueError("Aligned strings are not equal in length") - + raise ValueError("Aligned strings are not equal in length") + total_gt_tokens = len(tokenize(aligned_gt)) total_noise_tokens = len(tokenize(aligned_noise)) # Initialization - aligned_gt += ' ' # add whitespace padding to prevent ptr overflow - aligned_noise += ' ' # add whitespace padding to prevent ptr overflow + aligned_gt += " " # add whitespace padding to prevent ptr overflow + aligned_noise += " " # add whitespace padding to prevent ptr overflow tk_index_gt = tk_index_noise = 0 tk_start_gt, tk_end_gt = _find_next_token(aligned_gt, 0) tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, 0) @@ -364,8 +406,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): # If both tokens are aligned (one-to-one case) if tk_end_gt == tk_end_noise: # if both gt_token and noise_token are valid (missing token case) - if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \ - and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char): + if _is_valid_token( + aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char + ) and _is_valid_token( + aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char + ): # register the index of these tokens in the gt_to_noise_mapping index_row = gt_to_noise_mapping[tk_index_gt] index_row.append(tk_index_noise) @@ -381,8 +426,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): elif tk_end_gt < tk_end_noise: while tk_end_gt < tk_end_noise: # if both gt_token and noise_token are valid (missing token case) - if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \ - and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char): + if _is_valid_token( + aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char + ) and _is_valid_token( + aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char + ): # register the index of these tokens in the gt_to_noise_mapping index_row = gt_to_noise_mapping[tk_index_gt] index_row.append(tk_index_noise) @@ -397,8 +445,11 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): else: while tk_end_gt > tk_end_noise: # if both gt_token and noise_token are valid (missing token case) - if _is_valid_token(aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char) \ - and _is_valid_token(aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char): + if _is_valid_token( + aligned_gt[tk_start_gt:tk_end_gt], gap_char=gap_char + ) and _is_valid_token( + aligned_noise[tk_start_noise:tk_end_noise], gap_char=gap_char + ): # register the index of these token in the gt_to_noise mapping index_row = gt_to_noise_mapping[tk_index_gt] index_row.append(tk_index_noise) @@ -406,8 +457,10 @@ def parse_alignment(aligned_gt, aligned_noise, gap_char=GAP_CHAR): index_row = noise_to_gt_mapping[tk_index_noise] index_row.append(tk_index_gt) # Find the next gt_token - tk_start_noise, tk_end_noise = _find_next_token(aligned_noise, tk_end_noise) + tk_start_noise, tk_end_noise = _find_next_token( + aligned_noise, tk_end_noise + ) # Increment index tk_index_noise += 1 - return gt_to_noise_mapping, noise_to_gt_mapping \ No newline at end of file + return gt_to_noise_mapping, noise_to_gt_mapping diff --git a/genalog/text/anchor.py b/genalog/text/anchor.py index 23a8de0..9e6e0c6 
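# A minimal usage sketch of the two aligners touched in this patch, assuming the
# package layout shown in the diff (genalog.text.alignment and genalog.text.anchor).
# Gap placement depends on the scoring constants, so the values in the comments
# are examples rather than guaranteed outputs.
from genalog.text import alignment, anchor

gt = "New York is big"
ocr = "New Yerkis big"

# Baseline O(N^2) alignment; fine for short strings.
aligned_gt, aligned_ocr = alignment.align(gt, ocr)

# Anchor-based alignment; same contract, intended for longer documents.
aligned_gt_fast, aligned_ocr_fast = anchor.align_w_anchor(gt, ocr)

# Invariant from align(): the aligned strings have equal length, with gaps filled
# by the configured GAP_CHAR (shown as '@' in the docstring examples).
assert len(aligned_gt) == len(aligned_ocr)

# Parse the alignment into token-index mappings in both directions,
# e.g. ocr_to_gt could be [[0], [1, 2], [3]] if "Yerkis" aligns to both "York" and "is".
gt_to_ocr, ocr_to_gt = alignment.parse_alignment(aligned_gt, aligned_ocr)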
100644 --- a/genalog/text/anchor.py +++ b/genalog/text/anchor.py @@ -1,15 +1,15 @@ """ - Baseline alignment algorithm is slow on long documents. - The idea is to break down the longer text into smaller fragments - for quicker alignment on individual pieces. We refer "anchor words" + Baseline alignment algorithm is slow on long documents. + The idea is to break down the longer text into smaller fragments + for quicker alignment on individual pieces. We refer "anchor words" as these points of breakage. The bulk of this algorithm is to identify these "anchor words". - This is an re-implementation of the algorithm in this paper + This is an re-implementation of the algorithm in this paper "A Fast Alignment Scheme for Automatic OCR Evaluation of Books" (https://ieeexplore.ieee.org/document/6065412) - + We rely on `genalog.text.alignment` to align the subsequences. """ import itertools @@ -18,19 +18,20 @@ from genalog.text import preprocess, alignment from genalog.text.lcs import LCS from genalog.text.alignment import GAP_CHAR -# The recursively portion of the algorithm will run on -# segments longer than this value to find anchor points in +# The recursively portion of the algorithm will run on +# segments longer than this value to find anchor points in # the longer segment (to break it up further). -MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length +MAX_ALIGN_SEGMENT_LENGTH = 100 # in characters length + def get_unique_words(tokens, case_sensitive=False): - """ Get a set of unique words from a Counter dictionary of word occurrences - + """Get a set of unique words from a Counter dictionary of word occurrences + Arguments: d {dict} -- a Counter dictionary of word occurrences - + Keyword Arguments: - case_sensitive {bool} -- whether unique words are case sensitive + case_sensitive {bool} -- whether unique words are case sensitive (default: {False}) Returns: @@ -38,14 +39,15 @@ def get_unique_words(tokens, case_sensitive=False): """ if case_sensitive: word_count = Counter(tokens) - return {word for word, count in word_count.items() if count < 2 } + return {word for word, count in word_count.items() if count < 2} else: tokens_lowercase = [tk.lower() for tk in tokens] word_count = Counter(tokens_lowercase) - return {tk for tk in tokens if word_count[tk.lower()] < 2 } + return {tk for tk in tokens if word_count[tk.lower()] < 2} + def segment_len(tokens): - """ Get length of the segment + """Get length of the segment Arguments: segment {list} -- a list of tokens @@ -54,8 +56,9 @@ def segment_len(tokens): """ return sum(map(len, tokens)) + def get_word_map(unique_words, src_tokens): - """ Arrange the set of unique words by the order they original appear in the text + """Arrange the set of unique words by the order they original appear in the text Arguments: unique_words {set} -- a set of unique words @@ -70,18 +73,19 @@ def get_word_map(unique_words, src_tokens): # Find the indices of the unique words in the source text unique_word_indices = map(src_tokens.index, unique_words) word_map = list(zip(unique_words, unique_word_indices)) - word_map.sort(key = lambda x: x[1]) # Re-arrange order by the index + word_map.sort(key=lambda x: x[1]) # Re-arrange order by the index return word_map + def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2): - """ Find the location of anchor words in both the gt and ocr text. - Anchor words are location where we can split both the source gt + """Find the location of anchor words in both the gt and ocr text. 
+ Anchor words are location where we can split both the source gt and ocr text into smaller text fragment for faster alignment. Arguments: gt_tokens {list} -- a list of ground truth tokens ocr_tokens {list} -- a list of tokens from OCR'ed document - + Keyword Arguments: min_anchor_len {int} -- minimum len of the anchor word (default: {2}) @@ -91,9 +95,9 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2): (anchor_map_gt, anchor_map_ocr) 1. `anchor_map_gt` is a `word_map` that locates all the anchor words in the gt tokens 2. `anchor_map_gt` is a `word_map` that locates all the anchor words in the ocr tokens - + For example: - Input: + Input: gt_tokens: ["b", "a", "c"] ocr_tokens: ["c", "b", "a"] Ourput: @@ -113,15 +117,15 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2): unique_word_map_ocr = get_word_map(unique_words_common, ocr_tokens) # Unzip to get the ordered unique_words ordered_unique_words_gt, _ = zip(*unique_word_map_gt) - ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr) + ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr) # Join words into a space-separated string for finding LCS - unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt) + unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt) unique_words_ocr_str = preprocess.join_tokens(ordered_unique_words_ocr) # 3. Find the LCS between the two ordered list of unique words lcs = LCS(unique_words_gt_str, unique_words_ocr_str) lcs_str = lcs.get_str() - + # 4. Break up the LCS string into tokens lcs_words = set(preprocess.tokenize(lcs_str)) @@ -129,19 +133,30 @@ def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2): anchor_words = lcs_words.intersection(unique_words_common) # 6. Filter the unique words to keep the anchor words ONLY - anchor_map_gt = list(filter( - # This is a list of (unique_word, unique_word_index) - lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_gt - )) - anchor_map_ocr = list(filter( - lambda word_coordinate: word_coordinate[0] in anchor_words, unique_word_map_ocr - )) + anchor_map_gt = list( + filter( + # This is a list of (unique_word, unique_word_index) + lambda word_coordinate: word_coordinate[0] in anchor_words, + unique_word_map_gt, + ) + ) + anchor_map_ocr = list( + filter( + lambda word_coordinate: word_coordinate[0] in anchor_words, + unique_word_map_ocr, + ) + ) return anchor_map_gt, anchor_map_ocr -def find_anchor_recur(gt_tokens, ocr_tokens, - start_pos_gt=0, start_pos_ocr=0, - max_seg_length=MAX_ALIGN_SEGMENT_LENGTH): - """ Recursively find anchor positions in the gt and ocr text + +def find_anchor_recur( + gt_tokens, + ocr_tokens, + start_pos_gt=0, + start_pos_ocr=0, + max_seg_length=MAX_ALIGN_SEGMENT_LENGTH, +): + """Recursively find anchor positions in the gt and ocr text Arguments: gt_tokens {list} -- a list of ground truth tokens @@ -150,12 +165,12 @@ def find_anchor_recur(gt_tokens, ocr_tokens, Keyword Arguments: start_pos {int} -- a constant to add to all the resulting indices (default: {0}) - max_seg_length {int} -- trigger recursion if any text segment is larger than this + max_seg_length {int} -- trigger recursion if any text segment is larger than this (default: {MAX_ALIGN_SEGMENT_LENGTH}) Raises: ValueError: when there different number of anchor points in gt and ocr. - + Returns: tuple -- two lists of token indices: (output_gt_anchors, output_ocr_anchors) @@ -165,7 +180,7 @@ def find_anchor_recur(gt_tokens, ocr_tokens, # 1. 
Try to find anchor words anchor_word_map_gt, anchor_word_map_ocr = get_anchor_map(gt_tokens, ocr_tokens) - # 2. Check invariant + # 2. Check invariant if len(anchor_word_map_gt) != len(anchor_word_map_ocr): raise ValueError("Unequal number of anchor points across gt and ocr string") # Return empty if no anchor word found @@ -182,17 +197,26 @@ def find_anchor_recur(gt_tokens, ocr_tokens, seg_start_gt = list(itertools.chain([0], anchor_indices_gt)) seg_start_ocr = list(itertools.chain([0], anchor_indices_ocr)) start_n_end_gt = zip(seg_start_gt, itertools.chain(anchor_indices_gt, [None])) - start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None])) + start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None])) gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt] ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr] - # 4. Loop through each segment - for gt_seg, ocr_seg, gt_start, ocr_start in zip(gt_segments, ocr_segments, seg_start_gt, seg_start_ocr): - if segment_len(gt_seg) > max_seg_length or segment_len(ocr_seg) > max_seg_length: + # 4. Loop through each segment + for gt_seg, ocr_seg, gt_start, ocr_start in zip( + gt_segments, ocr_segments, seg_start_gt, seg_start_ocr + ): + if ( + segment_len(gt_seg) > max_seg_length + or segment_len(ocr_seg) > max_seg_length + ): # recur on the segment in between the two anchors. # We assume the first token in the segment is an anchor word - gt_anchors, ocr_anchors = find_anchor_recur(gt_seg[1:], ocr_seg[1:], - start_pos_gt=gt_start + 1, start_pos_ocr=ocr_start + 1, - max_seg_length=max_seg_length) + gt_anchors, ocr_anchors = find_anchor_recur( + gt_seg[1:], + ocr_seg[1:], + start_pos_gt=gt_start + 1, + start_pos_ocr=ocr_start + 1, + max_seg_length=max_seg_length, + ) # shift the token indices # (these are indices of a subsequence and does not reflect true position in the source sequence) gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors)) @@ -200,12 +224,13 @@ def find_anchor_recur(gt_tokens, ocr_tokens, # merge recursion results output_gt_anchors = output_gt_anchors.union(gt_anchors) output_ocr_anchors = output_ocr_anchors.union(ocr_anchors) - + return sorted(output_gt_anchors), sorted(output_ocr_anchors) + def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH): """A faster alignment scheme of two text segments. This method first - breaks the strings into smaller segments with anchor words. + breaks the strings into smaller segments with anchor words. Then these smaller segments are aligned. NOTE: this function shares the same contract as `genalog.text.alignment.align()` @@ -222,7 +247,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_ "The planet Mar, " "The plamet Maris, " - + "I scarcely need " "I scacely neee " @@ -230,7 +255,7 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_ "remind te reader," And run sequence alignment on each pair. - + Arguments: gt {str} -- ground truth text noise {str} -- text with ocr noise @@ -248,11 +273,17 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_ ocr_tokens = preprocess.tokenize(ocr) # 1. Find anchor positions - gt_anchors, ocr_anchors = find_anchor_recur(gt_tokens, ocr_tokens, max_seg_length=max_seg_length) + gt_anchors, ocr_anchors = find_anchor_recur( + gt_tokens, ocr_tokens, max_seg_length=max_seg_length + ) # 2. 
Split into segments - start_n_end_gt = zip(itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])) - start_n_end_ocr = zip(itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])) + start_n_end_gt = zip( + itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None]) + ) + start_n_end_ocr = zip( + itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None]) + ) gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt] ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr] @@ -263,13 +294,15 @@ def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_ gt_segment = preprocess.join_tokens(gt_segment) noisy_segment = preprocess.join_tokens(noisy_segment) # Run alignment algorithm - aligned_seg_gt, aligned_seg_ocr = alignment.align(gt_segment, noisy_segment, gap_char=gap_char) - if aligned_seg_gt and aligned_seg_ocr: # if not empty string "" + aligned_seg_gt, aligned_seg_ocr = alignment.align( + gt_segment, noisy_segment, gap_char=gap_char + ) + if aligned_seg_gt and aligned_seg_ocr: # if not empty string "" aligned_segments_gt.append(aligned_seg_gt) aligned_segments_ocr.append(aligned_seg_ocr) - - # Stitch all segments together - aligned_gt = ' '.join(aligned_segments_gt) - aligned_noise = ' '.join(aligned_segments_ocr) - return aligned_gt, aligned_noise \ No newline at end of file + # Stitch all segments together + aligned_gt = " ".join(aligned_segments_gt) + aligned_noise = " ".join(aligned_segments_ocr) + + return aligned_gt, aligned_noise diff --git a/genalog/text/conll_format.py b/genalog/text/conll_format.py index 58478d7..005e6e3 100644 --- a/genalog/text/conll_format.py +++ b/genalog/text/conll_format.py @@ -4,7 +4,7 @@ usage: conll_format.py [-h] [--train_subset] [--test_subset] [--gt_folder GT_FOLDER] - base_folder degraded_folder + base_folder degraded_folder positional argument: base_folder base directory containing the collection of dataset @@ -18,20 +18,20 @@ optional arguments: optional arguments: -h, --help show this help message and exit -example usage +example usage (to run for specified degradation of the dataset on both train and test) - python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' + python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' (to run for specified degradation of the dataset and ground truth) - python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' + python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' --gt_folder='shared' (to run for specified degradation of the dataset on only test subset) - python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' + python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' --test_subset (to run for specified degradation of the dataset on only train subset) - python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' + python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' --train_subset """ import itertools @@ -39,20 +39,20 @@ import difflib import argparse import json import os -import sys import timeit import concurrent.futures from tqdm import tqdm -from genalog.text import ner_label, ner_label, alignment +from genalog.text import ner_label, alignment EMPTY_SENTENCE_SENTINEL = "<<<>>>" EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O" + def 
propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens): """ propagate_labels_sentences propagates clean labels for clean tokens to ocr tokens and splits ocr tokens into sentences - + Parameters ---------- clean_tokens : list @@ -63,7 +63,7 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ list of sentences (each sentence is a list of tokens) ocr_tokens : list list of tokens in ocr text - + Returns ------- list, list @@ -73,9 +73,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ # Ensure equal number of tokens in both clean_tokens and clean_sentences merged_sentences = list(itertools.chain(*clean_sentences)) if merged_sentences != clean_tokens: - delta = "\n".join(difflib.unified_diff(merged_sentences, clean_tokens, fromfile='merged_clean_sentences', tofile="clean_tokens")) - raise ValueError(f"Inconsistent tokens. " + - f"Delta between clean_text and clean_labels:\n{delta}") + delta = "\n".join( + difflib.unified_diff( + merged_sentences, + clean_tokens, + fromfile="merged_clean_sentences", + tofile="clean_tokens", + ) + ) + raise ValueError( + "Inconsistent tokens. " + + f"Delta between clean_text and clean_labels:\n{delta}" + ) # Ensure that there's OCR result if len(ocr_tokens) == 0: raise ValueError("Empty OCR tokens.") @@ -83,14 +92,18 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ raise ValueError("Empty clean tokens.") # 1. Propagate labels + alig - ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr(clean_labels, clean_tokens, ocr_tokens) + ocr_labels, aligned_clean, aligned_ocr, gap_char = ner_label.propagate_label_to_ocr( + clean_labels, clean_tokens, ocr_tokens + ) # 2. Parse alignment to get mapping - gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_clean, aligned_ocr, gap_char=gap_char) + gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment( + aligned_clean, aligned_ocr, gap_char=gap_char + ) # 3. 
Find sentence breaks in clean text sentences gt_to_ocr_mapping_is_empty = [len(mapping) == 0 for mapping in gt_to_ocr_mapping] - gt_to_ocr_mapping_is_empty_reverse = gt_to_ocr_mapping_is_empty[::-1] + sentence_index = [] sentence_token_counts = 0 for sentence in clean_sentences: @@ -108,12 +121,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ ocr_start = 0 # if gt token at sentence break is not mapped to any ocr token elif len(gt_to_ocr_mapping[gt_start]) < 1: - try: # finding next gt token that is mapped to an ocr token + try: # finding next gt token that is mapped to an ocr token new_gt_start = gt_to_ocr_mapping_is_empty.index(False, gt_start) ocr_start = gt_to_ocr_mapping[new_gt_start][0] # If no valid token mapping in the remaining gt tokens except ValueError: - ocr_start = len(ocr_tokens) # use the last ocr token + ocr_start = len(ocr_tokens) # use the last ocr token else: ocr_start = gt_to_ocr_mapping[gt_start][0] @@ -121,12 +134,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ if gt_end >= len(gt_to_ocr_mapping): ocr_end = len(ocr_tokens) elif len(gt_to_ocr_mapping[gt_end]) < 1: - try: # finding next gt token that is mapped to an ocr token + try: # finding next gt token that is mapped to an ocr token new_gt_end = gt_to_ocr_mapping_is_empty.index(False, gt_end) ocr_end = gt_to_ocr_mapping[new_gt_end][0] # If no valid token mapping in the remaining gt tokens except ValueError: - ocr_end = len(ocr_tokens) # use the last ocr token + ocr_end = len(ocr_tokens) # use the last ocr token else: ocr_end = gt_to_ocr_mapping[gt_end][0] ocr_sentence = ocr_tokens[ocr_start:ocr_end] @@ -135,11 +148,12 @@ def propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_ ocr_labels_sentences.append(ocr_sentence_labels) return ocr_text_sentences, ocr_labels_sentences + def get_sentences_from_iob_format(iob_format_str): sentences = [] sentence = [] for line in iob_format_str: - if line.strip() == '': # if line is empty (sentence separator) + if line.strip() == "": # if line is empty (sentence separator) sentences.append(sentence) sentence = [] else: @@ -149,47 +163,81 @@ def get_sentences_from_iob_format(iob_format_str): # filter any empty sentences return list(filter(lambda sentence: len(sentence) > 0, sentences)) -def propagate_labels_sentence_single_file(arg): - clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, input_filename = arg - clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace(clean_label_ext, ".txt") +def propagate_labels_sentence_single_file(arg): + ( + clean_labels_dir, + output_text_dir, + output_labels_dir, + clean_label_ext, + input_filename, + ) = arg + + clean_labels_file = os.path.join(clean_labels_dir, input_filename).replace( + clean_label_ext, ".txt" + ) ocr_text_file = os.path.join(output_text_dir, input_filename) ocr_labels_file = os.path.join(output_labels_dir, input_filename) if not os.path.exists(clean_labels_file): - print(f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index...") + print( + f"Warning: missing clean label file '{clean_labels_file}'. Please check file corruption. Skipping this file index..." + ) return elif not os.path.exists(ocr_text_file): - print(f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. Skipping this file index...") + print( + f"Warning: missing ocr text file '{ocr_text_file}'. Please check file corruption. 
Skipping this file index..." + ) return else: - with open(clean_labels_file, 'r', encoding='utf-8') as clf: + with open(clean_labels_file, "r", encoding="utf-8") as clf: tokens_labels_str = clf.readlines() - clean_tokens = [line.split()[0].strip() for line in tokens_labels_str if len(line.split()) == 2] - clean_labels = [line.split()[1].strip() for line in tokens_labels_str if len(line.split()) == 2] + clean_tokens = [ + line.split()[0].strip() + for line in tokens_labels_str + if len(line.split()) == 2 + ] + clean_labels = [ + line.split()[1].strip() + for line in tokens_labels_str + if len(line.split()) == 2 + ] clean_sentences = get_sentences_from_iob_format(tokens_labels_str) # read ocr tokens - with open(ocr_text_file, 'r', encoding='utf-8') as otf: - ocr_text_str = ' '.join(otf.readlines()) - ocr_tokens = [token.strip() for token in ocr_text_str.split()] # already tokenized in data + with open(ocr_text_file, "r", encoding="utf-8") as otf: + ocr_text_str = " ".join(otf.readlines()) + ocr_tokens = [ + token.strip() for token in ocr_text_str.split() + ] # already tokenized in data try: - ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences(clean_tokens, clean_labels, clean_sentences, ocr_tokens) + ocr_tokens_sentences, ocr_labels_sentences = propagate_labels_sentences( + clean_tokens, clean_labels, clean_sentences, ocr_tokens + ) except Exception as e: - print(f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file...") + print( + f"\nWarning: error processing '{input_filename}': {str(e)}.\nSkipping this file..." + ) return - # Write result to file - with open(ocr_labels_file, 'w', encoding="utf-8") as olf: - for ocr_tokens, ocr_labels in zip(ocr_tokens_sentences, ocr_labels_sentences): - if len(ocr_tokens) == 0: # if empty OCR sentences - olf.write(f'{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n') - else: - for token, label in zip(ocr_tokens, ocr_labels): - olf.write(f"{token}\t{label}\n") - olf.write('\n') + # Write result to file + with open(ocr_labels_file, "w", encoding="utf-8") as olf: + for ocr_tokens, ocr_labels in zip( + ocr_tokens_sentences, ocr_labels_sentences + ): + if len(ocr_tokens) == 0: # if empty OCR sentences + olf.write( + f"{EMPTY_SENTENCE_SENTINEL}\t{EMPTY_SENTENCE_SENTINEL_NER_LABEL}\n" + ) + else: + for token, label in zip(ocr_tokens, ocr_labels): + olf.write(f"{token}\t{label}\n") + olf.write("\n") -def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext): + +def propagate_labels_sentences_multiprocess( + clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext +): """ propagate_labels_sentences_all_files propagates labels and sentences for all files in dataset - + Parameters ---------- clean_labels_dir : str @@ -204,20 +252,28 @@ def propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, o file extension of the clean_labels """ clean_label_files = os.listdir(clean_labels_dir) - args = list(map( - lambda clean_label_filename: - (clean_labels_dir, output_text_dir, output_labels_dir, clean_label_ext, clean_label_filename), - clean_label_files - )) + args = list( + map( + lambda clean_label_filename: ( + clean_labels_dir, + output_text_dir, + output_labels_dir, + clean_label_ext, + clean_label_filename, + ), + clean_label_files, + ) + ) with concurrent.futures.ProcessPoolExecutor() as executor: iterator = executor.map(propagate_labels_sentence_single_file, args) - for _ in tqdm(iterator, total=len(args)): # 
wrapping tqdm for progress report + for _ in tqdm(iterator, total=len(args)): # wrapping tqdm for progress report pass + def extract_ocr_text(input_file, output_file): """ extract_ocr_text from GROK json - + Parameters ---------- input_file : str @@ -227,16 +283,17 @@ def extract_ocr_text(input_file, output_file): """ out_dir = os.path.dirname(output_file) in_file_name = os.path.basename(input_file) - file_pre = in_file_name.split('_')[-1].split('.')[0] - output_file_name = '{}.txt'.format(file_pre) + file_pre = in_file_name.split("_")[-1].split(".")[0] + output_file_name = "{}.txt".format(file_pre) output_file = os.path.join(out_dir, output_file_name) - with open(input_file, 'r', encoding='utf-8') as fin: + with open(input_file, "r", encoding="utf-8") as fin: json_data = json.load(fin) json_dict = json_data[0] - text = json_dict['text'] - with open(output_file, 'wb') as fout: + text = json_dict["text"] + with open(output_file, "wb") as fout: fout.write(text.encode("utf-8")) + def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext): """ check_n_sentences prints file name if number of sentences is different in clean and OCR files @@ -253,49 +310,54 @@ def check_n_sentences(clean_labels_dir, output_labels_dir, clean_label_ext): text_files = os.listdir(output_labels_dir) skip_files = [] for text_filename in tqdm(text_files): - clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace(".txt", clean_label_ext) + clean_labels_file = os.path.join(clean_labels_dir, text_filename).replace( + ".txt", clean_label_ext + ) ocr_labels_file = os.path.join(output_labels_dir, text_filename) remove_first_line(clean_labels_file, clean_labels_file) remove_first_line(ocr_labels_file, ocr_labels_file) remove_last_line(clean_labels_file, clean_labels_file) remove_last_line(ocr_labels_file, ocr_labels_file) - with open(clean_labels_file, 'r', encoding='utf-8') as lf: + with open(clean_labels_file, "r", encoding="utf-8") as lf: clean_tokens_labels = lf.readlines() - with open(ocr_labels_file, 'r', encoding='utf-8') as of: + with open(ocr_labels_file, "r", encoding="utf-8") as of: ocr_tokens_labels = of.readlines() error = False n_clean_sentences = 0 nl = False for line in clean_tokens_labels: - if line == '\n': + if line == "\n": if nl is True: error = True else: - nl=True + nl = True n_clean_sentences += 1 else: nl = False n_ocr_sentences = 0 nl = False for line in ocr_tokens_labels: - if line == '\n': + if line == "\n": if nl is True: error = True else: - nl=True + nl = True n_ocr_sentences += 1 else: nl = False if error or n_ocr_sentences != n_clean_sentences: - print(f"Warning: Inconsistent numbers of sentences in '{text_filename}''." + - f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}") + print( + f"Warning: Inconsistent numbers of sentences in '{text_filename}''." 
+ + f"clean_sentences to ocr_sentences: {n_clean_sentences}:{n_ocr_sentences}" + ) skip_files.append(text_filename) return skip_files + def remove_first_line(input_file, output_file): """ remove_first_line from files (some clean CoNLL files have an empty first line) - + Parameters ---------- input_file : str @@ -303,17 +365,18 @@ def remove_first_line(input_file, output_file): output_file : str output file path """ - with open(input_file, 'r', encoding='utf-8') as in_f: + with open(input_file, "r", encoding="utf-8") as in_f: lines = in_f.readlines() - if len(lines) > 1 and lines[0].strip() == '': + if len(lines) > 1 and lines[0].strip() == "": # the clean CoNLL formatted files had a newline as the first line - with open(output_file, 'w', encoding='utf-8') as out_f: + with open(output_file, "w", encoding="utf-8") as out_f: out_f.writelines(lines[1:]) + def remove_last_line(input_file, output_file): """ remove_last_line from files (some clean CoNLL files have an empty last line) - + Parameters ---------- input_file : str @@ -322,16 +385,17 @@ def remove_last_line(input_file, output_file): output file path """ - with open(input_file, 'r', encoding='utf-8') as in_f: + with open(input_file, "r", encoding="utf-8") as in_f: lines = in_f.readlines() - if len(lines) > 1 and lines[-1].strip() == '': - with open(output_file, 'w', encoding='utf-8') as out_f: + if len(lines) > 1 and lines[-1].strip() == "": + with open(output_file, "w", encoding="utf-8") as out_f: out_f.writelines(lines[:-1]) + def for_all_files(input_dir, output_dir, func): """ for_all_files will apply function to every file in a director - + Parameters ---------- input_dir : str @@ -347,25 +411,34 @@ def for_all_files(input_dir, output_dir, func): output_file = os.path.join(output_dir, text_filename) func(input_file, output_file) + def main(args): if not args.train_subset and not args.test_subset: - subsets = ['train', 'test'] + subsets = ["train", "test"] else: subsets = [] if args.train_subset: - subsets.append('train') + subsets.append("train") if args.test_subset: - subsets.append('test') + subsets.append("test") for subset in subsets: print("Processing {} subset...".format(subset)) - clean_labels_dir = os.path.join(args.base_folder, args.gt_folder, subset,'clean_labels') - ocr_json_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr') + clean_labels_dir = os.path.join( + args.base_folder, args.gt_folder, subset, "clean_labels" + ) + ocr_json_dir = os.path.join( + args.base_folder, args.degraded_folder, subset, "ocr" + ) - output_text_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_text') - output_labels_dir = os.path.join(args.base_folder, args.degraded_folder, subset, 'ocr_labels') + output_text_dir = os.path.join( + args.base_folder, args.degraded_folder, subset, "ocr_text" + ) + output_labels_dir = os.path.join( + args.base_folder, args.degraded_folder, subset, "ocr_labels" + ) # remove first empty line of labels file, if exists for_all_files(clean_labels_dir, clean_labels_dir, remove_first_line) @@ -380,21 +453,50 @@ def main(args): os.mkdir(output_labels_dir) # make ocr labels files by propagating clean labels to ocr_text and creating files in ocr_labels - propagate_labels_sentences_multiprocess(clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext) + propagate_labels_sentences_multiprocess( + clean_labels_dir, output_text_dir, output_labels_dir, args.clean_label_ext + ) print("Validating number of sentences in gt and ocr labels") - 
check_n_sentences(clean_labels_dir, output_labels_dir, args.clean_label_ext) # check number of sentences and make sure same; print anomaly files + check_n_sentences( + clean_labels_dir, output_labels_dir, args.clean_label_ext + ) # check number of sentences and make sure same; print anomaly files + def create_parser(): parser = argparse.ArgumentParser() - parser.add_argument("base_folder", help="base directory containing the collection of dataset") - parser.add_argument("degraded_folder", help="directory containing train and test subset for degradation") - parser.add_argument("--gt_folder", type=str, default="shared", help="directory containing the ground truth") - parser.add_argument("--clean_label_ext", type=str, default=".txt", help="file extension of the clean_labels files") - parser.add_argument('--train_subset', help="include if only train folder should be processed", action='store_true') - parser.add_argument('--test_subset', help="include if only test folder should be processed", action='store_true') + parser.add_argument( + "base_folder", help="base directory containing the collection of dataset" + ) + parser.add_argument( + "degraded_folder", + help="directory containing train and test subset for degradation", + ) + parser.add_argument( + "--gt_folder", + type=str, + default="shared", + help="directory containing the ground truth", + ) + parser.add_argument( + "--clean_label_ext", + type=str, + default=".txt", + help="file extension of the clean_labels files", + ) + parser.add_argument( + "--train_subset", + help="include if only train folder should be processed", + action="store_true", + ) + parser.add_argument( + "--test_subset", + help="include if only test folder should be processed", + action="store_true", + ) return parser -if __name__ == '__main__': + +if __name__ == "__main__": start = timeit.default_timer() parser = create_parser() args = parser.parse_args() diff --git a/genalog/text/lcs.py b/genalog/text/lcs.py index a0ce179..abe7d79 100644 --- a/genalog/text/lcs.py +++ b/genalog/text/lcs.py @@ -1,14 +1,13 @@ -from genalog.text import preprocess - -class LCS(): +class LCS: """ Compute the Longest Common Subsequence (LCS) of two given string.""" + def __init__(self, str_m, str_n): self.str_m_len = len(str_m) self.str_n_len = len(str_n) dp_table = self._construct_dp_table(str_m, str_n) self._lcs_len = dp_table[self.str_m_len][self.str_n_len] self._lcs = self._find_lcs_str(str_m, str_n, dp_table) - + def _construct_dp_table(self, str_m, str_n): m = self.str_m_len n = self.str_n_len @@ -16,14 +15,14 @@ class LCS(): # Initialize DP table dp = [[0 for j in range(n + 1)] for i in range(m + 1)] - for i in range(1, m+1): - for j in range(1, n+1): + for i in range(1, m + 1): + for j in range(1, n + 1): # Case 1: if char1 == char2 - if str_m[i-1] == str_n[j-1]: - dp[i][j] = 1 + dp[i-1][j-1] + if str_m[i - 1] == str_n[j - 1]: + dp[i][j] = 1 + dp[i - 1][j - 1] # Case 2: take the max of the values in the top and left cell else: - dp[i][j] = max(dp[i-1][j], dp[i][j-1]) + dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) return dp def _find_lcs_str(self, str_m, str_n, dp_table): @@ -32,13 +31,13 @@ class LCS(): lcs = "" while m > 0 and n > 0: # same char - if str_m[m-1] == str_n[n-1]: + if str_m[m - 1] == str_n[n - 1]: # prepend the character - lcs = str_m[m - 1] + lcs + lcs = str_m[m - 1] + lcs m -= 1 n -= 1 # top cell > left cell - elif dp_table[m-1][n] > dp_table[m][n-1]: + elif dp_table[m - 1][n] > dp_table[m][n - 1]: m -= 1 else: n -= 1 diff --git a/genalog/text/ner_label.py 
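# A minimal usage sketch of the LCS class above. The constructor fills the DP table
# with the recurrence dp[i][j] = 1 + dp[i-1][j-1] on a character match, otherwise
# max(dp[i-1][j], dp[i][j-1]), then backtracks through it (_find_lcs_str); get_str(),
# the accessor genalog.text.anchor already calls, is assumed to return that stored
# subsequence.
from genalog.text.lcs import LCS

lcs = LCS("New York is big", "New Yerk was big")
print(lcs.get_str())  # prints a longest common subsequence of the two inputs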
b/genalog/text/ner_label.py index 3b3f571..67817db 100644 --- a/genalog/text/ner_label.py +++ b/genalog/text/ner_label.py @@ -10,71 +10,81 @@ import itertools # For example, given a label 'B-PLACE' # Group 1 (denoted by \1): Label Indicator (B-) # Group 2 (denoted by \2): Label Name (PLACE) -MULTI_TOKEN_BEGIN_LABEL_REGEX = r'^\s*(B-)([a-z|A-Z]+)\s*$' -MULTI_TOKEN_INSIDE_LABEL_REGEX = r'^\s*(I-)([a-z|A-Z]+)\s*$' -MULTI_TOKEN_LABEL_REGEX = r'^\s*([B|I]-)([a-z|A-Z]+)\s*' +MULTI_TOKEN_BEGIN_LABEL_REGEX = r"^\s*(B-)([a-z|A-Z]+)\s*$" +MULTI_TOKEN_INSIDE_LABEL_REGEX = r"^\s*(I-)([a-z|A-Z]+)\s*$" +MULTI_TOKEN_LABEL_REGEX = r"^\s*([B|I]-)([a-z|A-Z]+)\s*" # To avoid confusion in the Python interpreter, -# gap char should not be any of the following special characters -SPECIAL_CHAR = set(" \t\n'\x0b''\x0c''\r'") # Notice space characters (' ', '\t', '\n') are in this set. +# gap char should not be any of the following special characters +SPECIAL_CHAR = set( + " \t\n'\x0b''\x0c''\r'" +) # Notice space characters (' ', '\t', '\n') are in this set. GAP_CHAR_SET = set(string.printable).difference(SPECIAL_CHAR) # GAP_CHAR_SET = '!"#$%&()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~' + class GapCharError(Exception): pass + def _is_begin_label(label): """ Return true if the NER label is a begin label (eg. B-PLACE) """ - return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) != None + return re.match(MULTI_TOKEN_BEGIN_LABEL_REGEX, label) is not None + def _is_inside_label(label): """ Return true if the NER label is an inside label (eg. I-PLACE) """ - return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) != None + return re.match(MULTI_TOKEN_INSIDE_LABEL_REGEX, label) is not None + def _is_multi_token_label(label): """ Return true if the NER label is a multi token label (eg. B-PLACE, I-PLACE) """ - return re.match(MULTI_TOKEN_LABEL_REGEX, label) != None + return re.match(MULTI_TOKEN_LABEL_REGEX, label) is not None + def _clean_multi_token_label(label): """ Rid the multi-token-labels of whitespaces""" - return re.sub(MULTI_TOKEN_LABEL_REGEX, r'\1\2', label) + return re.sub(MULTI_TOKEN_LABEL_REGEX, r"\1\2", label) + def _convert_to_begin_label(label): - """ Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE) - + """Convert an inside label, or I-label, (ex. I-PLACE) to a begin label, or B-Label, (ex. B-PLACE) + Arguments: label {str} -- an NER label - + Returns: an NER label. This method DOES NOT alter the label unless it is an inside label """ if _is_inside_label(label): # Replace the Label Indicator to 'B-'(\1) and keep the Label Name (\2) - return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r'B-\2', label) + return re.sub(MULTI_TOKEN_INSIDE_LABEL_REGEX, r"B-\2", label) return label + def _convert_to_inside_label(label): - """ Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE) - + """Convert a begin label, or B-label, (ex. B-PLACE) to an inside label, or I-Label, (ex. B-PLACE) + Arguments: label {str} -- an NER label - + Returns: an NER label. 
This method DOES NOT alter the label unless it is a begin label """ if _is_begin_label(label): # Replace the Label Indicator to 'I-'(\1) and keep the Label Name (\2) - return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r'I-\2', label) + return re.sub(MULTI_TOKEN_BEGIN_LABEL_REGEX, r"I-\2", label) return label + def _is_missing_begin_label(begin_label, inside_label): - """ Validate a inside label given an begin label - + """Validate a inside label given an begin label + Arguments: - begin_label {str} -- a begin NER label used to + begin_label {str} -- a begin NER label used to check if the given label is part of a multi-token label inside_label {str} -- an inside label to check for its validity - + Returns: True if the inside label paired with the begin_label. False otherwise. Also False if input is not an inside label @@ -87,20 +97,21 @@ def _is_missing_begin_label(begin_label, inside_label): inside_label = _clean_multi_token_label(inside_label) begin_label = _clean_multi_token_label(begin_label) # convert inside label to a begin label for string comparison - # True if the two labels have different names + # True if the two labels have different names # (e.g. B-LOC followed by I-ORG, and I-ORG is missing a begin label) return _convert_to_begin_label(inside_label) != begin_label else: return True + def correct_ner_labels(labels): - """ Correct the given list of labels for the following case: + """Correct the given list of labels for the following case: 1. Missing B-Label (i.e. I-PLACE I-PLACE -> B-PLACE I-PLACE) - + Arguments: labels {list} -- list of NER labels - + Returns: a list of NER labels """ @@ -109,7 +120,7 @@ def correct_ner_labels(labels): if _is_multi_token_label(label): if _is_begin_label(label): cur_begin_label = label - # else is an inside label, so we check if it's missing a begin label + # else is an inside label, so we check if it's missing a begin label else: if _is_missing_begin_label(cur_begin_label, label): labels[i] = _convert_to_begin_label(label) @@ -118,10 +129,11 @@ def correct_ner_labels(labels): else: cur_begin_label = "" return labels - + + def _select_from_multiple_ner_labels(label_indices): - """ Private method to select a NER label from a list of candidate - + """Private method to select a NER label from a list of candidate + Note: this method is used to tackle the issue when multiple gt tokens are aligned to ONE ocr_token @@ -129,7 +141,7 @@ def _select_from_multiple_ner_labels(label_indices): gt_labels: B-p I-p O O | | | | - gt: New York is big + gt: New York is big | \\ / | ocr: New Yorkis big | | | @@ -140,15 +152,16 @@ def _select_from_multiple_ner_labels(label_indices): Arguments: label_indices {list} -- a list of token indices - + Returns: a specific index """ # TODO: may need a more sophisticated way to select from multiple NER labels return label_indices[0] + def _find_gap_char_candidates(gt_tokens, ocr_tokens): - """ Find a set of suitable GAP_CHARs based not in the set of input characters + """Find a set of suitable GAP_CHARs based not in the set of input characters Arguments: gt_tokens {list} -- a list of tokens @@ -159,15 +172,18 @@ def _find_gap_char_candidates(gt_tokens, ocr_tokens): 1. the set of suitable GAP_CHARs 2. 
the set of input characters """ - input_char_set = set(''.join(itertools.chain(gt_tokens, ocr_tokens))) # The set of input characters + input_char_set = set( + "".join(itertools.chain(gt_tokens, ocr_tokens)) + ) # The set of input characters gap_char_set = GAP_CHAR_SET # The set of possible GAP_CHARs # Find a set of gap_char that is NOT in the set of input characters gap_char_candidates = gap_char_set.difference(input_char_set) return gap_char_candidates, input_char_set + def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True): - """ Propagate NER label for ground truth tokens to to ocr tokens. - + """Propagate NER label for ground truth tokens to to ocr tokens. + NOTE that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens. Invalid tokens are: 1. non-atomic tokens, or space-separated string ("New York") @@ -175,7 +191,7 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True): 4. string with spaces (" ") Arguments: - gt_labels {list} -- a list of NER label for ground truth token + gt_labels {list} -- a list of NER label for ground truth token gt_tokens {list} -- a list of ground truth string tokens ocr_tokens {list} -- a list of OCR'ed text tokens @@ -185,8 +201,8 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True): (default: {True}) Raises: - GapCharError: - when the set of input character is EQUAL + GapCharError: + when the set of input character is EQUAL to set of all possible gap characters (GAP_CHAR_SET) Returns: @@ -199,21 +215,30 @@ def propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, use_anchor=True): `gap_char` is the char used to alignment for inserting gaps """ # Find a set of suitable GAP_CHAR based not in the set of input characters - gap_char_candidates, input_char_set = _find_gap_char_candidates(gt_tokens, ocr_tokens) + gap_char_candidates, input_char_set = _find_gap_char_candidates( + gt_tokens, ocr_tokens + ) if len(gap_char_candidates) == 0: - raise GapCharError("Exhausted all possible GAP_CHAR candidates for alignment." + - " Consider reducing cardinality of the input character set.\n" + - f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" + - f"The set of input character is: '{''.join(sorted(input_char_set))}'") + raise GapCharError( + "Exhausted all possible GAP_CHAR candidates for alignment." + + " Consider reducing cardinality of the input character set.\n" + + f"The set of possible GAP_CHAR candidates is: '{''.join(sorted(GAP_CHAR_SET))}'\n" + + f"The set of input character is: '{''.join(sorted(input_char_set))}'" + ) else: if alignment.GAP_CHAR in gap_char_candidates: - gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR + gap_char = alignment.GAP_CHAR # prefer to use default GAP_CHAR else: gap_char = gap_char_candidates.pop() - return _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor) + return _propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens, gap_char=gap_char, use_anchor=use_anchor + ) -def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True): - """ Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation + +def _propagate_label_to_ocr( + gt_labels, gt_tokens, ocr_tokens, gap_char=alignment.GAP_CHAR, use_anchor=True +): + """Propagate NER label for ground truth tokens to to ocr tokens. Low level implementation NOTE: that `gt_tokens` and `ocr_tokens` MUST NOT contain invalid tokens. 
Invalid tokens are: @@ -221,10 +246,10 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment 2. multiple occurrences of the GAP_CHAR ('@@@') 3. empty string ("") 4. string with spaces (" ") - + Case Analysis: ******************************** MULTI-TOKEN-LABELS ******************************** - + Case 1: Case 2: Case 3: Case 4: Case 5: one-to-many many-to-one many-to-many missing tokens missing tokens (Case 1&2 comb) (I-label) (B-label) @@ -233,24 +258,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment gt_token New York New York New York New York New York City / \\ / \\ \\/ /\\ / | | | ocr_token N ew Yo rk NewYork N ew@York New York City - | | | | | | | | | | + | | | | | | | | | | ocr label B-p I-p I-p I-p B-p B-p I-p B-p B-p I-p ******************************** SINGLE-TOKEN-LABELS ******************************** - - Case 1: Case 2: Case 3: Case 4: - one-to-many many-to-one many-to-many missing tokens - (Case 1&2 comb) - gt label O V O O V W O O - | | | | | | | | - gt_token something is big this is huge is big - / \\ \\ \\/ /\\ /\\/ | - ocr_token so me thing isbig th isi shuge is - | | | | | | | | - ocr label o o o V O O V O - + + Case 1: Case 2: Case 3: Case 4: + one-to-many many-to-one many-to-many missing tokens + (Case 1&2 comb) + gt label O V O O V W O O + | | | | | | | | + gt_token something is big this is huge is big + / \\ \\ \\/ /\\ /\\/ | + ocr_token so me thing isbig th isi shuge is + | | | | | | | | + ocr label o o o V O O V O + Arguments: - gt_labels {list} -- a list of NER label for ground truth token + gt_labels {list} -- a list of NER label for ground truth token gt_tokens {list} -- a list of ground truth string tokens ocr_tokens {list} -- a list of OCR'ed text tokens @@ -259,13 +284,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment use_anchor {bool} -- use faster alignment method with anchors if set to True (default: {True}) Raises: - ValueError: when + ValueError: when 1. there is unequal number of gt_tokens and gt_labels 2. there is a non-atomic token in gt_tokens or ocr_tokens 3. there is an empty string in gt_tokens or ocr_tokens 4. there is a token full of space characters only in gt_tokens or ocr_tokens 5. gt_to_ocr_mapping has more tokens than gt_tokens - GapCharError: when + GapCharError: when 1. 
there is a token consisted of GAP_CHAR only @@ -278,24 +303,24 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment `aligned_ocr` is the ocr text aligned with ground true `gap_char` is the char used to alignment for inserting gaps - For example, + For example, given input: gt_labels: ["B-place", "I-place", "o", "o"] gt_tokens: ["New", "York", "is", "big"] ocr_tokens: ["N", "ewYork", "big"] - + output: ( ["B-place", "I-place", "o"], "N@ew York is big", - "N ew@York@@@ big" + "N ew@York@@@ big" ) """ # Pseudo-algorithm: # ocr_to_gt_mapping = [ - # gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City') + # gt_labels: B-P I-P I-P O O B-P I-P [1, 2], ('YorkCity' maps to 'York' and 'City') # | | | | | | | [3], ('i' maps to 'is') # gt_txt: "New York City is in New York" [3, 4], ('sin' maps to 'is' and 'in') # \/ /\ | /\ [5], ('N' maps to 'New') @@ -312,13 +337,13 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment # # gt_to_ocr_mapping = [ - # gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token) + # gt_labels: B-P I-P I-P O O B-P I-P [], ('New' does not map to any ocr token) # | | | | | | | [0], ('York' maps to 'YorkCity') # gt_txt: "New York City is in New York" [0], ('City' maps to 'YorkCity') # \/ /\ | /\ [1, 2], ('is' maps to 'i' and 'sin') # ocr_txt: "YorkCity i sin N ew" [2], ('in' maps to 'sin) - # | | | | | [3,4], ('New' maps to 'N' and 'ew') - # I-P O O B-P B-P [] ('York' does not map to any ocr token) + # | | | | | [3,4], ('New' maps to 'N' and 'ew') + # I-P O O B-P B-P [] ('York' does not map to any ocr token) # ] # STEP 2, clean up corner cases from multi-token-labels @@ -332,41 +357,53 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment # YorkCity # We can address MULTI-TOKEN-LABELS Case 1 with following pseudo-algorithm: - # 1. For each gt_token in gt_to_ocr_mapping: - # 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label - # 1. For all the ocr_tokens this gt_token mapped to - # 1. Keep the B-label for the 1st ocr_token - # 2. For the rest of the ocr_token, convert the B-label to an I-label - + # 1. For each gt_token in gt_to_ocr_mapping: + # 1. If the gt_token is mapped to 2 or more ocr_tokens AND the gt_token has a B-label + # 1. For all the ocr_tokens this gt_token mapped to + # 1. Keep the B-label for the 1st ocr_token + # 2. For the rest of the ocr_token, convert the B-label to an I-label + # We can address the MULTI-TOKEN-LABELS Case 5 with the '_correct_ner_labels()' method - + # Sanity check: if len(gt_tokens) != len(gt_labels): - raise ValueError(f"Unequal number of gt_tokens ({len(gt_tokens)})" + - f"to that of gt_labels ({len(gt_labels)})") - - for tk in (gt_tokens + ocr_tokens): + raise ValueError( + f"Unequal number of gt_tokens ({len(gt_tokens)})" + + f"to that of gt_labels ({len(gt_labels)})" + ) + + for tk in gt_tokens + ocr_tokens: if len(preprocess.tokenize(tk)) > 1: raise ValueError(f"Invalid token '{tk}'. Tokens must be atomic.") if not alignment._is_valid_token(tk, gap_char=gap_char): - if re.search(rf'{re.escape(gap_char)}+', tk): # Escape special regex chars - raise GapCharError(f"Invalid token '{tk}'. Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'") + if re.search(rf"{re.escape(gap_char)}+", tk): # Escape special regex chars + raise GapCharError( + f"Invalid token '{tk}'. 
Tokens cannot be a chain repetition of the GAP_CHAR '{gap_char}'" + ) else: - raise ValueError(f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)") + raise ValueError( + f"Invalid token '{tk}'. Tokens cannot be an empty string or a mix of space characters (spaces, tabs, newlines)" + ) # Stitch tokens together into one string for alignment gt_txt = preprocess.join_tokens(gt_tokens) ocr_txt = preprocess.join_tokens(ocr_tokens) # Align the ground truth and ocr text first if use_anchor: - aligned_gt, aligned_ocr = anchor.align_w_anchor(gt_txt, ocr_txt, gap_char=gap_char) + aligned_gt, aligned_ocr = anchor.align_w_anchor( + gt_txt, ocr_txt, gap_char=gap_char + ) else: aligned_gt, aligned_ocr = alignment.align(gt_txt, ocr_txt, gap_char=gap_char) - gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment(aligned_gt, aligned_ocr, gap_char=gap_char) + gt_to_ocr_mapping, ocr_to_gt_mapping = alignment.parse_alignment( + aligned_gt, aligned_ocr, gap_char=gap_char + ) # Check invariant if len(gt_to_ocr_mapping) != len(gt_tokens): - raise ValueError(f"Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " + - f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment().") + raise ValueError( + "Alignment modified number of gt_tokens. aligned_gt_tokens to gt_tokens: " + + f"{len(gt_to_ocr_mapping)}:{len(gt_tokens)}. \nCheck alignment.parse_alignment()." + ) ocr_labels = [] # STEP 1: naively propagate NER label based on text-alignment @@ -374,7 +411,9 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment # if is not mapping to missing a token (Case 4) if ocr_to_gt_token_relationship: # Find the corresponding gt_token it is aligned to - ner_label_index = _select_from_multiple_ner_labels(ocr_to_gt_token_relationship) + ner_label_index = _select_from_multiple_ner_labels( + ocr_to_gt_token_relationship + ) # Get the NER label for that particular gt_token ocr_labels.append(gt_labels[ner_label_index]) @@ -397,7 +436,7 @@ def _propagate_label_to_ocr(gt_labels, gt_tokens, ocr_tokens, gap_char=alignment def format_labels(tokens, labels, label_top=True): """Format tokens and their NER label for display - + Arguments: tokens {list} -- a list of word tokens labels {list} -- a list of NER labels @@ -405,7 +444,7 @@ def format_labels(tokens, labels, label_top=True): Keyword Arguments: label_top {bool} -- True if label is place on top of the token (default: {True}) - + Returns: a str with NER label align to the token it is labeling @@ -429,20 +468,28 @@ def format_labels(tokens, labels, label_top=True): len_diff = abs(len(label) - len(token)) # Add padding spaces for whichever is shorter if len(label) > len(token): - formatted_labels += label + ' ' - formatted_tokens += token + ' '*len_diff + ' ' + formatted_labels += label + " " + formatted_tokens += token + " " * len_diff + " " else: - formatted_labels += label + ' '*len_diff + ' ' - formatted_tokens += token + ' ' + formatted_labels += label + " " * len_diff + " " + formatted_tokens += token + " " if label_top: - return formatted_labels + '\n' + formatted_tokens + '\n' + return formatted_labels + "\n" + formatted_tokens + "\n" else: - return formatted_tokens + '\n' + formatted_labels + '\n' + return formatted_tokens + "\n" + formatted_labels + "\n" + + +def format_label_propagation( + gt_tokens, + gt_labels, + ocr_tokens, + ocr_labels, + aligned_gt, + aligned_ocr, + show_alignment=True, +): + """Format label propagation for display 
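As a quick illustration of the behavior documented in the `_propagate_label_to_ocr` docstring above, the following minimal sketch reuses the docstring's own example. It is illustrative only: it assumes the helper is importable from `genalog.text.ner_label` (the module patched elsewhere in this series) and relies on the default gap character and anchor-based alignment.

from genalog.text import ner_label

# Ground-truth tokens with their NER labels, and the noisy OCR tokens to label.
gt_labels = ["B-place", "I-place", "o", "o"]
gt_tokens = ["New", "York", "is", "big"]
ocr_tokens = ["N", "ewYork", "big"]

# Align the two token streams and propagate the ground-truth labels onto the OCR tokens.
ocr_labels, aligned_gt, aligned_ocr = ner_label._propagate_label_to_ocr(
    gt_labels, gt_tokens, ocr_tokens
)

print(ocr_labels)   # expected per the docstring above: ["B-place", "I-place", "o"]
print(aligned_gt)   # expected: "N@ew York is big"
print(aligned_ocr)  # expected: "N ew@York@@@ big"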
-def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \ - aligned_gt, aligned_ocr, show_alignment=True): - """ Format label propagation for display - Arguments: gt_tokens {list} -- list of ground truth tokens gt_labels {list} -- list of NER labels for ground truth tokens @@ -450,15 +497,15 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \ ocr_labels {list} -- list of NER labels for the OCR'ed tokens aligned_gt {str} -- ground truth string aligned with the OCR'ed text aligned_ocr {str} -- OCR'ed text aligned with ground truth - + Keyword Arguments: show_alignment {bool} -- if true, show alignment result (default: {True}) - + Returns: a string formatted for display as follows: if show_alignment=TRUE - " + " B-PLACE I-PLACE V O [gt_labels] New York is big [gt_txt] New York is big [aligned_gt] @@ -468,19 +515,18 @@ def format_label_propagation(gt_tokens, gt_labels, ocr_tokens, ocr_labels, \ B-PLACE V O [ocr_labels] " else - " + " B-PLACE I-PLACE V O [gt_labels] New York is big [gt_txt] New is big [ocr_txt] B-PLACE V O [ocr_labels] " """ - + gt_label_str = format_labels(gt_tokens, gt_labels) - label_str = format_labels(ocr_tokens, ocr_labels, label_top=False) + label_str = format_labels(ocr_tokens, ocr_labels, label_top=False) if show_alignment: alignment_str = alignment._format_alignment(aligned_gt, aligned_ocr) return gt_label_str + alignment_str + label_str else: return gt_label_str + label_str - diff --git a/genalog/text/preprocess.py b/genalog/text/preprocess.py index 326284b..1b60d51 100644 --- a/genalog/text/preprocess.py +++ b/genalog/text/preprocess.py @@ -1,10 +1,11 @@ -import re +import re -END_OF_TOKEN = {' ', '\t', '\n'} +END_OF_TOKEN = {" ", "\t", "\n"} NON_ASCII_REPLACEMENT = "_" + def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT): - """ Remove non ascii characters in a token + """Remove non ascii characters in a token Arguments: token {str} -- a word token @@ -15,27 +16,29 @@ def remove_non_ascii(token, replacement=NON_ASCII_REPLACEMENT): str -- a word token with non-ASCII characters removed """ # Remove non-ASCII characters in the token - ascii_token = str(token.encode('utf-8').decode('ascii', 'ignore')) - # If token becomes an empty string as a result + ascii_token = str(token.encode("utf-8").decode("ascii", "ignore")) + # If token becomes an empty string as a result if len(ascii_token) == 0 and len(token) != 0: - ascii_token = replacement # replace with a default character + ascii_token = replacement # replace with a default character return ascii_token + def tokenize(s): - """ Tokenize string - + """Tokenize string + Arguments: s {str} -- aligned string - + Returns: a list of tokens """ # split alignment tokens by spaces, tabs and newline (and excluding them in the tokens) return s.split() + def join_tokens(tokens): - """ Join a list of tokens into a string - + """Join a list of tokens into a string + Arguments: tokens {list} -- a list of tokens @@ -44,14 +47,17 @@ def join_tokens(tokens): """ return " ".join(tokens) + def _is_spacing(c): """ Determine if the character is ignorable """ return True if c in END_OF_TOKEN else False + def split_sentences(text, delimiter="\n"): """ Split a text into sentences with a delimiter""" - return re.sub(r'(( /?[.!?])+ )', rf'\1{delimiter}', text) + return re.sub(r"(( /?[.!?])+ )", rf"\1{delimiter}", text) + def is_sentence_separator(token): """ Returns true if the token is a sentence splitter """ - return re.match(r'^/?[.!?]$', token) != None \ No newline at end of file + 
return re.match(r"^/?[.!?]$", token) is not None diff --git a/genalog/text/splitter.py b/genalog/text/splitter.py index b17a518..9732c14 100644 --- a/genalog/text/splitter.py +++ b/genalog/text/splitter.py @@ -1,5 +1,5 @@ -"""This is a utility tool to split CoNLL formated files. It has the capability to pack sentences into generated -pages more tightly. +"""This is a utility tool to split CoNLL formated files. +It has the capability to pack sentences into generated pages more tightly. usage: splitter.py [-h] [--doc_sep DOC_SEP] [--line_sep LINE_SEP] [--force_doc_sep] @@ -22,7 +22,6 @@ example usage: python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train """ -import re import os import multiprocessing import argparse @@ -40,14 +39,15 @@ CONLL2012_DOC_SEPERATOR = "" CONLL2003_DOC_SEPERATOR = "-DOCSTART-" SEPERATOR = "" -STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text -MAX_SIZE = 100 # max number of sentences to pack on a doc page +STARTING_SPLIT_GUESS = 100 # starting estimate of point where to split text +MAX_SIZE = 100 # max number of sentences to pack on a doc page -SPLIT_ITERS = 2 # number of iterations to run to find a good split +SPLIT_ITERS = 2 # number of iterations to run to find a good split WORKERS_PER_CPU = 2 default_generator = DocumentGenerator() + def unwrap(size, accumulator): words = [] labels = [] @@ -55,11 +55,14 @@ def unwrap(size, accumulator): sentence = accumulator[i] for word, tok in sentence: words.append(word) - labels.append((word,tok)) + labels.append((word, tok)) return words, labels -def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name='text_block.html.jinja'): - """Run a few iterations of binary search to find the best split point + +def find_split_position( + accumulator, start_pos, iters=SPLIT_ITERS, template_name="text_block.html.jinja" +): + """Run a few iterations of binary search to find the best split point from the start to pack in sentences into a page without overflow. 
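    In outline, the strategy is: render a candidate slice of sentences, and if it spills onto a
    second page, shrink the slice, otherwise grow it. A minimal, self-contained sketch of that
    search is shown below; `fits_on_one_page` is a hypothetical stand-in for rendering the
    candidate with DocumentGenerator and checking the resulting page count, and the actual
    implementation that follows additionally seeds the search with STARTING_SPLIT_GUESS and caps
    the number of iterations at SPLIT_ITERS.

    def find_best_split(sentences, start_pos, fits_on_one_page, max_size=100):
        # Binary-search for the largest slice sentences[start_pos:split] that still
        # fits on a single rendered page.
        lo, hi = start_pos, min(len(sentences), start_pos + max_size)
        best = None
        while lo <= hi:
            split = (lo + hi) // 2
            if fits_on_one_page(sentences[start_pos:split]):
                best = split      # fits on one page: try to pack in more sentences
                lo = split + 1
            else:
                hi = split - 1    # overflowed onto a second page: back off
        return best

    # Example with a dummy predicate based on character count instead of rendering:
    # find_best_split(sentences, 0, lambda s: len(" ".join(" ".join(words) for words in s)) < 2000)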
Args: @@ -71,38 +74,48 @@ def find_split_position(accumulator,start_pos,iters=SPLIT_ITERS, template_name=' """ global STARTING_SPLIT_GUESS # use binary search to find page split point - start, end = start_pos, min(len(accumulator), MAX_SIZE+start_pos) + start, end = start_pos, min(len(accumulator), MAX_SIZE + start_pos) best = None count = 0 while start <= end: - if count==0 and (STARTING_SPLIT_GUESS+start_pos > start and STARTING_SPLIT_GUESS + start_pos < end): + if count == 0 and ( + STARTING_SPLIT_GUESS + start_pos > start + and STARTING_SPLIT_GUESS + start_pos < end + ): split_point = STARTING_SPLIT_GUESS else: - split_point = (start + end)//2 - doc_buf = (start_pos,split_point) + split_point = (start + end) // 2 + doc_buf = (start_pos, split_point) content_words, labels = unwrap(doc_buf, accumulator) content_types = [ContentType.PARAGRAPH] text = " ".join(content_words) content = CompositeContent([text], content_types) - doc_gen = default_generator.create_generator(content, [template_name]) + doc_gen = default_generator.create_generator(content, [template_name]) doc = next(doc_gen) - + if len(doc._document.pages) > 1: - end = split_point-1 + end = split_point - 1 else: - start = split_point+1 - best = split_point, doc, labels,text + start = split_point + 1 + best = split_point, doc, labels, text if count >= iters: break count += 1 STARTING_SPLIT_GUESS = split_point return best - -def generate_splits(input_file, output_folder, sentence_seperator="", doc_seperator=None, pool=None, - force_doc_sep=False, ext="txt"): + +def generate_splits( + input_file, + output_folder, + sentence_seperator="", + doc_seperator=None, + pool=None, + force_doc_sep=False, + ext="txt", +): """Processes the file line by line and add sentences to the buffer for processing. 
Args: @@ -120,19 +133,21 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera with open(input_file) as f: for line in f: if line.strip() == sentence_seperator or line.strip() == doc_seperator: - - if len(sentence) > 0: + + if len(sentence) > 0: accumulator.append(sentence) sentence = [] if line.strip() == doc_seperator and force_doc_sep: # progress to processing buffer immediately if force_doc_sep pass - elif len(accumulator) < BUFFER_SIZE: + elif len(accumulator) < BUFFER_SIZE: continue start_pos = 0 while start_pos < len(accumulator): - start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool) + start_pos = next_doc( + accumulator, doc_id, start_pos, output_folder, pool + ) doc_id += 1 progress_bar.update(1) accumulator = [] @@ -141,69 +156,93 @@ def generate_splits(input_file, output_folder, sentence_seperator="", doc_sepera word, tok = line.split("\t") if word.strip() == "": continue - sentence.append((word,tok)) + sentence.append((word, tok)) # process any left over lines start_pos = 0 - if len(sentence) > 0 : accumulator.append(sentence) + if len(sentence) > 0: + accumulator.append(sentence) while start_pos < len(accumulator): - start_pos = next_doc(accumulator,doc_id, start_pos, output_folder,pool) + start_pos = next_doc(accumulator, doc_id, start_pos, output_folder, pool) doc_id += 1 progress_bar.update(1) -def next_doc(accumulator,doc_id, start_pos, output_folder,pool,ext="txt"): - split_pos,doc,labels,text = find_split_position(accumulator, start_pos) + +def next_doc(accumulator, doc_id, start_pos, output_folder, pool, ext="txt"): + split_pos, doc, labels, text = find_split_position(accumulator, start_pos) handle_doc(doc, labels, doc_id, text, output_folder, pool, ext) return split_pos - + + def write_doc(doc, doc_id, labels, text, output_folder, ext="txt", write_png=False): if write_png: f = f"{output_folder}/img/img_{doc_id}.png" doc.render_png(target=f) - text += " " # adding a space at EOF + text += " " # adding a space at EOF text = preprocess.split_sentences(text) - with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as l: + with open(f"{output_folder}/clean_labels/{doc_id}.{ext}", "w") as fp: for idx, (token, label) in enumerate(labels): - l.write(token + "\t" + label) + fp.write(token + "\t" + label) next_token, _ = labels[(idx + 1) % len(labels)] - if preprocess.is_sentence_separator(token) and not \ - preprocess.is_sentence_separator(next_token): - l.write("\n") - if idx == len(labels): # Reach the end of the document - l.write("\n") - + if preprocess.is_sentence_separator( + token + ) and not preprocess.is_sentence_separator(next_token): + fp.write("\n") + if idx == len(labels): # Reach the end of the document + fp.write("\n") + with open(f"{output_folder}/clean_text/{doc_id}.txt", "w") as text_file: text_file.write(text) return f"wrote: doc id: {doc_id}" + def _error_callback(err): raise RuntimeError(err) + def handle_doc(doc, labels, doc_id, text, output_folder, pool, ext="txt"): if pool: - pool.apply_async(write_doc, args=(doc, doc_id, labels, text, output_folder,ext), error_callback=_error_callback) + pool.apply_async( + write_doc, + args=(doc, doc_id, labels, text, output_folder, ext), + error_callback=_error_callback, + ) else: write_doc(doc, doc_id, labels, text, output_folder) + def setup_folder(output_folder): - os.makedirs(os.path.join(output_folder,"clean_text"), exist_ok=True) - os.makedirs(os.path.join(output_folder,"clean_labels"), exist_ok=True) + os.makedirs(os.path.join(output_folder, 
"clean_text"), exist_ok=True) + os.makedirs(os.path.join(output_folder, "clean_labels"), exist_ok=True) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("input_file", default="CoNLL-2012_train.txt", help="path to input CoNLL formated file.") - parser.add_argument("output_folder", default="conll2012_train", help="folder to write results to.") + parser.add_argument( + "input_file", + default="CoNLL-2012_train.txt", + help="path to input CoNLL formated file.", + ) + parser.add_argument( + "output_folder", default="conll2012_train", help="folder to write results to." + ) parser.add_argument("--doc_sep", help="CoNLL doc seperator") parser.add_argument("--ext", help="file extension", default="txt") - parser.add_argument("--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator") - parser.add_argument("--force_doc_sep", default=False, action="store_true", - help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)") + parser.add_argument( + "--line_sep", default=CONLL2012_DOC_SEPERATOR, help="CoNLL line seperator" + ) + parser.add_argument( + "--force_doc_sep", + default=False, + action="store_true", + help="If set, documents are forced to be split by the doc seperator (recommended to turn this off)", + ) args = parser.parse_args() - unescape = lambda s: s.encode('utf-8').decode('unicode_escape') if s else None + unescape = lambda s: s.encode("utf-8").decode("unicode_escape") if s else None # noqa: E731 input_file = args.input_file output_folder = args.output_folder @@ -211,10 +250,18 @@ if __name__ == "__main__": # allow special characters in seperators line_sep = unescape(args.line_sep) or "" - doc_sep = unescape(args.doc_sep) + doc_sep = unescape(args.doc_sep) n_workers = WORKERS_PER_CPU * multiprocessing.cpu_count() with ThreadPool(processes=n_workers) as pool: - generate_splits(input_file, output_folder, line_sep, doc_seperator=doc_sep, pool=pool, force_doc_sep=False, ext=args.ext) + generate_splits( + input_file, + output_folder, + line_sep, + doc_seperator=doc_sep, + pool=pool, + force_doc_sep=False, + ext=args.ext, + ) pool.close() - pool.join() \ No newline at end of file + pool.join() From a9620f51528c9715b33408a3ed9c902475c76eda Mon Sep 17 00:00:00 2001 From: "Jianjie Liu (MAIDAP)" Date: Mon, 25 Jan 2021 16:13:16 -0500 Subject: [PATCH 3/5] Update build pipeline --- azure-pipeline.yml | 37 ++++++++++--------------------------- pytest.ini | 2 -- requirements-dev.txt | 3 +++ setup.py | 2 +- tox.ini | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 46 insertions(+), 30 deletions(-) delete mode 100644 pytest.ini create mode 100644 requirements-dev.txt create mode 100644 tox.ini diff --git a/azure-pipeline.yml b/azure-pipeline.yml index 969320f..c0a5b4e 100644 --- a/azure-pipeline.yml +++ b/azure-pipeline.yml @@ -35,31 +35,20 @@ steps: displayName: 'Use Python $(python.version)' - bash: | - python -m venv .venv - displayName: 'Create virtual environment' - -- bash: | - if [[ '$(Agent.OS)' == Windows* ]] - then - source .venv/Scripts/activate - else - source .venv/bin/activate - fi - pip install --upgrade pip - pip install setuptools wheel - pip install -r requirements.txt - pip install pytest==5.3.5 pytest-cov==2.8.1 + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -r requirements.txt + python -m pip install -r requirements-dev.txt workingDirectory: $(Build.SourcesDirectory) displayName: 'Install dependencies' - bash: | - if [[ '$(Agent.OS)' 
== Windows* ]] - then - source .venv/Scripts/activate - else - source .venv/bin/activate - fi - python -m pytest tests --cov=genalog --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=html + python -m flake8 + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Run Linter (flake8)' + +- bash: | + python -m pytest tests env: BLOB_KEY : $(BLOB_KEY) SEARCH_SERVICE_KEY: $(SEARCH_SERVICE_KEY) @@ -86,12 +75,6 @@ steps: displayName: 'Publish test coverage' - bash: | - if [[ '$(Agent.OS)' == Windows* ]] - then - source .venv/Scripts/activate - else - source .venv/bin/activate - fi python setup.py bdist_wheel --build-number $(Build.BuildNumber) --dist-dir dist workingDirectory: $(Build.SourcesDirectory) displayName: 'Building wheel package' diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c857d08..0000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -junit_family=xunit1 \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..f604537 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pytest +pytest-cov +flake8 diff --git a/setup.py b/setup.py index 91293ec..3ea6f51 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ with open("README.md", "r", encoding="utf8") as fh: setuptools.setup( name="genalog", - install_requires=requirements, + install_requires=requirements, version=BUILD_VERSION, author="Team Enki", author_email="ta_nerds@microsoft.com", diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..f00b8fc --- /dev/null +++ b/tox.ini @@ -0,0 +1,32 @@ +; [tox] +; envlist = flake8, py36 # add other python versions if necessary + +; [testenv] +; # Reading additional dependencies to run the test +; # https://tox.readthedocs.io/en/latest/example/basic.html#depending-on-requirements-txt-or-defining-constraints +; deps = -rdev-requirements.txt +; commands = +; pytest + +; [testenv:flake8] +; deps = flake8 +; skip_install = True +; commands = flake8 . + +; # Configurations for running pytest +[pytest] +junit_family=xunit2 +testpaths = + tests +addopts = + -rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml + +[flake8] +max-line-length = 140 +exclude = + build, dist + .env*,.venv* # local virtual environments + .tox + +; [mypy] +; ignore_missing_imports = True \ No newline at end of file From f8c1e63b37ad2c6d0285ec31c920c4724d63d198 Mon Sep 17 00:00:00 2001 From: "Jianjie Liu (MAIDAP)" Date: Mon, 25 Jan 2021 16:49:14 -0500 Subject: [PATCH 4/5] Add flake8 plugin to check import orders --- requirements-dev.txt | 1 + tox.ini | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index f604537..01c7355 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ pytest pytest-cov flake8 +flake8-import-order \ No newline at end of file diff --git a/tox.ini b/tox.ini index f00b8fc..5cc5272 100644 --- a/tox.ini +++ b/tox.ini @@ -22,6 +22,10 @@ addopts = -rsx --cov=genalog --cov-report=html --cov-report=term-missing --cov-report=xml --junitxml=junit/test-results.xml [flake8] +# Configs for flake8-import-order, see https://pypi.org/project/flake8-import-order/ for more info. 
+import-order-style=edited +application-import-names=genalog, tests +# Native flake8 configs max-line-length = 140 exclude = build, dist From 4c8445b8742f5eee3be23ef23e7ffcf58384af86 Mon Sep 17 00:00:00 2001 From: "Jianjie Liu (MAIDAP)" Date: Mon, 25 Jan 2021 16:49:24 -0500 Subject: [PATCH 5/5] Fix import orders --- genalog/degradation/degrader.py | 5 +++-- genalog/degradation/effect.py | 3 ++- genalog/generation/content.py | 2 +- genalog/generation/document.py | 14 ++++++-------- genalog/ocr/blob_client.py | 12 +++++++----- genalog/ocr/grok.py | 5 +++-- genalog/ocr/metrics.py | 12 +++++++----- genalog/ocr/rest_client.py | 8 +++++--- genalog/pipeline.py | 18 ++++++++++-------- genalog/text/alignment.py | 6 ++++-- genalog/text/anchor.py | 5 +++-- genalog/text/conll_format.py | 9 +++++---- genalog/text/ner_label.py | 7 ++++--- genalog/text/splitter.py | 14 ++++++++------ setup.py | 3 ++- tests/degradation/test_degrader.py | 7 ++++--- tests/degradation/test_effect.py | 3 ++- tests/e2e/test_anchor_e2e.py | 9 +++++---- tests/e2e/test_conll_format_e2e.py | 3 ++- tests/e2e/test_document_generation.py | 5 +++-- tests/e2e/test_generaton_n_degradation.py | 4 ++-- tests/e2e/test_image_channel.py | 4 ++-- tests/e2e/test_ocr_e2e.py | 8 +++++--- tests/e2e/test_pipeline.py | 5 +++-- tests/e2e/test_splitter.py | 4 ++-- tests/generation/test_content.py | 6 +++--- tests/generation/test_document.py | 8 +++++--- tests/ocr/test_metrics.py | 13 +++++-------- tests/ocr/test_ocr.py | 2 +- tests/text/test_alignment.py | 14 ++++++-------- tests/text/test_anchor.py | 9 +++++---- tests/text/test_conll_format.py | 7 ++++--- tests/text/test_lcs.py | 4 ++-- tests/text/test_ner_label.py | 4 ++-- tests/text/test_preprocess.py | 3 ++- tests/text/test_utf8.py | 3 ++- 36 files changed, 137 insertions(+), 111 deletions(-) diff --git a/genalog/degradation/degrader.py b/genalog/degradation/degrader.py index 1024449..62c68da 100644 --- a/genalog/degradation/degrader.py +++ b/genalog/degradation/degrader.py @@ -1,7 +1,8 @@ -from genalog.degradation import effect -from enum import Enum import copy import inspect +from enum import Enum + +from genalog.degradation import effect DEFAULT_METHOD_PARAM_TO_INCLUDE = "src" diff --git a/genalog/degradation/effect.py b/genalog/degradation/effect.py index a3c87b1..24ef110 100644 --- a/genalog/degradation/effect.py +++ b/genalog/degradation/effect.py @@ -1,6 +1,7 @@ +from math import floor + import cv2 import numpy as np -from math import floor def blur(src, radius=5): diff --git a/genalog/generation/content.py b/genalog/generation/content.py index b39c369..71b186a 100644 --- a/genalog/generation/content.py +++ b/genalog/generation/content.py @@ -1,4 +1,4 @@ -from enum import Enum, auto +from enum import auto, Enum class ContentType(Enum): diff --git a/genalog/generation/document.py b/genalog/generation/document.py index 4651911..99aa23c 100644 --- a/genalog/generation/document.py +++ b/genalog/generation/document.py @@ -1,14 +1,12 @@ -from jinja2 import PackageLoader, FileSystemLoader -from jinja2 import Environment, select_autoescape -from weasyprint import HTML -from cairocffi import FORMAT_ARGB32 -import numpy as np - - import itertools import os -import cv2 +import cv2 +import numpy as np +from cairocffi import FORMAT_ARGB32 +from jinja2 import Environment, select_autoescape +from jinja2 import FileSystemLoader, PackageLoader +from weasyprint import HTML DEFAULT_DOCUMENT_STYLE = { "language": "en_US", diff --git a/genalog/ocr/blob_client.py b/genalog/ocr/blob_client.py index 
547b3d4..43b856a 100644 --- a/genalog/ocr/blob_client.py +++ b/genalog/ocr/blob_client.py @@ -1,18 +1,20 @@ """Uses the python sdk to make operation on Azure Blob storage. see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python """ -import os +import asyncio +import base64 import hashlib import json -import asyncio +import os import random + +import aiofiles +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.storage.blob import BlobServiceClient from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient -from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from tqdm import tqdm + from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME -import base64 -import aiofiles # maximum number of simultaneous requests REQUEST_SEMAPHORE = asyncio.Semaphore(50) diff --git a/genalog/ocr/grok.py b/genalog/ocr/grok.py index b51f0cb..84e6ada 100644 --- a/genalog/ocr/grok.py +++ b/genalog/ocr/grok.py @@ -1,7 +1,8 @@ -from .rest_client import GrokRestClient -from .blob_client import GrokBlobClient import time +from .blob_client import GrokBlobClient +from .rest_client import GrokRestClient + class Grok: @staticmethod diff --git a/genalog/ocr/metrics.py b/genalog/ocr/metrics.py index 5466f6d..50aa498 100644 --- a/genalog/ocr/metrics.py +++ b/genalog/ocr/metrics.py @@ -15,17 +15,19 @@ since the rest of the segments align.The assumption here is that we do not expec hence collecting and counting these substitutions will be managable. """ +import argparse +import json +import multiprocessing import os import re -import json -import argparse -import multiprocessing +from multiprocessing import Pool + import pandas as pd from tqdm import tqdm -from multiprocessing import Pool + from genalog.text.alignment import GAP_CHAR -from genalog.text.ner_label import _find_gap_char_candidates from genalog.text.anchor import align_w_anchor +from genalog.text.ner_label import _find_gap_char_candidates LOG_LEVEL = 0 WORKERS_PER_CPU = 2 diff --git a/genalog/ocr/rest_client.py b/genalog/ocr/rest_client.py index dbcdce8..c00c0d8 100644 --- a/genalog/ocr/rest_client.py +++ b/genalog/ocr/rest_client.py @@ -1,13 +1,15 @@ """Uses the REST api to perform operations on the search service. 
see: https://docs.microsoft.com/en-us/rest/api/searchservice/ """ -import requests +import json import os import pkgutil -import json -import time import sys +import time from itertools import cycle + +import requests + from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME API_VERSION = "?api-version=2019-05-06-Preview" diff --git a/genalog/pipeline.py b/genalog/pipeline.py index 8a8534b..5872514 100644 --- a/genalog/pipeline.py +++ b/genalog/pipeline.py @@ -1,14 +1,16 @@ -from genalog.generation.document import DocumentGenerator -from genalog.generation.document import DEFAULT_STYLE_COMBINATION -from genalog.generation.content import CompositeContent, ContentType -from genalog.degradation.degrader import Degrader, ImageState -from json import JSONEncoder -from tqdm import tqdm import concurrent.futures -import timeit -import cv2 import os +import timeit +from json import JSONEncoder + +import cv2 +from tqdm import tqdm + +from genalog.degradation.degrader import Degrader, ImageState +from genalog.generation.content import CompositeContent, ContentType +from genalog.generation.document import DEFAULT_STYLE_COMBINATION +from genalog.generation.document import DocumentGenerator class ImageStateEncoder(JSONEncoder): diff --git a/genalog/text/alignment.py b/genalog/text/alignment.py index a12d253..728a07a 100644 --- a/genalog/text/alignment.py +++ b/genalog/text/alignment.py @@ -1,7 +1,9 @@ -from genalog.text.preprocess import _is_spacing, tokenize -from Bio import pairwise2 import re +from Bio import pairwise2 + +from genalog.text.preprocess import _is_spacing, tokenize + # Configuration params for global sequence alignment algorithm (Needleman-Wunsch) MATCH_REWARD = 1 GAP_PENALTY = -0.5 diff --git a/genalog/text/anchor.py b/genalog/text/anchor.py index 9e6e0c6..b885b67 100644 --- a/genalog/text/anchor.py +++ b/genalog/text/anchor.py @@ -14,9 +14,10 @@ """ import itertools from collections import Counter -from genalog.text import preprocess, alignment -from genalog.text.lcs import LCS + +from genalog.text import alignment, preprocess from genalog.text.alignment import GAP_CHAR +from genalog.text.lcs import LCS # The recursively portion of the algorithm will run on # segments longer than this value to find anchor points in diff --git a/genalog/text/conll_format.py b/genalog/text/conll_format.py index 005e6e3..05b82d0 100644 --- a/genalog/text/conll_format.py +++ b/genalog/text/conll_format.py @@ -34,16 +34,17 @@ example usage python -m genalog.text.conll_format '/data/enki/datasets/synthetic_dataset/' 'hyphens_all' --train_subset """ -import itertools -import difflib import argparse +import concurrent.futures +import difflib +import itertools import json import os import timeit -import concurrent.futures from tqdm import tqdm -from genalog.text import ner_label, alignment + +from genalog.text import alignment, ner_label EMPTY_SENTENCE_SENTINEL = "<<<>>>" EMPTY_SENTENCE_SENTINEL_NER_LABEL = "O" diff --git a/genalog/text/ner_label.py b/genalog/text/ner_label.py index 67817db..b2e16a3 100644 --- a/genalog/text/ner_label.py +++ b/genalog/text/ner_label.py @@ -1,8 +1,9 @@ -from genalog.text import alignment, anchor -from genalog.text import preprocess +import itertools import re import string -import itertools + +from genalog.text import alignment, anchor +from genalog.text import preprocess # Both regex below has the following behavior: # 1. 
whitespace-tolerant at both ends of the string diff --git a/genalog/text/splitter.py b/genalog/text/splitter.py index 9732c14..9df2fac 100644 --- a/genalog/text/splitter.py +++ b/genalog/text/splitter.py @@ -22,15 +22,17 @@ example usage: python -m genalog.text.splitter CoNLL-2012_train.txt conll2012_train """ -import os -import multiprocessing import argparse -from tqdm import tqdm -from genalog.text import preprocess -from genalog.generation.document import DocumentGenerator -from genalog.generation.content import CompositeContent, ContentType +import multiprocessing +import os from multiprocessing.pool import ThreadPool +from tqdm import tqdm + +from genalog.generation.content import CompositeContent, ContentType +from genalog.generation.document import DocumentGenerator +from genalog.text import preprocess + # default buffer. Preferebly set this to something large # It holds the lines read from the CoNLL file BUFFER_SIZE = 50000 diff --git a/setup.py b/setup.py index 3ea6f51..09bf4f5 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -import setuptools import os +import setuptools + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'VERSION.txt')) as version_file: BUILD_VERSION = version_file.read().strip() diff --git a/tests/degradation/test_degrader.py b/tests/degradation/test_degrader.py index 699aeb4..863c89a 100644 --- a/tests/degradation/test_degrader.py +++ b/tests/degradation/test_degrader.py @@ -1,10 +1,11 @@ -from genalog.degradation.degrader import Degrader, ImageState -from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE +import copy from unittest.mock import patch import numpy as np import pytest -import copy + +from genalog.degradation.degrader import DEFAULT_METHOD_PARAM_TO_INCLUDE +from genalog.degradation.degrader import Degrader, ImageState MOCK_IMAGE_SHAPE = (4, 3) MOCK_IMAGE = np.arange(12, dtype=np.uint8).reshape(MOCK_IMAGE_SHAPE) diff --git a/tests/degradation/test_effect.py b/tests/degradation/test_effect.py index b127995..2f406f9 100644 --- a/tests/degradation/test_effect.py +++ b/tests/degradation/test_effect.py @@ -1,9 +1,10 @@ -from genalog.degradation import effect from unittest.mock import patch import numpy as np import pytest +from genalog.degradation import effect + NEW_IMG_SHAPE = (100, 100) MOCK_IMG_SHAPE = (100, 120) MOCK_IMG = np.ones(MOCK_IMG_SHAPE, dtype=np.uint8) diff --git a/tests/e2e/test_anchor_e2e.py b/tests/e2e/test_anchor_e2e.py index 1c8d0ee..119a915 100644 --- a/tests/e2e/test_anchor_e2e.py +++ b/tests/e2e/test_anchor_e2e.py @@ -1,10 +1,11 @@ -from genalog.text import alignment, anchor, preprocess - -import glob -import pytest import difflib +import glob import warnings +import pytest + +from genalog.text import alignment, anchor, preprocess + @pytest.mark.parametrize( "gt_file, ocr_file", diff --git a/tests/e2e/test_conll_format_e2e.py b/tests/e2e/test_conll_format_e2e.py index 4f9c49a..b7bbf6a 100644 --- a/tests/e2e/test_conll_format_e2e.py +++ b/tests/e2e/test_conll_format_e2e.py @@ -1,7 +1,8 @@ -import pytest import glob import itertools +import pytest + from genalog.text import conll_format diff --git a/tests/e2e/test_document_generation.py b/tests/e2e/test_document_generation.py index abdb90b..fd65dd2 100644 --- a/tests/e2e/test_document_generation.py +++ b/tests/e2e/test_document_generation.py @@ -1,8 +1,9 @@ -import pytest import os -from genalog.generation.document import DocumentGenerator +import pytest + from genalog.generation.content import CompositeContent, ContentType +from 
genalog.generation.document import DocumentGenerator CONTENT = CompositeContent( ["foo", "bar"], [ContentType.PARAGRAPH, ContentType.PARAGRAPH] diff --git a/tests/e2e/test_generaton_n_degradation.py b/tests/e2e/test_generaton_n_degradation.py index 3b16793..00986ed 100644 --- a/tests/e2e/test_generaton_n_degradation.py +++ b/tests/e2e/test_generaton_n_degradation.py @@ -1,6 +1,6 @@ -from genalog.generation.document import DocumentGenerator -from genalog.generation.content import CompositeContent, ContentType from genalog.degradation.degrader import Degrader +from genalog.generation.content import CompositeContent, ContentType +from genalog.generation.document import DocumentGenerator TEST_OUTPUT_DIR = "test_out/" diff --git a/tests/e2e/test_image_channel.py b/tests/e2e/test_image_channel.py index dfbd1fe..461a116 100644 --- a/tests/e2e/test_image_channel.py +++ b/tests/e2e/test_image_channel.py @@ -1,8 +1,8 @@ -import pytest import cv2 +import pytest -from genalog.generation.document import DocumentGenerator from genalog.generation.content import CompositeContent, ContentType +from genalog.generation.document import DocumentGenerator TEMPLATE_PATH = "tests/e2e/templates" TEST_OUT_FOLDER = "test_out/" diff --git a/tests/e2e/test_ocr_e2e.py b/tests/e2e/test_ocr_e2e.py index f0929c6..0f526e4 100644 --- a/tests/e2e/test_ocr_e2e.py +++ b/tests/e2e/test_ocr_e2e.py @@ -1,8 +1,10 @@ +import json + +import pytest +from dotenv import load_dotenv + from genalog.ocr.blob_client import GrokBlobClient from genalog.ocr.grok import Grok -import pytest -import json -from dotenv import load_dotenv load_dotenv("tests/ocr/.env") diff --git a/tests/e2e/test_pipeline.py b/tests/e2e/test_pipeline.py index b9e8270..648797b 100644 --- a/tests/e2e/test_pipeline.py +++ b/tests/e2e/test_pipeline.py @@ -1,7 +1,8 @@ -from genalog import pipeline +import glob import pytest -import glob + +from genalog import pipeline EXAMPLE_TEXT_FILE = "tests/text/data/gt_1.txt" diff --git a/tests/e2e/test_splitter.py b/tests/e2e/test_splitter.py index 317b803..6a5bc40 100644 --- a/tests/e2e/test_splitter.py +++ b/tests/e2e/test_splitter.py @@ -1,7 +1,7 @@ -import os import difflib +import os -from genalog.text.splitter import generate_splits, CONLL2003_DOC_SEPERATOR +from genalog.text.splitter import CONLL2003_DOC_SEPERATOR, generate_splits def _compare_content(file1, file2): diff --git a/tests/generation/test_content.py b/tests/generation/test_content.py index 768128e..775dc0a 100644 --- a/tests/generation/test_content.py +++ b/tests/generation/test_content.py @@ -1,8 +1,8 @@ -from genalog.generation.content import ContentType, Content, CompositeContent -from genalog.generation.content import Paragraph, Title - import pytest +from genalog.generation.content import CompositeContent, Content, ContentType +from genalog.generation.content import Paragraph, Title + CONTENT_LIST = ["foo", "bar"] COMPOSITE_CONTENT_TYPE = [ContentType.TITLE, ContentType.PARAGRAPH] TEXT = "foo bar" diff --git a/tests/generation/test_document.py b/tests/generation/test_document.py index 0e6e601..687657a 100644 --- a/tests/generation/test_document.py +++ b/tests/generation/test_document.py @@ -1,8 +1,10 @@ -from genalog.generation.document import Document, DocumentGenerator -from genalog.generation.document import DEFAULT_DOCUMENT_STYLE +from unittest.mock import MagicMock, patch import pytest -from unittest.mock import MagicMock, patch + +from genalog.generation.document import DEFAULT_DOCUMENT_STYLE +from genalog.generation.document import Document, 
DocumentGenerator + FRENCH = "fr" CONTENT = ["some text"] diff --git a/tests/ocr/test_metrics.py b/tests/ocr/test_metrics.py index a96405d..50c9265 100644 --- a/tests/ocr/test_metrics.py +++ b/tests/ocr/test_metrics.py @@ -1,13 +1,10 @@ -from genalog.ocr.metrics import ( - get_align_stats, - get_editops_stats, - get_stats, -) -from genalog.text.alignment import GAP_CHAR, align +import pytest + +import genalog.ocr.metrics +from genalog.ocr.metrics import get_align_stats, get_editops_stats, get_stats +from genalog.text.alignment import align, GAP_CHAR from genalog.text.ner_label import _find_gap_char_candidates -import pytest -import genalog.ocr.metrics genalog.ocr.metrics.LOG_LEVEL = 0 diff --git a/tests/ocr/test_ocr.py b/tests/ocr/test_ocr.py index d6385c7..98ccfd1 100644 --- a/tests/ocr/test_ocr.py +++ b/tests/ocr/test_ocr.py @@ -1,7 +1,7 @@ -import requests import json import pytest +import requests from dotenv import load_dotenv from genalog.ocr.rest_client import GrokRestClient diff --git a/tests/text/test_alignment.py b/tests/text/test_alignment.py index 65eef01..88ed70d 100644 --- a/tests/text/test_alignment.py +++ b/tests/text/test_alignment.py @@ -1,14 +1,12 @@ -from genalog.text import alignment - -from tests.cases.text_alignment import ( - PARSE_ALIGNMENT_REGRESSION_TEST_CASES, - ALIGNMENT_REGRESSION_TEST_CASES, -) - +import warnings from random import randint from unittest.mock import MagicMock + import pytest -import warnings + +from genalog.text import alignment +from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES +from tests.cases.text_alignment import PARSE_ALIGNMENT_REGRESSION_TEST_CASES RANDOM_INT = randint(1, 100) MOCK_ALIGNMENT_RESULT = [("X", "X", 0, 0, 1)] diff --git a/tests/text/test_anchor.py b/tests/text/test_anchor.py index 1c63595..0f8a789 100644 --- a/tests/text/test_anchor.py +++ b/tests/text/test_anchor.py @@ -1,10 +1,11 @@ +import glob +import warnings + +import pytest + from genalog.text import alignment, anchor, preprocess from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES -import glob -import pytest -import warnings - @pytest.mark.parametrize( "tokens, case_sensitive, desired_output", diff --git a/tests/text/test_conll_format.py b/tests/text/test_conll_format.py index 462f5ec..563b4e4 100644 --- a/tests/text/test_conll_format.py +++ b/tests/text/test_conll_format.py @@ -1,9 +1,10 @@ -from genalog.text import conll_format +import itertools +import warnings from unittest.mock import patch -import itertools import pytest -import warnings + +from genalog.text import conll_format @pytest.mark.parametrize( diff --git a/tests/text/test_lcs.py b/tests/text/test_lcs.py index 1ce8c7b..74334eb 100644 --- a/tests/text/test_lcs.py +++ b/tests/text/test_lcs.py @@ -1,7 +1,7 @@ -from genalog.text.lcs import LCS - import pytest +from genalog.text.lcs import LCS + @pytest.fixture( params=[ diff --git a/tests/text/test_ner_label.py b/tests/text/test_ner_label.py index 255c89b..c4c1891 100644 --- a/tests/text/test_ner_label.py +++ b/tests/text/test_ner_label.py @@ -1,8 +1,8 @@ +import pytest + from genalog.text import ner_label from tests.cases.label_propagation import LABEL_PROPAGATION_REGRESSION_TEST_CASES -import pytest - @pytest.mark.parametrize( "label, desired_output", diff --git a/tests/text/test_preprocess.py b/tests/text/test_preprocess.py index 5e313e2..d0231cf 100644 --- a/tests/text/test_preprocess.py +++ b/tests/text/test_preprocess.py @@ -1,6 +1,7 @@ +import pytest + from genalog.text import preprocess from 
genalog.text.alignment import GAP_CHAR -import pytest @pytest.mark.parametrize( diff --git a/tests/text/test_utf8.py b/tests/text/test_utf8.py index 934fc20..0cb6b16 100644 --- a/tests/text/test_utf8.py +++ b/tests/text/test_utf8.py @@ -1,7 +1,8 @@ import random -import pytest import warnings +import pytest + from genalog.text import alignment from genalog.text.alignment import GAP_CHAR from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES
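The import reshuffling throughout this last patch follows the ordering configured in tox.ini above (import-order-style=edited, with genalog and tests registered via application-import-names): roughly, standard-library modules first, then third-party packages, then first-party genalog/tests code, alphabetized within each group and with a blank line between groups. For instance, the new header of tests/text/test_anchor.py reads:

import glob
import warnings

import pytest

from genalog.text import alignment, anchor, preprocess
from tests.cases.text_alignment import ALIGNMENT_REGRESSION_TEST_CASES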