Clickable legend in Venn diagram and removing the warnings in the example notebooks (#78)

* Replace usage of size_average keyword argument with reduction keyword argument set to "sum" where appropriate

* Selectable legend in Venn diagram

* Remove commented code

* Make the clickable area larger
This commit is contained in:
ilmarinen 2020-11-20 14:00:38 -08:00 коммит произвёл GitHub
Родитель 7a677c0256
Коммит c8ebe6adb7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 351 добавлений и 289 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -199,7 +199,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = network(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_a)*batch_size_train\n",
@ -214,14 +214,6 @@
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
" warnings.warn(warning.format(ret))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -540,7 +532,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = h2(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_b)*batch_size_train\n",

Просмотреть файл

@ -199,7 +199,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = network(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_a)*batch_size_train\n",
@ -214,14 +214,6 @@
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
" warnings.warn(warning.format(ret))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -542,7 +534,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = h2(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_b)*batch_size_train\n",

Просмотреть файл

@ -199,7 +199,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = network(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_a)*batch_size_train\n",
@ -214,14 +214,6 @@
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/xavier/mnt/datapartition/work/virtualenvs/venv/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
" warnings.warn(warning.format(ret))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -542,7 +534,7 @@
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" _, _, output = h2(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
" test_loss /= len(train_loader_b)*batch_size_train\n",

Просмотреть файл

@ -202,87 +202,195 @@ class IntersectionBetweenModelErrors extends Component<IntersectionBetweenModelE
var red = "rgba(206, 160, 205, 0.8)";
var yellow = "rgba(241, 241, 127, 0.8)";
// Draw the legend of the Vnn diagram
var legendEntries = [
{"label": "Progress", "color": green},
{"label": "Regress", "color": red},
{"label": "Common", "color": yellow},
];
var vennLegend = svg.append("g").attr("id", "vennlegend");
vennLegend.append("rect")
.attr("id", "progress")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "10px")
.attr("y", "5px");
vennLegend.append("rect")
.attr("id", "regress")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "90px")
.attr("y", "5px");
vennLegend.append("rect")
.attr("id", "commonerror")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "170px")
.attr("y", "5px");
vennLegend.append("text")
.attr("x", "25px")
.attr("y", "12px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle");
vennLegend.append("text")
.attr("x", "105px")
.attr("y", "12px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle");
vennLegend.append("text")
.attr("x", "185px")
.attr("y", "12px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle");
vennLegend.selectAll("rect")
.data(legendEntries)
.attr("fill", function(d) {
return d["color"];
});
vennLegend.selectAll("text")
.data(legendEntries)
.text(function(d) {
return d["label"];
});
function getRegionFill(regionName) {
if ((regionName == "intersection") && !(_this.state.regionSelected == "intersection")) {
return "rgba(241, 241, 127, 0.8)";
} else if ((regionName == "intersection") && (_this.state.regionSelected == "intersection")) {
return "rgba(141, 141, 27, 0.8)";
} else if((regionName == "progress") && !(_this.state.regionSelected == "progress")) {
// return "rgba(206, 160, 205, 0.8)";
return "rgba(175, 227, 141, 0.8)";
} else if ((regionName == "progress") && (_this.state.regionSelected == "progress")) {
// return "rgba(106, 60, 105, 0.8)";
return "rgba(75, 127, 41, 0.8)";
} else if ((regionName == "regress") && !(_this.state.regionSelected == "regress")) {
// return "rgba(175, 227, 141, 0.8)";
return "rgba(206, 160, 205, 0.8)";
} else if ((regionName == "regress") && (_this.state.regionSelected == "regress")) {
// return "rgba(75, 127, 41, 0.8)";
return "rgba(106, 60, 105, 0.8)";
}
}
// Draw the legend of the Vnn diagram
var legendEntries = [
{"label": "Progress", "name": "progress", "color": green},
{"label": "Regress", "name": "regress", "color": red},
{"label": "Common", "name": "intersection", "color": yellow},
];
var vennLegend = svg.append("g").attr("id", "vennlegend");
vennLegend.append("text")
.attr("x", "25px")
.attr("y", "17px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle")
.text(function() {
return "Progress";
});
vennLegend.append("text")
.attr("x", "105px")
.attr("y", "17px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle")
.text(function() {
return "Regress";
});
vennLegend.append("text")
.attr("x", "180px")
.attr("y", "17px")
.attr("font-size", "10px")
.attr("text-anchor", "left")
.style("alignment-baseline", "middle")
.text(function() {
return "Intersection";
});
vennLegend.append("rect")
.attr("id", "progress")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "10px")
.attr("y", "10px")
.attr("stroke", "black")
.attr("stroke-width", "0px")
.attr("fill", function() {
return getRegionFill("progress");
});
vennLegend.append("rect")
.attr("width", "65px")
.attr("height", "20px")
.attr("x", "5px")
.attr("y", "5px")
.attr("fill", "rgba(255, 255, 255, 0.0)")
.attr("stroke", "black")
.attr("stroke-width", function(d) {
if (_this.state.regionSelected == "progress") {
return "1px";
} else {
return "0px";
}
})
.on("click", function() {
_this.setState({
regionSelected: "progress"
});
})
.on("mousemove", function() {
d3.select(this).attr("stroke-width", "2px");
})
.on("mouseout", function() {
d3.select(this)
.attr("stroke-width", function() {
if (_this.state.regionSelected == "progress") {
return "1px";
} else {
return "0px";
}
});
});
vennLegend.append("rect")
.attr("id", "regress")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "90px")
.attr("y", "10px")
.attr("stroke", "black")
.attr("stroke-width", "0px")
.attr("fill", function() {
return getRegionFill("regress");
});
vennLegend.append("rect")
.attr("width", "65px")
.attr("height", "20px")
.attr("x", "85px")
.attr("y", "5px")
.attr("fill", "rgba(255, 255, 255, 0.0)")
.attr("stroke", "black")
.attr("stroke-width", function(d) {
if (_this.state.regionSelected == "regress") {
return "1px";
} else {
return "0px";
}
})
.on("click", function() {
_this.setState({
regionSelected: "regress"
});
})
.on("mouseover", function() {
d3.select(this).attr("stroke-width", "2px");
})
.on("mouseout", function() {
d3.select(this)
.attr("stroke-width", function() {
if (_this.state.regionSelected == "regress") {
return "1px";
} else {
return "0px";
}
});
});
vennLegend.append("rect")
.attr("id", "commonerror")
.attr("width", "10px")
.attr("height", "10px")
.attr("x", "165px")
.attr("y", "10px")
.attr("stroke", "black")
.attr("stroke-width", "0px")
.attr("fill", function() {
return getRegionFill("intersection");
});
vennLegend.append("rect")
.attr("width", "75px")
.attr("height", "20px")
.attr("x", "160px")
.attr("y", "5px")
.attr("fill", "rgba(255, 255, 255, 0.0)")
.attr("stroke", "black")
.attr("stroke-width", function(d) {
if (_this.state.regionSelected == "intersection") {
return "1px";
} else {
return "0px";
}
})
.on("click", function() {
_this.setState({
regionSelected: "intersection"
});
})
.on("mouseover", function() {
d3.select(this).attr("stroke-width", "2px");
})
.on("mouseout", function() {
d3.select(this)
.attr("stroke-width", function() {
if (_this.state.regionSelected == "intersection") {
return "1px";
} else {
return "0px";
}
});
});
if (totalErrors > 0) {
aProportion = a / this.state.selectedDataPoint.dataset_size;
bProportion = b / this.state.selectedDataPoint.dataset_size;