Show incompatible points in histogram (#51)

* Change h2 errors by class to h1h2 incompatible points by class

Fix some variable names - feature is working end to end

Fix app crash if there are no incompatible points

Refactor finding incompatible instances to its own method

simplify get_incompatible_instances_by_class and add docstring

* Merged code in working state

* remove debug print

* fix bar click to use d.incompatibleInstanceIds instead of d.errorInstanceIds

Co-authored-by: Nicholas King <v-nicki@microsoft.com>
This commit is contained in:
Nicholas King 2020-11-10 11:15:19 -08:00 коммит произвёл Xavier Fernandes
Родитель fb20d2d7ce
Коммит e8062ebabf
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 1B011D38C073A7F2
2 изменённых файлов: 44 добавлений и 24 удалений

Просмотреть файл

@ -376,6 +376,26 @@ def get_all_error_instance_indices(h1, h2, batch_ids, batched_evaluation_data, b
instance_ground_truths))
def get_incompatible_instances_by_class(all_errors, batch_ids, batched_evaluation_target, class_incompatible_instance_ids):
    """
    Record, per class, the ids of instances that h1 predicted correctly
    but h2 predicted incorrectly.

    Nothing is returned: matching ids are appended in place to the lists
    held in class_incompatible_instance_ids, and a new list is created the
    first time a class is seen.

    Args:
        all_errors: An iterable of 5-tuples
            (instance id, instance metadata, h1 prediction, h2 prediction,
            ground truth), one per error instance.
        batch_ids: The instance ids of the data rows in the batch. Each id
            is located with ``batch_ids.index``, so the first matching
            position is the one used.
        batched_evaluation_target: The targets for the batch, indexed by
            position in batch_ids. Each element must provide ``.item()``.
        class_incompatible_instance_ids: Dictionary mapping a class to the
            list of incompatible instance ids; updated in place.
    """
    for instance_id, _metadata, h1_pred, h2_pred, truth in all_errors:
        # Incompatible means h1 matched the ground truth and h2 did not.
        h1_right_h2_wrong = h1_pred == truth and h2_pred != truth
        if not h1_right_h2_wrong:
            continue

        # The class comes from the target stored at the id's batch position.
        position = batch_ids.index(instance_id)
        label = batched_evaluation_target[position].item()
        class_incompatible_instance_ids.setdefault(label, []).append(instance_id)
def get_model_error_overlap(h1, h2, batch_ids, batched_evaluation_data, batched_evaluation_target,
device="cpu"):
"""
@ -530,33 +550,27 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
h1_dataset_error_instance_ids = []
h2_dataset_error_instance_ids = []
h1_and_h2_dataset_error_instance_ids = []
h2_dataset_error_instance_ids_by_class = {}
h1h2_dataset_incompatible_instance_ids_by_class = {}
classes = set()
all_error_instances = []
for batch_ids, data, target in dataset:
classes = classes.union(target.tolist())
h1_error_count_batch, h2_error_count_batch, h1_and_h2_error_count_batch =\
get_model_error_overlap(h1, h2, batch_ids, data, target, device=device)
h2_error_instance_ids_by_class =\
get_error_instance_ids_by_class(h2, batch_ids, data, target, device=device)
all_errors = get_all_error_instance_indices(
h1, h2, batch_ids, data, target,
get_instance_metadata=get_instance_metadata, device=device)
get_incompatible_instances_by_class(all_errors, batch_ids, target, h1h2_dataset_incompatible_instance_ids_by_class)
all_error_instances += all_errors
h1_dataset_error_instance_ids += h1_error_count_batch
h2_dataset_error_instance_ids += h2_error_count_batch
h1_and_h2_dataset_error_instance_ids += h1_and_h2_error_count_batch
for class_label, error_instance_ids in h2_error_instance_ids_by_class.items():
if class_label in h2_dataset_error_instance_ids_by_class:
h2_dataset_error_instance_ids_by_class[class_label] += error_instance_ids
else:
h2_dataset_error_instance_ids_by_class[class_label] = error_instance_ids
h2_ds_error_instance_ids_by_class = []
for class_label, error_instance_ids in h2_dataset_error_instance_ids_by_class.items():
h2_ds_error_instance_ids_by_class.append({
h1h2_ds_incompatible_instance_ids_by_class = []
for class_label, incompatible_instance_ids in h1h2_dataset_incompatible_instance_ids_by_class.items():
h1h2_ds_incompatible_instance_ids_by_class.append({
"class": class_label,
"errorInstanceIds": error_instance_ids
"incompatibleInstanceIds": incompatible_instance_ids
})
h2_performance = performance_metric(h2, dataset, device)
@ -580,7 +594,7 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
h1_and_h2_dataset_error_instance_ids
],
"h2_error_instance_ids_by_class": h2_ds_error_instance_ids_by_class,
"h2_incompatible_instance_ids_by_class": h1h2_ds_incompatible_instance_ids_by_class,
"sorted_classes": sorted(list(classes)),
"h2_performance": h2_performance,
"btc": btc,

Просмотреть файл

@ -75,28 +75,34 @@ class IncompatiblePointDistribution extends Component<IncompatiblePointDistribut
if (this.props.selectedDataPoint != null) {
// Sort the data into the dataRows based on the ordering of the sorted classes
var totalErrors = 0;
var totalIncompatible = 0;
var startI = this.state.page * this.props.pageSize;
var endI = Math.min(startI + this.props.pageSize, this.props.selectedDataPoint.sorted_classes.length);
for (var i = startI; i < endI; i++) {
var instanceClass = this.props.selectedDataPoint.sorted_classes[i];
var dataRow = this.props.selectedDataPoint.h2_error_instance_ids_by_class.filter(
var dataRow = this.props.selectedDataPoint.h2_incompatible_instance_ids_by_class.filter(
dataDict => (dataDict["class"] == instanceClass)).pop();
totalErrors += dataRow["errorInstanceIds"].length;
if (dataRow) {
totalIncompatible += dataRow["incompatibleInstanceIds"]?.length ?? 0;
}
}
// We add the following so that we do not get a divide by zero
// error later on if there are no errors.
if (totalErrors == 0) {
totalErrors = 1;
// error later on if there are no incompatible points.
if (totalIncompatible == 0) {
totalIncompatible = 1;
}
var dataRows = [];
for (var i=startI; i < endI; i++) {
var instanceClass = this.props.selectedDataPoint.sorted_classes[i];
var dataRow = this.props.selectedDataPoint.h2_error_instance_ids_by_class.filter(
var dataRow = this.props.selectedDataPoint.h2_incompatible_instance_ids_by_class.filter(
dataDict => (dataDict["class"] == instanceClass)).pop();
dataRows.push(dataRow);
if (dataRow) {
dataRows.push(dataRow);
} else {
dataRows.push({class: instanceClass, incompatibleInstanceIds: []})
}
}
var xScale = d3.scaleBand().range([0, w]).padding(0.4),
@ -136,11 +142,11 @@ class IncompatiblePointDistribution extends Component<IncompatiblePointDistribut
.enter().append("rect")
.attr("class", "bar")
.attr("x", function(d) { return xScale(d.class); })
.attr("y", function(d) { return yScale(d.errorInstanceIds.length/totalErrors * 100); })
.attr("y", function(d) { return yScale(d.incompatibleInstanceIds.length/totalIncompatible * 100); })
.attr("width", xScale.bandwidth())
.attr("height", function(d) { return h - yScale(d.errorInstanceIds.length/totalErrors * 100); })
.attr("height", function(d) { return h - yScale(d.incompatibleInstanceIds.length/totalIncompatible * 100); })
.on("click", function(d) {
_this.props.filterByInstanceIds(d.errorInstanceIds);
_this.props.filterByInstanceIds(d.incompatibleInstanceIds);
});
}
}