Show incompatible points in histogram (#51)

* Change h2 errors by class to h1h2 incompatible points by class
* Fix some variable names - feature is working end to end
* Fix app crash if there are no incompatible points
* Refactor finding incompatible instances to its own method
* Simplify get_incompatible_instances_by_class and add docstring
* Merged code in working state
* Remove debug print
* Fix bar click to use d.incompatibleInstanceIds instead of d.errorInstanceIds

Co-authored-by: Nicholas King <v-nicki@microsoft.com>
Parent: fb20d2d7ce
Commit: e8062ebabf
@@ -376,6 +376,26 @@ def get_all_error_instance_indices(h1, h2, batch_ids, batched_evaluation_data, b
             instance_ground_truths))
 
 
+def get_incompatible_instances_by_class(all_errors, batch_ids, batched_evaluation_target, class_incompatible_instance_ids):
+    """
+    Finds instances where h2 is incompatible with h1 and inserts
+    {class : incompatible_data_id} mappings into the class_incompatible_instance_ids dictionary.
+
+    Args:
+        all_errors: A list of tuples of error indices, h1 and h2 predictions, and ground truth for each instance
+        batch_ids: The instance ids of the data rows in the batched data.
+        batched_evaluation_target: A single batch of the corresponding output targets.
+        class_incompatible_instance_ids: The dictionary to fill with incompatible instances and their ids
+    """
+    for (error_instance_id, error_instance_metadata, h1_prediction, h2_prediction, ground_truth) in all_errors:
+        if (h1_prediction == ground_truth and h2_prediction != ground_truth):
+            batch_index = batch_ids.index(error_instance_id)
+            incompatible_class = batched_evaluation_target[batch_index].item()
+            if (incompatible_class not in class_incompatible_instance_ids):
+                class_incompatible_instance_ids[incompatible_class] = []
+            class_incompatible_instance_ids[incompatible_class].append(error_instance_id)
+
+
 def get_model_error_overlap(h1, h2, batch_ids, batched_evaluation_data, batched_evaluation_target,
                             device="cpu"):
     """
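For reference, a minimal sketch of how the new helper could be driven, assuming get_incompatible_instances_by_class from the hunk above is in scope and that targets are PyTorch tensors (as elsewhere in this module); the ids, labels, and predictions below are made up:

# Illustrative driver only; assumes get_incompatible_instances_by_class
# (added above) is importable, and that targets are torch tensors.
import torch

# all_errors tuples: (instance_id, metadata, h1_prediction, h2_prediction, ground_truth)
all_errors = [
    (101, None, 3, 7, 3),  # h1 correct, h2 wrong -> incompatible point of class 3
    (102, None, 5, 5, 4),  # both wrong           -> not incompatible
    (103, None, 1, 0, 1),  # h1 correct, h2 wrong -> incompatible point of class 1
]
batch_ids = [101, 102, 103]
batched_evaluation_target = torch.tensor([3, 4, 1])

incompatible_by_class = {}
get_incompatible_instances_by_class(
    all_errors, batch_ids, batched_evaluation_target, incompatible_by_class)
print(incompatible_by_class)  # {3: [101], 1: [103]}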
@@ -530,33 +550,27 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
     h1_dataset_error_instance_ids = []
     h2_dataset_error_instance_ids = []
     h1_and_h2_dataset_error_instance_ids = []
-    h2_dataset_error_instance_ids_by_class = {}
+    h1h2_dataset_incompatible_instance_ids_by_class = {}
     classes = set()
     all_error_instances = []
     for batch_ids, data, target in dataset:
         classes = classes.union(target.tolist())
         h1_error_count_batch, h2_error_count_batch, h1_and_h2_error_count_batch =\
             get_model_error_overlap(h1, h2, batch_ids, data, target, device=device)
-        h2_error_instance_ids_by_class =\
-            get_error_instance_ids_by_class(h2, batch_ids, data, target, device=device)
         all_errors = get_all_error_instance_indices(
             h1, h2, batch_ids, data, target,
             get_instance_metadata=get_instance_metadata, device=device)
+        get_incompatible_instances_by_class(all_errors, batch_ids, target, h1h2_dataset_incompatible_instance_ids_by_class)
         all_error_instances += all_errors
         h1_dataset_error_instance_ids += h1_error_count_batch
         h2_dataset_error_instance_ids += h2_error_count_batch
         h1_and_h2_dataset_error_instance_ids += h1_and_h2_error_count_batch
-        for class_label, error_instance_ids in h2_error_instance_ids_by_class.items():
-            if class_label in h2_dataset_error_instance_ids_by_class:
-                h2_dataset_error_instance_ids_by_class[class_label] += error_instance_ids
-            else:
-                h2_dataset_error_instance_ids_by_class[class_label] = error_instance_ids
 
-    h2_ds_error_instance_ids_by_class = []
-    for class_label, error_instance_ids in h2_dataset_error_instance_ids_by_class.items():
-        h2_ds_error_instance_ids_by_class.append({
+    h1h2_ds_incompatible_instance_ids_by_class = []
+    for class_label, incompatible_instance_ids in h1h2_dataset_incompatible_instance_ids_by_class.items():
+        h1h2_ds_incompatible_instance_ids_by_class.append({
             "class": class_label,
-            "errorInstanceIds": error_instance_ids
+            "incompatibleInstanceIds": incompatible_instance_ids
         })
 
     h2_performance = performance_metric(h2, dataset, device)
@@ -580,7 +594,7 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
             h1_and_h2_dataset_error_instance_ids
         ],
 
-        "h2_error_instance_ids_by_class": h2_ds_error_instance_ids_by_class,
+        "h2_incompatible_instance_ids_by_class": h1h2_ds_incompatible_instance_ids_by_class,
         "sorted_classes": sorted(list(classes)),
         "h2_performance": h2_performance,
         "btc": btc,
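For orientation, a made-up sketch of the per-class slice that the evaluation result now carries for the widget; only the keys touched by this commit are shown, and all ids and class labels are invented:

# Hypothetical excerpt of the evaluation result consumed by the front end.
example_result = {
    "h2_incompatible_instance_ids_by_class": [
        {"class": 0, "incompatibleInstanceIds": [12, 57]},
        {"class": 1, "incompatibleInstanceIds": [3]},
    ],
    "sorted_classes": [0, 1],
}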
@@ -75,28 +75,34 @@ class IncompatiblePointDistribution extends Component<IncompatiblePointDistribut
 
     if (this.props.selectedDataPoint != null) {
       // Sort the data into the dataRows based on the ordering of the sorted classes
-      var totalErrors = 0;
+      var totalIncompatible = 0;
       var startI = this.state.page * this.props.pageSize;
       var endI = Math.min(startI + this.props.pageSize, this.props.selectedDataPoint.sorted_classes.length);
       for (var i = startI; i < endI; i++) {
         var instanceClass = this.props.selectedDataPoint.sorted_classes[i];
-        var dataRow = this.props.selectedDataPoint.h2_error_instance_ids_by_class.filter(
+        var dataRow = this.props.selectedDataPoint.h2_incompatible_instance_ids_by_class.filter(
           dataDict => (dataDict["class"] == instanceClass)).pop();
-        totalErrors += dataRow["errorInstanceIds"].length;
+        if (dataRow) {
+          totalIncompatible += dataRow["incompatibleInstanceIds"]?.length ?? 0;
+        }
       }
 
       // We add the following so that we do not get a divide by zero
-      // error later on if there are no errors.
-      if (totalErrors == 0) {
-        totalErrors = 1;
+      // error later on if there are no incompatible points.
+      if (totalIncompatible == 0) {
+        totalIncompatible = 1;
       }
 
       var dataRows = [];
       for (var i=startI; i < endI; i++) {
         var instanceClass = this.props.selectedDataPoint.sorted_classes[i];
-        var dataRow = this.props.selectedDataPoint.h2_error_instance_ids_by_class.filter(
+        var dataRow = this.props.selectedDataPoint.h2_incompatible_instance_ids_by_class.filter(
           dataDict => (dataDict["class"] == instanceClass)).pop();
-        dataRows.push(dataRow);
+        if (dataRow) {
+          dataRows.push(dataRow);
+        } else {
+          dataRows.push({class: instanceClass, incompatibleInstanceIds: []})
+        }
       }
 
       var xScale = d3.scaleBand().range([0, w]).padding(0.4),
@@ -136,11 +142,11 @@ class IncompatiblePointDistribution extends Component<IncompatiblePointDistribut
         .enter().append("rect")
         .attr("class", "bar")
         .attr("x", function(d) { return xScale(d.class); })
-        .attr("y", function(d) { return yScale(d.errorInstanceIds.length/totalErrors * 100); })
+        .attr("y", function(d) { return yScale(d.incompatibleInstanceIds.length/totalIncompatible * 100); })
         .attr("width", xScale.bandwidth())
-        .attr("height", function(d) { return h - yScale(d.errorInstanceIds.length/totalErrors * 100); })
+        .attr("height", function(d) { return h - yScale(d.incompatibleInstanceIds.length/totalIncompatible * 100); })
         .on("click", function(d) {
-          _this.props.filterByInstanceIds(d.errorInstanceIds);
+          _this.props.filterByInstanceIds(d.incompatibleInstanceIds);
         });
     }
   }
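To spell out the bar geometry above, a small Python sketch of the same normalization with made-up counts, including the divide-by-zero guard introduced in this commit:

# Made-up per-class rows, mirroring h2_incompatible_instance_ids_by_class.
data_rows = [
    {"class": 0, "incompatibleInstanceIds": [12, 57, 90]},
    {"class": 1, "incompatibleInstanceIds": [3]},
    {"class": 2, "incompatibleInstanceIds": []},
]

# Same guard as the component: avoid dividing by zero when nothing is incompatible.
total_incompatible = sum(len(r["incompatibleInstanceIds"]) for r in data_rows)
if total_incompatible == 0:
    total_incompatible = 1

# Each bar's height is the class's share of incompatible points, as a percentage.
for row in data_rows:
    share = len(row["incompatibleInstanceIds"]) / total_incompatible * 100
    print(row["class"], round(share, 1))  # prints: 0 75.0, 1 25.0, 2 0.0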