diff --git a/backwardcompatibilityml/helpers/training.py b/backwardcompatibilityml/helpers/training.py index fb6df1d..e9c6804 100644 --- a/backwardcompatibilityml/helpers/training.py +++ b/backwardcompatibilityml/helpers/training.py @@ -376,6 +376,26 @@ def get_all_error_instance_indices(h1, h2, batch_ids, batched_evaluation_data, b instance_ground_truths)) +def get_incompatible_instances_by_class(all_errors, batch_ids, batched_evaluation_target, class_incompatible_instance_ids): + """ + Finds instances where h2 is incompatible with h1 and inserts + {class : incompatible_data_id} mappings into the class_incompatible_instance_ids dictionary. + + Args: + all_errors: A list of tuples of error indices, h1 and h2 predictions, and ground truth for each instance + batch_ids: The instance ids of the data rows in the batched data. + batched_evaluation_target: A single batch of the corresponding output targets. + class_incompatible_instance_ids: The dictionary to fill with incompatible instances and their ids + """ + for (error_instance_id, error_instance_metadata, h1_prediction, h2_prediction, ground_truth) in all_errors: + if (h1_prediction == ground_truth and h2_prediction != ground_truth): + batch_index = batch_ids.index(error_instance_id) + incompatible_class = batched_evaluation_target[batch_index].item() + if (incompatible_class not in class_incompatible_instance_ids): + class_incompatible_instance_ids[incompatible_class] = [] + class_incompatible_instance_ids[incompatible_class].append(error_instance_id) + + def get_model_error_overlap(h1, h2, batch_ids, batched_evaluation_data, batched_evaluation_target, device="cpu"): """ @@ -530,33 +550,27 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per h1_dataset_error_instance_ids = [] h2_dataset_error_instance_ids = [] h1_and_h2_dataset_error_instance_ids = [] - h2_dataset_error_instance_ids_by_class = {} + h1h2_dataset_incompatible_instance_ids_by_class = {} classes = set() all_error_instances = [] for batch_ids, data, target in dataset: classes = classes.union(target.tolist()) h1_error_count_batch, h2_error_count_batch, h1_and_h2_error_count_batch =\ get_model_error_overlap(h1, h2, batch_ids, data, target, device=device) - h2_error_instance_ids_by_class =\ - get_error_instance_ids_by_class(h2, batch_ids, data, target, device=device) all_errors = get_all_error_instance_indices( h1, h2, batch_ids, data, target, get_instance_metadata=get_instance_metadata, device=device) + get_incompatible_instances_by_class(all_errors, batch_ids, target, h1h2_dataset_incompatible_instance_ids_by_class) all_error_instances += all_errors h1_dataset_error_instance_ids += h1_error_count_batch h2_dataset_error_instance_ids += h2_error_count_batch h1_and_h2_dataset_error_instance_ids += h1_and_h2_error_count_batch - for class_label, error_instance_ids in h2_error_instance_ids_by_class.items(): - if class_label in h2_dataset_error_instance_ids_by_class: - h2_dataset_error_instance_ids_by_class[class_label] += error_instance_ids - else: - h2_dataset_error_instance_ids_by_class[class_label] = error_instance_ids - h2_ds_error_instance_ids_by_class = [] - for class_label, error_instance_ids in h2_dataset_error_instance_ids_by_class.items(): - h2_ds_error_instance_ids_by_class.append({ + h1h2_ds_incompatible_instance_ids_by_class = [] + for class_label, incompatible_instance_ids in h1h2_dataset_incompatible_instance_ids_by_class.items(): + h1h2_ds_incompatible_instance_ids_by_class.append({ "class": class_label, - "errorInstanceIds": error_instance_ids + "incompatibleInstanceIds": incompatible_instance_ids }) h2_performance = performance_metric(h2, dataset, device) @@ -580,7 +594,7 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per h1_and_h2_dataset_error_instance_ids ], - "h2_error_instance_ids_by_class": h2_ds_error_instance_ids_by_class, + "h2_incompatible_instance_ids_by_class": h1h2_ds_incompatible_instance_ids_by_class, "sorted_classes": sorted(list(classes)), "h2_performance": h2_performance, "btc": btc, diff --git a/widget/IncompatiblePointDistribution.tsx b/widget/IncompatiblePointDistribution.tsx index 2fe911b..3e0738c 100644 --- a/widget/IncompatiblePointDistribution.tsx +++ b/widget/IncompatiblePointDistribution.tsx @@ -75,28 +75,34 @@ class IncompatiblePointDistribution extends Component (dataDict["class"] == instanceClass)).pop(); - totalErrors += dataRow["errorInstanceIds"].length; + if (dataRow) { + totalIncompatible += dataRow["incompatibleInstanceIds"]?.length ?? 0; + } } // We add the following so that we do not get a divide by zero - // error later on if there are no errors. - if (totalErrors == 0) { - totalErrors = 1; + // error later on if there are no incompatible points. + if (totalIncompatible == 0) { + totalIncompatible = 1; } var dataRows = []; for (var i=startI; i < endI; i++) { var instanceClass = this.props.selectedDataPoint.sorted_classes[i]; - var dataRow = this.props.selectedDataPoint.h2_error_instance_ids_by_class.filter( + var dataRow = this.props.selectedDataPoint.h2_incompatible_instance_ids_by_class.filter( dataDict => (dataDict["class"] == instanceClass)).pop(); - dataRows.push(dataRow); + if (dataRow) { + dataRows.push(dataRow); + } else { + dataRows.push({class: instanceClass, incompatibleInstanceIds: []}) + } } var xScale = d3.scaleBand().range([0, w]).padding(0.4), @@ -136,11 +142,11 @@ class IncompatiblePointDistribution extends Component