зеркало из https://github.com/microsoft/nni.git
NAS benchmark (stage 1) - move files (#5379)
This commit is contained in:
Родитель
7a25f7d78d
Коммит
55d7ffb133
|
@ -87,7 +87,7 @@ autodoc_mock_imports = [
|
|||
# Some of our modules cannot generate summary
|
||||
autosummary_mock_imports = [
|
||||
'nni.retiarii.codegen.tensorflow',
|
||||
'nni.nas.benchmarks.nasbench101.db_gen',
|
||||
'nni.nas.benchmark.nasbench101.db_gen',
|
||||
'nni.tools.jupyter_extension.management',
|
||||
] + autodoc_mock_imports
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ Data Preparation
|
|||
Option 1 (Recommended)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can download the preprocessed benchmark files via ``python -m nni.nas.benchmarks.download <benchmark_name>``, where ``<benchmark_name>`` can be ``nasbench101``, ``nasbench201``, and etc. Add ``--help`` to the command for supported command line arguments.
|
||||
You can download the preprocessed benchmark files via ``python -m nni.nas.benchmark.download <benchmark_name>``, where ``<benchmark_name>`` can be ``nasbench101``, ``nasbench201``, and etc. Add ``--help`` to the command for supported command line arguments.
|
||||
|
||||
Option 2
|
||||
^^^^^^^^
|
||||
|
@ -79,7 +79,7 @@ Instead of storing results obtained with different configurations in separate fi
|
|||
|
||||
Here is a list of available operators used in NDS.
|
||||
|
||||
.. automodule:: nni.nas.benchmarks.nds.constants
|
||||
.. automodule:: nni.nas.benchmark.nds.constants
|
||||
:noindex:
|
||||
|
||||
See :doc:`example usages </tutorials/nasbench_as_dataset>` and :ref:`API references <nds-reference>`.
|
||||
|
|
|
@ -15,5 +15,5 @@ fi
|
|||
echo "Generating database..."
|
||||
rm -f ${NASBENCHMARK_DIR}/nasbench101.db ${NASBENCHMARK_DIR}/nasbench101.db-journal
|
||||
mkdir -p ${NASBENCHMARK_DIR}
|
||||
python3 -m nni.nas.benchmarks.nasbench101.db_gen nasbench_full.tfrecord
|
||||
python3 -m nni.nas.benchmark.nasbench101.db_gen nasbench_full.tfrecord
|
||||
rm -f nasbench_full.tfrecord
|
||||
|
|
|
@ -15,5 +15,5 @@ fi
|
|||
echo "Generating database..."
|
||||
rm -f ${NASBENCHMARK_DIR}/nasbench201.db ${NASBENCHMARK_DIR}/nasbench201.db-journal
|
||||
mkdir -p ${NASBENCHMARK_DIR}
|
||||
python3 -m nni.nas.benchmarks.nasbench201.db_gen a.pth
|
||||
python3 -m nni.nas.benchmark.nasbench201.db_gen a.pth
|
||||
rm -f a.pth
|
||||
|
|
|
@ -16,5 +16,5 @@ unzip data.zip
|
|||
echo "Generating database..."
|
||||
rm -f ${NASBENCHMARK_DIR}/nds.db ${NASBENCHMARK_DIR}/nds.db-journal
|
||||
mkdir -p ${NASBENCHMARK_DIR}
|
||||
python3 -m nni.nas.benchmarks.nds.db_gen nds_data
|
||||
python3 -m nni.nas.benchmark.nds.db_gen nds_data
|
||||
rm -rf data.zip nds_data
|
||||
|
|
|
@ -35,5 +35,5 @@ cd ..
|
|||
echo "Generating database..."
|
||||
rm -f ${NASBENCHMARK_DIR}/nlp.db ${NASBENCHMARK_DIR}/nlp.db-journal
|
||||
mkdir -p ${NASBENCHMARK_DIR}
|
||||
python3 -m nni.nas.benchmarks.nlp.db_gen nlp_data
|
||||
python3 -m nni.nas.benchmark.nlp.db_gen nlp_data
|
||||
rm -rf nlp_data
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import warnings
|
||||
try:
|
||||
import peewee
|
||||
except ImportError:
|
||||
warnings.warn('peewee is not installed. Please install it to use NAS benchmarks.')
|
||||
|
||||
# from .evaluator import *
|
||||
# from .space import *
|
||||
from .utils import load_benchmark, download_benchmark
|
|
@ -2,5 +2,5 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
|
||||
from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
|
||||
from .schema import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
|
||||
from .query import query_nb101_trial_stats
|
|
@ -6,8 +6,8 @@ import argparse
|
|||
from tqdm import tqdm
|
||||
from nasbench import api # pylint: disable=import-error
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from .model import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .schema import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
|
||||
from .graph_util import nasbench_format_to_architecture_repr, hash_module
|
||||
|
||||
|
|
@ -6,8 +6,8 @@ import functools
|
|||
from peewee import fn
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from .model import Nb101TrialStats, Nb101TrialConfig, proxy
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .schema import Nb101TrialStats, Nb101TrialConfig, proxy
|
||||
from .graph_util import hash_module, infer_num_vertices
|
||||
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
from nni.nas.benchmarks.utils import json_dumps
|
||||
from nni.nas.benchmark.utils import json_dumps
|
||||
|
||||
proxy = Proxy()
|
||||
|
|
@ -2,5 +2,5 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
|
||||
from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
|
||||
from .schema import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
|
||||
from .query import query_nb201_trial_stats
|
|
@ -7,9 +7,9 @@ import re
|
|||
import tqdm
|
||||
import torch
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
|
||||
from .model import Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats
|
||||
from .schema import Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats
|
||||
|
||||
|
||||
def parse_arch_str(arch_str):
|
|
@ -6,8 +6,8 @@ import functools
|
|||
from peewee import fn
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from .model import Nb201TrialStats, Nb201TrialConfig, proxy
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .schema import Nb201TrialStats, Nb201TrialConfig, proxy
|
||||
|
||||
|
||||
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
|
|
@ -4,7 +4,7 @@
|
|||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
from nni.nas.benchmarks.utils import json_dumps
|
||||
from nni.nas.benchmark.utils import json_dumps
|
||||
|
||||
proxy = Proxy()
|
||||
|
|
@ -2,5 +2,5 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import *
|
||||
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
|
||||
from .schema import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
|
||||
from .query import query_nds_trial_stats
|
|
@ -8,8 +8,8 @@ import os
|
|||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .schema import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
|
||||
|
||||
|
||||
def inject_item(db, item, proposer, dataset, generator):
|
|
@ -6,8 +6,8 @@ import functools
|
|||
from peewee import fn
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from nni.nas.benchmarks.utils import load_benchmark
|
||||
from .model import NdsTrialStats, NdsTrialConfig, proxy
|
||||
from nni.nas.benchmark.utils import load_benchmark
|
||||
from .schema import NdsTrialStats, NdsTrialConfig, proxy
|
||||
|
||||
|
||||
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
|
|
@ -4,7 +4,7 @@
|
|||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
from nni.nas.benchmarks.utils import json_dumps
|
||||
from nni.nas.benchmark.utils import json_dumps
|
||||
|
||||
proxy = Proxy()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .model import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
|
||||
from .schema import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
|
||||
from .query import query_nlp_trial_stats
|
|
@ -6,7 +6,7 @@ import os
|
|||
import argparse
|
||||
import tqdm
|
||||
|
||||
from .model import db, NlpTrialConfig, NlpTrialStats, NlpIntermediateStats
|
||||
from .schema import db, NlpTrialConfig, NlpTrialStats, NlpIntermediateStats
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
|
@ -5,7 +5,7 @@ import functools
|
|||
|
||||
from peewee import fn
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from .model import NlpTrialStats, NlpTrialConfig
|
||||
from .schema import NlpTrialStats, NlpTrialConfig
|
||||
|
||||
def query_nlp_trial_stats(arch, dataset, reduction=None, include_intermediates=False):
|
||||
"""
|
|
@ -6,8 +6,8 @@ import os
|
|||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
|
||||
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
|
||||
|
||||
from nni.nas.benchmarks.utils import json_dumps
|
||||
from nni.nas.benchmarks.constants import DATABASE_DIR
|
||||
from nni.nas.benchmark.utils import json_dumps
|
||||
from nni.nas.benchmark.constants import DATABASE_DIR
|
||||
|
||||
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nlp.db'), autoconnect=True)
|
||||
|
|
@ -35,7 +35,7 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
|
|||
load_or_download_file(local_path, url)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(
|
||||
f'Please use `nni.nas.benchmarks.download_benchmark("{benchmark}")` to setup the benchmark first before using it.'
|
||||
f'Please use `nni.nas.benchmark.download_benchmark("{benchmark}")` to setup the benchmark first before using it.'
|
||||
)
|
||||
_loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True)
|
||||
return _loaded_benchmarks[benchmark]
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .utils import load_benchmark, download_benchmark
|
Загрузка…
Ссылка в новой задаче