diff --git a/tests/unit/similarity/test_similarity_data.py b/tests/unit/similarity/test_similarity_data.py index bbe2ccc..d340ef2 100644 --- a/tests/unit/similarity/test_similarity_data.py +++ b/tests/unit/similarity/test_similarity_data.py @@ -1,14 +1,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import numpy as np + from utils_cv.similarity.data import comparative_set_builder def test_comparative_set_builder(testing_databunch): - resulting_set = comparative_set_builder(testing_databunch) - # first_key, first_value = next(iter(resulting_set.items())) - - assert isinstance(resulting_set, list) - assert len(resulting_set) == len(testing_databunch.y) - # assert isinstance(first_key, str) - # assert isinstance(first_value, list) + comparative_sets = comparative_set_builder(testing_databunch, num_sets = 20, num_negatives=50) + assert isinstance(comparative_sets, list) + assert len(comparative_sets) == 20 + for cs in comparative_sets: + assert len(cs.neg_im_paths) == 50 + neg_and_pos_label_identical = np.where(np.array(cs.neg_labels) == cs.pos_label)[0] + assert len(neg_and_pos_label_identical)==0, "Negative contains at least one image with same label as the positive"