Add statefulness to selectivity of the regions of the Venn diagram, and calculate the percentage of errors out of the total number of instances (#61)
* Add statefulness to selectivity of the regions of the Venn diagram, and calculate the percentage of errors out of the total number of instances * Better differentiate hover vs select styles
This commit is contained in:
Родитель
e8062ebabf
Коммит
26908f40d5
|
@ -553,7 +553,9 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
|
|||
h1h2_dataset_incompatible_instance_ids_by_class = {}
|
||||
classes = set()
|
||||
all_error_instances = []
|
||||
dataset_size = 0
|
||||
for batch_ids, data, target in dataset:
|
||||
dataset_size += len(batch_ids)
|
||||
classes = classes.union(target.tolist())
|
||||
h1_error_count_batch, h2_error_count_batch, h1_and_h2_error_count_batch =\
|
||||
get_model_error_overlap(h1, h2, batch_ids, data, target, device=device)
|
||||
|
@ -599,7 +601,8 @@ def evaluate_model_performance_and_compatibility_on_dataset(h1, h2, dataset, per
|
|||
"h2_performance": h2_performance,
|
||||
"btc": btc,
|
||||
"bec": bec,
|
||||
"error_instances": all_error_instances_results
|
||||
"error_instances": all_error_instances_results,
|
||||
"dataset_size": dataset_size
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ import { bisect } from "./optimization.tsx";
|
|||
|
||||
|
||||
type IntersectionBetweenModelErrorsState = {
|
||||
selectedDataPoint: any
|
||||
selectedDataPoint: any,
|
||||
regionSelected: any
|
||||
}
|
||||
|
||||
type IntersectionBetweenModelErrorsProps = {
|
||||
|
@ -21,7 +22,8 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
super(props);
|
||||
|
||||
this.state = {
|
||||
selectedDataPoint: this.props.selectedDataPoint
|
||||
selectedDataPoint: this.props.selectedDataPoint,
|
||||
regionSelected: null
|
||||
};
|
||||
|
||||
this.node = React.createRef<HTMLDivElement>();
|
||||
|
@ -36,7 +38,8 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
|
||||
componentWillReceiveProps(nextProps) {
|
||||
this.setState({
|
||||
selectedDataPoint: nextProps.selectedDataPoint
|
||||
selectedDataPoint: nextProps.selectedDataPoint,
|
||||
regionSelected: null
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -90,100 +93,38 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
var b = errorPartition[1].length;
|
||||
var ab = errorPartition[2].length;
|
||||
|
||||
// Error instance ids of regress instances
|
||||
var regress = errorPartition[0].filter(instanceId => (errorPartition[2].indexOf(instanceId) == -1));
|
||||
// Error instance ids of progress instances
|
||||
var progress = errorPartition[1].filter(instanceId => (errorPartition[2].indexOf(instanceId) == -1));
|
||||
var regressSize = regress.length;
|
||||
var regressProportion = regressSize / this.state.selectedDataPoint.dataset_size;
|
||||
var progressSize = progress.length;
|
||||
var progressProportion = progressSize / this.state.selectedDataPoint.dataset_size;
|
||||
|
||||
var totalErrors = a + b - ab
|
||||
var aProportion = 0.0
|
||||
var bProportion = 0.0
|
||||
var abProportion = 0.0
|
||||
|
||||
if (totalErrors > 0) {
|
||||
aProportion = a / totalErrors;
|
||||
bProportion = b / totalErrors;
|
||||
abProportion = ab / totalErrors;
|
||||
}
|
||||
|
||||
var data = [
|
||||
{"name": "intersectionRaRb", "area": ab},
|
||||
{"name": "Ra", "area": a},
|
||||
{"name": "Rb", "area": b},
|
||||
{"name": "intersectionRaRb", "area": ab}
|
||||
]
|
||||
|
||||
let Ra;
|
||||
let Rb;
|
||||
let Aab;
|
||||
let x = 1;
|
||||
if (a >= b) {
|
||||
x = (50 * 50 * 3.14) / a;
|
||||
Ra = 50;
|
||||
Rb = Math.sqrt(b * x / 3.14);
|
||||
Aab = ab * x;
|
||||
} else {
|
||||
x = (50 * 50 * 3.14) / b;
|
||||
Rb = 50;
|
||||
Ra = Math.sqrt(a * x / 3.14);
|
||||
Aab = ab * x;
|
||||
}
|
||||
|
||||
function areaIntersection(r1, r2, dist) {
|
||||
var r = Math.min(r1, r2);
|
||||
var R = Math.max(r1, r2);
|
||||
if (dist == 0) {
|
||||
return (3.14 * Math.pow(r, 2));
|
||||
} else if (dist >= (R + r)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
var sectorAreas = Math.pow(r, 2) * Math.acos((Math.pow(dist, 2) +
|
||||
Math.pow(r, 2) - Math.pow(R, 2)) / (2 * dist * r)) +
|
||||
Math.pow(R, 2) * Math.acos((Math.pow(dist, 2) + Math.pow(R, 2) - Math.pow(r, 2)) / (2 * dist * R));
|
||||
|
||||
var triangleAreas = 1/2 * Math.sqrt((-dist + r + R) * (dist + r - R) * (dist - r + R) * (dist + r + R));
|
||||
var intersectionArea = sectorAreas - triangleAreas;
|
||||
|
||||
return intersectionArea;
|
||||
}
|
||||
|
||||
var r = Math.min(Ra, Rb);
|
||||
var R = Math.max(Ra, Rb);
|
||||
|
||||
function aIntersection(dist) {
|
||||
return (areaIntersection(Ra, Rb, dist) - Aab);
|
||||
}
|
||||
|
||||
let d = bisect(aIntersection, (r + R - 0.00001), (R - r + 0.00001));
|
||||
|
||||
var circleRad = 50;
|
||||
var xCenter = w/2 - d/2;
|
||||
var yCenter = h/2
|
||||
var xCenter2 = xCenter + d;
|
||||
{"name": "Rb", "area": b}
|
||||
];
|
||||
|
||||
// Colors
|
||||
var green = "rgba(175, 227, 141, 0.8)";
|
||||
var red = "rgba(206, 160, 205, 0.8)";
|
||||
var yellow = "rgba(241, 241, 127, 0.8)";
|
||||
|
||||
var areas = svg.append("g").attr("id", "areas");
|
||||
// Draw the legend of the Vnn diagram
|
||||
var legendEntries = [
|
||||
{"label": "Progress", "color": green},
|
||||
{"label": "Regress", "color": red},
|
||||
{"label": "Common", "color": yellow},
|
||||
];
|
||||
|
||||
function bringToTop(name) {
|
||||
areas.selectAll("g")
|
||||
.sort(function(a, b) {
|
||||
if (a.name == "intersectionRaRb") {
|
||||
return 1;
|
||||
} else if (b.name == "intersectionRaRb") {
|
||||
return -1;
|
||||
} else if (a.name == name) {
|
||||
return 1;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
var vennLegend = svg.append("g").attr("id", "vennlegend");
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "progress")
|
||||
.attr("width", "10px")
|
||||
|
@ -238,52 +179,164 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
return d["label"];
|
||||
});
|
||||
|
||||
areas.append("circle")
|
||||
.attr("r", Ra)
|
||||
.attr('transform',
|
||||
"translate(" +
|
||||
xCenter + "," +
|
||||
yCenter + ")")
|
||||
.attr("fill", "rgba(175, 227, 141, 0.8)")
|
||||
if (totalErrors > 0) {
|
||||
aProportion = a / this.state.selectedDataPoint.dataset_size;
|
||||
bProportion = b / this.state.selectedDataPoint.dataset_size;
|
||||
abProportion = ab / this.state.selectedDataPoint.dataset_size;
|
||||
|
||||
// We use the following reasoning:
|
||||
// aProportion = Area(circleA) / totalArea
|
||||
// bProportion = Area(circleB) / totalArea
|
||||
// We want the largest circle to have a radius of 50.
|
||||
// So we select the larget of the two circular regions and set
|
||||
// its area to be Pi * 50 * 50.
|
||||
// Thus if aProportion is the larger proportion, we have
|
||||
// that Ra must be 50 and aProportion = (Pi * 50 * 50) / totalArea
|
||||
// Thus totalArea = (Pi * 50 * 50) / aProportion.
|
||||
// Similarly if bProportion is the larger proportion:
|
||||
// We have that totalArea = (Pi * 50 * 50) / bProportion.
|
||||
let Ra;
|
||||
let Rb;
|
||||
let Aab;
|
||||
let totalArea = 1;
|
||||
if (a >= b) {
|
||||
totalArea = (50 * 50 * 3.14) / aProportion;
|
||||
Ra = 50;
|
||||
Rb = Math.sqrt(bProportion * totalArea / 3.14);
|
||||
Aab = abProportion * totalArea;
|
||||
} else {
|
||||
totalArea = (50 * 50 * 3.14) / bProportion;
|
||||
Rb = 50;
|
||||
Ra = Math.sqrt(aProportion * totalArea / 3.14);
|
||||
Aab = abProportion * totalArea;
|
||||
}
|
||||
|
||||
// This function calcuates the area of overlap
|
||||
// of two circles of radii r1 and r2 whose
|
||||
// centers are separated by a distance d.
|
||||
function areaOverlap(r1, r2, dist) {
|
||||
var r = Math.min(r1, r2);
|
||||
var R = Math.max(r1, r2);
|
||||
if (dist == 0) {
|
||||
return (3.14 * Math.pow(r, 2));
|
||||
} else if (dist >= (R + r)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
var sectorAreas = Math.pow(r, 2) * Math.acos((Math.pow(dist, 2) +
|
||||
Math.pow(r, 2) - Math.pow(R, 2)) / (2 * dist * r)) +
|
||||
Math.pow(R, 2) * Math.acos((Math.pow(dist, 2) + Math.pow(R, 2) - Math.pow(r, 2)) / (2 * dist * R));
|
||||
|
||||
var triangleAreas = 1/2 * Math.sqrt((-dist + r + R) * (dist + r - R) * (dist - r + R) * (dist + r + R));
|
||||
var overlapArea = sectorAreas - triangleAreas;
|
||||
|
||||
return overlapArea;
|
||||
}
|
||||
|
||||
// In order to decide the distance d that the circles
|
||||
// Need to be from each other, we use the bisection method
|
||||
// to search for the point at which the difference between
|
||||
// The area of overlap between the circles and the actual
|
||||
// are of the intersection is minimal.
|
||||
function aIntersection(dist) {
|
||||
return (areaOverlap(Ra, Rb, dist) - Aab);
|
||||
}
|
||||
|
||||
var r = Math.min(Ra, Rb);
|
||||
var R = Math.max(Ra, Rb);
|
||||
|
||||
// We perform the bisection search between the two extreme values
|
||||
let d = bisect(aIntersection, (r + R - 0.00001), (R - r + 0.00001));
|
||||
|
||||
var circleRad = 50;
|
||||
var xCenter = w/2 - d/2;
|
||||
var yCenter = h/2
|
||||
var xCenter2 = xCenter + d;
|
||||
|
||||
var areas = svg.append("g").attr("id", "areas");
|
||||
|
||||
// Draw the path that demarcates the boundary of the Intersection region
|
||||
var path = areas.append("path");
|
||||
var intersectionPath = d3.path();
|
||||
intersectionPath.arc(xCenter, yCenter, Ra, -Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)));
|
||||
intersectionPath.arc(xCenter2, yCenter, Rb, Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)));
|
||||
intersectionPath.closePath();
|
||||
|
||||
// Draw the path that demarcates the boundary of the Regress region
|
||||
var rPath = areas.append("path");
|
||||
var regressPath = d3.path();
|
||||
regressPath.arc(xCenter2, yCenter, Rb,
|
||||
Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)),
|
||||
Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), true);
|
||||
regressPath.arc(xCenter, yCenter, Ra,
|
||||
Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)),
|
||||
-Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), false);
|
||||
regressPath.closePath();
|
||||
|
||||
// Draw the path that demarcates the boundary of the Progress region
|
||||
var pPath = areas.append("path");
|
||||
var progressPath = d3.path();
|
||||
progressPath.arc(xCenter2, yCenter, Rb,
|
||||
Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)),
|
||||
Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), true);
|
||||
progressPath.arc(xCenter, yCenter, Ra,
|
||||
-Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)),
|
||||
Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), false);
|
||||
progressPath.closePath();
|
||||
|
||||
function getRegionFill(regionName) {
|
||||
if ((regionName == "intersection") && !(_this.state.regionSelected == "intersection")) {
|
||||
return "rgba(241, 241, 127, 0.8)";
|
||||
} else if ((regionName == "intersection") && (_this.state.regionSelected == "intersection")) {
|
||||
return "rgba(141, 141, 27, 0.8)";
|
||||
} else if((regionName == "progress") && !(_this.state.regionSelected == "progress")) {
|
||||
return "rgba(206, 160, 205, 0.8)";
|
||||
} else if ((regionName == "progress") && (_this.state.regionSelected == "progress")) {
|
||||
return "rgba(106, 60, 105, 0.8)";
|
||||
} else if ((regionName == "regress") && !(_this.state.regionSelected == "regress")) {
|
||||
return "rgba(175, 227, 141, 0.8)";
|
||||
} else if ((regionName == "regress") && (_this.state.regionSelected == "regress")) {
|
||||
return "rgba(75, 127, 41, 0.8)";
|
||||
}
|
||||
}
|
||||
|
||||
// Draw and style the Intersection region
|
||||
path.attr("d", intersectionPath)
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "1px")
|
||||
.on("mouseover", function() {
|
||||
tooltip.text(`${a} (${(aProportion * 100).toFixed(0)}%)`)
|
||||
.style("opacity", 0.8);
|
||||
.attr("fill", getRegionFill("intersection"))
|
||||
.on("mouseover", function() {
|
||||
tooltip.text(`${ab} (${(abProportion * 100).toFixed(3)}%)`)
|
||||
.style("opacity", 0.8);
|
||||
|
||||
bringToTop("Ra");
|
||||
d3.select(this).attr("stroke-width", "3px");
|
||||
})
|
||||
.on("mousemove", function() {
|
||||
var vennDiagramPlot = document.getElementById("venndiagramplot");
|
||||
var coords = d3.mouse(vennDiagramPlot);
|
||||
tooltip.style("left", `${coords[0] - (margin.left + margin.right)/2}px`)
|
||||
.style("top", `${coords[1] - (margin.top + margin.bottom)/2}px`);
|
||||
})
|
||||
.on("mouseout", function() {
|
||||
tooltip.style("opacity", 0);
|
||||
d3.select(this).attr("stroke-width", "1px");
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.props.filterByInstanceIds(errorPartition[2]);
|
||||
_this.setState({
|
||||
regionSelected: "intersection"
|
||||
});
|
||||
});
|
||||
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
})
|
||||
.on("mousemove", function() {
|
||||
var vennDiagramPlot = document.getElementById("venndiagramplot");
|
||||
var coords = d3.mouse(vennDiagramPlot);
|
||||
tooltip.style("left", `${coords[0] - (margin.left + margin.right)/2}px`)
|
||||
.style("top", `${coords[1] - (margin.top + margin.bottom)/2}px`);
|
||||
})
|
||||
.on("mouseout", function() {
|
||||
tooltip.style("opacity", 0);
|
||||
d3.select(this).attr("stroke-width", "1px");
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.props.filterByInstanceIds(errorPartition[0]);
|
||||
});
|
||||
|
||||
areas.append("circle")
|
||||
.attr("r", Rb)
|
||||
.attr('transform',
|
||||
"translate(" +
|
||||
xCenter2 + "," +
|
||||
yCenter + ")")
|
||||
.attr("fill", "rgba(206, 160, 205, 0.8)")
|
||||
// Draw and style the Regress region
|
||||
rPath.attr("d", regressPath)
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "1px")
|
||||
.attr("fill", getRegionFill("regress"))
|
||||
.on("mouseover", function() {
|
||||
tooltip.text(`${b} (${(bProportion * 100).toFixed(0)}%)`)
|
||||
tooltip.text(`${regressSize} (${(regressProportion * 100).toFixed(3)}%)`)
|
||||
.style("opacity", 0.8);
|
||||
|
||||
bringToTop("Rb");
|
||||
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
})
|
||||
.on("mousemove", function() {
|
||||
|
@ -297,27 +350,22 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
d3.select(this).attr("stroke-width", "1px");
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.props.filterByInstanceIds(errorPartition[0]);
|
||||
_this.props.filterByInstanceIds(regress);
|
||||
_this.setState({
|
||||
regionSelected: "regress"
|
||||
});
|
||||
});
|
||||
|
||||
var path = areas.append("path");
|
||||
var myPath = d3.path();
|
||||
myPath.arc(xCenter, yCenter, Ra, -Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)), Math.acos((Math.pow(d, 2) + Math.pow(Ra, 2) - Math.pow(Rb, 2))/(2 * d *Ra)));
|
||||
myPath.arc(xCenter2, yCenter, Rb, Math.PI - Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)), Math.PI + Math.acos((Math.pow(d, 2) + Math.pow(Rb, 2) - Math.pow(Ra, 2))/(2 * d *Rb)));
|
||||
myPath.closePath();
|
||||
|
||||
path.attr("d", myPath)
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "1px")
|
||||
.attr("fill", "rgba(241, 241, 127, 0.8)")
|
||||
// Draw and style the Progress region
|
||||
pPath.attr("d", progressPath)
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "1px")
|
||||
.attr("fill", getRegionFill("progress"))
|
||||
.on("mouseover", function() {
|
||||
tooltip.text(`${ab} (${(abProportion * 100).toFixed(0)}%)`)
|
||||
tooltip.text(`${progressSize} (${(progressProportion * 100).toFixed(3)}%)`)
|
||||
.style("opacity", 0.8);
|
||||
|
||||
bringToTop("intersectionRaRb");
|
||||
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
d3.select(this).attr("stroke", "black");
|
||||
})
|
||||
.on("mousemove", function() {
|
||||
var vennDiagramPlot = document.getElementById("venndiagramplot");
|
||||
|
@ -328,14 +376,17 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
.on("mouseout", function() {
|
||||
tooltip.style("opacity", 0);
|
||||
d3.select(this).attr("stroke-width", "1px");
|
||||
d3.select(this).attr("stroke", "black");
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.props.filterByInstanceIds(errorPartition[2]);
|
||||
_this.props.filterByInstanceIds(progress);
|
||||
_this.setState({
|
||||
regionSelected: "progress"
|
||||
});
|
||||
});
|
||||
|
||||
areas.selectAll("g")
|
||||
.data(data);
|
||||
areas.selectAll("g")
|
||||
.data(data);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче