docs(ds): example notebooks in nbsphinx docs

piero2c 2022-12-12 21:51:09 -03:00 committed by Gustavo Rosa
Parent 64075df028
Commit a6b8d81627
7 changed files: 5231 additions and 1 deletion

View file

@@ -0,0 +1,14 @@
Discrete Search Tutorials
==============================

This page is under development.

.. toctree::
   :hidden:
   :maxdepth: 2

   Search Spaces <discrete_search/search_space.ipynb>
   Objectives <discrete_search/objectives.ipynb>
   Search Algorithms <discrete_search/search_algos.ipynb>
   Config Search <discrete_search/config_search.ipynb>

The file diff is not shown because it is too large. Load diff

View file

@@ -0,0 +1,731 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d6fa8b24",
"metadata": {},
"source": [
"## Objective functions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1b103993",
"metadata": {},
"outputs": [],
"source": [
"from overrides import overrides\n",
"import torch\n",
"from typing import Tuple, List, Optional\n",
"from archai.discrete_search import ArchaiModel"
]
},
{
"cell_type": "markdown",
"id": "f81488a6",
"metadata": {},
"source": [
"We will use SegmentationDag search space for this example"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e7f475e2",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search.search_spaces.segmentation_dag.search_space import SegmentationDagSearchSpace"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "338ecd2c",
"metadata": {},
"outputs": [],
"source": [
"ss = SegmentationDagSearchSpace(nb_classes=1, img_size=(64, 64), max_layers=5, seed=11)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8390ddcb",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: architecture Pages: 1 -->\n",
"<svg width=\"878pt\" height=\"260pt\"\n",
" viewBox=\"0.00 0.00 877.63 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
"<title>architecture</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 873.63,-256 873.63,4 -4,4\"/>\n",
"<!-- input -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>input</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"305.69\" cy=\"-162\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"305.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
"</g>\n",
"<!-- layer_0 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>layer_0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"413.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"413.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
"</g>\n",
"<!-- input&#45;&gt;layer_0 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>input&#45;&gt;layer_0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M331.29,-144.94C345.71,-135.32 363.91,-123.19 379.53,-112.77\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"381.53,-115.65 387.91,-107.19 377.65,-109.82 381.53,-115.65\"/>\n",
"</g>\n",
"<!-- layer_2 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>layer_2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"629.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"629.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
"</g>\n",
"<!-- input&#45;&gt;layer_2 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>input&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M362.5,-149.38C418.43,-136.95 503.83,-117.97 562.96,-104.83\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"563.95,-108.2 572.95,-102.61 562.43,-101.36 563.95,-108.2\"/>\n",
"</g>\n",
"<!-- layer_1 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>layer_1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"521.69\" cy=\"-18\" rx=\"49.29\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"521.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">conv5x5</text>\n",
"</g>\n",
"<!-- layer_0&#45;&gt;layer_1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>layer_0&#45;&gt;layer_1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M439.29,-72.94C454.23,-62.97 473.24,-50.3 489.23,-39.64\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"491.38,-42.41 497.76,-33.95 487.5,-36.59 491.38,-42.41\"/>\n",
"</g>\n",
"<!-- layer_0&#45;&gt;layer_2 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>layer_0&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M493.2,-90C508.5,-90 524.61,-90 540.2,-90\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"540.35,-93.5 550.35,-90 540.35,-86.5 540.35,-93.5\"/>\n",
"</g>\n",
"<!-- layer_1&#45;&gt;layer_2 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>layer_1&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M545.39,-33.8C559.96,-43.51 578.87,-56.12 595.09,-66.93\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"593.52,-70.09 603.78,-72.72 597.4,-64.27 593.52,-70.09\"/>\n",
"</g>\n",
"<!-- output -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>output</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-18\" rx=\"41.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">output</text>\n",
"</g>\n",
"<!-- layer_2&#45;&gt;output -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>layer_2&#45;&gt;output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M655.29,-72.94C670.55,-62.76 690.04,-49.76 706.24,-38.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"708.47,-41.69 714.85,-33.23 704.59,-35.86 708.47,-41.69\"/>\n",
"</g>\n",
"<!-- upsample -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>upsample</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-234\" rx=\"131.88\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Upsample + 2 x Conv 3x3</text>\n",
"</g>\n",
"<!-- output&#45;&gt;upsample -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>output&#45;&gt;upsample</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M737.69,-36.04C737.69,-73.61 737.69,-160.45 737.69,-205.59\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"734.19,-205.85 737.69,-215.85 741.19,-205.85 734.19,-205.85\"/>\n",
"</g>\n",
"<!-- scale&#45;1 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>scale&#45;1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-162\" rx=\"83.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=2, ch=40</text>\n",
"</g>\n",
"<!-- scale&#45;2 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>scale&#45;2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-90\" rx=\"83.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=4, ch=72</text>\n",
"</g>\n",
"<!-- scale&#45;4 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>scale&#45;4</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-18\" rx=\"89.88\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=8, ch=104</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f577fda8f10>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = ss.random_sample()\n",
"m.arch.view()"
]
},
{
"cell_type": "markdown",
"id": "de3d9ea3",
"metadata": {},
"source": [
"`SegmentationDagSearchSpace` is a subclass of `EvolutionarySearchSpace`, so `mutate` and `crossover` methods are already implemented"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dab02f97",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: architecture Pages: 1 -->\n",
"<svg width=\"878pt\" height=\"260pt\"\n",
" viewBox=\"0.00 0.00 877.63 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
"<title>architecture</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 873.63,-256 873.63,4 -4,4\"/>\n",
"<!-- input -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>input</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"305.69\" cy=\"-162\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"305.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
"</g>\n",
"<!-- layer_0 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>layer_0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"413.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"413.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
"</g>\n",
"<!-- input&#45;&gt;layer_0 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>input&#45;&gt;layer_0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M331.29,-144.94C345.71,-135.32 363.91,-123.19 379.53,-112.77\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"381.53,-115.65 387.91,-107.19 377.65,-109.82 381.53,-115.65\"/>\n",
"</g>\n",
"<!-- layer_1 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>layer_1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"521.69\" cy=\"-18\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"521.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv3x3_e2</text>\n",
"</g>\n",
"<!-- input&#45;&gt;layer_1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>input&#45;&gt;layer_1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M308.9,-143.96C312.47,-128.22 319.57,-105.51 332.8,-89.91 360.41,-57.36 404.42,-39.51 443.02,-29.73\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"444.3,-33.03 453.22,-27.3 442.68,-26.22 444.3,-33.03\"/>\n",
"</g>\n",
"<!-- layer_2 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>layer_2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"629.69\" cy=\"-90\" rx=\"79.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"629.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">mbconv5x5_e2</text>\n",
"</g>\n",
"<!-- input&#45;&gt;layer_2 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>input&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M362.5,-149.38C418.43,-136.95 503.83,-117.97 562.96,-104.83\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"563.95,-108.2 572.95,-102.61 562.43,-101.36 563.95,-108.2\"/>\n",
"</g>\n",
"<!-- layer_0&#45;&gt;layer_2 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>layer_0&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M493.2,-90C508.5,-90 524.61,-90 540.2,-90\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"540.35,-93.5 550.35,-90 540.35,-86.5 540.35,-93.5\"/>\n",
"</g>\n",
"<!-- layer_1&#45;&gt;layer_2 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>layer_1&#45;&gt;layer_2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M547.29,-35.06C561.71,-44.68 579.91,-56.81 595.53,-67.23\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"593.65,-70.18 603.91,-72.81 597.53,-64.35 593.65,-70.18\"/>\n",
"</g>\n",
"<!-- output -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>output</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-18\" rx=\"41.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">output</text>\n",
"</g>\n",
"<!-- layer_2&#45;&gt;output -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>layer_2&#45;&gt;output</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M655.29,-72.94C670.55,-62.76 690.04,-49.76 706.24,-38.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"708.47,-41.69 714.85,-33.23 704.59,-35.86 708.47,-41.69\"/>\n",
"</g>\n",
"<!-- upsample -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>upsample</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"737.69\" cy=\"-234\" rx=\"131.88\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"737.69\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Upsample + 3 x Conv 3x3</text>\n",
"</g>\n",
"<!-- output&#45;&gt;upsample -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>output&#45;&gt;upsample</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M737.69,-36.04C737.69,-73.61 737.69,-160.45 737.69,-205.59\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"734.19,-205.85 737.69,-215.85 741.19,-205.85 734.19,-205.85\"/>\n",
"</g>\n",
"<!-- scale&#45;1 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>scale&#45;1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-162\" rx=\"83.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=2, ch=40</text>\n",
"</g>\n",
"<!-- scale&#45;2 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>scale&#45;2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-90\" rx=\"83.69\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=4, ch=72</text>\n",
"</g>\n",
"<!-- scale&#45;4 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>scale&#45;4</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.69\" cy=\"-18\" rx=\"89.88\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.69\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">scale=8, ch=104</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f5968ddc070>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss.mutate(m).arch.view()"
]
},
{
"cell_type": "markdown",
"id": "8bc5d19f",
"metadata": {},
"source": [
"## Evaluating models"
]
},
{
"cell_type": "markdown",
"id": "225a22a1",
"metadata": {},
"source": [
"Objective functions are the main tool used to evaluate architectures in given criteria (task performance, speed, size, etc.). Objectives are optimized by search algorithms."
]
},
{
"cell_type": "markdown",
"id": "2e961df2",
"metadata": {},
"source": [
"There are two types of objective function abstractions:\n",
"\n",
"\n",
"* Objectives [archai.discrete_search.Objective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#module-archai.discrete_search.api.objective)\n",
" * Must override `Objective.higher_is_better` (optimization direction)\n",
" * Must implement `Objective.evaluate(model, dataset, budget)`\n",
" \n",
"\n",
"* Asynchronous Objectives [archai.discrete_search.AsyncObjective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#archai.discrete_search.api.objective.AsyncObjective):\n",
" * Must override `AsyncObjective.higher_is_better` (optimization direction)\n",
" * Must implement `AsyncObjective.send(model, dataset, budget)` and `AsyncObjective.fetch_all()` \n",
" \n",
"Objective functions may optionally use a dataset (e.g task accuracy) or not (e.g architecture latency). Objective functions may also receive a `budget` value from some search algorithms. This `budget` value is used by some algorithms to sinalize how much budget the model evaluation should spend.\n",
"\n",
"Read more about them [here](https://microsoft.github.io/archai/reference/api/archai.discrete_search.api.html#module-archai.discrete_search.api.objective)"
]
},
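{
"cell_type": "markdown",
"id": "3fa9c2d1",
"metadata": {},
"source": [
"As a rough illustration of the asynchronous interface, here is a minimal sketch of an `AsyncObjective` subclass. The class below and its queueing behavior are our own toy example (we assume `AsyncObjective` is importable from `archai.discrete_search`, like `Objective`); only the `higher_is_better`/`send`/`fetch_all` contract comes from the API described above:\n",
"\n",
"```python\n",
"from typing import List, Optional\n",
"\n",
"from overrides import overrides\n",
"from archai.discrete_search import ArchaiModel, AsyncObjective, DatasetProvider\n",
"\n",
"\n",
"class QueuedNumberOfModules(AsyncObjective):\n",
"    ''' Toy async objective: evaluates immediately in `send` and drains results in `fetch_all` '''\n",
"\n",
"    higher_is_better: bool = False\n",
"\n",
"    def __init__(self):\n",
"        self.results: List[float] = []\n",
"\n",
"    @overrides\n",
"    def send(self, model: ArchaiModel, dataset: DatasetProvider,\n",
"             budget: Optional[float] = None) -> None:\n",
"        # A real implementation would dispatch this work to a remote worker\n",
"        self.results.append(float(len(list(model.arch.modules()))))\n",
"\n",
"    @overrides\n",
"    def fetch_all(self) -> List[float]:\n",
"        # Return everything gathered so far and clear the internal queue\n",
"        results, self.results = self.results, []\n",
"        return results\n",
"```"
]
},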
{
"cell_type": "markdown",
"id": "6f649761",
"metadata": {},
"source": [
"### Built-in objective example (AvgOnnxLatency)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a438d949",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search.objectives.onnx_model import AvgOnnxLatency"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e250c928",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
"WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n"
]
},
{
"data": {
"text/plain": [
"0.0004309522919356823"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onnx_latency_obj = AvgOnnxLatency(input_shape=(1, 3, 64, 64))\n",
"onnx_latency_obj.evaluate(model=ss.random_sample(), dataset_provider=None, budget=None)"
]
},
{
"cell_type": "markdown",
"id": "4bf6fa13",
"metadata": {},
"source": [
"By default `AvgOnnxLatency` will be minimized"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "293f9b9a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onnx_latency_obj.higher_is_better"
]
},
{
"cell_type": "markdown",
"id": "cbc58063",
"metadata": {},
"source": [
"### Custom objective example"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cbff79f6",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search import Objective, DatasetProvider"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5e906f6b",
"metadata": {},
"outputs": [],
"source": [
"class NumberOfModules(Objective):\n",
" ''' Class that measures the size of a model by the number of torch modules '''\n",
"\n",
" higher_is_better: bool = False # Smaller models are better\n",
" \n",
" @overrides\n",
" def evaluate(self, model: ArchaiModel, dataset: DatasetProvider,\n",
" budget: Optional[float] = None):\n",
" return len(list(model.arch.modules()))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "58aa1eba",
"metadata": {},
"outputs": [],
"source": [
"m = ss.random_sample()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d6b5f1f0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"67"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_objective = NumberOfModules()\n",
"my_objective.evaluate(m, None, None)"
]
},
{
"cell_type": "markdown",
"id": "81051526",
"metadata": {},
"source": [
"### Some utility objectives"
]
},
{
"cell_type": "markdown",
"id": "ecb5ca78",
"metadata": {},
"source": [
"* [RayParallelObjective](https://microsoft.github.io/archai/reference/api/archai.discrete_search.objectives.html#module-archai.discrete_search.objectives.ray) - Wraps an existing objective and runs it using multiple Ray workers\n",
"\n",
"* [EvaluationFunction](https://microsoft.github.io/archai/reference/api/archai.discrete_search.objectives.html#module-archai.discrete_search.objectives.functional) - Wraps a function that takes a (model, dataset_provider, budget) triplet and creates an objective"
]
},
{
"cell_type": "markdown",
"id": "bcaf1433",
"metadata": {},
"source": [
"**Parallelizing NumberOfModules**"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8f6666c0",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search.objectives.ray import RayParallelObjective"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c77984d2",
"metadata": {},
"outputs": [],
"source": [
"my_objective_parallel = RayParallelObjective(\n",
" NumberOfModules(), timeout=10, num_cpus=1.0\n",
")"
]
},
{
"cell_type": "markdown",
"id": "61099086",
"metadata": {},
"source": [
"`my_objective_parallel` is now an `AsyncObjective` object"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "41a38d5d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dispatching job for 4aba6fbdb292e44d634daefa425ab1406684daed_64_64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-10-28 07:39:27,701\tINFO worker.py:1518 -- Started a local Ray instance.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dispatching job for e0521c00e4b6dfa7f624d2d7560d9c220591864b_64_64\n",
"Dispatching job for c60496d4923eaa0062de511eaab3b9cb4ec46a3e_64_64\n",
"Dispatching job for d31e4ef0912834bc51336aaf55fd879606fbf4ca_64_64\n",
"Dispatching job for 915ff7e0aca6e48bbae0def46d64b7300887fb80_64_64\n",
"Dispatching job for 90da2af4f0a0aa0f24cafa1cd59032623ada1c23_64_64\n",
"Dispatching job for fe6c11c85bbcbdaf6b716d9259f5415b7327192d_64_64\n",
"Dispatching job for 65e92bee3ecc899c5c346be82961c331d9f18933_64_64\n",
"Dispatching job for bdf6f69e2a8e08473e9e799ec2d7e627dd915d43_64_64\n",
"Dispatching job for 9b0f792a6e6c37c4e40abde72b4fbd2cdca9ebae_64_64\n"
]
}
],
"source": [
"model_list = [ss.random_sample() for _ in range(10)]\n",
"\n",
"for model in model_list:\n",
" print(f'Dispatching job for {model.archid}')\n",
" my_objective_parallel.send(model, dataset=None, budget=None)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "230bb8f1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[53, 29, 60, 31, 87, 49, 30, 83, 33, 61]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_objective_parallel.fetch_all()"
]
},
{
"cell_type": "markdown",
"id": "f86f894a",
"metadata": {},
"source": [
"**Example: Wrapping custom training code in an objective**"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "bb78fec1",
"metadata": {},
"outputs": [],
"source": [
"from archai.datasets.providers.mnist_provider import MnistProvider\n",
"from archai.discrete_search.objectives.functional import EvaluationFunction"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2e5a0d51",
"metadata": {},
"outputs": [],
"source": [
"dataset_provider = MnistProvider({'dataroot': '/home/pkauffmann/dataroot/'})"
]
},
{
"cell_type": "markdown",
"id": "854e656c",
"metadata": {},
"source": [
"```python\n",
" def custom_training_val_performance(model, dataset_provider, budget=None):\n",
" tr_data, val_data = dataset_provider.get_datasets(True, True, False, False)\n",
"\n",
" tr_dl = torch.utils.data.DataLoader(tr_data, shuffle=True, batch_size=16)\n",
" val_dl = torch.utils.data.DataLoader(tr_data, shuffle=True, batch_size=16)\n",
"\n",
" optimizer = torch.optim.Adam(model.arch.parameters(), lr=1e-3)\n",
" ...\n",
"\n",
" for batch in tr_dl:\n",
" ...\n",
"\n",
" for batch in val_dl:\n",
" ...\n",
"\n",
" return validation_metric\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "2db08cec",
"metadata": {},
"source": [
"```python\n",
"\n",
"training_objective = EvaluationFunction(custom_traininb_val_performance, higher_is_better=True)\n",
"\n",
"training_objective.evaluate(ss.random_sample(), dataset_provider, budget=None)\n",
"\n",
"```"
]
},
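{
"cell_type": "markdown",
"id": "7c1e5b20",
"metadata": {},
"source": [
"For reference, here is one way the elided function above could be completed. This is only a sketch: the `BCELoss` criterion and the negated validation loss metric are illustrative assumptions of ours (they presume the model output and targets are shape-compatible), not requirements of `EvaluationFunction`:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def custom_training_val_performance(model, dataset_provider, budget=None):\n",
"    # `budget` is ignored in this sketch\n",
"    tr_data, val_data = dataset_provider.get_datasets(True, True, False, False)\n",
"\n",
"    tr_dl = torch.utils.data.DataLoader(tr_data, shuffle=True, batch_size=16)\n",
"    val_dl = torch.utils.data.DataLoader(val_data, shuffle=False, batch_size=16)\n",
"\n",
"    arch = model.arch\n",
"    optimizer = torch.optim.Adam(arch.parameters(), lr=1e-3)\n",
"    criterion = torch.nn.BCELoss()  # assumed loss for a sigmoid-output model\n",
"\n",
"    arch.train()\n",
"    for x, y in tr_dl:\n",
"        optimizer.zero_grad()\n",
"        loss = criterion(arch(x), y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # Use the negated average validation loss, so that higher is better\n",
"    arch.eval()\n",
"    total_loss, n_batches = 0.0, 0\n",
"    with torch.no_grad():\n",
"        for x, y in val_dl:\n",
"            total_loss += criterion(arch(x), y).item()\n",
"            n_batches += 1\n",
"\n",
"    return -total_loss / n_batches\n",
"```"
]
},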
{
"cell_type": "markdown",
"id": "ebb4c012",
"metadata": {},
"source": [
"See the next notebook for a full example using a custom training objective"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diffs are hidden because one or more lines are too long.

View file

@@ -0,0 +1,845 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c4bd875c",
"metadata": {},
"source": [
"## Discrete Search Spaces"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5cdb0135",
"metadata": {},
"outputs": [],
"source": [
"from overrides import overrides\n",
"import numpy as np\n",
"from typing import Tuple, List, Optional\n",
"from archai.discrete_search import ArchaiModel, DiscreteSearchSpace"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d100fef",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "markdown",
"id": "c9e2ca76",
"metadata": {},
"source": [
"Discrete search spaces in Archai are defined using the `DiscreteSearchSpace` abstract class:\n",
"\n",
"```python\n",
"\n",
"class DiscreteSearchSpace(EnforceOverrides):\n",
"\n",
" @abstractmethod\n",
" def random_sample(self) -> ArchaiModel:\n",
" ...\n",
" \n",
" @abstractmethod\n",
" def save_arch(self, model: ArchaiModel, path: str) -> None:\n",
" ...\n",
"\n",
" @abstractmethod\n",
" def load_arch(self, path: str) -> ArchaiModel:\n",
" ...\n",
"\n",
" @abstractmethod\n",
" def save_model_weights(self, model: ArchaiModel, path: str) -> None:\n",
" ...\n",
"\n",
" @abstractmethod\n",
" def load_model_weights(self, model: ArchaiModel, path: str) -> None:\n",
" ...\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "16149883",
"metadata": {},
"source": [
"#### The `ArchaiModel` abstraction"
]
},
{
"cell_type": "markdown",
"id": "fb4c0081",
"metadata": {},
"source": [
"The `ArchaiModel` abstraction is used to wrap a model object with a given architecture id (`archid`) and optionally a metadata dictionary."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ca7f4686",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search import ArchaiModel"
]
},
{
"cell_type": "markdown",
"id": "9f563253",
"metadata": {},
"source": [
"Example"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ab5fe6fc",
"metadata": {},
"outputs": [],
"source": [
"class DummyModel(nn.Module):\n",
" def __init__(self, nb_layers: int = 2, kernel_size: int = 3):\n",
" super().__init__()\n",
" \n",
" self.nb_layers = nb_layers\n",
" self.kernel_size = kernel_size\n",
" \n",
" layers = []\n",
" for i in range(nb_layers):\n",
" input_dim = 3 if i == 0 else 16\n",
" \n",
" layers += [\n",
" nn.Conv2d(input_dim, 16, kernel_size=kernel_size, padding='same'),\n",
" nn.BatchNorm2d(16),\n",
" nn.ReLU(),\n",
" ]\n",
" \n",
" self.layers = nn.Sequential(*layers)\n",
" \n",
" def forward(self, x):\n",
" return self.layers(x)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0bb34bbf",
"metadata": {},
"outputs": [],
"source": [
"model_obj = DummyModel(nb_layers=2, kernel_size=3)"
]
},
{
"cell_type": "markdown",
"id": "e4dfe5ad",
"metadata": {},
"source": [
"Let's wrap model_obj into an `ArchaiModel`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "290e625a",
"metadata": {},
"outputs": [],
"source": [
"model = ArchaiModel(\n",
" arch=model_obj,\n",
" archid=f'L={model_obj.nb_layers}, K={model_obj.kernel_size}',\n",
" metadata={'optional': {'metadata'}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a61e5264",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'L=2, K=3'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.archid"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e0ea6fb2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'optional': {'metadata'}}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metadata"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2337eb37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DummyModel(\n",
" (layers): Sequential(\n",
" (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" )\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.arch"
]
},
{
"cell_type": "markdown",
"id": "5732d030",
"metadata": {},
"source": [
"Archid will be used to deduplicate seen architectures. It should only identify the architecture and not the model weights"
]
},
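{
"cell_type": "markdown",
"id": "9a4d2c77",
"metadata": {},
"source": [
"For instance, two separately instantiated models with the same hyperparameters (and hence different random weights) should map to the same `archid`. A small illustrative check of our own, reusing the `DummyModel` defined above:\n",
"\n",
"```python\n",
"m1 = ArchaiModel(arch=DummyModel(nb_layers=2, kernel_size=3), archid='L=2, K=3')\n",
"m2 = ArchaiModel(arch=DummyModel(nb_layers=2, kernel_size=3), archid='L=2, K=3')\n",
"\n",
"# Same architecture, even though the two weight initializations differ\n",
"assert m1.archid == m2.archid\n",
"```"
]
},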
{
"cell_type": "markdown",
"id": "6b1d6f94",
"metadata": {},
"source": [
"### ConvNet Search Space Example"
]
},
{
"cell_type": "markdown",
"id": "d9fb59c3",
"metadata": {},
"source": [
"Let's start with a (really) simple search space for image classification"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e3145b08",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self, nb_layers: int = 5, kernel_size: int = 3, hidden_dim: int = 32):\n",
" super().__init__()\n",
" \n",
" self.nb_layers = nb_layers\n",
" self.kernel_size = kernel_size\n",
" self.hidden_dim = hidden_dim\n",
" \n",
" layer_list = []\n",
"\n",
" for i in range(nb_layers):\n",
" in_ch = (3 if i == 0 else hidden_dim)\n",
" \n",
" layer_list += [\n",
" nn.Conv2d(in_ch, hidden_dim, kernel_size=kernel_size, padding='same'),\n",
" nn.BatchNorm2d(hidden_dim),\n",
" nn.ReLU()\n",
" ]\n",
"\n",
" layer_list += [\n",
" nn.Conv2d(hidden_dim, 1, kernel_size=1, padding='same'),\n",
" nn.Sigmoid()\n",
" ]\n",
" \n",
" self.model = nn.Sequential(*layer_list)\n",
" \n",
" def forward(self, x):\n",
" return self.model(x)\n",
" \n",
" def get_archid(self):\n",
" return f'({self.nb_layers}, {self.kernel_size}, {self.hidden_dim})'"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0ccc4d41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (model): Sequential(\n",
" (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (4): Sigmoid()\n",
" )\n",
")"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = MyModel(nb_layers=1)\n",
"m"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e4bd63f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'(1, 3, 32)'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m.get_archid()"
]
},
{
"cell_type": "markdown",
"id": "fa42080d",
"metadata": {},
"source": [
"Let's overide DiscreteSearchSpace"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "193ea617",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from typing import Tuple\n",
"from random import Random\n",
"\n",
"class CNNSearchSpace(DiscreteSearchSpace):\n",
" def __init__(self, min_layers: int = 1, max_layers: int = 12,\n",
" kernel_list=(1, 3, 5, 7), hidden_list=(16, 32, 64, 128),\n",
" seed: int = 1):\n",
"\n",
" self.min_layers = min_layers\n",
" self.max_layers = max_layers\n",
" self.kernel_list = kernel_list\n",
" self.hidden_list = hidden_list\n",
" \n",
" self.rng = Random(seed)\n",
" \n",
" @overrides\n",
" def random_sample(self) -> ArchaiModel:\n",
" # Randomly chooses architecture parameters\n",
" nb_layers = self.rng.randint(self.min_layers, self.max_layers)\n",
" kernel_size = self.rng.choice(self.kernel_list)\n",
" hidden_dim = self.rng.choice(self.hidden_list)\n",
" \n",
" model = MyModel(nb_layers, kernel_size, hidden_dim)\n",
" \n",
" # Wraps model into ArchaiModel\n",
" return ArchaiModel(arch=model, archid=model.get_archid())\n",
"\n",
" @overrides\n",
" def save_arch(self, model: ArchaiModel, file: str):\n",
" with open(file, 'w') as fp:\n",
" json.dump({\n",
" 'nb_layers': model.arch.nb_layers,\n",
" 'kernel_size': model.arch.kernel_size,\n",
" 'hidden_dim': model.arch.hidden_dim\n",
" }, fp)\n",
"\n",
" @overrides\n",
" def load_arch(self, file: str):\n",
" config = json.load(open(file))\n",
" model = MyModel(**config)\n",
" \n",
" return ArchaiModel(arch=model, archid=model.get_archid())\n",
"\n",
" @overrides\n",
" def save_model_weights(self, model: ArchaiModel, file: str):\n",
" state_dict = model.arch.get_state_dict()\n",
" torch.save(state_dict, file)\n",
" \n",
" @overrides\n",
" def load_model_weights(self, model: ArchaiModel, file: str):\n",
" model.arch.load_state_dict(torch.load(file))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7db02619",
"metadata": {},
"outputs": [],
"source": [
"ss = CNNSearchSpace(hidden_list=[32, 64, 128])"
]
},
{
"cell_type": "markdown",
"id": "6ce23725",
"metadata": {},
"source": [
"Sampling a model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "83c03fe1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ArchaiModel(\n",
"\tarchid=(3, 1, 64), \n",
"\tmetadata={}, \n",
"\tarch=MyModel(\n",
" (model): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" (6): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (8): ReLU()\n",
" (9): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (10): Sigmoid()\n",
" )\n",
")\n",
")"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = ss.random_sample()\n",
"m"
]
},
{
"cell_type": "markdown",
"id": "619c4d9c",
"metadata": {},
"source": [
"Saving an architecture"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3dace5d2",
"metadata": {},
"outputs": [],
"source": [
"ss.save_arch(m, 'arch.json')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f1d4dba3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"nb_layers\": 3, \"kernel_size\": 1, \"hidden_dim\": 64}"
]
}
],
"source": [
"!cat arch.json"
]
},
{
"cell_type": "markdown",
"id": "70813cc7",
"metadata": {},
"source": [
"Loading an architecture (not the weights)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "863ef766",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ArchaiModel(\n",
"\tarchid=(3, 1, 64), \n",
"\tmetadata={}, \n",
"\tarch=MyModel(\n",
" (model): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): ReLU()\n",
" (6): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (8): ReLU()\n",
" (9): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (10): Sigmoid()\n",
" )\n",
")\n",
")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss.load_arch('arch.json')"
]
},
{
"cell_type": "markdown",
"id": "f5c0b5a3",
"metadata": {},
"source": [
"### Making the search space compatible with different types of algorithms"
]
},
{
"cell_type": "markdown",
"id": "4083db69",
"metadata": {},
"source": [
"* Evolutionary-based algorithms:\n",
" - User must subclass `EvolutionarySearchSpace` and implement `EvolutionarySearchSpace.mutate` and `EvolutionarySearchSpace.crossover`\n",
"\n",
"\n",
"* BO-based algorithms:\n",
" - User must subclass `BayesOptSearchSpace` and override `BayesOptSearchSpace.encode`\n",
" - Encode should take an `ArchaiModel` and produce a fixed-length vector representation of that architecture. This numerical representation will be used to train surrogate models.\n"
]
},
{
"cell_type": "markdown",
"id": "d294ab69",
"metadata": {},
"source": [
"#### Example"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "78b73a68",
"metadata": {},
"outputs": [],
"source": [
"from archai.discrete_search import EvolutionarySearchSpace, BayesOptSearchSpace"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "0e02255f",
"metadata": {},
"outputs": [],
"source": [
"class CNNSearchSpaceExt(CNNSearchSpace, EvolutionarySearchSpace, BayesOptSearchSpace):\n",
" ''' We are subclassing CNNSearchSpace just to save up space'''\n",
" \n",
" @overrides\n",
" def mutate(self, model_1: ArchaiModel) -> ArchaiModel:\n",
" config = {\n",
" 'nb_layers': model_1.arch.nb_layers,\n",
" 'kernel_size': model_1.arch.kernel_size,\n",
" 'hidden_dim': model_1.arch.hidden_dim\n",
" }\n",
" \n",
" if self.rng.random() < 0.2:\n",
" config['nb_layers'] = self.rng.randint(self.min_layers, self.max_layers)\n",
" \n",
" if self.rng.random() < 0.2:\n",
" config['kernel_size'] = self.rng.choice(self.kernel_list)\n",
" \n",
" if self.rng.random() < 0.2:\n",
" config['hidden_dim'] = self.rng.choice(self.hidden_list)\n",
" \n",
" mutated_model = MyModel(**config)\n",
" \n",
" return ArchaiModel(\n",
" arch=mutated_model, archid=mutated_model.get_archid()\n",
" )\n",
" \n",
" @overrides\n",
" def crossover(self, model_list: List[ArchaiModel]) -> ArchaiModel:\n",
" model_1, model_2 = model_list[:2]\n",
" \n",
" new_config = {\n",
" 'nb_layers': self.rng.choice([model_1.arch.nb_layers, model_2.arch.nb_layers]),\n",
" 'kernel_size': self.rng.choice([model_1.arch.kernel_size, model_2.arch.kernel_size]),\n",
" 'hidden_dim': self.rng.choice([model_1.arch.hidden_dim, model_2.arch.hidden_dim]),\n",
" }\n",
" \n",
" crossover_model = MyModel(**new_config)\n",
" \n",
" return ArchaiModel(\n",
" arch=crossover_model, archid=crossover_model.get_archid()\n",
" )\n",
" \n",
" @overrides\n",
" def encode(self, model: ArchaiModel) -> np.ndarray:\n",
" return np.array([model.arch.nb_layers, model.arch.kernel_size, model.arch.hidden_dim])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "8f9b6ba7",
"metadata": {},
"outputs": [],
"source": [
"ss = CNNSearchSpaceExt(hidden_list=[32, 64, 128])"
]
},
{
"cell_type": "markdown",
"id": "7582b266",
"metadata": {},
"source": [
"Example"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "d23e6373",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'(3, 1, 64)'"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = ss.random_sample()\n",
"m.archid"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "6c695837",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'(8, 1, 64)'"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss.mutate(m).archid"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "e0b99677",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 3, 1, 64])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ss.encode(m)"
]
},
{
"cell_type": "markdown",
"id": "1201e318",
"metadata": {},
"source": [
"Now `CNNSearchSpaceExt` is compatible with Bayesian Optimization and Evolutionary based search algorithms!\n",
"\n",
"**To see a list of built-in search spaces, go to `archai/discrete_search/search_spaces`**"
]
},
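{
"cell_type": "markdown",
"id": "5b8e1f42",
"metadata": {},
"source": [
"The `crossover` method was not exercised above; here is a quick usage sketch of our own (not executed in this notebook):\n",
"\n",
"```python\n",
"parents = [ss.random_sample(), ss.random_sample()]\n",
"child = ss.crossover(parents)\n",
"child.archid  # a mix of the two parent configurations\n",
"```"
]
},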
{
"cell_type": "markdown",
"id": "1fad0c69",
"metadata": {},
"source": [
"Example: "
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "0ce94672",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ArchaiModel(\n",
"\tarchid=74f66612a0d01c5b7d4702234756b0ee4ffa5abc_64_64, \n",
"\tmetadata={'parent': '32fa5956ab3ce9e05bc42836599a8dc9dd53e847_64_64'}, \n",
"\tarch=SegmentationDagModel(\n",
" (edge_dict): ModuleDict(\n",
" (input-output): Block(\n",
" (op): Sequential(\n",
" (0): NormalConvBlock(\n",
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (stem_block): NormalConvBlock(\n",
" (conv): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU()\n",
" )\n",
" (up): Upsample(size=(64, 64), mode=nearest)\n",
" (post_upsample): Sequential(\n",
" (0): NormalConvBlock(\n",
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU()\n",
" )\n",
" (1): NormalConvBlock(\n",
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU()\n",
" )\n",
" (2): NormalConvBlock(\n",
" (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU()\n",
" )\n",
" )\n",
" (classifier): Conv2d(40, 1, kernel_size=(1, 1), stride=(1, 1))\n",
")\n",
")"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from archai.discrete_search.search_spaces.segmentation_dag.search_space import SegmentationDagSearchSpace\n",
"\n",
"ss = SegmentationDagSearchSpace(nb_classes=1, img_size=(64, 64), max_layers=3)\n",
"ss.mutate(ss.random_sample())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -26,7 +26,8 @@ extensions = [
     "sphinxcontrib.programoutput",
     "sphinxcontrib.mermaid",
     "sphinx_inline_tabs",
-    "sphinx_git"
+    "sphinx_git",
+    "nbsphinx"
 ]
 exclude_patterns = [
     "benchmarks/**",

View file

@@ -45,6 +45,7 @@ If you use Archai in a scientific publication, please consider citing it:
     30-Minute Tutorial <basic_guide/tutorial>
     Notebooks <basic_guide/notebooks>
+    Discrete Search Tutorial <basic_guide/discrete_search>
     Examples & Scripts <basic_guide/examples_scripts>

 .. toctree::