Clickable legend in Venn diagram and removing the warnings in the example notebooks (#78)
* Replace usage of size_average keyword argument with reduction keyword argument set to "sum" where appropriate * Selectable legend in Venn diagram * Remove commented code * Make the clickable area larger
This commit is contained in:
Родитель
7a677c0256
Коммит
c8ebe6adb7
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -199,7 +199,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = network(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_a)*batch_size_train\n",
|
||||
|
@ -214,14 +214,6 @@
|
|||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
|
||||
" warnings.warn(warning.format(ret))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
|
@ -540,7 +532,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = h2(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_b)*batch_size_train\n",
|
||||
|
|
|
@ -199,7 +199,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = network(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_a)*batch_size_train\n",
|
||||
|
@ -214,14 +214,6 @@
|
|||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
|
||||
" warnings.warn(warning.format(ret))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
|
@ -542,7 +534,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = h2(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_b)*batch_size_train\n",
|
||||
|
|
|
@ -199,7 +199,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = network(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_a)*batch_size_train\n",
|
||||
|
@ -214,14 +214,6 @@
|
|||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
|
||||
" warnings.warn(warning.format(ret))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
|
@ -542,7 +534,7 @@
|
|||
" with torch.no_grad():\n",
|
||||
" for data, target in test_loader:\n",
|
||||
" _, _, output = h2(data)\n",
|
||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
||||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
|
||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
||||
" test_loss /= len(train_loader_b)*batch_size_train\n",
|
||||
|
|
|
@ -202,87 +202,195 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
|
|||
var red = "rgba(206, 160, 205, 0.8)";
|
||||
var yellow = "rgba(241, 241, 127, 0.8)";
|
||||
|
||||
// Draw the legend of the Vnn diagram
|
||||
var legendEntries = [
|
||||
{"label": "Progress", "color": green},
|
||||
{"label": "Regress", "color": red},
|
||||
{"label": "Common", "color": yellow},
|
||||
];
|
||||
var vennLegend = svg.append("g").attr("id", "vennlegend");
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "progress")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "10px")
|
||||
.attr("y", "5px");
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "regress")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "90px")
|
||||
.attr("y", "5px");
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "commonerror")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "170px")
|
||||
.attr("y", "5px");
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "25px")
|
||||
.attr("y", "12px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle");
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "105px")
|
||||
.attr("y", "12px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle");
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "185px")
|
||||
.attr("y", "12px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle");
|
||||
|
||||
vennLegend.selectAll("rect")
|
||||
.data(legendEntries)
|
||||
.attr("fill", function(d) {
|
||||
return d["color"];
|
||||
});
|
||||
|
||||
vennLegend.selectAll("text")
|
||||
.data(legendEntries)
|
||||
.text(function(d) {
|
||||
return d["label"];
|
||||
});
|
||||
|
||||
function getRegionFill(regionName) {
|
||||
if ((regionName == "intersection") && !(_this.state.regionSelected == "intersection")) {
|
||||
return "rgba(241, 241, 127, 0.8)";
|
||||
} else if ((regionName == "intersection") && (_this.state.regionSelected == "intersection")) {
|
||||
return "rgba(141, 141, 27, 0.8)";
|
||||
} else if((regionName == "progress") && !(_this.state.regionSelected == "progress")) {
|
||||
// return "rgba(206, 160, 205, 0.8)";
|
||||
return "rgba(175, 227, 141, 0.8)";
|
||||
} else if ((regionName == "progress") && (_this.state.regionSelected == "progress")) {
|
||||
// return "rgba(106, 60, 105, 0.8)";
|
||||
return "rgba(75, 127, 41, 0.8)";
|
||||
} else if ((regionName == "regress") && !(_this.state.regionSelected == "regress")) {
|
||||
// return "rgba(175, 227, 141, 0.8)";
|
||||
return "rgba(206, 160, 205, 0.8)";
|
||||
} else if ((regionName == "regress") && (_this.state.regionSelected == "regress")) {
|
||||
// return "rgba(75, 127, 41, 0.8)";
|
||||
return "rgba(106, 60, 105, 0.8)";
|
||||
}
|
||||
}
|
||||
|
||||
// Draw the legend of the Vnn diagram
|
||||
var legendEntries = [
|
||||
{"label": "Progress", "name": "progress", "color": green},
|
||||
{"label": "Regress", "name": "regress", "color": red},
|
||||
{"label": "Common", "name": "intersection", "color": yellow},
|
||||
];
|
||||
var vennLegend = svg.append("g").attr("id", "vennlegend");
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "25px")
|
||||
.attr("y", "17px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle")
|
||||
.text(function() {
|
||||
return "Progress";
|
||||
});
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "105px")
|
||||
.attr("y", "17px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle")
|
||||
.text(function() {
|
||||
return "Regress";
|
||||
});
|
||||
|
||||
vennLegend.append("text")
|
||||
.attr("x", "180px")
|
||||
.attr("y", "17px")
|
||||
.attr("font-size", "10px")
|
||||
.attr("text-anchor", "left")
|
||||
.style("alignment-baseline", "middle")
|
||||
.text(function() {
|
||||
return "Intersection";
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "progress")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "10px")
|
||||
.attr("y", "10px")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "0px")
|
||||
.attr("fill", function() {
|
||||
return getRegionFill("progress");
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("width", "65px")
|
||||
.attr("height", "20px")
|
||||
.attr("x", "5px")
|
||||
.attr("y", "5px")
|
||||
.attr("fill", "rgba(255, 255, 255, 0.0)")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", function(d) {
|
||||
if (_this.state.regionSelected == "progress") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.setState({
|
||||
regionSelected: "progress"
|
||||
});
|
||||
})
|
||||
.on("mousemove", function() {
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
})
|
||||
.on("mouseout", function() {
|
||||
d3.select(this)
|
||||
.attr("stroke-width", function() {
|
||||
if (_this.state.regionSelected == "progress") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "regress")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "90px")
|
||||
.attr("y", "10px")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "0px")
|
||||
.attr("fill", function() {
|
||||
return getRegionFill("regress");
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("width", "65px")
|
||||
.attr("height", "20px")
|
||||
.attr("x", "85px")
|
||||
.attr("y", "5px")
|
||||
.attr("fill", "rgba(255, 255, 255, 0.0)")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", function(d) {
|
||||
if (_this.state.regionSelected == "regress") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.setState({
|
||||
regionSelected: "regress"
|
||||
});
|
||||
})
|
||||
.on("mouseover", function() {
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
})
|
||||
.on("mouseout", function() {
|
||||
d3.select(this)
|
||||
.attr("stroke-width", function() {
|
||||
if (_this.state.regionSelected == "regress") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("id", "commonerror")
|
||||
.attr("width", "10px")
|
||||
.attr("height", "10px")
|
||||
.attr("x", "165px")
|
||||
.attr("y", "10px")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", "0px")
|
||||
.attr("fill", function() {
|
||||
return getRegionFill("intersection");
|
||||
});
|
||||
|
||||
vennLegend.append("rect")
|
||||
.attr("width", "75px")
|
||||
.attr("height", "20px")
|
||||
.attr("x", "160px")
|
||||
.attr("y", "5px")
|
||||
.attr("fill", "rgba(255, 255, 255, 0.0)")
|
||||
.attr("stroke", "black")
|
||||
.attr("stroke-width", function(d) {
|
||||
if (_this.state.regionSelected == "intersection") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
})
|
||||
.on("click", function() {
|
||||
_this.setState({
|
||||
regionSelected: "intersection"
|
||||
});
|
||||
})
|
||||
.on("mouseover", function() {
|
||||
d3.select(this).attr("stroke-width", "2px");
|
||||
})
|
||||
.on("mouseout", function() {
|
||||
d3.select(this)
|
||||
.attr("stroke-width", function() {
|
||||
if (_this.state.regionSelected == "intersection") {
|
||||
return "1px";
|
||||
} else {
|
||||
return "0px";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (totalErrors > 0) {
|
||||
aProportion = a / this.state.selectedDataPoint.dataset_size;
|
||||
bProportion = b / this.state.selectedDataPoint.dataset_size;
|
||||
|
|
Загрузка…
Ссылка в новой задаче