Add statefulness to selectivity of the regions of the Venn diagram, and calculate the percentage of errors out of the total number of instances (#61)

* Add statefulness to selectivity of the regions of the Venn diagram, and calculate the percentage of errors out of the total number of instances

* Better differentiate hover vs select styles
This commit is contained in:
ilmarinen 2020-11-10 12:09:51 -08:00 коммит произвёл Xavier Fernandes
Родитель e8062ebabf
Коммит 26908f40d5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 1B011D38C073A7F2
2 изменённых файлов: 190 добавлений и 136 удалений

Просмотреть файл

@ -553,7 +553,9 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
h1h2_dataset_incompatible_instance_ids_by_class = {}
classes = set()
all_error_instances = []
dataset_size = 0
for batch_ids, data, target in dataset:
dataset_size += len(batch_ids)
classes = classes.union(target.tolist())
h1_error_count_batch, h2_error_count_batch, h1_and_h2_error_count_batch =\
get_model_error_overlap(h1, h2, batch_ids, data, target, device=device)
@ -599,7 +601,8 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
"h2_performance": h2_performance,
"btc": btc,
"bec": bec,
"error_instances": all_error_instances_results
"error_instances": all_error_instances_results,
"dataset_size": dataset_size
}

Просмотреть файл

@ -8,7 +8,8 @@ import { bisect } from "./optimization.tsx";
type IntersectionBetweenModelErrorsState = {
selectedDataPoint: any
selectedDataPoint: any,
regionSelected: any
}
type IntersectionBetweenModelErrorsProps = {
@ -21,7 +22,8 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
super(props);
this.state = {
selectedDataPoint: this.props.selectedDataPoint
selectedDataPoint: this.props.selectedDataPoint,
regionSelected: null
};
this.node = React.createRef<HTMLDivElement>();
@ -36,7 +38,8 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
componentWillReceiveProps(nextProps) {
this.setState({
selectedDataPoint: nextProps.selectedDataPoint
selectedDataPoint: nextProps.selectedDataPoint,
regionSelected: null
});
}
@ -90,100 +93,38 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
var b = errorPartition[1].length;
var ab = errorPartition[2].length;
// Error instance ids of regress instances
var regress = errorPartition[0].filter(instanceId => (errorPartition[2].indexOf(instanceId) == -1));
// Error instance ids of progress instances
var progress = errorPartition[1].filter(instanceId => (errorPartition[2].indexOf(instanceId) == -1));
var regressSize = regress.length;
var regressProportion = regressSize / this.state.selectedDataPoint.dataset_size;
var progressSize = progress.length;
var progressProportion = progressSize / this.state.selectedDataPoint.dataset_size;
var totalErrors = a + b - ab
var aProportion = 0.0
var bProportion = 0.0
var abProportion = 0.0
if (totalErrors > 0) {
aProportion = a / totalErrors;
bProportion = b / totalErrors;
abProportion = ab / totalErrors;
}
var data = [
{"name": "intersectionRaRb", "area": ab},
{"name": "Ra", "area": a},
{"name": "Rb", "area": b},
{"name": "intersectionRaRb", "area": ab}
]
let Ra;
let Rb;
let Aab;
let x = 1;
if (a >= b) {
x = (50 * 50 * 3.14) / a;
Ra = 50;
Rb = Math.sqrt(b * x / 3.14);
Aab = ab * x;
} else {
x = (50 * 50 * 3.14) / b;
Rb = 50;
Ra = Math.sqrt(a * x / 3.14);
Aab = ab * x;
}
function areaIntersection(r1, r2, dist) {
var r = Math.min(r1, r2);
var R = Math.max(r1, r2);
if (dist == 0) {
return (3.14 * Math.pow(r, 2));
} else if (dist >= (R + r)) {
return 0;
}
var sectorAreas = Math.pow(r, 2) * Math.acos((Math.pow(dist, 2) +
Math.pow(r, 2) - Math.pow(R, 2)) / (2 * dist * r)) +
Math.pow(R, 2) * Math.acos((Math.pow(dist, 2) + Math.pow(R, 2) - Math.pow(r, 2)) / (2 * dist * R));
var triangleAreas = 1/2 * Math.sqrt((-dist + r + R) * (dist + r - R) * (dist - r + R) * (dist + r + R));
var intersectionArea = sectorAreas - triangleAreas;
return intersectionArea;
}
var r = Math.min(Ra, Rb);
var R = Math.max(Ra, Rb);
function aIntersection(dist) {
return (areaIntersection(Ra, Rb, dist) - Aab);
}
let d = bisect(aIntersection, (r + R - 0.00001), (R - r + 0.00001));
var circleRad = 50;
var xCenter = w/2 - d/2;
var yCenter = h/2
var xCenter2 = xCenter + d;
{"name": "Rb", "area": b}
];
// Colors
var green = "rgba(175, 227, 141, 0.8)";
var red = "rgba(206, 160, 205, 0.8)";
var yellow = "rgba(241, 241, 127, 0.8)";
var areas = svg.append("g").attr("id", "areas");
// Draw the legend of the Vnn diagram
var legendEntries = [
{"label": "Progress", "color": green},
{"label": "Regress", "color": red},
{"label": "Common", "color": yellow},
];
function bringToTop(name) {
areas.selectAll("g")
.sort(function(a, b) {
if (a.name == "intersectionRaRb") {
return 1;
} else if (b.name == "intersectionRaRb") {
return -1;
} else if (a.name == name) {
return 1;
} else {
return -1;
}
});
}
var vennLegend = svg.append("g").attr("id", "vennlegend");
vennLegend.append("rect")
.attr("id", "progress")
.attr("width", "10px")
@ -238,52 +179,164 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
return d["label"];
});
areas.append("circle")
.attr("r", Ra)
.attr('transform',
"translate(" +
xCenter + "," +
yCenter + ")")
.attr("fill", "rgba(175, 227, 141, 0.8)")
if (totalErrors > 0) {
aProportion = a / this.state.selectedDataPoint.dataset_size;
bProportion = b / this.state.selectedDataPoint.dataset_size;
abProportion = ab / this.state.selectedDataPoint.dataset_size;
// We use the following reasoning:
// aProportion = Area(circleA) / totalArea
// bProportion = Area(circleB) / totalArea
// We want the largest circle to have a radius of 50.
// So we select the larget of the two circular regions and set
// its area to be Pi * 50 * 50.
// Thus if aProportion is the larger proportion, we have
// that Ra must be 50 and aProportion = (Pi * 50 * 50) / totalArea
// Thus totalArea = (Pi * 50 * 50) / aProportion.
// Similarly if bProportion is the larger proportion:
// We have that totalArea = (Pi * 50 * 50) / bProportion.
let Ra;
let Rb;
let Aab;
let totalArea = 1;
if (a >= b) {
totalArea = (50 * 50 * 3.14) / aProportion;
Ra = 50;
Rb = Math.sqrt(bProportion * totalArea / 3.14);
Aab = abProportion * totalArea;
} else {
totalArea = (50 * 50 * 3.14) / bProportion;
Rb = 50;
Ra = Math.sqrt(aProportion * totalArea / 3.14);
Aab = abProportion * totalArea;
}
// This function calcuates the area of overlap
// of two circles of radii r1 and r2 whose
// centers are separated by a distance d.
function areaOverlap(r1, r2, dist) {
var r = Math.min(r1, r2);
var R = Math.max(r1, r2);
if (dist == 0) {
return (3.14 * Math.pow(r, 2));
} else if (dist >= (R + r)) {
return 0;
}
var sectorAreas = Math.pow(r, 2) * Math.acos((Math.pow(dist, 2) +
Math.pow(r, 2) - Math.pow(R, 2)) / (2 * dist * r)) +
Math.pow(R, 2) * Math.acos((Math.pow(dist, 2) + Math.pow(R, 2) - Math.pow(r, 2)) / (2 * dist * R));
var triangleAreas = 1/2 * Math.sqrt((-dist + r + R) * (dist + r - R) * (dist - r + R) * (dist + r + R));
var overlapArea = sectorAreas - triangleAreas;
return overlapArea;
}
// In order to decide the distance d that the circles
// Need to be from each other, we use the bisection method
// to search for the point at which the difference between
// The area of overlap between the circles and the actual
// are of the intersection is minimal.
function aIntersection(dist) {
return (areaOverlap(Ra, Rb, dist) - Aab);
}
var r = Math.min(Ra, Rb);
var R = Math.max(Ra, Rb);
// We perform the bisection search between the two extreme values
let d = bisect(aIntersection, (r + R - 0.00001), (R - r + 0.00001));
var circleRad = 50;
var xCenter = w/2 - d/2;
var yCenter = h/2
var xCenter2 = xCenter + d;
var areas = svg.append("g").attr("id", "areas");
// Draw the path that demarcates the boundary of the Intersection region
var path = areas.append("path");
var intersectionPath = d3.path();
intersectionPath.arc(xCenter, yCenter, Ra, -Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)));
intersectionPath.arc(xCenter2, yCenter, Rb, Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)));
intersectionPath.closePath();
// Draw the path that demarcates the boundary of the Regress region
var rPath = areas.append("path");
var regressPath = d3.path();
regressPath.arc(xCenter2, yCenter, Rb,
Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)),
Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), true);
regressPath.arc(xCenter, yCenter, Ra,
Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)),
-Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), false);
regressPath.closePath();
// Draw the path that demarcates the boundary of the Progress region
var pPath = areas.append("path");
var progressPath = d3.path();
progressPath.arc(xCenter2, yCenter, Rb,
Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)),
Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), true);
progressPath.arc(xCenter, yCenter, Ra,
-Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)),
Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), false);
progressPath.closePath();
function getRegionFill(regionName) {
if ((regionName == "intersection") && !(_this.state.regionSelected == "intersection")) {
return "rgba(241, 241, 127, 0.8)";
} else if ((regionName == "intersection") && (_this.state.regionSelected == "intersection")) {
return "rgba(141, 141, 27, 0.8)";
} else if((regionName == "progress") && !(_this.state.regionSelected == "progress")) {
return "rgba(206, 160, 205, 0.8)";
} else if ((regionName == "progress") && (_this.state.regionSelected == "progress")) {
return "rgba(106, 60, 105, 0.8)";
} else if ((regionName == "regress") && !(_this.state.regionSelected == "regress")) {
return "rgba(175, 227, 141, 0.8)";
} else if ((regionName == "regress") && (_this.state.regionSelected == "regress")) {
return "rgba(75, 127, 41, 0.8)";
}
}
// Draw and style the Intersection region
path.attr("d", intersectionPath)
.attr("stroke", "black")
.attr("stroke-width", "1px")
.on("mouseover", function() {
tooltip.text(`${a} (${(aProportion * 100).toFixed(0)}%)`)
.style("opacity", 0.8);
.attr("fill", getRegionFill("intersection"))
.on("mouseover", function() {
tooltip.text(`${ab} (${(abProportion * 100).toFixed(3)}%)`)
.style("opacity", 0.8);
bringToTop("Ra");
d3.select(this).attr("stroke-width", "3px");
})
.on("mousemove", function() {
var vennDiagramPlot = document.getElementById("venndiagramplot");
var coords = d3.mouse(vennDiagramPlot);
tooltip.style("left", `${coords[0] - (margin.left + margin.right)/2}px`)
.style("top", `${coords[1] - (margin.top + margin.bottom)/2}px`);
})
.on("mouseout", function() {
tooltip.style("opacity", 0);
d3.select(this).attr("stroke-width", "1px");
})
.on("click", function() {
_this.props.filterByInstanceIds(errorPartition[2]);
_this.setState({
regionSelected: "intersection"
});
});
d3.select(this).attr("stroke-width", "2px");
})
.on("mousemove", function() {
var vennDiagramPlot = document.getElementById("venndiagramplot");
var coords = d3.mouse(vennDiagramPlot);
tooltip.style("left", `${coords[0] - (margin.left + margin.right)/2}px`)
.style("top", `${coords[1] - (margin.top + margin.bottom)/2}px`);
})
.on("mouseout", function() {
tooltip.style("opacity", 0);
d3.select(this).attr("stroke-width", "1px");
})
.on("click", function() {
_this.props.filterByInstanceIds(errorPartition[0]);
});
areas.append("circle")
.attr("r", Rb)
.attr('transform',
"translate(" +
xCenter2 + "," +
yCenter + ")")
.attr("fill", "rgba(206, 160, 205, 0.8)")
// Draw and style the Regress region
rPath.attr("d", regressPath)
.attr("stroke", "black")
.attr("stroke-width", "1px")
.attr("fill", getRegionFill("regress"))
.on("mouseover", function() {
tooltip.text(`${b} (${(bProportion * 100).toFixed(0)}%)`)
tooltip.text(`${regressSize} (${(regressProportion * 100).toFixed(3)}%)`)
.style("opacity", 0.8);
bringToTop("Rb");
d3.select(this).attr("stroke-width", "2px");
})
.on("mousemove", function() {
@ -297,27 +350,22 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
d3.select(this).attr("stroke-width", "1px");
})
.on("click", function() {
_this.props.filterByInstanceIds(errorPartition[0]);
_this.props.filterByInstanceIds(regress);
_this.setState({
regionSelected: "regress"
});
});
var path = areas.append("path");
var myPath = d3.path();
myPath.arc(xCenter, yCenter, Ra, -Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)));
myPath.arc(xCenter2, yCenter, Rb, Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)));
myPath.closePath();
path.attr("d", myPath)
.attr("stroke", "black")
.attr("stroke-width", "1px")
.attr("fill", "rgba(241, 241, 127, 0.8)")
// Draw and style the Progress region
pPath.attr("d", progressPath)
.attr("stroke", "black")
.attr("stroke-width", "1px")
.attr("fill", getRegionFill("progress"))
.on("mouseover", function() {
tooltip.text(`${ab} (${(abProportion * 100).toFixed(0)}%)`)
tooltip.text(`${progressSize} (${(progressProportion * 100).toFixed(3)}%)`)
.style("opacity", 0.8);
bringToTop("intersectionRaRb");
d3.select(this).attr("stroke-width", "2px");
d3.select(this).attr("stroke", "black");
})
.on("mousemove", function() {
var vennDiagramPlot = document.getElementById("venndiagramplot");
@ -328,14 +376,17 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
.on("mouseout", function() {
tooltip.style("opacity", 0);
d3.select(this).attr("stroke-width", "1px");
d3.select(this).attr("stroke", "black");
})
.on("click", function() {
_this.props.filterByInstanceIds(errorPartition[2]);
_this.props.filterByInstanceIds(progress);
_this.setState({
regionSelected: "progress"
});
});
areas.selectAll("g")
.data(data);
areas.selectAll("g")
.data(data);
}
}