CLI - Support SKU auto detect if running on Azure VM (#365)

Support SKU auto detect and using corresponding benchmark config if running on Azure VM.
This commit is contained in:
Yifan Xiong 2022-07-05 10:52:39 +08:00 коммит произвёл GitHub
Родитель 620192a242
Коммит a94ead34b0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 104 добавлений и 5 удалений

Просмотреть файл

@ -157,6 +157,7 @@ setup(
'pyyaml>=5.3',
'seaborn>=0.11.2',
'tcping>=0.1.1rc1',
'urllib3>=1.26.9',
'xlrd>=2.0.1',
'xlsxwriter>=1.3.8',
'xmltodict>=0.12.0',

Просмотреть файл

@ -3,6 +3,7 @@
"""Exposes the interface of SuperBench common utilities."""
from superbench.common.utils.azure import get_vm_size
from superbench.common.utils.logging import SuperBenchLogger, logger
from superbench.common.utils.file_handler import rotate_dir, create_sb_output_dir, get_sb_config
from superbench.common.utils.lazy_import import LazyImport
@ -14,10 +15,11 @@ __all__ = [
'LazyImport',
'SuperBenchLogger',
'create_sb_output_dir',
'device_manager',
'get_sb_config',
'get_vm_size',
'logger',
'network',
'device_manager',
'rotate_dir',
'run_command',
]

Просмотреть файл

@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Utilities for Azure services."""
import urllib3
def get_azure_imds(path, format='text'):
"""Get metadata from Azure Instance Metadata Service.
Args:
path (str): URL path for Azure Instance Metadata Service.
format (str, optional): Response format, text or json. Defaults to 'text'.
Returns:
str: Metadata in response. Defaults to '' if timeout or error occurs.
"""
http = urllib3.PoolManager(
headers={'Metadata': 'true'},
timeout=urllib3.Timeout(connect=1.0, read=1.0),
retries=urllib3.Retry(total=3, connect=0, backoff_factor=1.0),
)
try:
r = http.request('GET', f'http://169.254.169.254/metadata/{path}?api-version=2020-06-01&format={format}')
return r.data.decode('ascii')
except Exception:
return ''
def get_vm_size():
"""Get Azure VM SKU.
Returns:
str: Azure VM SKU.
"""
return get_azure_imds('instance/compute/vmSize')

Просмотреть файл

@ -10,7 +10,7 @@ from datetime import datetime
import yaml
from omegaconf import OmegaConf
from superbench.common.utils import logger
from superbench.common.utils import logger, get_vm_size
def rotate_dir(target_dir):
@ -57,7 +57,7 @@ def create_sb_output_dir(output_dir=None):
def get_sb_config(config_file):
"""Read SuperBench config yaml.
Read config file, use default config if None is provided.
Read config file, detect Azure SKU and use corresponding config if None is provided.
Args:
config_file (str): config file path.
@ -65,8 +65,18 @@ def get_sb_config(config_file):
Returns:
OmegaConf: Config object, None if file does not exist.
"""
default_config_file = Path(__file__).parent / '../../config/default.yaml'
p = Path(config_file) if config_file else default_config_file
p = Path(str(config_file))
if not config_file:
config_path = (Path(__file__).parent / '../../config').resolve()
p = config_path / 'default.yaml'
vm_size = get_vm_size().lower()
if vm_size:
logger.info('Detected Azure SKU %s.', vm_size)
for config in (config_path / 'azure').glob('**/*'):
if config.name.startswith(vm_size):
p = config
break
logger.info('No benchmark config provided, using config file %s.', str(p))
if not p.is_file():
return None
with p.open() as fp:

Просмотреть файл

@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for Azure services utilities."""
import unittest
from unittest import mock
from pathlib import Path
import yaml
from omegaconf import OmegaConf
from superbench.common.utils import get_sb_config
class FileHandlerUtilsTestCase(unittest.TestCase):
"""A class for file_handler test cases."""
@mock.patch('superbench.common.utils.azure.get_azure_imds')
def test_get_sb_config_default(self, mock_get_azure_imds):
"""Test get_sb_config when no SKU detected, should use default config.
Args:
mock_get_azure_imds (function): Mock get_azure_imds function.
"""
mock_get_azure_imds.return_value = ''
with (Path.cwd() / 'superbench/config/default.yaml').open() as fp:
self.assertEqual(get_sb_config(None), OmegaConf.create(yaml.load(fp, Loader=yaml.SafeLoader)))
@mock.patch('superbench.common.utils.azure.get_azure_imds')
def test_get_sb_config_sku(self, mock_get_azure_imds):
"""Test get_sb_config when SKU detected and config exists, should use corresponding config.
Args:
mock_get_azure_imds (function): Mock get_azure_imds function.
"""
mock_get_azure_imds.return_value = 'Standard_NC96ads_A100_v4'
with (Path.cwd() / 'superbench/config/azure/inference/standard_nc96ads_a100_v4.yaml').open() as fp:
self.assertEqual(get_sb_config(None), OmegaConf.create(yaml.load(fp, Loader=yaml.SafeLoader)))
@mock.patch('superbench.common.utils.azure.get_azure_imds')
def test_get_sb_config_sku_nonexist(self, mock_get_azure_imds):
"""Test get_sb_config when SKU detected and no config exists, should use default config.
Args:
mock_get_azure_imds (function): Mock get_azure_imds function.
"""
mock_get_azure_imds.return_value = 'Standard_Nonexist_A100_v4'
with (Path.cwd() / 'superbench/config/default.yaml').open() as fp:
self.assertEqual(get_sb_config(None), OmegaConf.create(yaml.load(fp, Loader=yaml.SafeLoader)))