Executor - Add stdout logging util module and enable real-time log flushing in executor (#445)

**Description**
Add a stdout logging util module and enable real-time log flushing in the executor.

**Major Revision**
- Add a stdout logging util module to redirect stdout into a log file
- Enable stdout logging in the executor to write benchmark output to both stdout and the file `sb-bench.log`
- Enable real-time log flushing in `run_command` of micro-benchmarks through the `log_flushing` config

**Minor Revision**
- Add a `log_n_steps` argument to enable regular step-time logging in model benchmarks
- Update related docs
Authored by Yuting Jiang on 2022-12-30 17:40:28 +08:00, committed by GitHub
Parent f2634d8608
Commit 9dfefce350
No key found matching this signature; GPG key ID: 4AEE18F83AFDEB23
18 changed files: 216 additions and 12 deletions


@@ -68,6 +68,7 @@ Here're the details about work directory structure for SuperBench Runner.
│   │   └── rank-0            # output for each rank in each benchmark
│   │       ├── results.json  # raw results
│   │       └── monitor.jsonl # monitor results (optional)
│   ├── sb-bench.log          # SuperBench benchmarks' runtime log for debugging
│   └── sb-exec.log           # collected SuperBench Executor log
├── sb-run.log                # SuperBench Runner log
├── sb.config.yaml            # SuperBench configuration snapshot
@@ -98,6 +99,7 @@ The `/root` directory is mounted from `$HOME/sb-workspace` on the host path.
│   ├── results.json   # raw results
│   └── monitor.jsonl  # monitor results (optional)
├── sb.config.yaml     # SuperBench configuration snapshot
├── sb-bench.log       # SuperBench benchmarks' runtime log for debugging
└── sb.env             # SuperBench runtime environment variables
```


@@ -336,10 +336,11 @@ A list of models to run, only supported in model-benchmark.
Parameters for the benchmark to use, varying for different benchmarks.

There are four common parameters for all benchmarks:
* run_count: how many times the user wants to run this benchmark, default value is 1.
* duration: the elapsed time of the benchmark in seconds. It works for all model-benchmarks, but for micro-benchmarks, benchmark authors should consume it themselves.
* log_raw_data: log raw data into a file instead of saving it into the result object, default value is `False`. Benchmarks that have large raw output may want to set it to `True`, such as `nccl-bw`/`rccl-bw`.
* log_flushing: real-time log flushing, default value is `False`.

For model-benchmarks, there are some parameters that can control the elapsed time.

* duration: the elapsed time of the benchmark in seconds.


@@ -74,6 +74,12 @@ class Benchmark(ABC):
            default=False,
            help='Log raw data into file instead of saving it into result object.',
        )
        self._parser.add_argument(
            '--log_flushing',
            action='store_true',
            default=False,
            help='Real-time log flushing.',
        )

    def get_configurable_settings(self):
        """Get all the configurable settings.


@@ -122,7 +122,7 @@ class DockerBenchmark(Benchmark):
                    self._curr_run_index, self._name, self._commands[cmd_idx]
                )
            )
            output = run_command(self._commands[cmd_idx], flush_output=self._args.log_flushing)
            if output.returncode != 0:
                self._result.set_return_code(ReturnCode.DOCKERBENCHMARK_EXECUTION_FAILURE)
                logger.error(


@@ -173,7 +173,7 @@ class MicroBenchmarkWithInvoke(MicroBenchmark):
                )
            )
            output = run_command(self._commands[cmd_idx], flush_output=self._args.log_flushing)
            if output.returncode != 0:
                self._result.set_return_code(ReturnCode.MICROBENCHMARK_EXECUTION_FAILURE)
                logger.error(


@@ -8,7 +8,7 @@ import time
import statistics
from abc import abstractmethod

from superbench.common.utils import logger, stdout_logger
from superbench.benchmarks import Precision, ModelAction, DistributedImpl, DistributedBackend, BenchmarkType, ReturnCode
from superbench.benchmarks.base import Benchmark
from superbench.benchmarks.context import Enum
@@ -131,6 +131,14 @@ class ModelBenchmark(Benchmark):
            help='Enable option to use full float32 precision.',
        )

        self._parser.add_argument(
            '--log_n_steps',
            type=int,
            default=0,
            required=False,
            help='Real-time log every n steps.',
        )

    @abstractmethod
    def _judge_gpu_availability(self):
        """Judge GPUs' availability according to arguments and running environment."""
@@ -435,3 +443,16 @@ class ModelBenchmark(Benchmark):
        """Print environments or dependencies information."""
        # TODO: will implement it when add real benchmarks in the future.
        pass

    def _log_step_time(self, curr_step, precision, duration):
        """Log step time into stdout regularly.

        Args:
            curr_step (int): the index of current step.
            precision (Precision): precision of model and input data, such as float32, float16.
            duration (list): the durations of all steps.
        """
        if self._args.log_n_steps and curr_step % self._args.log_n_steps == 0:
            step_time = statistics.mean(duration) if len(duration) < self._args.log_n_steps \
                else statistics.mean(duration[-self._args.log_n_steps:])
            stdout_logger.log(f'{self._name} - {precision.value}: step {curr_step}, step time {step_time}\n')
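For illustration (hypothetical numbers), with `--log_n_steps 100` a model benchmark would emit a line roughly every 100 steps, where the step time is the mean over the most recent 100 step durations (or over all collected durations while fewer than 100 are available). A standalone sketch of that windowed mean:

```python
import statistics

log_n_steps = 100                     # assumed value of --log_n_steps
duration = [5.2, 5.1, 5.3, 5.2, 5.4]  # made-up step times in milliseconds, for illustration only
curr_step = 500

if log_n_steps and curr_step % log_n_steps == 0:
    # Average over the last log_n_steps entries, or over all of them if fewer exist.
    window = duration if len(duration) < log_n_steps else duration[-log_n_steps:]
    print(f'model-benchmark - float16: step {curr_step}, step time {statistics.mean(window)}')
```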


@@ -151,6 +151,7 @@ class PytorchBERT(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration
@@ -179,6 +180,7 @@ class PytorchBERT(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end):
                    return duration


@@ -114,6 +114,7 @@ class PytorchCNN(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration
@@ -143,6 +144,7 @@ class PytorchCNN(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end):
                    return duration


@@ -145,6 +145,7 @@ class PytorchGPT2(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration
@@ -173,6 +174,7 @@ class PytorchGPT2(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end):
                    return duration


@@ -154,6 +154,7 @@ class PytorchLSTM(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration
@@ -183,6 +184,7 @@ class PytorchLSTM(PytorchBase):
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end):
                    return duration


@@ -5,6 +5,7 @@
from superbench.common.utils.azure import get_vm_size
from superbench.common.utils.logging import SuperBenchLogger, logger
from superbench.common.utils.stdout_logging import StdLogger, stdout_logger
from superbench.common.utils.file_handler import rotate_dir, create_sb_output_dir, get_sb_config
from superbench.common.utils.lazy_import import LazyImport
from superbench.common.utils.process import run_command
@@ -16,6 +17,8 @@ device_manager = LazyImport('superbench.common.utils.device_manager')
__all__ = [
    'LazyImport',
    'SuperBenchLogger',
    'StdLogger',
    'stdout_logger',
    'create_sb_output_dir',
    'device_manager',
    'get_sb_config',


@@ -129,7 +129,7 @@ class DeviceManager:
        Return:
            remapped_metrics (dict): the row remapped information, None means failed to get the data.
        """
        output = process.run_command('nvidia-smi -i {} -q'.format(idx), quite=True)
        if output.returncode == 0:
            begin = output.stdout.find('Remapped Rows')
            end = output.stdout.find('Temperature', begin)


@@ -4,19 +4,47 @@
"""Process Utility."""

import subprocess
import os
import shlex

from superbench.common.utils import stdout_logger


def run_command(command, quite=False, flush_output=False):
    """Run command in string format, return the result with stdout and stderr.

    Args:
        command (str): command to run.
        quite (bool): no stdout display of the command if quite is True.
        flush_output (bool): enable real-time output flush or not when running the command.

    Return:
        result (subprocess.CompletedProcess): The return value from subprocess.run().
    """
    if flush_output:
        process = None
        try:
            args = shlex.split(command)
            process = subprocess.Popen(
                args, cwd=os.getcwd(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True
            )
            output = ''
            for line in process.stdout:
                output += line
                if not quite:
                    stdout_logger.log(line)
            process.wait()
            retcode = process.poll()
            return subprocess.CompletedProcess(args=args, returncode=retcode, stdout=output)
        except Exception as e:
            if process:
                process.kill()
                process.wait()
            return subprocess.CompletedProcess(args=args, returncode=-1, stdout=str(e))
    else:
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, check=False, universal_newlines=True
        )
        if not quite:
            stdout_logger.log(result.stdout)
        return result
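As a minimal usage sketch of the updated helper (the commands below are arbitrary examples chosen for illustration, not part of this change):

```python
from superbench.common.utils import run_command

# Stream each output line through stdout_logger (terminal + sb-bench.log) while the command runs.
result = run_command('ping -c 3 localhost', flush_output=True)
print(result.returncode)

# Run quietly: the output is still captured in the returned CompletedProcess,
# but nothing is mirrored to stdout.
result = run_command('echo hello', quite=True)
print(result.stdout)
```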


@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench stdout logging module."""

import sys


class StdLogger:
    """Logger class to enable or disable redirecting STDOUT and STDERR to a file."""
    class StdoutLoggerStream:
        """StdoutLoggerStream class which redirects sys.stdout to a file."""
        def __init__(self, filename, rank):
            """Init the class with filename.

            Args:
                filename (str): the path of the file to save the log.
                rank (int): the rank id.
            """
            self._terminal = sys.stdout
            self._rank = rank
            self._log_file_handler = open(filename, 'a')

        def __getattr__(self, attr):
            """Override __getattr__.

            Args:
                attr (str): Attribute name.

            Returns:
                Any: Attribute value.
            """
            return getattr(self._terminal, attr)

        def write(self, message):
            """Write the message to the stream.

            Args:
                message (str): the message to log.
            """
            message = f'[{self._rank}]: {message}'
            self._terminal.write(message)
            self._log_file_handler.write(message)
            self._log_file_handler.flush()

        def flush(self):
            """Override flush."""
            pass

        def restore(self):
            """Restore sys.stdout and close the file."""
            self._log_file_handler.close()
            sys.stdout = self._terminal

    def add_file_handler(self, filename):
        """Set the path of the file to save the log.

        Args:
            filename (str): the path of the file to save the log.
        """
        self.filename = filename

    def __init__(self):
        """Init the logger."""
        self.logger_stream = None

    def start(self, rank):
        """Start the logger to redirect sys.stdout to the file.

        Args:
            rank (int): the rank id.
        """
        self.logger_stream = self.StdoutLoggerStream(self.filename, rank)
        sys.stdout = self.logger_stream
        sys.stderr = sys.stdout

    def stop(self):
        """Restore sys.stdout to the terminal."""
        if self.logger_stream is not None:
            self.logger_stream.restore()

    def log(self, message):
        """Write the message into the logger.

        Args:
            message (str): the message to log.
        """
        if self.logger_stream:
            self.logger_stream.write(message)
        else:
            sys.stdout.write(message)


stdout_logger = StdLogger()
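As a rough sketch of the intended lifecycle, mirroring how the executor wires this up later in this diff (the file path and rank below are arbitrary values for illustration):

```python
from superbench.common.utils import stdout_logger

# Point the logger at a file, then redirect stdout/stderr for the current rank.
stdout_logger.add_file_handler('/tmp/sb-bench.log')  # arbitrary path for illustration
stdout_logger.start(rank=0)                          # every line gets a "[0]: " rank prefix

print('benchmark output')          # goes to both the terminal and the log file
stdout_logger.log('extra line\n')  # explicit logging through the same stream

# Restore sys.stdout and close the log file when the benchmark run finishes.
stdout_logger.stop()
```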


@@ -10,7 +10,7 @@ from pathlib import Path
from omegaconf import ListConfig

from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
from superbench.common.utils import SuperBenchLogger, logger, rotate_dir, stdout_logger
from superbench.common.devices import GPU
from superbench.monitor import Monitor
@@ -29,6 +29,7 @@ class SuperBenchExecutor():
        self._output_path = Path(sb_output_dir).expanduser().resolve()

        self.__set_logger('sb-exec.log')
        self.__set_stdout_logger(self._output_path / 'sb-bench.log')
        logger.debug('Executor uses config: %s.', self._sb_config)
        logger.debug('Executor writes to: %s.', str(self._output_path))
@@ -46,6 +47,16 @@ class SuperBenchExecutor():
        """
        SuperBenchLogger.add_handler(logger.logger, filename=str(self._output_path / filename))

    def __set_stdout_logger(self, filename):
        """Set stdout logger and redirect logs and stdout into the file.

        Args:
            filename (str): Log file name.
        """
        stdout_logger.add_file_handler(filename)
        stdout_logger.start(self.__get_rank_id())
        SuperBenchLogger.add_handler(logger.logger, filename=filename)

    def __validate_sb_config(self):
        """Validate SuperBench config object.
@@ -244,5 +255,6 @@ class SuperBenchExecutor():
            if monitor:
                monitor.stop()
            stdout_logger.stop()

            self.__write_benchmark_results(benchmark_name, benchmark_results)
            os.chdir(cwd)


@@ -158,6 +158,8 @@ def test_arguments_related_interfaces():
  --duration int        The elapsed time of benchmark in seconds.
  --force_fp32          Enable option to use full float32 precision.
  --hidden_size int     Hidden size.
  --log_flushing        Real-time log flushing.
  --log_n_steps int     Real-time log every n steps.
  --log_raw_data        Log raw data into file instead of saving it into
                        result object.
  --model_action ModelAction [ModelAction ...]
@@ -194,6 +196,8 @@ def test_preprocess():
  --duration int        The elapsed time of benchmark in seconds.
  --force_fp32          Enable option to use full float32 precision.
  --hidden_size int     Hidden size.
  --log_flushing        Real-time log flushing.
  --log_n_steps int     Real-time log every n steps.
  --log_raw_data        Log raw data into file instead of saving it into
                        result object.
  --model_action ModelAction [ModelAction ...]


@@ -114,6 +114,7 @@ def test_get_benchmark_configurable_settings():
    expected = """optional arguments:
  --duration int      The elapsed time of benchmark in seconds.
  --log_flushing      Real-time log flushing.
  --log_raw_data      Log raw data into file instead of saving it into result
                      object.
  --lower_bound int   The lower bound for accumulation.


@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Test common.util.process."""

from superbench.common.utils import run_command


def test_run_command():
    """Test run_command."""
    command = 'echo 123'
    expected_output = '123\n'
    output = run_command(command)
    assert (output.stdout == expected_output)
    assert (output.returncode == 0)
    output = run_command(command, flush_output=True)
    assert (output.stdout == expected_output)
    assert (output.returncode == 0)

    command = 'abb'
    output = run_command(command)
    assert (output.returncode != 0)
    output = run_command(command, flush_output=True)
    assert (output.returncode != 0)