зеркало из https://github.com/microsoft/archai.git
docs(ds): example notebooks in nbsphinx docs
This commit is contained in:
Родитель
64075df028
Коммит
a6b8d81627
|
@ -0,0 +1,14 @@
|
|||
Discrete Search Tutorials
|
||||
==============================
|
||||
|
||||
This page is under development
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Search Spaces <discrete_search/search_space.ipynb>
|
||||
Objectives <discrete_search/objectives.ipynb>
|
||||
Search Algorithms <discrete_search/search_algos.ipynb>
|
||||
Config Search <discrete_search/config_search.ipynb>
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,731 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d6fa8b24",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Objective functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "1b103993",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from overrides import overrides\n",
|
||||
"import torch\n",
|
||||
"from typing import Tuple, List, Optional\n",
|
||||
"from archai.discrete_search import ArchaiModel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f81488a6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We will use SegmentationDag search space for this example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "e7f475e2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search.search_spaces.segmentation_dag.search_space import SegmentationDagSearchSpace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "338ecd2c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ss = SegmentationDagSearchSpace(nb_classes=1, img_size=(64, 64), max_layers=5, seed=11)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "8390ddcb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||
"<!-- Generated by graphviz version 2.43.0 (0)\n",
|
||||
" -->\n",
|
||||
"<!-- Title: architecture Pages: 1 -->\n",
|
||||
"<svg width=\"878pt\" height=\"260pt\"\n",
|
||||
" viewBox=\"0.00 0.00 877.63 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
|
||||
"<title>architecture</title>\n",
|
||||
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 873.63,-256 873.63,4 -4,4\"/>\n",
|
||||
"<!-- input -->\n",
|
||||
"<g id=\"node1\" class=\"node\">\n",
|
||||
"<title>input</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"305.69\" cy=\"-162\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"305.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_0 -->\n",
|
||||
"<g id=\"node2\" class=\"node\">\n",
|
||||
"<title>layer_0</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"413.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"413.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- input->layer_0 -->\n",
|
||||
"<g id=\"edge1\" class=\"edge\">\n",
|
||||
"<title>input->layer_0</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M331.29,-144.94C345.71,-135.32 363.91,-123.19 379.53,-112.77\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"381.53,-115.65 387.91,-107.19 377.65,-109.82 381.53,-115.65\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_2 -->\n",
|
||||
"<g id=\"node4\" class=\"node\">\n",
|
||||
"<title>layer_2</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"629.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"629.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- input->layer_2 -->\n",
|
||||
"<g id=\"edge5\" class=\"edge\">\n",
|
||||
"<title>input->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M362.5,-149.38C418.43,-136.95 503.83,-117.97 562.96,-104.83\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"563.95,-108.2 572.95,-102.61 562.43,-101.36 563.95,-108.2\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_1 -->\n",
|
||||
"<g id=\"node3\" class=\"node\">\n",
|
||||
"<title>layer_1</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"521.69\" cy=\"-18\" rx=\"49.29\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"521.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">conv5x5</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_0->layer_1 -->\n",
|
||||
"<g id=\"edge2\" class=\"edge\">\n",
|
||||
"<title>layer_0->layer_1</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M439.29,-72.94C454.23,-62.97 473.24,-50.3 489.23,-39.64\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"491.38,-42.41 497.76,-33.95 487.5,-36.59 491.38,-42.41\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_0->layer_2 -->\n",
|
||||
"<g id=\"edge4\" class=\"edge\">\n",
|
||||
"<title>layer_0->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M493.2,-90C508.5,-90 524.61,-90 540.2,-90\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"540.35,-93.5 550.35,-90 540.35,-86.5 540.35,-93.5\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_1->layer_2 -->\n",
|
||||
"<g id=\"edge3\" class=\"edge\">\n",
|
||||
"<title>layer_1->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M545.39,-33.8C559.96,-43.51 578.87,-56.12 595.09,-66.93\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"593.52,-70.09 603.78,-72.72 597.4,-64.27 593.52,-70.09\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- output -->\n",
|
||||
"<g id=\"node5\" class=\"node\">\n",
|
||||
"<title>output</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-18\" rx=\"41.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">output</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_2->output -->\n",
|
||||
"<g id=\"edge6\" class=\"edge\">\n",
|
||||
"<title>layer_2->output</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M655.29,-72.94C670.55,-62.76 690.04,-49.76 706.24,-38.97\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"708.47,-41.69 714.85,-33.23 704.59,-35.86 708.47,-41.69\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- upsample -->\n",
|
||||
"<g id=\"node9\" class=\"node\">\n",
|
||||
"<title>upsample</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-234\" rx=\"131.88\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Upsample + 2 x Conv 3x3</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- output->upsample -->\n",
|
||||
"<g id=\"edge7\" class=\"edge\">\n",
|
||||
"<title>output->upsample</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M737.69,-36.04C737.69,-73.61 737.69,-160.45 737.69,-205.59\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"734.19,-205.85 737.69,-215.85 741.19,-205.85 734.19,-205.85\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-1 -->\n",
|
||||
"<g id=\"node6\" class=\"node\">\n",
|
||||
"<title>scale-1</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-162\" rx=\"83.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=2, ch=40</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-2 -->\n",
|
||||
"<g id=\"node7\" class=\"node\">\n",
|
||||
"<title>scale-2</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-90\" rx=\"83.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=4, ch=72</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-4 -->\n",
|
||||
"<g id=\"node8\" class=\"node\">\n",
|
||||
"<title>scale-4</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-18\" rx=\"89.88\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=8, ch=104</text>\n",
|
||||
"</g>\n",
|
||||
"</g>\n",
|
||||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.graphs.Digraph at 0x7f577fda8f10>"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m = ss.random_sample()\n",
|
||||
"m.arch.view()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de3d9ea3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`SegmentationDagSearchSpace` is a subclass of `EvolutionarySearchSpace`, so `mutate` and `crossover` methods are already implemented"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "dab02f97",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||
"<!-- Generated by graphviz version 2.43.0 (0)\n",
|
||||
" -->\n",
|
||||
"<!-- Title: architecture Pages: 1 -->\n",
|
||||
"<svg width=\"878pt\" height=\"260pt\"\n",
|
||||
" viewBox=\"0.00 0.00 877.63 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
|
||||
"<title>architecture</title>\n",
|
||||
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 873.63,-256 873.63,4 -4,4\"/>\n",
|
||||
"<!-- input -->\n",
|
||||
"<g id=\"node1\" class=\"node\">\n",
|
||||
"<title>input</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"305.69\" cy=\"-162\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"305.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_0 -->\n",
|
||||
"<g id=\"node2\" class=\"node\">\n",
|
||||
"<title>layer_0</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"413.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"413.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- input->layer_0 -->\n",
|
||||
"<g id=\"edge1\" class=\"edge\">\n",
|
||||
"<title>input->layer_0</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M331.29,-144.94C345.71,-135.32 363.91,-123.19 379.53,-112.77\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"381.53,-115.65 387.91,-107.19 377.65,-109.82 381.53,-115.65\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_1 -->\n",
|
||||
"<g id=\"node3\" class=\"node\">\n",
|
||||
"<title>layer_1</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"521.69\" cy=\"-18\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"521.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- input->layer_1 -->\n",
|
||||
"<g id=\"edge2\" class=\"edge\">\n",
|
||||
"<title>input->layer_1</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M308.9,-143.96C312.47,-128.22 319.57,-105.51 332.8,-89.91 360.41,-57.36 404.42,-39.51 443.02,-29.73\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"444.3,-33.03 453.22,-27.3 442.68,-26.22 444.3,-33.03\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_2 -->\n",
|
||||
"<g id=\"node4\" class=\"node\">\n",
|
||||
"<title>layer_2</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"629.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"629.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- input->layer_2 -->\n",
|
||||
"<g id=\"edge5\" class=\"edge\">\n",
|
||||
"<title>input->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M362.5,-149.38C418.43,-136.95 503.83,-117.97 562.96,-104.83\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"563.95,-108.2 572.95,-102.61 562.43,-101.36 563.95,-108.2\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_0->layer_2 -->\n",
|
||||
"<g id=\"edge4\" class=\"edge\">\n",
|
||||
"<title>layer_0->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M493.2,-90C508.5,-90 524.61,-90 540.2,-90\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"540.35,-93.5 550.35,-90 540.35,-86.5 540.35,-93.5\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_1->layer_2 -->\n",
|
||||
"<g id=\"edge3\" class=\"edge\">\n",
|
||||
"<title>layer_1->layer_2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M547.29,-35.06C561.71,-44.68 579.91,-56.81 595.53,-67.23\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"593.65,-70.18 603.91,-72.81 597.53,-64.35 593.65,-70.18\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- output -->\n",
|
||||
"<g id=\"node5\" class=\"node\">\n",
|
||||
"<title>output</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-18\" rx=\"41.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">output</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- layer_2->output -->\n",
|
||||
"<g id=\"edge6\" class=\"edge\">\n",
|
||||
"<title>layer_2->output</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M655.29,-72.94C670.55,-62.76 690.04,-49.76 706.24,-38.97\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"708.47,-41.69 714.85,-33.23 704.59,-35.86 708.47,-41.69\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- upsample -->\n",
|
||||
"<g id=\"node9\" class=\"node\">\n",
|
||||
"<title>upsample</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-234\" rx=\"131.88\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Upsample + 3 x Conv 3x3</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- output->upsample -->\n",
|
||||
"<g id=\"edge7\" class=\"edge\">\n",
|
||||
"<title>output->upsample</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M737.69,-36.04C737.69,-73.61 737.69,-160.45 737.69,-205.59\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"734.19,-205.85 737.69,-215.85 741.19,-205.85 734.19,-205.85\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-1 -->\n",
|
||||
"<g id=\"node6\" class=\"node\">\n",
|
||||
"<title>scale-1</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-162\" rx=\"83.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=2, ch=40</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-2 -->\n",
|
||||
"<g id=\"node7\" class=\"node\">\n",
|
||||
"<title>scale-2</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-90\" rx=\"83.69\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=4, ch=72</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- scale-4 -->\n",
|
||||
"<g id=\"node8\" class=\"node\">\n",
|
||||
"<title>scale-4</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-18\" rx=\"89.88\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=8, ch=104</text>\n",
|
||||
"</g>\n",
|
||||
"</g>\n",
|
||||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.graphs.Digraph at 0x7f5968ddc070>"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ss.mutate(m).arch.view()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8bc5d19f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluating models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "225a22a1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Objective functions are the main tool used to evaluate architectures in given criteria (task performance, speed, size, etc.). Objectives are optimized by search algorithms."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2e961df2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"There are two types of objective function abstractions:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* Objectives [archai.discrete_search.Objective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#module-archai.discrete_search.api.objective)\n",
|
||||
" * Must override `Objective.higher_is_better` (optimization direction)\n",
|
||||
" * Must implement `Objective.evaluate(model, dataset, budget)`\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"* Asynchronous Objectives [archai.discrete_search.AsyncObjective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#archai.discrete_search.api.objective.AsyncObjective):\n",
|
||||
" * Must override `AsyncObjective.higher_is_better` (optimization direction)\n",
|
||||
" * Must implement `AsyncObjective.send(model, dataset, budget)` and `AsyncObjective.fetch_all()` \n",
|
||||
" \n",
|
||||
"Objective functions may optionally use a dataset (e.g task accuracy) or not (e.g architecture latency). Objective functions may also receive a `budget` value from some search algorithms. This `budget` value is used by some algorithms to sinalize how much budget the model evaluation should spend.\n",
|
||||
"\n",
|
||||
"Read more about them [here](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#module-archai.discrete_search.api.objective)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6f649761",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Built-in objective example (AvgOnnxLatency)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a438d949",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search.objectives.onnx_model import AvgOnnxLatency"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e250c928",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
|
||||
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.0004309522919356823"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"onnx_latency_obj = AvgOnnxLatency(input_shape=(1, 3, 64, 64))\n",
|
||||
"onnx_latency_obj.evaluate(model=ss.random_sample(), dataset_provider=None, budget=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4bf6fa13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"By default `AvgOnnxLatency` will be minimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "293f9b9a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"onnx_latency_obj.higher_is_better"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cbc58063",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Custom objective example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "cbff79f6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search import Objective, DatasetProvider"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "5e906f6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class NumberOfModules(Objective):\n",
|
||||
" ''' Class that measures the size of a model by the number of torch modules '''\n",
|
||||
"\n",
|
||||
" higher_is_better: bool = False # Smaller models are better\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def evaluate(self, model: ArchaiModel, dataset: DatasetProvider,\n",
|
||||
" budget: Optional[float] = None):\n",
|
||||
" return len(list(model.arch.modules()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "58aa1eba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"m = ss.random_sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "d6b5f1f0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"67"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"my_objective = NumberOfModules()\n",
|
||||
"my_objective.evaluate(m, None, None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "81051526",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Some utility objectives"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ecb5ca78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* [RayParallelObjective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.objectives.html#module-archai.discrete_search.objectives.ray) - Wraps an existing objective and runs it using multiple Ray workers\n",
|
||||
"\n",
|
||||
"* [EvaluationFunction](https://microsoft.github.io/archai/reference/api/archai.discrete_search.objectives.html#module-archai.discrete_search.objectives.functional) - Wraps a function that takes a (model, dataset_provider, budget) triplet and creates an objective"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bcaf1433",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Parallelizing NumberOfModules**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "8f6666c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search.objectives.ray import RayParallelObjective"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "c77984d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"my_objective_parallel = RayParallelObjective(\n",
|
||||
" NumberOfModules(), timeout=10, num_cpus=1.0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "61099086",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`my_objective_parallel` is now an `AsyncObjective` object"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "41a38d5d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dispatching job for 4aba6fbdb292e44d634daefa425ab1406684daed_64_64\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2022-10-28 07:39:27,701\tINFO worker.py:1518 -- Started a local Ray instance.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dispatching job for e0521c00e4b6dfa7f624d2d7560d9c220591864b_64_64\n",
|
||||
"Dispatching job for c60496d4923eaa0062de511eaab3b9cb4ec46a3e_64_64\n",
|
||||
"Dispatching job for d31e4ef0912834bc51336aaf55fd879606fbf4ca_64_64\n",
|
||||
"Dispatching job for 915ff7e0aca6e48bbae0def46d64b7300887fb80_64_64\n",
|
||||
"Dispatching job for 90da2af4f0a0aa0f24cafa1cd59032623ada1c23_64_64\n",
|
||||
"Dispatching job for fe6c11c85bbcbdaf6b716d9259f5415b7327192d_64_64\n",
|
||||
"Dispatching job for 65e92bee3ecc899c5c346be82961c331d9f18933_64_64\n",
|
||||
"Dispatching job for bdf6f69e2a8e08473e9e799ec2d7e627dd915d43_64_64\n",
|
||||
"Dispatching job for 9b0f792a6e6c37c4e40abde72b4fbd2cdca9ebae_64_64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_list = [ss.random_sample() for _ in range(10)]\n",
|
||||
"\n",
|
||||
"for model in model_list:\n",
|
||||
" print(f'Dispatching job for {model.archid}')\n",
|
||||
" my_objective_parallel.send(model, dataset=None, budget=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "230bb8f1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[53, 29, 60, 31, 87, 49, 30, 83, 33, 61]"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"my_objective_parallel.fetch_all()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f86f894a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Example: Wrapping custom training code in an objective**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "bb78fec1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.datasets.providers.mnist_provider import MnistProvider\n",
|
||||
"from archai.discrete_search.objectives.functional import EvaluationFunction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "2e5a0d51",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_provider = MnistProvider({'dataroot': '/home/pkauffmann/dataroot/'})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "854e656c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
" def custom_training_val_performance(model, dataset_provider, budget=None):\n",
|
||||
" tr_data, val_data = dataset_provider.get_datasets(True, True, False, False)\n",
|
||||
"\n",
|
||||
" tr_dl = torch.utils.data.DataLoader(tr_data, shuffle=True, batch_size=16)\n",
|
||||
" val_dl = torch.utils.data.DataLoader(tr_data, shuffle=True, batch_size=16)\n",
|
||||
"\n",
|
||||
" optimizer = torch.optim.Adam(model.arch.parameters(), lr=1e-3)\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" for batch in tr_dl:\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" for batch in val_dl:\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" return validation_metric\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2db08cec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"\n",
|
||||
"training_objective = EvaluationFunction(custom_traininb_val_performance, higher_is_better=True)\n",
|
||||
"\n",
|
||||
"training_objective.evaluate(ss.random_sample(), dataset_provider, budget=None)\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ebb4c012",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"See the next notebook for a full example using a custom training objective"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,845 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c4bd875c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Discrete Search Spaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "5cdb0135",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from overrides import overrides\n",
|
||||
"import numpy as np\n",
|
||||
"from typing import Tuple, List, Optional\n",
|
||||
"from archai.discrete_search import ArchaiModel, DiscreteSearchSpace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8d100fef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch import nn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9e2ca76",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Discrete search spaces in Archai are defined using the `DiscreteSearchSpace` abstract class:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"\n",
|
||||
"class DiscreteSearchSpace(EnforceOverrides):\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def random_sample(self) -> ArchaiModel:\n",
|
||||
" ...\n",
|
||||
" \n",
|
||||
" @abstractmethod\n",
|
||||
" def save_arch(self, model: ArchaiModel, path: str) -> None:\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def load_arch(self, path: str) -> ArchaiModel:\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def save_model_weights(self, model: ArchaiModel, path: str) -> None:\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def load_model_weights(self, model: ArchaiModel, path: str) -> None:\n",
|
||||
" ...\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "16149883",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### The `ArchaiModel` abstraction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fb4c0081",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `ArchaiModel` abstraction is used to wrap a model object with a given architecture id (`archid`) and optionally a metadata dictionary."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ca7f4686",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search import ArchaiModel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9f563253",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ab5fe6fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DummyModel(nn.Module):\n",
|
||||
" def __init__(self, nb_layers: int = 2, kernel_size: int = 3):\n",
|
||||
" super().__init__()\n",
|
||||
" \n",
|
||||
" self.nb_layers = nb_layers\n",
|
||||
" self.kernel_size = kernel_size\n",
|
||||
" \n",
|
||||
" layers = []\n",
|
||||
" for i in range(nb_layers):\n",
|
||||
" input_dim = 3 if i == 0 else 16\n",
|
||||
" \n",
|
||||
" layers += [\n",
|
||||
" nn.Conv2d(input_dim, 16, kernel_size=kernel_size, padding='same'),\n",
|
||||
" nn.BatchNorm2d(16),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" ]\n",
|
||||
" \n",
|
||||
" self.layers = nn.Sequential(*layers)\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.layers(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "0bb34bbf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_obj = DummyModel(nb_layers=2, kernel_size=3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4dfe5ad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's wrap model_obj into an `ArchaiModel`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "290e625a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ArchaiModel(\n",
|
||||
" arch=model_obj,\n",
|
||||
" archid=f'L={model_obj.nb_layers}, K={model_obj.kernel_size}',\n",
|
||||
" metadata={'optional': {'metadata'}}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "a61e5264",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'L=2, K=3'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.archid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "e0ea6fb2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'optional': {'metadata'}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "2337eb37",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"DummyModel(\n",
|
||||
" (layers): Sequential(\n",
|
||||
" (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
|
||||
" (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (2): ReLU()\n",
|
||||
" (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
|
||||
" (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (5): ReLU()\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.arch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5732d030",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Archid will be used to deduplicate seen architectures. It should only identify the architecture and not the model weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6b1d6f94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### ConvNet Search Space Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d9fb59c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start with a (really) simple search space for image classification"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "e3145b08",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyModel(nn.Module):\n",
|
||||
" def __init__(self, nb_layers: int = 5, kernel_size: int = 3, hidden_dim: int = 32):\n",
|
||||
" super().__init__()\n",
|
||||
" \n",
|
||||
" self.nb_layers = nb_layers\n",
|
||||
" self.kernel_size = kernel_size\n",
|
||||
" self.hidden_dim = hidden_dim\n",
|
||||
" \n",
|
||||
" layer_list = []\n",
|
||||
"\n",
|
||||
" for i in range(nb_layers):\n",
|
||||
" in_ch = (3 if i == 0 else hidden_dim)\n",
|
||||
" \n",
|
||||
" layer_list += [\n",
|
||||
" nn.Conv2d(in_ch, hidden_dim, kernel_size=kernel_size, padding='same'),\n",
|
||||
" nn.BatchNorm2d(hidden_dim),\n",
|
||||
" nn.ReLU()\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" layer_list += [\n",
|
||||
" nn.Conv2d(hidden_dim, 1, kernel_size=1, padding='same'),\n",
|
||||
" nn.Sigmoid()\n",
|
||||
" ]\n",
|
||||
" \n",
|
||||
" self.model = nn.Sequential(*layer_list)\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.model(x)\n",
|
||||
" \n",
|
||||
" def get_archid(self):\n",
|
||||
" return f'({self.nb_layers}, {self.kernel_size}, {self.hidden_dim})'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "0ccc4d41",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MyModel(\n",
|
||||
" (model): Sequential(\n",
|
||||
" (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
|
||||
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (2): ReLU()\n",
|
||||
" (3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (4): Sigmoid()\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m = MyModel(nb_layers=1)\n",
|
||||
"m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "e4bd63f2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'(1, 3, 32)'"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m.get_archid()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa42080d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's overide DiscreteSearchSpace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "193ea617",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from typing import Tuple\n",
|
||||
"from random import Random\n",
|
||||
"\n",
|
||||
"class CNNSearchSpace(DiscreteSearchSpace):\n",
|
||||
" def __init__(self, min_layers: int = 1, max_layers: int = 12,\n",
|
||||
" kernel_list=(1, 3, 5, 7), hidden_list=(16, 32, 64, 128),\n",
|
||||
" seed: int = 1):\n",
|
||||
"\n",
|
||||
" self.min_layers = min_layers\n",
|
||||
" self.max_layers = max_layers\n",
|
||||
" self.kernel_list = kernel_list\n",
|
||||
" self.hidden_list = hidden_list\n",
|
||||
" \n",
|
||||
" self.rng = Random(seed)\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def random_sample(self) -> ArchaiModel:\n",
|
||||
" # Randomly chooses architecture parameters\n",
|
||||
" nb_layers = self.rng.randint(self.min_layers, self.max_layers)\n",
|
||||
" kernel_size = self.rng.choice(self.kernel_list)\n",
|
||||
" hidden_dim = self.rng.choice(self.hidden_list)\n",
|
||||
" \n",
|
||||
" model = MyModel(nb_layers, kernel_size, hidden_dim)\n",
|
||||
" \n",
|
||||
" # Wraps model into ArchaiModel\n",
|
||||
" return ArchaiModel(arch=model, archid=model.get_archid())\n",
|
||||
"\n",
|
||||
" @overrides\n",
|
||||
" def save_arch(self, model: ArchaiModel, file: str):\n",
|
||||
" with open(file, 'w') as fp:\n",
|
||||
" json.dump({\n",
|
||||
" 'nb_layers': model.arch.nb_layers,\n",
|
||||
" 'kernel_size': model.arch.kernel_size,\n",
|
||||
" 'hidden_dim': model.arch.hidden_dim\n",
|
||||
" }, fp)\n",
|
||||
"\n",
|
||||
" @overrides\n",
|
||||
" def load_arch(self, file: str):\n",
|
||||
" config = json.load(open(file))\n",
|
||||
" model = MyModel(**config)\n",
|
||||
" \n",
|
||||
" return ArchaiModel(arch=model, archid=model.get_archid())\n",
|
||||
"\n",
|
||||
" @overrides\n",
|
||||
" def save_model_weights(self, model: ArchaiModel, file: str):\n",
|
||||
" state_dict = model.arch.get_state_dict()\n",
|
||||
" torch.save(state_dict, file)\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def load_model_weights(self, model: ArchaiModel, file: str):\n",
|
||||
" model.arch.load_state_dict(torch.load(file))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "7db02619",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ss = CNNSearchSpace(hidden_list=[32, 64, 128])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6ce23725",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Sampling a model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "83c03fe1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ArchaiModel(\n",
|
||||
"\tarchid=(3, 1, 64), \n",
|
||||
"\tmetadata={}, \n",
|
||||
"\tarch=MyModel(\n",
|
||||
" (model): Sequential(\n",
|
||||
" (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (2): ReLU()\n",
|
||||
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (5): ReLU()\n",
|
||||
" (6): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (8): ReLU()\n",
|
||||
" (9): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (10): Sigmoid()\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m = ss.random_sample()\n",
|
||||
"m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "619c4d9c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Saving an architecture"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "3dace5d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ss.save_arch(m, 'arch.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "f1d4dba3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\"nb_layers\": 3, \"kernel_size\": 1, \"hidden_dim\": 64}"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!cat arch.json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "70813cc7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Loading an architecture (not the weights)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "863ef766",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ArchaiModel(\n",
|
||||
"\tarchid=(3, 1, 64), \n",
|
||||
"\tmetadata={}, \n",
|
||||
"\tarch=MyModel(\n",
|
||||
" (model): Sequential(\n",
|
||||
" (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (2): ReLU()\n",
|
||||
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (5): ReLU()\n",
|
||||
" (6): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (8): ReLU()\n",
|
||||
" (9): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
|
||||
" (10): Sigmoid()\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ss.load_arch('arch.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f5c0b5a3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Making the search space compatible with different types of algorithms"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4083db69",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* Evolutionary-based algorithms:\n",
|
||||
" - User must subclass `EvolutionarySearchSpace` and implement `EvolutionarySearchSpace.mutate` and `EvolutionarySearchSpace.crossover`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* BO-based algorithms:\n",
|
||||
" - User must subclass `BayesOptSearchSpace` and override `BayesOptSearchSpace.encode`\n",
|
||||
" - Encode should take an `ArchaiModel` and produce a fixed-length vector representation of that architecture. This numerical representation will be used to train surrogate models.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d294ab69",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "78b73a68",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from archai.discrete_search import EvolutionarySearchSpace, BayesOptSearchSpace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "0e02255f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CNNSearchSpaceExt(CNNSearchSpace, EvolutionarySearchSpace, BayesOptSearchSpace):\n",
|
||||
" ''' We are subclassing CNNSearchSpace just to save up space'''\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def mutate(self, model_1: ArchaiModel) -> ArchaiModel:\n",
|
||||
" config = {\n",
|
||||
" 'nb_layers': model_1.arch.nb_layers,\n",
|
||||
" 'kernel_size': model_1.arch.kernel_size,\n",
|
||||
" 'hidden_dim': model_1.arch.hidden_dim\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" if self.rng.random() < 0.2:\n",
|
||||
" config['nb_layers'] = self.rng.randint(self.min_layers, self.max_layers)\n",
|
||||
" \n",
|
||||
" if self.rng.random() < 0.2:\n",
|
||||
" config['kernel_size'] = self.rng.choice(self.kernel_list)\n",
|
||||
" \n",
|
||||
" if self.rng.random() < 0.2:\n",
|
||||
" config['hidden_dim'] = self.rng.choice(self.hidden_list)\n",
|
||||
" \n",
|
||||
" mutated_model = MyModel(**config)\n",
|
||||
" \n",
|
||||
" return ArchaiModel(\n",
|
||||
" arch=mutated_model, archid=mutated_model.get_archid()\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def crossover(self, model_list: List[ArchaiModel]) -> ArchaiModel:\n",
|
||||
" model_1, model_2 = model_list[:2]\n",
|
||||
" \n",
|
||||
" new_config = {\n",
|
||||
" 'nb_layers': self.rng.choice([model_1.arch.nb_layers, model_2.arch.nb_layers]),\n",
|
||||
" 'kernel_size': self.rng.choice([model_1.arch.kernel_size, model_2.arch.kernel_size]),\n",
|
||||
" 'hidden_dim': self.rng.choice([model_1.arch.hidden_dim, model_2.arch.hidden_dim]),\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" crossover_model = MyModel(**new_config)\n",
|
||||
" \n",
|
||||
" return ArchaiModel(\n",
|
||||
" arch=crossover_model, archid=crossover_model.get_archid()\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" @overrides\n",
|
||||
" def encode(self, model: ArchaiModel) -> np.ndarray:\n",
|
||||
" return np.array([model.arch.nb_layers, model.arch.kernel_size, model.arch.hidden_dim])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "8f9b6ba7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ss = CNNSearchSpaceExt(hidden_list=[32, 64, 128])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7582b266",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "d23e6373",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'(3, 1, 64)'"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m = ss.random_sample()\n",
|
||||
"m.archid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "6c695837",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'(8, 1, 64)'"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ss.mutate(m).archid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "e0b99677",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([ 3, 1, 64])"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ss.encode(m)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1201e318",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now `CNNSearchSpaceExt` is compatible with Bayesian Optimization and Evolutionary based search algorithms!\n",
|
||||
"\n",
|
||||
"**To see a list of built-in search spaces, go to `archai/discrete_search/search_spaces`**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1fad0c69",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Example: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "0ce94672",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ArchaiModel(\n",
|
||||
"\tarchid=74f66612a0d01c5b7d4702234756b0ee4ffa5abc_64_64, \n",
|
||||
"\tmetadata={'parent': '32fa5956ab3ce9e05bc42836599a8dc9dd53e847_64_64'}, \n",
|
||||
"\tarch=SegmentationDagModel(\n",
|
||||
" (edge_dict): ModuleDict(\n",
|
||||
" (input-output): Block(\n",
|
||||
" (op): Sequential(\n",
|
||||
" (0): NormalConvBlock(\n",
|
||||
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
||||
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU()\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (stem_block): NormalConvBlock(\n",
|
||||
" (conv): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
|
||||
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU()\n",
|
||||
" )\n",
|
||||
" (up): Upsample(size=(64, 64), mode=nearest)\n",
|
||||
" (post_upsample): Sequential(\n",
|
||||
" (0): NormalConvBlock(\n",
|
||||
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
||||
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU()\n",
|
||||
" )\n",
|
||||
" (1): NormalConvBlock(\n",
|
||||
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
||||
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU()\n",
|
||||
" )\n",
|
||||
" (2): NormalConvBlock(\n",
|
||||
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
||||
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU()\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (classifier): Conv2d(40, 1, kernel_size=(1, 1), stride=(1, 1))\n",
|
||||
")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from archai.discrete_search.search_spaces.segmentation_dag.search_space import SegmentationDagSearchSpace\n",
|
||||
"\n",
|
||||
"ss = SegmentationDagSearchSpace(nb_classes=1, img_size=(64, 64), max_layers=3)\n",
|
||||
"ss.mutate(ss.random_sample())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@ -26,7 +26,8 @@ extensions = [
|
|||
"sphinxcontrib.programoutput",
|
||||
"sphinxcontrib.mermaid",
|
||||
"sphinx_inline_tabs",
|
||||
"sphinx_git"
|
||||
"sphinx_git",
|
||||
"nbsphinx"
|
||||
]
|
||||
exclude_patterns = [
|
||||
"benchmarks/**",
|
||||
|
|
|
@ -45,6 +45,7 @@ If you use Archai in a scientific publication, please consider citing it:
|
|||
|
||||
30-Minute Tutorial <basic_guide/tutorial>
|
||||
Notebooks <basic_guide/notebooks>
|
||||
Discrete Search Tutorial <basic_guide/discrete_search>
|
||||
Examples & Scripts <basic_guide/examples_scripts>
|
||||
|
||||
.. toctree::
|
||||
|
|
Загрузка…
Ссылка в новой задаче