# archai/tests/zero_model_test.py
#
# 30 lines
# 1.1 KiB
# Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from archai.nas.model import Model
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.common.common import common_init
def _check_zero_model(config_filepath: str) -> None:
    """Build the search model described by a config file and smoke-test it.

    Loads the config with ``common_init`` and builds the model description
    stored under ``conf['nas']['search']['model_desc']`` via
    ``ModelDescBuilder``. It then instantiates ``Model`` and runs one forward
    pass on a random ``(1, 3, 32, 32)`` input.

    Args:
        config_filepath: Path of the YAML config passed to ``common_init``.
            NOTE(review): the path is relative; presumably it is resolved
            against the current working directory. Confirm how tests are run.

    Raises:
        AssertionError: If the first output of the forward pass is not a
            ``torch.Tensor`` of shape ``(1, 10)``, or if the second
            (auxiliary) output is not ``None``.
    """
    conf = common_init(config_filepath=config_filepath)
    conf_search = conf['nas']['search']
    model_desc = ModelDescBuilder().build(conf_search['model_desc'])

    # NOTE(review): positional flags kept exactly as before; presumably
    # (droppath=False, affine=True). Confirm against Model.__init__.
    m = Model(model_desc, False, True)

    # Only output types/shapes are checked, so don't build an autograd graph.
    with torch.no_grad():
        y, aux = m(torch.rand((1, 3, 32, 32)))

    # Separate asserts so a failure reports which condition broke.
    assert isinstance(y, torch.Tensor), f'expected a Tensor, got {type(y)}'
    assert y.shape == (1, 10), f'unexpected output shape {tuple(y.shape)}'
    assert aux is None, f'expected no auxiliary output, got {type(aux)}'


def test_darts_zero_model():
    """Smoke-test the search model built from ``confs/algos/darts.yaml``."""
    _check_zero_model('confs/algos/darts.yaml')


def test_petridish_zero_model():
    """Smoke-test the search model built from ``confs/petridish_cifar.yaml``."""
    _check_zero_model('confs/petridish_cifar.yaml')