This commit is contained in:
chongyangbai 2024-08-20 21:48:26 +00:00
Родитель 1ef304fb93
Коммит fce17ce68a
6 изменённых файлов: 18 добавлений и 17 удалений

Просмотреть файл

@ -364,7 +364,7 @@ class VisualObjectGroundingTestCases:
]
}
]
class KeyValuePairTestCases:
schema_dicts = [
@ -446,7 +446,7 @@ class KeyValuePairTestCases:
"id": 1,
"image_ids": [1],
"text": {
"query": "Complete the order"
"query": "Complete the order"
},
"fields": {
"action": {

Просмотреть файл

@ -15,16 +15,16 @@ class TestKeyValuePair(BaseCocoAdaptor):
# Parametrized over (coco_dict, schema) pairs taken in lockstep from
# coco_database[TASK] and schema_database[TASK]; zip stops at the shorter list.
# Each case is forwarded unchanged to the base-class test of the same name.
@pytest.mark.parametrize("coco_dict, schema", zip(coco_database[TASK], schema_database[TASK]))
def test_create_data_manifest(self, coco_dict, schema):
super().test_create_data_manifest(coco_dict, schema)
# Same (coco_dict, schema) parametrization as the plain manifest test, but
# forwards each case to the base-class "with_additional_info" variant.
@pytest.mark.parametrize("coco_dict, schema", zip(coco_database[TASK], schema_database[TASK]))
def test_create_data_manifest_with_additional_info(self, coco_dict, schema):
super().test_create_data_manifest_with_additional_info(coco_dict, schema)
# Returns (schema, coco_dict) built from the entries at index 1 of this task's
# schema and COCO fixture lists. Both are deep-copied, so mutating the result
# leaves the shared fixture databases untouched.
def prepare_schema_and_coco_dict(self):
schema = copy.deepcopy(schema_database[TestKeyValuePair.TASK][1])
coco_dict = copy.deepcopy(coco_database[TestKeyValuePair.TASK][1])
return schema, coco_dict
def test_create_data_manifest_example(self):
schema, coco_dict = self.prepare_schema_and_coco_dict()
adaptor = CocoManifestAdaptorFactory.create(TestKeyValuePair.TASK, schema=schema)

Просмотреть файл

@ -12,7 +12,7 @@ class MultiClassClassificationCocoManifestAdaptor(CocoManifestWithCategoriesAdap
# Attaches one classification label to `image` from a single COCO annotation.
# An image that already has a label is rejected with ValueError, so each image
# ends up with at most one label. `coco_manifest` is not used in the lines
# visible here.
def process_label(self, image: ImageDataManifest, annotation, coco_manifest, label_id_to_pos):
if len(image.labels) != 0:
# NOTE(review): the two raise lines below are the removed and added sides of
# this diff hunk; they differ only in the trailing period of the message.
raise ValueError(f"image with id {annotation['image_id']} will possess unexpected number of annotations {len(image.labels) + 1} for {DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS} dataset.")
raise ValueError(f"image with id {annotation['image_id']} will possess unexpected number of annotations {len(image.labels) + 1} for {DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS} dataset")
# The label's class value is label_id_to_pos[category_id]. The key set passed
# to _get_additional_info is presumably the set of annotation keys to leave
# out of additional_info -- confirm against the helper.
label = ImageClassificationLabelManifest(label_id_to_pos[annotation['category_id']], additional_info=self._get_additional_info(annotation, {'id', 'image_id', 'category_id'}))
image.labels.append(label)

Просмотреть файл

@ -14,7 +14,8 @@ class KeyValuePairCocoManifestAdaptor(CocoManifestWithMultiImageLabelAdaptor):
# Builds a KeyValuePairLabelManifest for one annotation: the annotation's 'id',
# the given image ids, the label data returned by process_label, and whatever
# _get_additional_info returns for the annotation and the listed key set
# (presumably the keys to exclude -- confirm against the helper).
def _construct_label_manifest(self, img_ids, ann, coco_manifest):
label_data = self.process_label(ann, coco_manifest)
# NOTE(review): the first return below is the removed side of this diff hunk;
# the wrapped two-line return after it is the added side. Only the line
# wrapping differs between them.
return KeyValuePairLabelManifest(ann['id'], img_ids, label_data, self._get_additional_info(ann, {'id', KeyValuePairLabelManifest.IMAGES_INPUT_KEY, KeyValuePairLabelManifest.LABEL_KEY, KeyValuePairLabelManifest.TEXT_INPUT_KEY}))
return KeyValuePairLabelManifest(ann['id'], img_ids, label_data, self._get_additional_info(ann, {'id', KeyValuePairLabelManifest.IMAGES_INPUT_KEY, KeyValuePairLabelManifest.LABEL_KEY,
KeyValuePairLabelManifest.TEXT_INPUT_KEY}))
def _construct_manifest(self, images_by_id, coco_manifest, data_type, additional_info):
images, annotations = self.get_images_and_annotations(images_by_id, coco_manifest)

Просмотреть файл

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
from enum import Enum
from ..common import MultiImageLabelManifest, DatasetManifestWithMultiImageLabel, DatasetTypes
@ -40,11 +40,11 @@ class KeyValuePairFieldSchema:
KeyValuePairValueTypes.BOOLEAN: bool,
KeyValuePairValueTypes.STRING: str
}
def __init__(self, type: str,
description: str = None,
examples: List[str] = None,
classes: Dict[str|int|float, KeyValuePairClassSchema] = None,
classes: Dict[str | int | float, KeyValuePairClassSchema] = None,
items: 'KeyValuePairFieldSchema' = None,
properties: Dict[str, 'KeyValuePairFieldSchema'] = None,
includeGrounding: bool = False) -> None:
@ -71,7 +71,7 @@ class KeyValuePairFieldSchema:
self.properties = {k: KeyValuePairFieldSchema(**v) for k, v in properties.items()} if properties else None
self.includeGrounding = includeGrounding
self._check()
def __eq__(self, other) -> bool:
if not isinstance(other, KeyValuePairFieldSchema):
return False
@ -82,7 +82,7 @@ class KeyValuePairFieldSchema:
and self.items == other.items
and self.properties == other.properties
and self.includeGrounding == other.includeGrounding)
def _check(self):
if self.type not in self.TYPE_NAME_TO_PYTHON_TYPE:
raise ValueError(f'Invalid type: {self.type}')
@ -110,7 +110,7 @@ class KeyValuePairSchema:
class KeyValuePairLabelManifest(MultiImageLabelManifest):
"""
Label manifest for key-value pair annotations. The "fields" field follows KeyValuePairSchema.
For example, the label data can be:
For example, the label data can be:
{
"fields": {
"key1": {"value": "v1", "groundings": [[10,10,5,5]]},
@ -134,7 +134,7 @@ class KeyValuePairLabelManifest(MultiImageLabelManifest):
LABEL_GROUNDINGS_KEY = 'groundings'
TEXT_INPUT_KEY = 'text'
IMAGES_INPUT_KEY = 'image_ids'
# Read-only accessor for the entry of self.label_data stored under LABEL_KEY;
# annotated as returning a dict.
@property
def fields(self) -> dict:
return self.label_data[self.LABEL_KEY]
@ -145,7 +145,7 @@ class KeyValuePairLabelManifest(MultiImageLabelManifest):
# Always raises NotImplementedError: reading label data is not supported for
# this manifest type.
def _read_label_data(self):
raise NotImplementedError('Read label data is not supported!')
def _check_label(self, label_data):
if not isinstance(label_data, dict) or self.LABEL_KEY not in label_data:
raise ValueError(f'{self.LABEL_KEY} not found in label_data dictionary: {label_data}')
@ -193,7 +193,7 @@ class KeyValuePairDatasetManifest(DatasetManifestWithMultiImageLabel):
self.schema = KeyValuePairSchema(schema['name'], schema['fieldSchema'], schema.get('description'))
super().__init__(images, annotations, DatasetTypes.KEY_VALUE_PAIR, additional_info)
self._check_annotations()
def _check_annotations(self):
for ann in self.annotations:
if not isinstance(ann, KeyValuePairLabelManifest):

Просмотреть файл

@ -1,6 +1,6 @@
from ..common import DatasetTypes, MultiImageCocoDictGenerator, \
MultiImageDatasetSingleTaskMerge, CocoDictGeneratorFactory, ManifestMergeStrategyFactory
from .manifest import KeyValuePairLabelManifest, KeyValuePairDatasetManifest
_DATA_TYPE = DatasetTypes.KEY_VALUE_PAIR