add average parameter to MeanAveragePrecision to specify micro or macro calculation (#2412)
This commit is contained in:
Родитель
4c01e0cbb1
Коммит
6a5d308567
|
@ -1191,6 +1191,7 @@ class RAIVisionInsights(RAIBaseInsights):
|
|||
continue
|
||||
|
||||
metric_OD = MeanAveragePrecision(
|
||||
average=aggregate_method.lower(),
|
||||
class_metrics=True,
|
||||
iou_thresholds=normalized_iou_threshold).to(device)
|
||||
true_y_cohort = [true_y[cohort_index] for cohort_index
|
||||
|
|
|
@ -332,11 +332,12 @@ def run_rai_insights(model, test_data, target_column,
|
|||
ignore_index)
|
||||
if task_type == ModelTask.OBJECT_DETECTION:
|
||||
selection_indexes = [[0]]
|
||||
aggregate_method = 'Macro'
|
||||
class_name = classes[0]
|
||||
iou_threshold = 70
|
||||
object_detection_cache = {}
|
||||
metrics = rai_insights.compute_object_detection_metrics(
|
||||
selection_indexes, aggregate_method, class_name, iou_threshold,
|
||||
object_detection_cache)
|
||||
assert len(metrics) == 2
|
||||
aggregate_methods = ['macro', 'micro']
|
||||
for aggregate_method in aggregate_methods:
|
||||
metrics = rai_insights.compute_object_detection_metrics(
|
||||
selection_indexes, aggregate_method, class_name, iou_threshold,
|
||||
object_detection_cache)
|
||||
assert len(metrics) == 2
|
||||
|
|
Загрузка…
Ссылка в новой задаче