Removed all the deprecated python training code and related tests and utils (#18333)

### Description
The motivation for this PR is code cleanup.

1. Remove all deprecated Python code related to ORTTrainer, the old checkpoint utilities, and the related tests and utils.
2. Clean up orttraining_pybind_state.cc to remove all deprecated bindings.
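
For downstream users, the practical effect is that the long-deprecated `ORTTrainer` / `TrainingSession` entry points disappear from the Python package, while the newer training surface (ORTModule, `onnxblock`/`artifacts`, `amp`, `optim`) remains. A rough sketch of what stops importing versus what keeps importing after this change, with module names taken from the diffs below (illustrative, not exhaustive):

```python
# Imports removed by this PR -- these now raise ImportError:
#   from onnxruntime.capi.ort_trainer import ORTTrainer, ModelDescription
#   from onnxruntime.capi.training.training_session import TrainingSession
#   from onnxruntime.training import ORTTrainer, ORTTrainerOptions, checkpoint

# Still exported from onnxruntime.training after the cleanup,
# per the updated training/__init__.py in this PR:
from onnxruntime.training import TrainingParameters, amp, artifacts, optim  # noqa: F401
```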
Ashwini Khade 2023-11-17 18:19:21 -08:00 committed by GitHub
Parent cbb85b4874
Commit 02333293de
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
57 changed files with 21 additions and 16534 deletions

View file

@@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in
${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py)
if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py"
)
file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/*.py"
)
@@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING)
"${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*"
)
endif()
else()
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/training/*.py"
)
endif()
if (onnxruntime_BUILD_UNIT_TESTS)
@@ -577,9 +570,6 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_capi_training_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:onnxruntime_pybind11_state>
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
@@ -750,9 +740,6 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/data/
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/hooks/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_capi_training_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_root_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/

View file

@@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa:
from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401
from onnxruntime.capi.training import * # noqa: F403
# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
try: # noqa: SIM105

The diff for this file is not shown because of its large size.

View file

@@ -1,102 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
from numpy.testing import assert_allclose, assert_array_equal
from onnxruntime_test_ort_trainer import run_bert_training_test
class TestOrtTrainer(unittest.TestCase):
def test_bert_training_mixed_precision(self):
expected_losses = [
11.034248352050781,
11.125300407409668,
11.006105422973633,
11.047048568725586,
11.027417182922363,
11.015759468078613,
11.060905456542969,
10.971782684326172,
]
expected_all_finites = [True, True, True, True, True, True, True, True]
expected_eval_loss = [10.959012985229492]
actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test(
gradient_accumulation_steps=1,
use_mixed_precision=True,
allreduce_post_accumulation=False,
use_simple_model_desc=False,
)
rtol = 1e-02
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
assert_allclose(
expected_eval_loss,
actual_eval_loss,
rtol=rtol,
err_msg="evaluation loss mismatch",
)
def test_bert_training_mixed_precision_internal_loss_scale(self):
expected_losses = [
11.034248352050781,
11.125300407409668,
11.006105422973633,
11.047048568725586,
11.027417182922363,
11.015759468078613,
11.060905456542969,
10.971782684326172,
]
expected_eval_loss = [10.959012985229492]
actual_losses, actual_eval_loss = run_bert_training_test(
gradient_accumulation_steps=1,
use_mixed_precision=True,
allreduce_post_accumulation=False,
use_simple_model_desc=False,
use_internel_loss_scale=True,
)
rtol = 1e-02
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_allclose(
expected_eval_loss,
actual_eval_loss,
rtol=rtol,
err_msg="evaluation loss mismatch",
)
def test_bert_training_gradient_accumulation_mixed_precision(self):
expected_losses = [
11.034248352050781,
11.125300407409668,
11.006077766418457,
11.047025680541992,
11.027434349060059,
11.0156831741333,
11.060973167419434,
10.971841812133789,
]
expected_all_finites = [True, True]
expected_eval_loss = [10.95903205871582]
actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test(
gradient_accumulation_steps=4,
use_mixed_precision=True,
allreduce_post_accumulation=False,
use_simple_model_desc=False,
)
rtol = 1e-02
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
assert_allclose(
expected_eval_loss,
actual_eval_loss,
rtol=rtol,
err_msg="evaluation loss mismatch",
)
if __name__ == "__main__":
unittest.main(module=__name__, buffer=True)

View file

@@ -1,95 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import torch
import torch.nn as nn
from numpy.testing import assert_allclose
from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description
from onnxruntime_test_training_unittest_utils import process_dropout
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer
class TestTrainingDropout(unittest.TestCase):
def setUp(self):
torch.manual_seed(1)
onnxruntime.set_seed(1)
@unittest.skip(
"Temporarily disable this test. The graph below will trigger ORT to "
"sort backward graph before forward graph which gives incorrect result. "
"https://github.com/microsoft/onnxruntime/issues/16801"
)
def test_training_and_eval_dropout(self):
class TwoDropoutNet(nn.Module):
def __init__(self, drop_prb_1, drop_prb_2, dim_size):
super().__init__()
self.drop_1 = nn.Dropout(drop_prb_1)
self.drop_2 = nn.Dropout(drop_prb_2)
self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))
def forward(self, x):
x = x + self.weight_1
x = self.drop_1(x)
x = self.drop_2(x)
output = x
return output[0]
dim_size = 3
device = torch.device("cuda", 0)
# This will drop all values, therefore expecting all 0 in output tensor
model = TwoDropoutNet(0.999, 0.999, dim_size)
input_desc = IODescription("input", [dim_size], torch.float32)
output_desc = IODescription("output", [], torch.float32)
model_desc = ModelDescription([input_desc], [output_desc])
lr_desc = ort_trainer_learning_rate_description()
model = ORTTrainer(
model,
None,
model_desc,
"LambOptimizer",
map_optimizer_attributes,
lr_desc,
device,
postprocess_model=process_dropout,
world_rank=0,
world_size=1,
)
input = torch.ones(dim_size, dtype=torch.float32).to(device)
expected_training_output = [0.0]
expected_eval_output = [1.0]
learning_rate = torch.tensor([1.0000000e00]).to(device)
input_args = [input, learning_rate]
train_output = model.train_step(*input_args)
rtol = 1e-04
assert_allclose(
expected_training_output,
train_output.item(),
rtol=rtol,
err_msg="dropout training loss mismatch",
)
eval_output = model.eval_step(input)
assert_allclose(
expected_eval_output,
eval_output.item(),
rtol=rtol,
err_msg="dropout eval loss mismatch",
)
# Do another train step to make sure it's using original ratios
train_output_2 = model.train_step(*input_args)
assert_allclose(
expected_training_output,
train_output_2.item(),
rtol=rtol,
err_msg="dropout training loss 2 mismatch",
)
if __name__ == "__main__":
unittest.main(module=__name__, buffer=True)

View file

@@ -1,56 +0,0 @@
import numpy as np
from onnx import numpy_helper
def get_node_index(model, node):
i = 0
while i < len(model.graph.node):
if model.graph.node[i] == node:
break
i += 1
return i if i < len(model.graph.node) else None
def add_const(model, name, output, t_value=None, f_value=None):
const_node = model.graph.node.add()
const_node.op_type = "Constant"
const_node.name = name
const_node.output.extend([output])
attr = const_node.attribute.add()
attr.name = "value"
if t_value is not None:
attr.type = 4
attr.t.CopyFrom(t_value)
else:
attr.type = 1
attr.f = f_value
return const_node
def process_dropout(model):
dropouts = []
index = 0
for node in model.graph.node:
if node.op_type == "Dropout":
new_dropout = model.graph.node.add()
new_dropout.op_type = "TrainableDropout"
new_dropout.name = "TrainableDropout_%d" % index
# make ratio node
ratio = np.asarray([node.attribute[0].f], dtype=np.float32)
print(ratio.shape)
ratio_value = numpy_helper.from_array(ratio)
ratio_node = add_const(
model,
"dropout_node_ratio_%d" % index,
"dropout_node_ratio_%d" % index,
t_value=ratio_value,
)
print(ratio_node)
new_dropout.input.extend([node.input[0], ratio_node.output[0]])
new_dropout.output.extend(node.output)
dropouts.append(get_node_index(model, node))
index += 1
dropouts.sort(reverse=True)
for d in dropouts:
del model.graph.node[d]
model.opset_import[0].version = 10
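
For context, the helper deleted above rewrote every `Dropout` node into a `TrainableDropout` node fed by a `Constant` ratio input and then pinned the model opset to 10. A minimal sketch of how it was exercised before this PR (the one-node model is hypothetical, and `process_dropout` is the function removed above, so this no longer runs):

```python
from onnx import TensorProto, helper

# Hypothetical one-node model with an attribute-style Dropout ratio.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [3])
dropout = helper.make_node("Dropout", ["x"], ["y"], ratio=0.5)
model = helper.make_model(helper.make_graph([dropout], "g", [x], [y]))

process_dropout(model)  # helper removed by this PR

# The Dropout node is gone; a TrainableDropout + Constant pair replaces it.
assert any(n.op_type == "TrainableDropout" for n in model.graph.node)
assert model.opset_import[0].version == 10
```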

View file

@@ -1,127 +0,0 @@
import os
import torch
def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"):
ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)]
ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)]
ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names]
assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format(
checkpoint_prefix, checkpoint_dir
)
return ckpt_file_names
def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None):
SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806
MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806
if is_partitioned:
filename = MULTIPLE_CHECKPOINT_FILENAME.format(
prefix=prefix, world_rank=world_rank, world_size=(world_size - 1)
)
else:
filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix)
return filename
def _split_state_dict(state_dict):
optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"]
split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}}
for k, v in state_dict.items():
mode = "fp32_param"
for optim_key in optimizer_keys:
if k.startswith(optim_key):
mode = "optimizer"
break
if k.endswith("_fp16"):
mode = "fp16_param"
split_sd[mode][k] = v
return split_sd
class CombineZeroCheckpoint:
def __init__(self, checkpoint_files, clean_state_dict=None):
assert len(checkpoint_files) > 0, "No checkpoint files passed"
self.checkpoint_files = checkpoint_files
self.clean_state_dict = clean_state_dict
self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1
assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files"
self.weight_shape_map = dict()
self.sharded_params = set()
def _split_name(self, name: str):
name_split = name.split("_view_")
view_num = None
if len(name_split) > 1:
view_num = int(name_split[1])
optimizer_key = ""
mp_suffix = ""
if name_split[0].startswith("Moment_1"):
optimizer_key = "Moment_1_"
elif name_split[0].startswith("Moment_2"):
optimizer_key = "Moment_2_"
elif name_split[0].startswith("Update_Count"):
optimizer_key = "Update_Count_"
elif name_split[0].endswith("_fp16"):
mp_suffix = "_fp16"
param_name = name_split[0]
if optimizer_key:
param_name = param_name.split(optimizer_key)[1]
param_name = param_name.split("_fp16")[0]
return param_name, optimizer_key, view_num, mp_suffix
def _update_weight_statistics(self, name, value):
if name not in self.weight_shape_map:
self.weight_shape_map[name] = value.size() # original shape of tensor
def _reshape_tensor(self, key):
value = self.aggregate_state_dict[key]
weight_name, _, _, _ = self._split_name(key)
set_size = self.weight_shape_map[weight_name]
self.aggregate_state_dict[key] = value.reshape(set_size)
def _aggregate(self, param_dict):
for k, v in param_dict.items():
weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k)
if view_num is not None:
# parameter is sharded
param_name = optimizer_key + weight_name + mp_suffix
if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]:
self.sharded_params.add(param_name)
# Found a previous shard of the param, concatenate shards ordered by ranks
self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v))
else:
self.aggregate_state_dict[param_name] = v
else:
if k in self.aggregate_state_dict:
assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value"
else:
self.aggregate_state_dict[k] = v
self._update_weight_statistics(weight_name, v)
def aggregate_checkpoints(self):
checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0]
self.aggregate_state_dict = dict()
for i in range(self.world_size):
checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size)
rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu"))
if "model" in rank_state_dict:
rank_state_dict = rank_state_dict["model"]
if self.clean_state_dict:
rank_state_dict = self.clean_state_dict(rank_state_dict)
rank_state_dict = _split_state_dict(rank_state_dict)
self._aggregate(rank_state_dict["fp16_param"])
self._aggregate(rank_state_dict["fp32_param"])
self._aggregate(rank_state_dict["optimizer"])
for k in self.sharded_params:
self._reshape_tensor(k)
return self.aggregate_state_dict
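
The aggregation utility deleted above relied on a fixed ZeRO checkpoint naming scheme. A small sketch of that convention, reconstructed from the removed `get_checkpoint_name` and `CombineZeroCheckpoint` (the prefix is a placeholder):

```python
prefix, world_size = "ORT_checkpoint", 4

# get_checkpoint_name embedded the rank and (world_size - 1) in each filename.
names = [f"{prefix}.ZeRO.{rank}.{world_size - 1}.ort.pt" for rank in range(world_size)]
# ['ORT_checkpoint.ZeRO.0.3.ort.pt', ..., 'ORT_checkpoint.ZeRO.3.3.ort.pt']

# CombineZeroCheckpoint recovered world_size from the first filename the same way.
assert int(names[0].split("ZeRO")[1].split(".")[2]) + 1 == world_size
```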

View file

@@ -1,6 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import TrainingParameters # noqa: F401
from onnxruntime.capi.training.training_session import TrainingSession # noqa: F401

View file

@@ -1,68 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os # noqa: F401
import sys # noqa: F401
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import (
InferenceSession,
Session,
check_and_normalize_provider_args,
)
class TrainingSession(InferenceSession):
def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
Session.__init__(self)
if sess_options:
self._sess = C.TrainingSession(sess_options)
else:
self._sess = C.TrainingSession()
# providers needs to be passed explicitly as of ORT 1.10
# retain the pre-1.10 behavior by setting to the available providers.
if providers is None:
providers = C.get_available_providers()
providers, provider_options = check_and_normalize_provider_args(
providers, provider_options, C.get_available_providers()
)
if isinstance(path_or_bytes, str):
config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
elif isinstance(path_or_bytes, bytes):
config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
else:
raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'")
self.loss_scale_input_name = config_result.loss_scale_input_name
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
def __del__(self):
if self._sess:
self._sess.finalize()
def get_state(self):
return self._sess.get_state()
def get_model_state(self, include_mixed_precision_weights=False):
return self._sess.get_model_state(include_mixed_precision_weights)
def get_optimizer_state(self):
return self._sess.get_optimizer_state()
def get_partition_info_map(self):
return self._sess.get_partition_info_map()
def load_state(self, dict, strict=False):
self._sess.load_state(dict, strict)
def is_output_fp32_node(self, output_name):
return self._sess.is_output_fp32_node(output_name)
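
Before this change, the removed wrapper was driven roughly as follows; this is a sketch reconstructed from the deleted `training_session.py` and it no longer runs after this PR (the model path is a placeholder, and a real `TrainingParameters` setup needs more fields than shown):

```python
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession  # removed by this PR

params = TrainingParameters()
params.loss_output_name = "loss"

# Providers default to all available providers, matching pre-1.10 behavior.
sess = TrainingSession("training_model.onnx", params)
state = sess.get_state()            # trained weights as a name -> numpy array map
sess.load_state(state, strict=False)
```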

The diff for this file is not shown because of its large size.

View file

@@ -18,7 +18,6 @@
#include "core/session/environment.h"
#include "core/session/custom_ops.h"
#include "core/dlpack/dlpack_converter.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/agent/training_agent.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/optimizer_config.h"
@@ -113,14 +112,11 @@ struct TrainingParameters {
std::unordered_set<std::string> weights_to_train;
std::unordered_set<std::string> weights_not_to_train;
onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights;
// optimizer
std::string training_optimizer_name;
std::string lr_params_feed_name = "Learning_Rate";
std::unordered_map<std::string, std::unordered_map<std::string, float>> optimizer_attributes_map;
std::unordered_map<std::string, std::unordered_map<std::string, int64_t>> optimizer_int_attributes_map;
onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state;
std::unordered_map<std::string, std::vector<int>> sliced_schema;
std::unordered_map<std::string, int> sliced_axes;
std::vector<std::string> sliced_tensor_names;
@@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext {
local_registries_(local_registries) {}
};
// TODO: this method does not handle parallel optimization.
TrainingConfigurationResult ConfigureSessionForTraining(
training::PipelineTrainingSession* sess, TrainingParameters& parameters) {
// TODO tix, refactor the mpi related code to populate all fields correctly by default.
ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size);
ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size);
ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size);
// When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks.
if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) {
ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0,
"D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP.");
const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size);
parameters.data_parallel_size = new_data_parallel_size;
const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) +
", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size);
LOGS(*(sess->GetLogger()), WARNING) << msg;
}
training::PipelineTrainingSession::TrainingConfiguration config{};
config.weight_names_to_train = parameters.weights_to_train;
config.weight_names_to_not_train = parameters.weights_not_to_train;
config.immutable_weights = parameters.immutable_weights;
config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;
config.distributed_config.world_rank = parameters.world_rank;
config.distributed_config.world_size = parameters.world_size;
config.distributed_config.local_rank = parameters.local_rank;
config.distributed_config.local_size = parameters.local_size;
config.distributed_config.data_parallel_size = parameters.data_parallel_size;
config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size;
config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size;
config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches;
config.distributed_config.sliced_schema = parameters.sliced_schema;
config.distributed_config.sliced_axes = parameters.sliced_axes;
config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names;
if (parameters.use_mixed_precision) {
training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{};
mp.use_mixed_precision_initializers = true;
config.mixed_precision_config = mp;
}
if (config.distributed_config.pipeline_parallel_size > 1) {
training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config;
// Currently don't support auto-partition. User needs to pass in cut information for pipeline
pipeline_config.do_partition = true;
assert(!parameters.pipeline_cut_info_string.empty());
auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) {
std::vector<std::string> result;
size_t pos = 0;
while ((pos = input_str.find(delimiter)) != std::string::npos) {
std::string token = input_str.substr(0, pos);
result.emplace_back(token);
input_str.erase(0, pos + delimiter.length());
}
// push the last split of substring into result.
result.emplace_back(input_str);
return result;
};
auto process_cut_info = [&](std::string& cut_info_string) {
std::vector<PipelineTrainingSession::TrainingConfiguration::CutInfo> cut_list;
const std::string group_delimiter = ",";
const std::string edge_delimiter = ":";
const std::string consumer_delimiter = "/";
const std::string producer_consumer_delimiter = "-";
auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter);
for (auto& cut_info_group : cut_info_groups) {
PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info;
auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter);
for (auto& cut_edge : cut_edges) {
auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
if (process_edge.size() == 1) {
PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
cut_info.emplace_back(edge);
} else {
ORT_ENFORCE(process_edge.size() == 2);
auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);
PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
cut_info.emplace_back(edge);
}
}
cut_list.emplace_back(cut_info);
}
return cut_list;
};
pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string);
config.pipeline_config = pipeline_config;
}
config.loss_name = parameters.loss_output_name;
if (!parameters.training_optimizer_name.empty()) {
training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{};
opt.name = parameters.training_optimizer_name;
opt.learning_rate_input_name = parameters.lr_params_feed_name;
opt.weight_attributes_generator = [&parameters](const std::string& weight_name) {
const auto it = parameters.optimizer_attributes_map.find(weight_name);
ORT_ENFORCE(
it != parameters.optimizer_attributes_map.end(),
"Failed to find attribute map for weight ", weight_name);
return it->second;
};
opt.weight_int_attributes_generator = [&parameters](const std::string& weight_name) {
const auto it = parameters.optimizer_int_attributes_map.find(weight_name);
ORT_ENFORCE(
it != parameters.optimizer_int_attributes_map.end(),
"Failed to find int attribute map for weight ", weight_name);
return it->second;
};
opt.use_mixed_precision_moments = parameters.use_fp16_moments;
opt.do_all_reduce_in_mixed_precision_type = true;
// TODO: this mapping is temporary.
// For now, nccl allreduce kernel only implements for allreduce_post_accumulation
// hovorod allreduce kernel only implements for not allreduce_post_accumulation.
// eventually we will have one all reduce kernel and let opt to have
// an allreduce_post_accumulation option and remove the use_nccl option.
opt.use_nccl = parameters.allreduce_post_accumulation;
opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage);
opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip;
// TODO reduction types
if (parameters.enable_adasum) {
#ifdef USE_CUDA
opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction;
#else
opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction;
#endif
}
config.optimizer_config = opt;
}
if (!parameters.optimizer_initial_state.empty()) {
config.init_optimizer_states = parameters.optimizer_initial_state;
}
config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient;
config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute;
config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute;
config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute;
config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers;
config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy;
config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level;
config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow;
if (!parameters.model_after_graph_transforms_path.empty()) {
config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path);
}
if (!parameters.model_with_gradient_graph_path.empty()) {
config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path);
}
if (!parameters.model_with_training_graph_path.empty()) {
config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path);
}
training::PipelineTrainingSession::TrainingConfigurationResult config_result{};
OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
TrainingConfigurationResult python_config_result{};
if (config_result.mixed_precision_config_result.has_value()) {
const auto& mp_config_result = config_result.mixed_precision_config_result.value();
python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name;
}
return python_config_result;
}
#if defined(USE_MPI)
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
@@ -424,7 +241,7 @@ std::unordered_map<std::string, std::unordered_map<std::string, py::object>> Con
return py_tensor_state;
}
void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
void addObjectMethodsForTraining(py::module& m) {
py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
.def(py::init<>())
.def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) {
@@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
.def_readwrite("weights_to_train", &TrainingParameters::weights_to_train)
.def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names)
@@ -484,25 +300,6 @@
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
.def("set_optimizer_initial_state",
[](TrainingParameters& parameters, const std::unordered_map<std::string, std::unordered_map<std::string, py::object>>& py_state) -> void {
onnxruntime::training::TrainingSession::OptimizerState optim_state;
for (const auto& weight_it : py_state) {
auto state = weight_it.second;
NameMLValMap state_tensors;
for (auto& initializer : state) {
OrtValue ml_value;
// InputDeflist is null because parameters havent been tied to session yet
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true);
ThrowIfPyErrOccured();
state_tensors.emplace(initializer.first, ml_value);
}
optim_state.emplace(weight_it.first, state_tensors);
}
parameters.optimizer_initial_state = optim_state;
})
.def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path)
.def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path)
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
@@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
});
#endif
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
config_result.def(py::init())
.def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
if (result.loss_scale_input_name.has_value()) {
return py::str{result.loss_scale_input_name.value()};
}
return py::none();
});
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyTrainingSession : public PyInferenceSession {
PyTrainingSession(std::shared_ptr<Environment> env, const PySessionOptions& so)
: PyInferenceSession(env, std::make_unique<PipelineTrainingSession>(so.value, *env)) {
}
~PyTrainingSession() = default;
};
py::class_<PyTrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
training_session
.def(py::init([](const PySessionOptions& so) {
auto& training_env = GetTrainingEnv();
return std::make_unique<PyTrainingSession>(training_env.GetORTEnv(), so);
}))
.def(py::init([]() {
auto& training_env = GetTrainingEnv();
return std::make_unique<PyTrainingSession>(training_env.GetORTEnv(), GetDefaultCPUSessionOptions());
}))
.def("finalize", [](py::object) {
#if defined(USE_MPI)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
})
.def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("get_state", [](PyTrainingSession* sess) {
NameMLValMap state_tensors;
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
// convert to numpy array
std::map<std::string, py::object> rmap;
for (auto& kv : state_tensors) {
if (kv.second.IsTensor()) {
py::object obj;
const Tensor& rtensor = kv.second.Get<Tensor>();
GetPyObjFromTensor(rtensor, obj, &data_transfer_manager);
rmap.insert({kv.first, obj});
} else {
throw std::runtime_error("Non tensor type in session state tensors is not expected.");
}
}
return rmap;
})
.def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) {
std::unordered_map<std::string, NameMLValMap> model_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager);
})
.def("get_optimizer_state", [](PyTrainingSession* sess) {
std::unordered_map<std::string, NameMLValMap> opt_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager);
})
.def("get_partition_info_map", [](PyTrainingSession* sess) {
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> part_info_map;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map));
return part_info_map;
})
.def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
NameMLValMap state_tensors;
for (auto initializer : state) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value);
ThrowIfPyErrOccured();
state_tensors.insert(std::make_pair(initializer.first, ml_value));
}
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
})
.def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
py::class_<PartialGraphExecutionState>(m, "PartialGraphExecutionState")
.def(py::init([]() {
return std::make_unique<PartialGraphExecutionState>();
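
Among the bindings deleted above, `ConfigureSessionForTraining` parsed `pipeline_cut_info_string` with a small set of delimiters. A sketch of that format, using the delimiters from the removed C++ (the edge and consumer names are hypothetical):

```python
# ',' separates cut groups, ':' separates edges within a group,
# '-' separates a producer edge from its consumers, '/' separates consumers.
cut_info = "layer1_out:layer2_out-consumer_a/consumer_b,layer5_out"

groups = [group.split(":") for group in cut_info.split(",")]
# [['layer1_out', 'layer2_out-consumer_a/consumer_b'], ['layer5_out']]
```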

View file

@@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*
void addGlobalMethods(py::module& m);
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn);
void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn);
void addObjectMethodsForTraining(py::module& m);
void addObjectMethodsForEager(py::module& m);
#ifdef ENABLE_LAZY_TENSOR
void addObjectMethodsForLazyTensor(py::module& m);
@@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
}
#endif
addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders);
addObjectMethodsForTraining(m);
#ifdef ENABLE_LAZY_TENSOR
addObjectMethodsForLazyTensor(m);

View file

@@ -8,26 +8,16 @@ from onnxruntime.capi._pybind_state import (
TrainingParameters,
is_ortmodule_available,
)
from onnxruntime.capi.training.training_session import TrainingSession
# Options need to be imported before `ORTTrainer`.
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, artifacts, checkpoint, model_desc_validation, optim
from . import amp, artifacts, optim
__all__ = [
"PropagateCastOpsStrategy",
"TrainingParameters",
"is_ortmodule_available",
"TrainingSession",
"ORTTrainerOptions",
"ORTTrainer",
"TrainStepInfo",
"amp",
"artifacts",
"checkpoint",
"model_desc_validation",
"optim",
]

View file

@@ -1,107 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import pickle
from collections.abc import Mapping
import h5py
def _dfs_save(group, save_obj):
"""Recursively go over each level in the save_obj dictionary and save values to a hdf5 group"""
for key, value in save_obj.items():
if isinstance(value, Mapping):
subgroup = group.create_group(key)
_dfs_save(subgroup, value)
else:
group[key] = value
def save(save_obj: dict, path):
"""Persists the input dictionary to a file specified by path.
Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path.
Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a
numpy object. So, user types may be converted from their original types to numpy equivalent types.
Args:
save_obj: dictionary that needs to be saved.
save_obj should consist of types supported by hdf5 file format.
if hdf5 does not recognize a type, an exception is raised.
if save_obj is not a dictionary, a ValueError is raised.
path: string representation to a file path or a python file-like object.
if file already exists at path, an exception is raised.
"""
if not isinstance(save_obj, Mapping):
raise ValueError("Object to be saved must be a dictionary")
with h5py.File(path, "w-") as f:
_dfs_save(f, save_obj)
def _dfs_load(group, load_obj):
"""Recursively go over each level in the hdf5 group and load the values into the given dictionary"""
for key in group:
if isinstance(group[key], h5py.Group):
load_obj[key] = {}
_dfs_load(group[key], load_obj[key])
else:
load_obj[key] = group[key][()]
def load(path, key=None):
"""Loads the data stored in the binary file specified at the given path into a dictionary and returns it.
Loads the data from an hdf5 file specified at the given path into a python dictionary.
Loaded dictionary contains numpy equivalents of python data types. For example:
PyTorch tensor -> saved as a numpy array and loaded as a numpy array.
bool -> saved as a numpy bool and loaded as a numpy bool
If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned.
Args:
path: string representation to a file path or a python file-like object.
if file does not already exist at path, an exception is raised.
key: '/' separated representation of the hierarchy level value that needs to be returned/
for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like
to query the value for c, the key provided should be 'a/c'.
the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned.
Returns:
a dictionary loaded from the specified binary hdf5 file.
"""
if not h5py.is_hdf5(path):
raise ValueError(f"{path} is not an hdf5 file or a python file-like object.")
load_obj = {}
with h5py.File(path, "r") as f:
if key:
f = f[key] # noqa: PLW2901
if isinstance(f, h5py.Dataset):
return f[()]
_dfs_load(f, load_obj)
return load_obj
def to_serialized_hex(user_dict):
"""Serialize the user_dict and convert the serialized bytes to a hex string and return"""
return pickle.dumps(user_dict).hex()
def from_serialized_hex(serialized_hex):
"""Convert serialized_hex to bytes and deserialize it and return"""
# serialized_hex can be either a regular string or a byte string.
# if it is a byte string, convert to regular string using decode()
# if it is a regular string, do nothing to it
try: # noqa: SIM105
serialized_hex = serialized_hex.decode()
except AttributeError:
pass
return pickle.loads(bytes.fromhex(serialized_hex))
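
The `_checkpoint_storage` module deleted above persisted nested dictionaries to HDF5 and read them back with `/`-separated keys. A minimal usage sketch based on its docstrings (the file name is a placeholder, and `save`/`load` are the functions removed above):

```python
import numpy as np

obj = {"model": {"weight": np.arange(3, dtype=np.float32)}, "step": 7}

save(obj, "state.h5")                      # fails if state.h5 already exists ("w-" mode)
everything = load("state.h5")              # nested dict of numpy-typed values
weight_only = load("state.h5", key="model/weight")
```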

View file

@@ -6,11 +6,9 @@
import importlib.util
import os
import sys
from functools import wraps # noqa: F401
import numpy as np
import torch
from onnx import TensorProto # noqa: F401
from packaging.version import Version
@@ -23,16 +21,6 @@ def get_device_index(device):
return 0 if device.index is None else device.index
def get_device_index_from_input(input):
"""Returns device index from a input PyTorch Tensor"""
if isinstance(input, (list, tuple)):
device_index = get_device_index(input[0].device)
else:
device_index = get_device_index(input.device)
return device_index
def get_device_str(device):
if isinstance(device, str):
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
@ -50,24 +38,6 @@ def get_device_str(device):
return device
def get_all_gradients_finite_name_from_session(session):
"""Find all_gradients_finite node on Session graph and return its name"""
nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name]
if len(nodes) != 1:
raise RuntimeError("'all_gradients_finite' node not found within training session")
return nodes[0].name
def get_gradient_accumulation_name_from_session(session):
"""Find Group_Accumulated_Gradients node on Session graph and return its name"""
nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name]
if len(nodes) != 1:
raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session")
return nodes[0].name
def dtype_torch_to_numpy(torch_dtype):
"""Converts PyTorch types to Numpy types
@@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None):
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def state_dict_model_key():
"""Returns the model key name in the state dictionary"""
return "model"
def state_dict_optimizer_key():
"""Returns the optimizer key name in the state dictionary"""
return "optimizer"
def state_dict_partition_info_key():
"""Returns the partition info key name in the state dictionary"""
return "partition_info"
def state_dict_trainer_options_key():
"""Returns the trainer options key name in the state dictionary"""
return "trainer_options"
def state_dict_full_precision_key():
"""Returns the full precision key name in the state dictionary"""
return "full_precision"
def state_dict_original_dimension_key():
"""Returns the original dimension key name in the state dictionary"""
return "original_dim"
def state_dict_sharded_optimizer_keys():
"""Returns the optimizer key names that can be sharded in the state dictionary"""
return {"Moment_1", "Moment_2"}
def state_dict_user_dict_key():
"""Returns the user dict key name in the state dictionary"""
return "user_dict"
def state_dict_trainer_options_mixed_precision_key():
"""Returns the trainer options mixed precision key name in the state dictionary"""
return "mixed_precision"
def state_dict_trainer_options_zero_stage_key():
"""Returns the trainer options zero_stage key name in the state dictionary"""
return "zero_stage"
def state_dict_trainer_options_world_rank_key():
"""Returns the trainer options world_rank key name in the state dictionary"""
return "world_rank"
def state_dict_trainer_options_world_size_key():
"""Returns the trainer options world_size key name in the state dictionary"""
return "world_size"
def state_dict_trainer_options_data_parallel_size_key():
"""Returns the trainer options data_parallel_size key name in the state dictionary"""
return "data_parallel_size"
def state_dict_trainer_options_horizontal_parallel_size_key():
"""Returns the trainer options horizontal_parallel_size key name in the state dictionary"""
return "horizontal_parallel_size"
def state_dict_trainer_options_optimizer_name_key():
"""Returns the trainer options optimizer_name key name in the state dictionary"""
return "optimizer_name"
def state_dict_train_step_info_key():
"""Returns the train step info key name in the state dictionary"""
return "train_step_info"
def state_dict_train_step_info_optimization_step_key():
"""Returns the train step info optimization step key name in the state dictionary"""
return "optimization_step"
def state_dict_train_step_info_step_key():
"""Returns the train step info step key name in the state dictionary"""
return "step"

View file

@@ -1,748 +0,0 @@
import os
import tempfile
import warnings
from enum import Enum
import numpy as np
import onnx
import torch
from . import _checkpoint_storage, _utils
################################################################################
# Experimental Checkpoint APIs
################################################################################
def experimental_state_dict(ort_trainer, include_optimizer_state=True):
warnings.warn(
"experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.",
DeprecationWarning,
)
if not ort_trainer._training_session:
warnings.warn(
"ONNX Runtime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling state_dict()."
)
return ort_trainer._state_dict
# extract trained weights
session_state = ort_trainer._training_session.get_state()
torch_state = {}
for name in session_state:
torch_state[name] = torch.from_numpy(session_state[name])
# extract untrained weights and buffer
for n in ort_trainer._onnx_model.graph.initializer:
if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights:
torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n)))
# Need to remove redundant (optimizer) initializers to map back to original torch state names
if not include_optimizer_state and ort_trainer._torch_state_dict_keys:
return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state}
return torch_state
def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
warnings.warn(
"experimental_load_state_dict() will be deprecated soon. Please use ORTTrainer.load_state_dict() instead.",
DeprecationWarning,
)
# Note: It may happen ONNX model has not yet been initialized
# In this case we cache a reference to desired state and delay the restore until after initialization
# Unexpected behavior will result if the user changes the reference before initialization
if not ort_trainer._training_session:
ort_trainer._state_dict = state_dict
ort_trainer._load_state_dict_strict = strict
return
# Update onnx model from loaded state dict
cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer]
new_initializers = {}
for name in state_dict:
if name in cur_initializers_names:
new_initializers[name] = state_dict[name].numpy()
elif strict:
raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.")
ort_trainer._update_onnx_model_initializers(new_initializers)
# create new session based on updated onnx model
ort_trainer._state_dict = None
ort_trainer._init_session()
# load training state
session_state = {name: state_dict[name].numpy() for name in state_dict}
ort_trainer._training_session.load_state(session_state, strict)
def experimental_save_checkpoint(
ort_trainer,
checkpoint_dir,
checkpoint_prefix="ORT_checkpoint",
checkpoint_state_dict=None,
include_optimizer_state=True,
):
warnings.warn(
"experimental_save_checkpoint() will be deprecated soon. Please use ORTTrainer.save_checkpoint() instead.",
DeprecationWarning,
)
if checkpoint_state_dict is None:
checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)}
else:
checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)})
assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist"
checkpoint_name = _get_checkpoint_name(
checkpoint_prefix,
ort_trainer.options.distributed.deepspeed_zero_optimization.stage,
ort_trainer.options.distributed.world_rank,
ort_trainer.options.distributed.world_size,
)
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
if os.path.exists(checkpoint_file):
msg = f"{checkpoint_file} already exists, overwriting."
warnings.warn(msg)
torch.save(checkpoint_state_dict, checkpoint_file)
def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
warnings.warn(
"experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.",
DeprecationWarning,
)
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
is_partitioned = False
if len(checkpoint_files) > 1:
msg = (
f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}."
" Attempting to load ZeRO checkpoint."
)
warnings.warn(msg)
is_partitioned = True
if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned:
return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict)
else:
return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
class _AGGREGATION_MODE(Enum): # noqa: N801
Zero = 0
Megatron = 1
def _order_paths(paths, D_groups, H_groups):
"""Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively
and returns the ordered dict"""
trainer_options_path_tuples = []
world_rank = _utils.state_dict_trainer_options_world_rank_key()
for path in paths:
trainer_options_path_tuples.append(
(_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path)
)
# sort paths according to rank
sorted_paths = [
path
for _, path in sorted(
trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank]
)
]
ordered_paths = dict()
ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))]
ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))]
return ordered_paths
def _add_or_update_sharded_key(
state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode
):
"""Add or update the record for the sharded state_key in the state_sub_dict"""
# record the original dimension for this state
original_dim = _utils.state_dict_original_dimension_key()
sharded_states_original_dims[model_state_key] = state_partition_info[original_dim]
axis = 0
if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0:
axis = -1
if state_key in state_sub_dict:
# state_dict already contains a record for this state
# since this state is sharded, concatenate the state value to
# the record in the state_dict
state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis)
else:
# create a new entry for this state in the state_dict
state_sub_dict[state_key] = state_value
def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string):
"""Add or validate the record for the unsharded state_key in the state_sub_dict"""
if state_key in state_sub_dict:
# state_dict already contains a record for this unsharded state.
# assert that all values are the same for this previously loaded state
assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string
else:
# create a new entry for this state in the state_sub_dict
state_sub_dict[state_key] = state_value
def _aggregate_model_states(
rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero
):
"""Aggregates all model states from the rank_state_dict into state_dict"""
model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
partition_info = _utils.state_dict_partition_info_key()
# if there are no model states in the rank_state_dict, no model aggregation is needed
if model not in rank_state_dict:
return
if model not in state_dict:
state_dict[model] = {}
if full_precision not in state_dict[model]:
state_dict[model][full_precision] = {}
# iterate over all model state keys
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
# ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed
# precision training was enabled. for full precision training, full precision model states are not sharded
# MEGATRON : full precision model states are sharded when they exist in the partition_info subdict
if (model_state_key in rank_state_dict[partition_info]) and (
mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled
):
# this model state is sharded
_add_or_update_sharded_key(
model_state_key,
model_state_value,
state_dict[model][full_precision],
model_state_key,
rank_state_dict[partition_info][model_state_key],
sharded_states_original_dims,
mode,
)
else:
# this model state is not sharded since a record for it does not exist in the partition_info subdict
_add_or_validate_unsharded_key(
model_state_key,
model_state_value,
state_dict[model][full_precision],
f"Value mismatch for model state {model_state_key}",
)
def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero):
"""Aggregates all optimizer states from the rank_state_dict into state_dict"""
optimizer = _utils.state_dict_optimizer_key()
partition_info = _utils.state_dict_partition_info_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
# if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed
if optimizer not in rank_state_dict:
return
if optimizer not in state_dict:
state_dict[optimizer] = {}
# iterate over all optimizer state keys
for model_state_key, optimizer_dict in rank_state_dict[optimizer].items():
for optimizer_key, optimizer_value in optimizer_dict.items():
if model_state_key not in state_dict[optimizer]:
state_dict[optimizer][model_state_key] = {}
if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]:
# this optimizer state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key(
optimizer_key,
optimizer_value,
state_dict[optimizer][model_state_key],
model_state_key,
rank_state_dict[partition_info][model_state_key],
sharded_states_original_dims,
mode,
)
else:
# this optimizer state is not sharded since a record for it does not exist in the partition_info subdict
# or this optimizer key is not one of the sharded optimizer keys
_add_or_validate_unsharded_key(
optimizer_key,
optimizer_value,
state_dict[optimizer][model_state_key],
f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}",
)
def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled):
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""
model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
optimizer = _utils.state_dict_optimizer_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
for sharded_state_key, original_dim in sharded_states_original_dims.items():
# reshape model states to original_dim only when mixed precision is enabled
if mixed_precision_enabled and (model in state_dict):
state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][
sharded_state_key
].reshape(original_dim)
# reshape optimizer states to original_dim
if optimizer in state_dict:
for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items():
if optimizer_key in sharded_optimizer_keys:
state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim)
def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation):
"""Extracts trainer options from rank_state_dict and loads them accordingly on state_dict"""
trainer_options = _utils.state_dict_trainer_options_key()
state_dict[trainer_options] = {}
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_rank = _utils.state_dict_trainer_options_world_rank_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806
state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision]
state_dict[trainer_options][zero_stage] = 0
state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0
state_dict[trainer_options][world_size] = 1
state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name]
state_dict[trainer_options][D_size] = 1
state_dict[trainer_options][H_size] = 1
def _aggregate_megatron_partition_info(rank_state_dict, state_dict):
"""Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights"""
partition_info = _utils.state_dict_partition_info_key()
if partition_info not in state_dict:
state_dict[partition_info] = {}
rank_partition_info = rank_state_dict[partition_info]
for model_state_key, partition_info_dict in rank_partition_info.items():
if model_state_key not in state_dict[partition_info]:
# add partition info only if weight is megatron partitioned
if partition_info_dict["megatron_row_partition"] >= 0:
state_dict[partition_info][model_state_key] = partition_info_dict
def _to_pytorch_format(state_dict):
"""Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)"""
pytorch_state_dict = {}
for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][
_utils.state_dict_full_precision_key()
].items():
# convert numpy array to a torch tensor
pytorch_state_dict[model_state_key] = torch.tensor(model_state_value)
return pytorch_state_dict
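# A minimal sketch of the conversion above (keys and tensor values are illustrative):
#   hierarchical ORT schema: state_dict[model_key][full_precision_key] == {"fc.weight": <numpy array>, ...}
#   flat PyTorch schema:     _to_pytorch_format(state_dict)            == {"fc.weight": <torch tensor>, ...}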
def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size):
"""Returns the D and H groups for the given sizes"""
num_data_groups = world_size // data_parallel_size
data_groups = []
for data_group_id in range(num_data_groups):
data_group_ranks = []
for r in range(data_parallel_size):
data_group_ranks.append(data_group_id + horizontal_parallel_size * r)
data_groups.append(data_group_ranks)
num_horizontal_groups = world_size // horizontal_parallel_size
horizontal_groups = []
for hori_group_id in range(num_horizontal_groups):
hori_group_ranks = []
for r in range(horizontal_parallel_size):
hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r)
horizontal_groups.append(hori_group_ranks)
return data_groups, horizontal_groups
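# Worked example tracing the function above:
#   _get_parallellism_groups(data_parallel_size=2, horizontal_parallel_size=2, world_size=4)
#   -> data_groups = [[0, 2], [1, 3]], horizontal_groups = [[0, 1], [2, 3]]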
def _aggregate_over_ranks(
ordered_paths,
ranks,
sharded_states_original_dims=None,
mode=_AGGREGATION_MODE.Zero,
partial_aggregation=False,
pytorch_format=True,
):
"""Aggregate checkpoint files over set of ranks and return a single state dictionary
Args:
ordered_paths: list of paths in the order in which they must be aggregated
ranks: list of ranks that are to be aggregated
sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over
multiple calls to _aggregate_over_ranks()
mode: mode of aggregation: Zero or Megatron
partial_aggregation: boolean flag to indicate whether to produce a partially
aggregated state which can be further aggregated over
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
Returns:
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
"""
state_dict = {}
if sharded_states_original_dims is None:
sharded_states_original_dims = dict()
world_rank = _utils.state_dict_trainer_options_world_rank_key()
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
loaded_mixed_precision = None
loaded_world_size = None
loaded_zero_stage = None
loaded_optimizer_name = None
for i, path in enumerate(ordered_paths):
rank_state_dict = _checkpoint_storage.load(path)
assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info"
assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options"
assert (
ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]
), "Unexpected rank in file at path {}. Expected {}, got {}".format(
path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] # noqa: F821
)
if loaded_mixed_precision is None:
loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
else:
assert (
loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
), f"Mixed precision state mismatch among checkpoint files. File: {path}"
if loaded_world_size is None:
loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
else:
assert (
loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
), f"World size state mismatch among checkpoint files. File: {path}"
if loaded_zero_stage is None:
loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
else:
assert (
loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
), f"Zero stage mismatch among checkpoint files. File: {path}"
if loaded_optimizer_name is None:
loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
else:
assert (
loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
), f"Optimizer name mismatch among checkpoint files. File: {path}"
# aggregate all model states
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode)
if not pytorch_format:
# aggregate all optimizer states if pytorch_format is False
_aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode)
# for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
# to aggregate over Zero, and another pass to aggregate Megatron partitioned
# states. Preserve the relevant partition info only for weights that are megatron partitioned for
# a partial aggregation call
if partial_aggregation:
_aggregate_megatron_partition_info(rank_state_dict, state_dict)
# entry for trainer_options in the state_dict to perform other sanity checks
if _utils.state_dict_trainer_options_key() not in state_dict:
_aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation)
# entry for user_dict in the state_dict if not already present
if (
_utils.state_dict_user_dict_key() not in state_dict
and _utils.state_dict_user_dict_key() in rank_state_dict
):
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]
# for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape
if not partial_aggregation:
# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
_reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision)
# return a flat structure for PyTorch model in case pytorch_format is True
# else return the hierarchical structure for ORTTrainer
return _to_pytorch_format(state_dict) if pytorch_format else state_dict
def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802
"""Aggregate checkpoint files and return a single state dictionary for the D+H
(Zero+Megatron) partitioning strategy.
For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
to aggregate over Zero, and another pass over the previously aggregated states
to aggregate Megatron partitioned states.
"""
sharded_states_original_dims = {}
aggregate_data_checkpoint_files = []
# combine for Zero over data groups and save to temp file
with tempfile.TemporaryDirectory() as save_dir:
for group_id, d_group in enumerate(D_groups):
aggregate_state_dict = _aggregate_over_ranks(
ordered_paths["D"][group_id],
d_group,
sharded_states_original_dims,
partial_aggregation=True,
pytorch_format=False,
)
filename = "ort.data_group." + str(group_id) + ".ort.pt"
filepath = os.path.join(save_dir, filename)
_checkpoint_storage.save(aggregate_state_dict, filepath)
aggregate_data_checkpoint_files.append(filepath)
assert len(aggregate_data_checkpoint_files) > 0
# combine for megatron:
aggregate_state = _aggregate_over_ranks(
aggregate_data_checkpoint_files,
H_groups[0],
sharded_states_original_dims,
mode=_AGGREGATION_MODE.Megatron,
pytorch_format=pytorch_format,
)
return aggregate_state
def aggregate_checkpoints(paths, pytorch_format=True):
"""Aggregate checkpoint files and return a single state dictionary
Aggregates checkpoint files specified by paths and loads them one at a time, merging
them into a single state dictionary.
The checkpoint files represented by paths must have been saved through the ORTTrainer.save_checkpoint() function.
The schema of the returned state_dict will be the same as the one returned by ORTTrainer.state_dict()
Args:
paths: list of paths (as strings) to the checkpoint files to be aggregated; more than one file is expected
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
Returns:
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
"""
loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key())
D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806
world_size = _utils.state_dict_trainer_options_world_size_key()
D_size = loaded_trainer_options[D_size] # noqa: N806
H_size = loaded_trainer_options[H_size] # noqa: N806
world_size = loaded_trainer_options[world_size]
D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806
combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0
combine_megatron = len(H_groups[0]) > 1
# order the paths in the order of groups in which they must be aggregated according to
# data-parallel groups and H-parallel groups obtained
# eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]}
ordered_paths = _order_paths(paths, D_groups, H_groups)
aggregate_state = None
if combine_zero and combine_megatron:
aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format)
elif combine_zero:
aggregate_state = _aggregate_over_ranks(
ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format
)
elif combine_megatron:
aggregate_state = _aggregate_over_ranks(
ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format
)
return aggregate_state
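# A minimal usage sketch (file names are hypothetical; the files must come from ORTTrainer.save_checkpoint()):
#   paths = ["ckpt.rank.0.ort.pt", "ckpt.rank.1.ort.pt"]
#   pytorch_state_dict = aggregate_checkpoints(paths)                   # flat, PyTorch-style schema
#   ort_state_dict = aggregate_checkpoints(paths, pytorch_format=False) # hierarchical, ORT-style schema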
################################################################################
# Helper functions
################################################################################
def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict):
checkpoint_name = _get_checkpoint_name(
checkpoint_prefix,
is_partitioned,
ort_trainer.options.distributed.world_rank,
ort_trainer.options.distributed.world_size,
)
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
if is_partitioned:
assert_msg = (
f"Couldn't find checkpoint file {checkpoint_file}."
" Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists "
f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}"
)
else:
assert_msg = f"Couldn't find checkpoint file {checkpoint_file}."
assert os.path.exists(checkpoint_file), assert_msg
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict)
del checkpoint_state["model"]
return checkpoint_state
def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict):
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
ckpt_agg = _CombineZeroCheckpoint(checkpoint_files)
aggregate_state_dict = ckpt_agg.aggregate_checkpoints()
experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict)
# aggregate other keys in the state_dict.
# Values will be overwritten for matching keys among workers
all_checkpoint_states = dict()
for checkpoint_file in checkpoint_files:
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
del checkpoint_state["model"]
all_checkpoint_states.update(checkpoint_state)
return all_checkpoint_states
def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"):
ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)]
ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)]
ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names]
assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'"
return ckpt_file_names
def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None):
SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806
MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806
if is_partitioned:
filename = MULTIPLE_CHECKPOINT_FILENAME.format(
prefix=prefix, world_rank=world_rank, world_size=(world_size - 1)
)
else:
filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix)
return filename
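# Worked examples tracing the function above:
#   _get_checkpoint_name("ckpt", is_partitioned=False)                            -> "ckpt.ort.pt"
#   _get_checkpoint_name("ckpt", is_partitioned=True, world_rank=1, world_size=4) -> "ckpt.ZeRO.1.3.ort.pt"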
def _split_state_dict(state_dict):
optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"]
split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}}
for k, v in state_dict.items():
mode = "fp32_param"
for optim_key in optimizer_keys:
if k.startswith(optim_key):
mode = "optimizer"
break
if k.endswith("_fp16"):
mode = "fp16_param"
split_sd[mode][k] = v
return split_sd
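# Worked example tracing the function above (state names and values are illustrative):
#   _split_state_dict({"fc.weight": w, "fc.weight_fp16": w16, "Moment_1_fc.weight": m1})
#   -> {"optimizer": {"Moment_1_fc.weight": m1}, "fp32_param": {"fc.weight": w}, "fp16_param": {"fc.weight_fp16": w16}}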
class _CombineZeroCheckpoint:
def __init__(self, checkpoint_files, clean_state_dict=None):
assert len(checkpoint_files) > 0, "No checkpoint files passed"
self.checkpoint_files = checkpoint_files
self.clean_state_dict = clean_state_dict
self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1
assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files"
self.weight_shape_map = {}
self.sharded_params = set()
def _split_name(self, name: str):
name_split = name.split("_view_")
view_num = None
if len(name_split) > 1:
view_num = int(name_split[1])
optimizer_key = ""
mp_suffix = ""
if name_split[0].startswith("Moment_1"):
optimizer_key = "Moment_1_"
elif name_split[0].startswith("Moment_2"):
optimizer_key = "Moment_2_"
elif name_split[0].startswith("Update_Count"):
optimizer_key = "Update_Count_"
elif name_split[0].endswith("_fp16"):
mp_suffix = "_fp16"
param_name = name_split[0]
if optimizer_key:
param_name = param_name.split(optimizer_key)[1]
param_name = param_name.split("_fp16")[0]
return param_name, optimizer_key, view_num, mp_suffix
def _update_weight_statistics(self, name, value):
if name not in self.weight_shape_map:
self.weight_shape_map[name] = value.size() # original shape of tensor
def _reshape_tensor(self, key):
value = self.aggregate_state_dict[key]
weight_name, _, _, _ = self._split_name(key)
set_size = self.weight_shape_map[weight_name]
self.aggregate_state_dict[key] = value.reshape(set_size)
def _aggregate(self, param_dict):
for k, v in param_dict.items():
weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k)
if view_num is not None:
# parameter is sharded
param_name = optimizer_key + weight_name + mp_suffix
if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]:
self.sharded_params.add(param_name)
# Found a previous shard of the param, concatenate shards ordered by ranks
self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v))
else:
self.aggregate_state_dict[param_name] = v
else:
if k in self.aggregate_state_dict:
assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value"
else:
self.aggregate_state_dict[k] = v
self._update_weight_statistics(weight_name, v)
def aggregate_checkpoints(self):
warnings.warn(
"_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. "
"Please use aggregate_checkpoints() instead.",
DeprecationWarning,
)
checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0]
self.aggregate_state_dict = dict()
for i in range(self.world_size):
checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size)
rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu"))
if "model" in rank_state_dict:
rank_state_dict = rank_state_dict["model"]
if self.clean_state_dict:
rank_state_dict = self.clean_state_dict(rank_state_dict)
rank_state_dict = _split_state_dict(rank_state_dict)
self._aggregate(rank_state_dict["fp16_param"])
self._aggregate(rank_state_dict["fp32_param"])
self._aggregate(rank_state_dict["optimizer"])
for k in self.sharded_params:
self._reshape_tensor(k)
return self.aggregate_state_dict
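# A minimal usage sketch of the class above (directory and prefix are hypothetical):
#   checkpoint_files = _list_checkpoint_files("/tmp/checkpoints", "my_model")
#   aggregated_state_dict = _CombineZeroCheckpoint(checkpoint_files).aggregate_checkpoints()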

View file

@ -1,408 +0,0 @@
from collections import namedtuple
import cerberus
import torch
from ._utils import static_vars
LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate"
ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite"
LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name"
GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name"
class _ORTTrainerModelDesc:
def __init__(self, model_desc):
# Keep a copy of original input for debug
self._original = dict(model_desc)
# Global counter used to validate occurrences of 'is_loss=True' within 'model_desc.outputs'
# A stateless validator is used for each tuple, but validation across the whole list of tuples is needed
# because just one 'is_loss=True' is allowed within the 'model_desc.outputs' list of tuples
_model_desc_outputs_validation.loss_counter = 0
# Used for logging purposes
self._main_class_name = self.__class__.__name__
# Validates user input
self._validated = dict(self._original)
validator = cerberus.Validator(MODEL_DESC_SCHEMA)
self._validated = validator.validated(self._validated)
if self._validated is None:
raise ValueError(f"Invalid model_desc: {validator.errors}")
# Normalize inputs to a list of namedtuple(name, shape)
self._InputDescription = namedtuple("InputDescription", ["name", "shape"])
self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"])
for idx, input in enumerate(self._validated["inputs"]):
self._validated["inputs"][idx] = self._InputDescription(*input)
# Normalize outputs to a list of namedtuple(name, shape, is_loss)
self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"])
self._OutputDescriptionTyped = namedtuple(
"OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"]
)
for idx, output in enumerate(self._validated["outputs"]):
if len(output) == 2:
self._validated["outputs"][idx] = self._OutputDescription(*output, False)
else:
self._validated["outputs"][idx] = self._OutputDescription(*output)
# Hard-code learning rate, all_finite descriptors
self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32)
# Convert dict to object
for k, v in self._validated.items():
setattr(self, k, self._wrap(v))
def __repr__(self):
"""Pretty representation for a model description class"""
pretty_msg = "Model description:\n"
# Inputs
inputs = []
for i_desc in self.inputs:
if isinstance(i_desc, self._InputDescription):
inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})")
elif isinstance(i_desc, self._InputDescriptionTyped):
inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})")
else:
raise ValueError(f"Unexpected type {type(i_desc)} for input description")
pretty_msg += "\nInputs:"
for idx, item in enumerate(inputs):
pretty_msg += f"\n\t{idx}: {item}"
# Outputs
outputs = []
for o_desc in self.outputs:
if isinstance(o_desc, self._OutputDescription):
outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})")
elif isinstance(o_desc, self._OutputDescriptionTyped):
outputs.append(
f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})"
)
else:
raise ValueError(f"Unexpected type {type(o_desc)} for output description")
pretty_msg += "\nOutputs:"
for idx, item in enumerate(outputs):
pretty_msg += f"\n\t{idx}: {item}"
# Learning rate
if self.learning_rate:
pretty_msg += "\nLearning rate: "
pretty_msg += (
f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})"
)
# Mixed precision
if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr(
self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None
):
pretty_msg += "\nMixed Precision:"
if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None):
pretty_msg += "\n\tis gradients finite: "
pretty_msg += (
f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})"
)
if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None):
pretty_msg += "\n\tloss scale input name: "
pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})"
# Gradient Accumulation steps
if self.gradient_accumulation:
pretty_msg += "\nGradient Accumulation: "
pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})"
return pretty_msg
def add_type_to_input_description(self, index, dtype):
"""Updates an existing input description at position 'index' with 'dtype' type information
Args:
index (int): position within 'inputs' description
dtype (torch.dtype): input data type
"""
assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int"
assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type"
existing_values = (*self.inputs[index],)
if isinstance(self.inputs[index], self._InputDescriptionTyped):
existing_values = (*existing_values[:-1],)
self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype)
def add_type_to_output_description(self, index, dtype, dtype_amp=None):
"""Updates an existing output description at position 'index' with 'dtype' type information
Args:
index (int): position within 'outputs' description
dtype (torch.dtype): output data type
dtype_amp (torch.dtype, default is None): output data type for evaluation with mixed precision
"""
assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int"
assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type"
assert dtype_amp is None or isinstance(
dtype_amp, torch.dtype
), "output 'dtype_amp' must be either None or torch.dtype type"
existing_values = (*self.outputs[index],)
if isinstance(self.outputs[index], self._OutputDescriptionTyped):
existing_values = (*existing_values[:-2],)
self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp)
@property
def gradient_accumulation(self):
return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None)
@gradient_accumulation.setter
def gradient_accumulation(self, name):
self._add_output_description(
self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True
)
@property
def all_finite(self):
return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None)
@all_finite.setter
def all_finite(self, name):
self._add_output_description(
self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True
)
@property
def loss_scale_input(self):
return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None)
@loss_scale_input.setter
def loss_scale_input(self, name):
self._add_input_description(
self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True
)
def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False):
"""Add a new input description into the node object
If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created.
Otherwise an untyped input description namedtuple(name, shape) is created instead.
Args:
node (list or object): node to append input description to. When 'node' is 'self.inputs',
a new input description is appended to the list.
Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name'
name (str): name of input description
shape (list): shape of input description
dtype (torch.dtype): input data type
attr_name (str, default is None): friendly name to allow direct access to the input description
ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs
"""
assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name"
not_found = True
if not ignore_duplicate:
if id(node) == id(self.inputs):
not_found = all([name not in i_desc.name for i_desc in node])
assert not_found, f"'name' {name} already exists in the inputs description"
else:
not_found = attr_name not in dir(self)
assert not_found, f"'attr_name' {attr_name} already exists in the 'node'"
elif not not_found:
return
assert isinstance(shape, list) and all(
[(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape]
), "'shape' must be a list of int or str with length at least 1"
assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type"
if dtype:
new_input_desc = self._InputDescriptionTyped(name, shape, dtype)
else:
new_input_desc = self._InputDescription(name, shape)
if id(node) == id(self.inputs):
self.inputs.append(new_input_desc)
else:
assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'"
setattr(node, attr_name, new_input_desc)
def _add_output_description(
self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False
):
"""Add a new output description into the node object as a tuple
When (name, shape, is_loss, dtype) is specified, a typed output description is created
Otherwise an untyped output description (name, shape, is_loss) is created instead
Args:
node (list or object): node to append output description to. When 'node' is 'self.outputs',
a new output description is appended to the list.
Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name'
name (str): name of output description
shape (list): shape of output description
is_loss (bool): specifies whether this output is a loss
dtype (torch.dtype): output data type
dtype_amp (torch.dtype, default is None): output data type for evaluation with mixed precision.
attr_name (str, default is None): friendly name to allow direct access to the output description
ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs
"""
assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name"
assert isinstance(shape, list) and all(
[(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape]
), "'shape' must be a list of int or str with length at least 1"
assert isinstance(is_loss, bool), "'is_loss' must be a bool"
not_found = True
if not ignore_duplicate:
if id(node) == id(self.outputs):
not_found = all([name not in o_desc.name for o_desc in node])
assert not_found, f"'name' {name} already exists in the outputs description"
assert (
all([not o_desc.is_loss for o_desc in node]) if is_loss else True
), "Only one 'is_loss' is supported at outputs description"
else:
not_found = attr_name not in dir(self)
assert not_found, f"'attr_name' {attr_name} already exists in the 'node'"
elif not not_found:
return
assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type"
if dtype:
new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None)
else:
new_output_desc = self._OutputDescription(name, shape, is_loss)
if id(node) == id(self.outputs):
self.outputs.append(new_output_desc)
else:
assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'"
setattr(node, attr_name, new_output_desc)
def _wrap(self, v):
"""Add 'v' as self's attribute to allow direct access as self.v"""
if isinstance(v, list):
return type(v)([self._wrap(item) for item in v])
elif isinstance(
v,
(
self._InputDescription,
self._InputDescriptionTyped,
self._OutputDescription,
self._OutputDescriptionTyped,
),
):
return v
elif isinstance(v, tuple):
return type(v)([self._wrap(item) for item in v])
elif isinstance(v, (dict, int, float, bool, str)):
return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v
else:
raise ValueError(
f"Unsupported type for model_desc ({v})."
"Only int, float, bool, str, list, tuple and dict are supported"
)
class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc):
r"""Internal class used by ONNX Runtime training backend for input validation
NOTE: Users MUST NOT use this class in any way!
"""
def __init__(self, main_class_name, model_desc):
# Used for logging purposes
self._main_class_name = main_class_name
# Convert dict to object
for k, v in dict(model_desc).items():
setattr(self, k, self._wrap(v))
def _model_desc_inputs_validation(field, value, error):
r"""Cerberus custom check method for 'model_desc.inputs'
'model_desc.inputs' is a list of tuples.
The list has variable length, but each tuple has size 2
The first element of the tuple is a string which represents the input name
The second element is a list of shapes. Each shape must be either an int or string.
Empty list represents a scalar input
Validation is done within each tuple to enforce the schema described above.
Example:
.. code-block:: python
model_desc['inputs'] = [('input1', ['batch', 1024]),
('input2', []),
('input3', [512])]
"""
if not isinstance(value, tuple) or len(value) != 2:
error(field, "must be a tuple with size 2")
if not isinstance(value[0], str):
error(field, "the first element of the tuple (aka name) must be a string")
if not isinstance(value[1], list):
error(field, "the second element of the tuple (aka shape) must be a list")
else:
for shape in value[1]:
if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool):
error(field, "each shape must be either a string or integer")
@static_vars(loss_counter=0)
def _model_desc_outputs_validation(field, value, error):
r"""Cerberus custom check method for 'model_desc.outputs'
'model_desc.outputs' is a list of tuples with variable length.
The first element of the tuple is a string which represents the output name
The second element is a list of shapes. Each shape must be either an int or string.
Empty list represents a scalar output
The third element is optional and is a flag that signals whether the output is a loss value
Validation is done within each tuple to enforce the schema described above, but also
throughout the list of tuples to ensure a single 'is_loss=True' occurrence.
Example:
.. code-block:: python
model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True),
('output2', [], is_loss=False),
('output3', [512])]
"""
if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3:
error(field, "must be a tuple with size 2 or 3")
if len(value) == 3 and not isinstance(value[2], bool):
error(field, "the third element of the tuple (aka is_loss) must be a boolean")
elif len(value) == 3:
if value[2]:
_model_desc_outputs_validation.loss_counter += 1
if _model_desc_outputs_validation.loss_counter > 1:
error(field, "only one is_loss can bet set to True")
if not isinstance(value[0], str):
error(field, "the first element of the tuple (aka name) must be a string")
if not isinstance(value[1], list):
error(field, "the second element of the tuple (aka shape) must be a list")
else:
for shape in value[1]:
if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool):
error(field, "each shape must be either a string or integer")
# Validation schema for model description dictionary
MODEL_DESC_SCHEMA = {
"inputs": {
"type": "list",
"required": True,
"minlength": 1,
"schema": {"check_with": _model_desc_inputs_validation},
},
"outputs": {
"type": "list",
"required": True,
"minlength": 1,
"schema": {"check_with": _model_desc_outputs_validation},
},
}
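# A minimal model_desc that passes the schema above (names and shapes are illustrative):
#   model_desc = {
#       "inputs": [("input_ids", ["batch", 128]), ("labels", ["batch"])],
#       "outputs": [("loss", [], True), ("logits", ["batch", 128, "vocab_size"])],
#   }
#   desc = _ORTTrainerModelDesc(model_desc)
#   desc.inputs[0].name     -> "input_ids"
#   desc.outputs[0].is_loss -> True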

The diff for this file is not shown because of its size.

View file

@ -1,692 +0,0 @@
import cerberus
import onnxruntime as ort
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy
from .amp import loss_scaler
from .optim import lr_scheduler
class ORTTrainerOptions:
r"""Settings used by ONNX Runtime training backend
The parameters are hierarchically organized to facilitate configuration through semantic groups
that encompasses features, such as distributed training, etc.
Input validation is performed on the input dict during instantiation to ensure
that supported parameters and values are passed in. Invalid input results
in :py:obj:`ValueError` exception with details on it.
Args:
options (dict): contains all training options
_validate (bool, default is True): for internal use only
Supported schema for kwargs:
.. code-block:: python
schema = {
'batch' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'gradient_accumulation_steps' : {
'type' : 'integer',
'min' : 1,
'default' : 1
}
},
},
'device' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'id' : {
'type' : 'string',
'default' : 'cuda'
},
'mem_limit' : {
'type' : 'integer',
'min' : 0,
'default' : 0
}
}
},
'distributed': {
'type': 'dict',
'default': {},
'required': False,
'schema': {
'world_rank': {
'type': 'integer',
'min': 0,
'default': 0
},
'world_size': {
'type': 'integer',
'min': 1,
'default': 1
},
'local_rank': {
'type': 'integer',
'min': 0,
'default': 0
},
'data_parallel_size': {
'type': 'integer',
'min': 1,
'default': 1
},
'horizontal_parallel_size': {
'type': 'integer',
'min': 1,
'default': 1
},
'pipeline_parallel' : {
'type': 'dict',
'default': {},
'required': False,
'schema': {
'pipeline_parallel_size': {
'type': 'integer',
'min': 1,
'default': 1
},
'num_pipeline_micro_batches': {
'type': 'integer',
'min': 1,
'default': 1
},
'pipeline_cut_info_string': {
'type': 'string',
'default': ''
},
'sliced_schema': {
'type': 'dict',
'default': {},
'keysrules': {'type': 'string'},
'valuesrules': {
'type': 'list',
'schema': {'type': 'integer'}
}
},
'sliced_axes': {
'type': 'dict',
'default': {},
'keysrules': {'type': 'string'},
'valuesrules': {'type': 'integer'}
},
'sliced_tensor_names': {
'type': 'list',
'schema': {'type': 'string'},
'default': []
}
}
},
'allreduce_post_accumulation': {
'type': 'boolean',
'default': False
},
'deepspeed_zero_optimization': {
'type': 'dict',
'default': {},
'required': False,
'schema': {
'stage': {
'type': 'integer',
'min': 0,
'max': 1,
'default': 0
},
}
},
'enable_adasum': {
'type': 'boolean',
'default': False
}
}
},
'lr_scheduler' : {
'type' : 'optim.lr_scheduler',
'nullable' : True,
'default' : None
},
'mixed_precision' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'enabled' : {
'type' : 'boolean',
'default' : False
},
'loss_scaler' : {
'type' : 'amp.loss_scaler',
'nullable' : True,
'default' : None
}
}
},
'graph_transformer': {
'type': 'dict',
'required': False,
'default': {},
'schema': {
'attn_dropout_recompute': {
'type': 'boolean',
'default': False
},
'gelu_recompute': {
'type': 'boolean',
'default': False
},
'transformer_layer_recompute': {
'type': 'boolean',
'default': False
},
'number_recompute_layers': {
'type': 'integer',
'min': 0,
'default': 0
},
'propagate_cast_ops_config': {
'type': 'dict',
'required': False,
'default': {},
'schema': {
'propagate_cast_ops_strategy': {
'type': 'onnxruntime.training.PropagateCastOpsStrategy',
'default': PropagateCastOpsStrategy.FLOOD_FILL
},
'propagate_cast_ops_level': {
'type': 'integer',
'default': 1
},
'propagate_cast_ops_allow': {
'type': 'list',
'schema': {'type': 'string'},
'default': []
}
}
}
}
},
'utils' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'frozen_weights' : {
'type' : 'list',
'default' : []
},
'grad_norm_clip' : {
'type' : 'boolean',
'default' : True
},
'memory_efficient_gradient' : {
'type' : 'boolean',
'default' : False
},
'run_symbolic_shape_infer' : {
'type' : 'boolean',
'default' : False
}
}
},
'debug' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'deterministic_compute' : {
'type' : 'boolean',
'default' : False
},
'check_model_export' : {
'type' : 'boolean',
'default' : False
},
'graph_save_paths' : {
'type' : 'dict',
'default': {},
'required': False,
'schema': {
'model_after_graph_transforms_path': {
'type': 'string',
'default': ''
},
'model_with_gradient_graph_path':{
'type': 'string',
'default': ''
},
'model_with_training_graph_path': {
'type': 'string',
'default': ''
},
'model_with_training_graph_after_optimization_path': {
'type': 'string',
'default': ''
},
}
},
}
},
'_internal_use' : {
'type' : 'dict',
'required': False,
'default' : {},
'schema' : {
'enable_internal_postprocess' : {
'type' : 'boolean',
'default' : True
},
'extra_postprocess' : {
'type' : 'callable',
'nullable' : True,
'default' : None
},
'onnx_opset_version': {
'type': 'integer',
'min' : 12,
'max' : 14,
'default': 14
},
'enable_onnx_contrib_ops' : {
'type' : 'boolean',
'default' : True
}
}
},
'provider_options':{
'type': 'dict',
'default': {},
'required': False,
'schema': {}
},
'session_options': {
'type': 'SessionOptions',
'nullable': True,
'default': None
},
}
Keyword arguments:
batch (dict):
batch related settings
batch.gradient_accumulation_steps (int, default is 1):
number of steps to accumulate gradients before doing collective gradient reduction
device (dict):
compute device related settings
device.id (string, default is 'cuda'):
device to run training
device.mem_limit (int):
maximum memory size (in bytes) used by device.id
distributed (dict):
distributed training options.
distributed.world_rank (int, default is 0):
rank ID used for data/horizontal parallelism
distributed.world_size (int, default is 1):
number of ranks participating in parallelism
distributed.data_parallel_size (int, default is 1):
number of ranks participating in data parallelism
distributed.horizontal_parallel_size (int, default is 1):
number of ranks participating in horizontal parallelism
distributed.pipeline_parallel (dict):
Options which are only useful to pipeline parallel.
distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1):
number of ranks participating in pipeline parallelism
distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1):
number of micro-batches. We divide input batch into micro-batches and run the graph.
distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''):
string of cutting ids for pipeline partition.
distributed.allreduce_post_accumulation (bool, default is False):
True enables overlap of AllReduce with computation, while False
postpones AllReduce until all gradients are ready
distributed.deepspeed_zero_optimization:
DeepSpeed ZeRO options.
distributed.deepspeed_zero_optimization.stage (int, default is 0):
select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled.
distributed.enable_adasum (bool, default is False):
enable `Adasum <https://arxiv.org/abs/2006.02924>`_
algorithm for AllReduce
lr_scheduler (optim._LRScheduler, default is None):
specifies learning rate scheduler
mixed_precision (dict):
mixed precision training options
mixed_precision.enabled (bool, default is False):
enable mixed precision (fp16)
mixed_precision.loss_scaler (amp.LossScaler, default is None):
specifies a loss scaler to be used for fp16. If not specified,
:py:class:`.DynamicLossScaler` is used with default values.
Users can also instantiate :py:class:`.DynamicLossScaler` and
override its parameters. Lastly, a completely new implementation
can be specified by extending :py:class:`.LossScaler` class from scratch
graph_transformer (dict):
graph transformer related configurations
graph_transformer.attn_dropout_recompute(bool, default False)
graph_transformer.gelu_recompute(bool, default False)
graph_transformer.transformer_layer_recompute(bool, default False)
graph_transformer.number_recompute_layers(int, default 0)
graph_transformer.propagate_cast_ops_config (dict):
graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL)
Specify the choice of the cast propagation optimization strategy: NONE, INSERT_AND_REDUCE or FLOOD_FILL.
NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations
may locally change cast operations; for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFusion optimization could
interchange Transpose and Cast if a Cast node exists between Transpose and MatMul.
INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes.
FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike
INSERT_AND_REDUCE does not touch opcodes outside the expanded float16 region.
graph_transformer.propagate_cast_ops_config.level(integer, default 1)
Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
Use predetermined list of opcodes considered safe to move before/after cast operation
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
graph_transformer.propagate_cast_ops_config.allow(list of str, [])
List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
attn_dropout_recompute (bool, default is False):
enable recomputing attention dropout to save memory
gelu_recompute (bool, default is False):
enable recomputing Gelu activation output to save memory
transformer_layer_recompute (bool, default is False):
enable recomputing transformer layerwise to save memory
number_recompute_layers (int, default is 0)
number of layers to apply transformer_layer_recompute to; by default the system will
apply recompute to all layers, except for the last one
utils (dict):
miscellaneous options
utils.frozen_weights (list of str, []):
list of model parameter names to skip training (weights don't change)
utils.grad_norm_clip (bool, default is True):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
utils.memory_efficient_gradient (bool, default is False):
enables use of memory aware gradient builder.
utils.run_symbolic_shape_infer (bool, default is False):
runs symbolic shape inference on the model
debug (dict):
debug options
debug.deterministic_compute (bool, default is False)
forces compute to be deterministic across runs
debug.check_model_export (bool, default is False)
compares PyTorch model outputs with ONNX model outputs in inference before the first
train step to ensure successful model export
debug.graph_save_paths (dict):
paths used for dumping ONNX graphs for debugging purposes
debug.graph_save_paths.model_after_graph_transforms_path (str, default is "")
path to export the ONNX graph after training-related graph transforms have been applied.
No output when it is empty.
debug.graph_save_paths.model_with_gradient_graph_path (str, default is "")
path to export the ONNX graph with the gradient graph added. No output when it is empty.
debug.graph_save_paths.model_with_training_graph_path (str, default is "")
path to export the training ONNX graph with forward, gradient and optimizer nodes.
No output when it is empty.
debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "")
outputs the optimized training graph to the path if nonempty.
_internal_use (dict):
internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True):
enable internal post-processing of the ONNX model
_internal_use.extra_postprocess (callable, default is None)
a functor to postprocess the ONNX model and return a new ONNX model.
It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complements it
_internal_use.onnx_opset_version (int, default is 14):
ONNX opset version used during model exporting.
_internal_use.enable_onnx_contrib_ops (bool, default is True)
enable PyTorch to export nodes as contrib ops in ONNX.
This flag may be removed anytime in the future.
session_options (onnxruntime.SessionOptions):
The SessionOptions instance that TrainingSession will use.
provider_options (dict):
The provider_options for customized execution providers. It is a dict mapping each EP name to
a dict of key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....}
Example:
.. code-block:: python
opts = ORTTrainerOptions({
'batch' : {
'gradient_accumulation_steps' : 128
},
'device' : {
'id' : 'cuda:0',
'mem_limit' : 2*1024*1024*1024,
},
'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(),
'mixed_precision' : {
'enabled': True,
'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16))
}
})
fp16_enabled = opts.mixed_precision.enabled
"""
def __init__(self, options={}): # noqa: B006
# Keep a copy of original input for debug
self._original_opts = dict(options)
# Used for logging purposes
self._main_class_name = self.__class__.__name__
# Validates user input
self._validated_opts = dict(self._original_opts)
validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA)
self._validated_opts = validator.validated(self._validated_opts)
if self._validated_opts is None:
raise ValueError(f"Invalid options: {validator.errors}")
# Convert dict to object
for k, v in self._validated_opts.items():
setattr(self, k, self._wrap(v))
def __repr__(self):
return "{%s}" % str(
", ".join(
f"'{k}': {v!r}"
for (k, v) in self.__dict__.items()
if k not in ["_original_opts", "_validated_opts", "_main_class_name"]
)
)
def _wrap(self, v):
if isinstance(v, (tuple, list, set, frozenset)):
return type(v)([self._wrap(i) for i in v])
else:
return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v
class _ORTTrainerOptionsInternal(ORTTrainerOptions):
r"""Internal class used by ONNX Runtime training backend for input validation
NOTE: Users MUST NOT use this class in any way!
"""
def __init__(self, main_class_name, options):
# Used for logging purposes
self._main_class_name = main_class_name
# We don't call super().__init__(options) here, but we still name the attribute "_validated_opts"
# instead of "_original_opts" because the options have already been validated in the top-level
# ORTTrainerOptions constructor.
self._validated_opts = dict(options)
# Convert dict to object
for k, v in dict(options).items():
setattr(self, k, self._wrap(v))
class ORTTrainerOptionsValidator(cerberus.Validator):
_LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ())
_LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ())
_SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ())
_PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition(
"propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), ()
)
types_mapping = cerberus.Validator.types_mapping.copy()
types_mapping["lr_scheduler"] = _LR_SCHEDULER
types_mapping["loss_scaler"] = _LOSS_SCALER
types_mapping["session_options"] = _SESSION_OPTIONS
types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY
def _check_is_callable(field, value, error):
result = False
try:
# Python 3
result = value is None or callable(value)
except Exception:
# Python 3 but < 3.2
if hasattr(value, "__call__"): # noqa: B004
result = True
if not result:
error(field, "Must be callable or None")
_ORTTRAINER_OPTIONS_SCHEMA = {
"batch": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}},
},
"device": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"id": {"type": "string", "default": "cuda"},
"mem_limit": {"type": "integer", "min": 0, "default": 0},
},
},
"distributed": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"world_rank": {"type": "integer", "min": 0, "default": 0},
"world_size": {"type": "integer", "min": 1, "default": 1},
"local_rank": {"type": "integer", "min": 0, "default": 0},
"data_parallel_size": {"type": "integer", "min": 1, "default": 1},
"horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1},
"pipeline_parallel": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1},
"num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1},
"pipeline_cut_info_string": {"type": "string", "default": ""},
"sliced_schema": {
"type": "dict",
"default_setter": lambda _: {},
"keysrules": {"type": "string"},
"valuesrules": {"type": "list", "schema": {"type": "integer"}},
},
"sliced_axes": {
"type": "dict",
"default_setter": lambda _: {},
"keysrules": {"type": "string"},
"valuesrules": {"type": "integer"},
},
"sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []},
},
},
"allreduce_post_accumulation": {"type": "boolean", "default": False},
"deepspeed_zero_optimization": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"stage": {"type": "integer", "min": 0, "max": 1, "default": 0},
},
},
"enable_adasum": {"type": "boolean", "default": False},
},
},
"lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None},
"mixed_precision": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"enabled": {"type": "boolean", "default": False},
"loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None},
},
},
"graph_transformer": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"attn_dropout_recompute": {"type": "boolean", "default": False},
"gelu_recompute": {"type": "boolean", "default": False},
"transformer_layer_recompute": {"type": "boolean", "default": False},
"number_recompute_layers": {"type": "integer", "min": 0, "default": 0},
"propagate_cast_ops_config": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"strategy": {
"type": "propagate_cast_ops_strategy",
"nullable": True,
"default": PropagateCastOpsStrategy.FLOOD_FILL,
},
"level": {"type": "integer", "min": -1, "default": 1},
"allow": {"type": "list", "schema": {"type": "string"}, "default": []},
},
},
},
},
"utils": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"frozen_weights": {"type": "list", "default": []},
"grad_norm_clip": {"type": "boolean", "default": True},
"memory_efficient_gradient": {"type": "boolean", "default": False},
"run_symbolic_shape_infer": {"type": "boolean", "default": False},
},
},
"debug": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"deterministic_compute": {"type": "boolean", "default": False},
"check_model_export": {"type": "boolean", "default": False},
"graph_save_paths": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"model_after_graph_transforms_path": {"type": "string", "default": ""},
"model_with_gradient_graph_path": {"type": "string", "default": ""},
"model_with_training_graph_path": {"type": "string", "default": ""},
"model_with_training_graph_after_optimization_path": {"type": "string", "default": ""},
},
},
},
},
"_internal_use": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"schema": {
"enable_internal_postprocess": {"type": "boolean", "default": True},
"extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None},
"onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14},
"enable_onnx_contrib_ops": {"type": "boolean", "default": True},
},
},
"provider_options": {
"type": "dict",
"default_setter": lambda _: {},
"required": False,
"allow_unknown": True,
"schema": {},
},
"session_options": {"type": "session_options", "nullable": True, "default": None},
}
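# A minimal sketch of default resolution by the validator above:
#   opts = ORTTrainerOptions({})
#   opts.batch.gradient_accumulation_steps  -> 1
#   opts.device.id                          -> "cuda"
#   opts.distributed.world_size             -> 1
#   opts.mixed_precision.enabled            -> False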

View file

@ -1,478 +0,0 @@
import os.path # noqa: F401
import struct
import sys # noqa: F401
import numpy as np # noqa: F401
import onnx
from onnx import * # noqa: F403
from onnx import helper, numpy_helper # noqa: F401
def run_postprocess(model):
# this post pass is not required for pytorch >= 1.5
# where add_node_name in torch.onnx.export defaults to True
model = add_name(model)
# this post pass is not required for pytorch > 1.6
model = fuse_softmaxNLL_to_softmaxCE(model)
model = fix_expand_shape(model)
model = fix_expand_shape_pt_1_5(model)
return model
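# A minimal usage sketch (the model path is hypothetical):
#   model = onnx.load("exported_model.onnx")
#   model = run_postprocess(model)
#   onnx.save(model, "postprocessed_model.onnx")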
def find_input_node(model, arg):
result = []
for node in model.graph.node:
for output in node.output:
if output == arg:
result.append(node)
return result[0] if len(result) == 1 else None
def find_output_node(model, arg):
result = []
for node in model.graph.node:
for input in node.input:
if input == arg:
result.append(node)
return result[0] if len(result) == 1 else result
def add_name(model):
i = 0
for node in model.graph.node:
node.name = "%s_%d" % (node.op_type, i)
i += 1
return model
# Expand Shape PostProcess
def fix_expand_shape(model):
expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"]
model_inputs_names = [i.name for i in model.graph.input]
for expand_node in expand_nodes:
shape = find_input_node(model, expand_node.input[1])
if shape.op_type == "Shape":
# an expand subgraph
# Input Input2
# | |
# | Shape
# | |
# |__ __|
# | |
# Expand
# |
# output
#
# Only if Input2 is one of the model inputs, assign Input2's shape to output of expand.
shape_input_name = shape.input[0]
if shape_input_name in model_inputs_names:
index = model_inputs_names.index(shape_input_name)
expand_out = model.graph.value_info.add()
expand_out.name = expand_node.output[0]
expand_out.type.CopyFrom(model.graph.input[index].type)
return model
def fix_expand_shape_pt_1_5(model):
# expand subgraph
# Constant
# +
# ConstantOfShape
# | + |
# | + |
# (Reshape subgraph) Mul |
# |___ _________| |
# + | | |
# + Equal |
# +++++|++++++++++++++|++
# |____________ | +
# | | +
# (subgraph) Where
# | |
# |_____ ___________|
# | |
# Expand
# |
# output
#
# where the Reshape subgraph is
#
# Input
# | |
# | |___________________
# | |
# Shape Constant Shape Constant
# | ______| | ______|
# | | | |
# Gather Gather
# | |
# Unsqueeze Unsqueeze
# | |
# | ..Number of dims.. |
# | _________________|
# |...|
# Concat Constant
# | |
# |______ __________________|
# | |
# Reshape
# |
# output
#
# This pass will copy Input's shape to the output of Expand.
expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"]
model_inputs_names = [i.name for i in model.graph.input]
for expand_node in expand_nodes:
n_where = find_input_node(model, expand_node.input[1])
if n_where.op_type != "Where":
continue
n_equal = find_input_node(model, n_where.input[0])
n_cos = find_input_node(model, n_where.input[1])
n_reshape = find_input_node(model, n_where.input[2])
if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape":
continue
n_reshape_e = find_input_node(model, n_equal.input[0])
n_mul = find_input_node(model, n_equal.input[1])
if n_reshape_e != n_reshape or n_mul.op_type != "Mul":
continue
n_cos_m = find_input_node(model, n_mul.input[0])
n_constant = find_input_node(model, n_mul.input[1])
if n_cos_m != n_cos or n_constant.op_type != "Constant":
continue
n_concat = find_input_node(model, n_reshape.input[0])
n_constant_r = find_input_node(model, n_reshape.input[1])
if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant":
continue
n_input_candidates = []
for concat_in in n_concat.input:
n_unsqueeze = find_input_node(model, concat_in)
if n_unsqueeze.op_type != "Unsqueeze":
break
n_gather = find_input_node(model, n_unsqueeze.input[0])
if n_gather.op_type != "Gather":
break
n_shape = find_input_node(model, n_gather.input[0])
n_constant_g = find_input_node(model, n_gather.input[1])
if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant":
break
n_input = n_shape.input[0]
if n_input not in model_inputs_names:
break
n_input_candidates.append(n_input)
if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates):
continue
index = model_inputs_names.index(n_input_candidates[0])
expand_out = model.graph.value_info.add()
expand_out.name = expand_node.output[0]
expand_out.type.CopyFrom(model.graph.input[index].type)
return model
# LayerNorm PostProcess
def find_nodes(graph, op_type):
nodes = []
for node in graph.node:
if node.op_type == op_type:
nodes.append(node)
return nodes
def is_type(node, op_type):
if node is None or isinstance(node, list):
return False
return node.op_type == op_type
def add_const(model, name, output, t_value=None, f_value=None):
const_node = model.graph.node.add()
const_node.op_type = "Constant"
const_node.name = name
const_node.output.extend([output])
attr = const_node.attribute.add()
attr.name = "value"
if t_value is not None:
attr.type = 4
attr.t.CopyFrom(t_value)
else:
attr.type = 1
attr.f = f_value
return const_node
def layer_norm_transform(model):
# DEPRECATED: This pass is no longer needed as the transform is handled at the backend.
# Converting below subgraph
#
# input
# |
# ReduceMean
# |
# Sub Constant
# _||_____ |
# | | |
# | | |
# | (optional) Cast (optional) Cast
# | | |
# | | ____________________|
# | | |
# | Pow
# | |
# | ReduceMean
# | |
# | Add
# | |
# |__ __Sqrt
# | |
# Div (weight)
# | |
# | _____|
# | |
# Mul (bias)
# | |
# | _____|
# | |
# Add
# |
# output
#
# to the below subgraph
#
# input (weight) (bias)
# | | |
# | _______| |
# | | ________________|
# | | |
# LayerNormalization
# |
# output
graph = model.graph
nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806
id = 0
layer_norm_nodes = []
remove_nodes = []
for reduce_mean in nodes_ReduceMean:
# check that reduce_mean output is Sub
sub = find_output_node(model, reduce_mean.output[0])
if not is_type(sub, "Sub"):
continue
# check that sub output[0] is Div and output[1] is Pow
pow, div = find_output_node(model, sub.output[0])
if is_type(pow, "Cast"):
# During an update in PyTorch, Cast nodes are inserted between Sub and Pow.
remove_nodes += [pow]
pow = find_output_node(model, pow.output[0])
if not is_type(pow, "Pow"):
continue
cast_pow = find_input_node(model, pow.input[1])
if not is_type(cast_pow, "Cast"):
continue
remove_nodes += [cast_pow]
if not is_type(div, "Div") or not is_type(pow, "Pow"):
continue
# check that pow output is ReduceMean
reduce_mean2 = find_output_node(model, pow.output[0])
if not is_type(reduce_mean2, "ReduceMean"):
continue
# check that reduce_mean2 output is Add
add = find_output_node(model, reduce_mean2.output[0])
if not is_type(add, "Add"):
continue
# check that add output is Sqrt
sqrt = find_output_node(model, add.output[0])
if not is_type(sqrt, "Sqrt"):
continue
# check that sqrt output is div
if div != find_output_node(model, sqrt.output[0]):
continue
# check if div output is Mul
optional_mul = find_output_node(model, div.output[0])
if not is_type(optional_mul, "Mul"):
optional_mul = None
continue # default bias and weight not supported
# check if mul output is Add
if optional_mul is not None:
optional_add = find_output_node(model, optional_mul.output[0])
else:
optional_add = find_output_node(model, div.output[0])
if not is_type(optional_add, "Add"):
optional_add = None
continue # default bias and weight not supported
# add nodes to remove_nodes
remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt])
# create LayerNorm node
layer_norm_input = []
layer_norm_output = []
layer_norm_input.append(reduce_mean.input[0])
if optional_mul is not None:
remove_nodes.append(optional_mul)
weight = optional_mul.input[1]
layer_norm_input.append(weight)
if optional_add is not None:
remove_nodes.append(optional_add)
bias = optional_add.input[1]
layer_norm_input.append(bias)
if optional_add is not None:
layer_norm_output.append(optional_add.output[0])
elif optional_mul is not None:
layer_norm_output.append(optional_mul.output[0])
else:
layer_norm_output.append(div.output[0])
layer_norm_output.append("saved_mean_" + str(id))
layer_norm_output.append("saved_inv_std_var_" + str(id))
epsilon_node = find_input_node(model, add.input[1])
epsilon = epsilon_node.attribute[0].t.raw_data
epsilon = struct.unpack("f", epsilon)[0]
layer_norm = helper.make_node(
"LayerNormalization",
layer_norm_input,
layer_norm_output,
"LayerNormalization_" + str(id),
None,
axis=reduce_mean.attribute[0].ints[0],
epsilon=epsilon,
)
layer_norm_nodes.append(layer_norm)
id += 1
# remove orphan constant nodes
for constant in graph.node:
if constant.op_type == "Constant" and constant not in remove_nodes:
is_orphan = True
for out_name in constant.output:
out = find_output_node(model, out_name)
if out not in remove_nodes:
is_orphan = False
if is_orphan:
remove_nodes.append(constant)
all_nodes = []
for node in graph.node:
if node not in remove_nodes:
all_nodes.append(node)
for node in layer_norm_nodes:
all_nodes.append(node) # noqa: PERF402
graph.ClearField("node")
graph.node.extend(all_nodes)
return model
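A side note on the epsilon handling above: the Constant feeding the Add stores epsilon as raw float32 bytes, which is why it is decoded with struct. A standalone round-trip for illustration only:
import struct
raw = struct.pack("f", 1e-12)  # float32 bytes, as stored in the attribute's raw_data
eps = struct.unpack("f", raw)[0]  # back to a Python float (up to float32 rounding)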
# Fuse SoftmaxCrossEntropy
def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802
# Converting below subgraph
#
# (subgraph)
# |
# LogSoftmax (target) (optional weight)
# | | |
# nll_loss/NegativeLogLikelihoodLoss
# |
# output
#
# to the following
#
# (subgraph) (target) (optional weight)
# | | _____|
# | | |
# SparseSoftmaxCrossEntropy
# |
# output
nll_count = 0
while True:
nll_count = nll_count + 1
nll_loss_node = None
nll_loss_node_index = 0
for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss":
nll_loss_node = node
break
if nll_loss_node is None:
break
softmax_node = None
softmax_node_index = 0
label_input_name = None
weight_input_name = None
for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
if node.op_type == "LogSoftmax":
# has to be connected to nll_loss
if len(nll_loss_node.input) > 2:
weight_input_name = nll_loss_node.input[2]
if node.output[0] == nll_loss_node.input[0]:
softmax_node = node
label_input_name = nll_loss_node.input[1]
break
elif node.output[0] == nll_loss_node.input[1]:
softmax_node = node
label_input_name = nll_loss_node.input[0]
break
else:
if softmax_node is not None:
break
if softmax_node is None:
break
# delete nll_loss and LogSoftmax nodes in order
if nll_loss_node_index < softmax_node_index:
del onnx_model.graph.node[softmax_node_index]
del onnx_model.graph.node[nll_loss_node_index]
else:
del onnx_model.graph.node[nll_loss_node_index]
del onnx_model.graph.node[softmax_node_index]
probability_output_name = softmax_node.output[0]
node = onnx_model.graph.node.add()
inputs = (
[softmax_node.input[0], label_input_name, weight_input_name]
if weight_input_name
else [softmax_node.input[0], label_input_name]
)
node.CopyFrom(
onnx.helper.make_node(
"SparseSoftmaxCrossEntropy",
inputs,
[nll_loss_node.output[0], probability_output_name],
"nll_loss_node_" + str(nll_count),
)
)
return onnx_model
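The fusion above relies on the standard identity that LogSoftmax followed by a negative log likelihood loss computes cross entropy. A quick PyTorch sanity check of that identity (not part of the removed pass):
import torch
import torch.nn.functional as F
x = torch.randn(4, 10)  # logits
t = torch.randint(0, 10, (4,))  # integer targets
assert torch.allclose(F.nll_loss(F.log_softmax(x, dim=1), t), F.cross_entropy(x, t))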

View file

@ -1,144 +0,0 @@
import sys
import threading
import time
class OutputGrabber:
"""
Class used to grab standard output or another stream.
"""
escape_char = "\b"
def __init__(self, stream=None, threaded=False):
self.origstream = stream
self.threaded = threaded
if self.origstream is None:
self.origstream = sys.stdout
self.origstreamfd = self.origstream.fileno()
self.capturedtext = ""
# Create a pipe so the stream can be captured:
self.pipe_out, self.pipe_in = os.pipe()
def __enter__(self):
self.start()
return self
def __exit__(self, type, value, traceback):
self.stop()
def start(self):
"""
Start capturing the stream data.
"""
self.capturedtext = ""
# Save a copy of the stream:
self.streamfd = os.dup(self.origstreamfd)
# Replace the original stream with our write pipe:
os.dup2(self.pipe_in, self.origstreamfd)
if self.threaded:
# Start thread that will read the stream:
self.workerThread = threading.Thread(target=self.readOutput)
self.workerThread.start()
# Make sure that the thread is running and os.read() has executed:
time.sleep(0.01)
def stop(self):
"""
Stop capturing the stream data and save the text in `capturedtext`.
"""
# Print the escape character to make the readOutput method stop:
self.origstream.write(self.escape_char)
# Flush the stream to make sure all our data goes in before
# the escape character:
self.origstream.flush()
if self.threaded:
# wait until the thread finishes so we are sure that
# we have read up to the last character:
self.workerThread.join()
else:
self.readOutput()
# Close the pipe:
os.close(self.pipe_in)
os.close(self.pipe_out)
# Restore the original stream:
os.dup2(self.streamfd, self.origstreamfd)
# Close the duplicate stream:
os.close(self.streamfd)
def readOutput(self):
"""
Read the stream data (one byte at a time)
and save the text in `capturedtext`.
"""
while True:
char = os.read(self.pipe_out, 1).decode(self.origstream.encoding)
if not char or self.escape_char in char:
break
self.capturedtext += char
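A minimal usage sketch for the OutputGrabber helper above (capturing stdout; this is illustrative only and assumes the module-level imports that follow in the original file):
with OutputGrabber() as out:
    print("hello from the captured stream")
assert "hello from the captured stream" in out.capturedtext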
import os # noqa: E402
import unittest # noqa: E402
import numpy as np # noqa: E402, F401
import torch # noqa: E402
import torch.nn as nn # noqa: E402
import torch.nn.functional as F # noqa: E402
from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401
from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401
def my_loss(x, target):
return F.nll_loss(F.log_softmax(x, dim=1), target)
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x, target):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return my_loss(out, target)
class OrtEPTests(unittest.TestCase):
def test_external_graph_transformer_triggering(self):
input_size = 784
hidden_size = 500
num_classes = 10
batch_size = 128
model = NeuralNet(input_size, hidden_size, num_classes)
model_desc = {
"inputs": [
("x", [batch_size, input_size]),
(
"target",
[
batch_size,
],
),
],
"outputs": [("loss", [], True)],
}
optim_config = optim.SGDConfig()
opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}})
model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
# because orttrainer is lazily initialized, feed in random data to trigger the graph transformer
data = torch.rand(batch_size, input_size)
target = torch.randint(0, 10, (batch_size,))
with OutputGrabber() as out:
model.train_step(data, target)
assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext
if __name__ == "__main__":
unittest.main()

View file

@ -1,35 +0,0 @@
#include "core/optimizer/rewrite_rule.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#include "onnx/defs/schema.h"
#include <memory>
#include <iostream>
namespace onnxruntime {
namespace training {
class MyRewriteRule : public RewriteRule {
public:
MyRewriteRule() noexcept
: RewriteRule("MyRewriteRule") {
}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {};
}
private:
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override {
return true;
}
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override {
std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl;
return Status::OK();
}
};
void RegisterTrainingExternalTransformers() {
ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true);
}
} // namespace training
} // namespace onnxruntime

View file

@ -1,26 +1,7 @@
import copy
import math
import os
import subprocess
import sys
import numpy as np
import onnx
import torch
from numpy.testing import assert_allclose
import onnxruntime
from onnxruntime.training import _utils, optim
def _single_run(execution_file, scenario, checkpoint_dir=None):
cmd = [sys.executable, execution_file]
if scenario:
cmd += ["--scenario", scenario]
if checkpoint_dir:
cmd += ["--checkpoint_dir", checkpoint_dir]
assert subprocess.call(cmd) == 0
def is_windows():
return sys.platform.startswith("win")
@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en
if log:
log.debug("Subprocess completed. Return code=" + str(completed_process.returncode))
return completed_process
def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup):
num_warmup_steps = warmup * total_steps
if global_step < num_warmup_steps:
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
else:
new_lr = initial_lr
return new_lr
def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles):
num_warmup_steps = warmup * total_steps
if global_step < num_warmup_steps:
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
else:
progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps))
new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress)))
return new_lr
def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup):
num_warmup_steps = warmup * total_steps
if global_step < num_warmup_steps:
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
else:
new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps)))
return new_lr
def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end):
num_warmup_steps = warmup * total_steps
if global_step < num_warmup_steps:
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
elif global_step > total_steps:
new_lr = lr_end
else:
lr_range = initial_lr - lr_end
decay_steps = total_steps - num_warmup_steps
pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps
decay = lr_range * pct_remaining**power + lr_end
new_lr = decay
return new_lr
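A quick numeric check of the linear schedule above, with made-up hyperparameters (initial_lr=1e-3, total_steps=1000, warmup=0.1, i.e. 100 warmup steps):
for step in (50, 100, 550, 1000):
    print(step, legacy_linear_lr_scheduler(step, 1e-3, 1000, 0.1))
# 50 -> 0.0005 (halfway through warmup)
# 100 -> 0.001 (warmup finished, linear decay starts)
# 550 -> 0.0005 (halfway through the decay)
# 1000 -> 0.0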
def generate_dummy_optim_state(model, optimizer):
np.random.seed(0)
if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))):
return dict()
moment_keys = ["Moment_1", "Moment_2"]
uc_key = "Update_Count"
step_key = "Step"
shared_state_key = "shared_optimizer_state"
optim_state = dict()
weight_shape_map = dict()
if isinstance(model, torch.nn.Module):
weight_shape_map = {name: param.size() for name, param in model.named_parameters()}
elif isinstance(model, onnx.ModelProto):
weight_shape_map = {n.name: n.dims for n in model.graph.initializer}
else:
raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")
for weight_name, weight_shape in weight_shape_map.items():
per_weight_state = dict()
for moment in moment_keys:
per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32)
if isinstance(optimizer, optim.AdamConfig):
per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64)
optim_state[weight_name] = copy.deepcopy(per_weight_state)
if isinstance(optimizer, optim.LambConfig):
step_val = np.full([1], 5, dtype=np.int64)
optim_state[shared_state_key] = {step_key: step_val}
return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}}
def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None):
# Loads external Pytorch TransformerModel into utils
root = "samples"
if not os.path.exists(root):
root = os.path.normpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples")
)
if not os.path.exists(root):
raise FileNotFoundError("Unable to find folder 'samples', tried %r." % root)
pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer")
pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py")
pt_model = _utils.import_module_from_file(pt_model_path)
ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py")
ort_utils = _utils.import_module_from_file(ort_utils_path)
utils_path = os.path.join(pytorch_transformer_path, "utils.py")
utils = _utils.import_module_from_file(utils_path)
# Modeling
model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
my_loss = ort_utils.my_loss
if legacy_api:
if dynamic_axes:
model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes()
else:
model_desc = ort_utils.legacy_transformer_model_description()
else:
if dynamic_axes:
model_desc = ort_utils.transformer_model_description_dynamic_axes()
else:
model_desc = ort_utils.transformer_model_description()
# Preparing data
train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir)
return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data
def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"):
"""Generates a sample input for the BART model using the model desc"""
torch.manual_seed(seed)
onnxruntime.set_seed(seed)
dtype = torch.int64
vocab_size = 30528
sample_input = []
for _index, input in enumerate(desc["inputs"]):
size = []
for s in input[1]:
if isinstance(s, (int)):
size.append(s)
else:
size.append(1)
sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device))
return sample_input
def _load_bart_model():
bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx")
model = onnx.load(bart_onnx_model_path)
batch = 2
seq_len = 1024
model_desc = {
"inputs": [
(
"src_tokens",
[batch, seq_len],
),
(
"prev_output_tokens",
[batch, seq_len],
),
(
"target",
[batch * seq_len],
),
],
"outputs": [("loss", [], True)],
}
return model, model_desc
def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False):
"""Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states"""
assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint)
assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint)
if "model" in state_dict_pre_checkpoint:
for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]:
if reshape_states:
assert_allclose(
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key],
state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape(
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape
),
)
else:
assert_allclose(
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key],
state_dict_post_checkpoint["model"]["full_precision"][model_state_key],
)
if "optimizer" in state_dict_pre_checkpoint:
for model_state_key in state_dict_pre_checkpoint["optimizer"]:
for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]:
if reshape_states:
assert_allclose(
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key],
state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape(
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape
),
)
else:
assert_allclose(
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key],
state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key],
)

View file

@ -1,30 +1,11 @@
import copy
import os
import numpy as np
import torch
from numpy.testing import assert_allclose
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import orttrainer
try:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401
GraphExecutionManagerFactory,
)
except ImportError:
# Some pipelines do not contain ORTModule
pass
except Exception as e:
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
if isinstance(e, ORTModuleInitException):
# ORTModule is present but not ready to run
# That is OK because this file is also used by ORTTrainer tests
pass
raise
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401
def is_all_or_nothing_fallback_enabled(model, policy=None):
@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch")
def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of two different ONNX models (model_a and model_b)
and raises AssertError when they diverge by more than atol or rtol
Args:
model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
verbose (bool, default is False): if True, prints absolute difference for each weight
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of a legacy model model_a and experimental model_b model
and raises AssertError when they diverge by more than atol or rtol.
Args:
model_a (ORTTrainer): Instance of legacy ORTTrainer
model_b (ORTTrainer): Instance of experimental ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol):
r"""Asserts whether dicts a and b value differences are within specified tolerance
Compares the weights of two model's state_dict dicts and raises AssertError
when they diverge by more than atol or rtol
Args:
model_a (ORTTrainer): Instance of legacy ORTTrainer
model_b (ORTTrainer): Instance of experimental ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
np_a_vals = np.array(a_val).flatten()
np_b_vals = np.array(b_val).flatten()
assert np_a_vals.shape == np_b_vals.shape
if verbose:
print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}")
assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}")
def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
r"""Asserts whether optimizer state differences are within specified tolerance
Compares the expected and actual optimizer states of dicts and raises AssertError
when they diverge by more than atol or rtol.
The optimizer dict is of the form:
model_weight_name:
{
"Moment_1": moment1_tensor,
"Moment_2": moment2_tensor,
"Update_Count": update_tensor # if optimizer is adam, absent otherwise
},
...
"shared_optimizer_state": # if optimizer is shared, absent otherwise.
So far, only lamb optimizer uses this.
{
"step": step_tensor # int array of size 1
}
Args:
expected_state (dict(dict())): Expected optimizer state
actual_state (dict(dict())): Actual optimizer state
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert expected_state.keys() == actual_state.keys()
for param_name, a_state in actual_state.items():
for k, v in a_state.items():
assert_allclose(
v,
expected_state[param_name][k],
rtol=rtol,
atol=atol,
err_msg=f"Optimizer state mismatch for param {param_name}, key {k}",
)
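For reference, an illustrative instance of the optimizer state layout described in the docstring above; the parameter name, shapes and values are made up, and the key names follow generate_dummy_optim_state earlier in this diff (which uses "Step" where the docstring writes "step"):
import numpy as np
example_state = {
    "fc1.weight": {
        "Moment_1": np.zeros((500, 784), dtype=np.float32),
        "Moment_2": np.zeros((500, 784), dtype=np.float32),
        "Update_Count": np.array([5], dtype=np.int64),  # Adam only
    },
    "shared_optimizer_state": {  # Lamb only
        "Step": np.array([5], dtype=np.int64),
    },
}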
def is_dynamic_axes(model):
# Check inputs
for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input:

View file

@ -1,325 +0,0 @@
import os
import unittest
import torch
import torch.nn as nn
from orttraining_test_bert_postprocess import postprocess_model
from orttraining_test_data_loader import create_ort_test_dataloader
from orttraining_test_transformers import BertForPreTraining, BertModelTest
from orttraining_test_utils import map_optimizer_attributes
import onnxruntime
from onnxruntime.capi.ort_trainer import ( # noqa: F401
IODescription,
LossScaler,
ModelDescription,
ORTTrainer,
generate_sample,
)
torch.manual_seed(1)
onnxruntime.set_seed(1)
class Test_PostPasses(unittest.TestCase): # noqa: N801
def get_onnx_model(
self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None
):
lr_desc = IODescription(
"Learning_Rate",
[
1,
],
torch.float32,
)
model = ORTTrainer(
model,
None,
model_desc,
"LambOptimizer",
map_optimizer_attributes,
lr_desc,
device,
world_rank=0,
world_size=1,
_opset_version=14,
_enable_internal_postprocess=_enable_internal_postprocess,
_extra_postprocess=_extra_postprocess,
)
model.train_step(*inputs)
return model.onnx_model_
def count_all_nodes(self, model):
return len(model.graph.node)
def count_nodes(self, model, node_type):
count = 0
for node in model.graph.node:
if node.op_type == node_type:
count += 1
return count
def find_nodes(self, model, node_type):
nodes = []
for node in model.graph.node:
if node.op_type == node_type:
nodes.append(node)
return nodes
def get_name(self, name):
if os.path.exists(name):
return name
rel = os.path.join("testdata", name)
if os.path.exists(rel):
return rel
this = os.path.dirname(__file__)
data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata")
res = os.path.join(data, name)
if os.path.exists(res):
return res
raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'")
def test_layer_norm(self):
class LayerNormNet(nn.Module):
def __init__(self, target):
super().__init__()
self.ln_1 = nn.LayerNorm(10)
self.loss = nn.CrossEntropyLoss()
self.target = target
def forward(self, x):
output1 = self.ln_1(x)
loss = self.loss(output1, self.target)
return loss, output1
device = torch.device("cpu")
target = torch.ones(20, 10, 10, dtype=torch.int64).to(device)
model = LayerNormNet(target)
input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device)
input_desc = IODescription("input", [], "float32")
output0_desc = IODescription("output0", [], "float32")
output1_desc = IODescription("output1", [20, 5, 10, 10], "float32")
model_desc = ModelDescription([input_desc], [output0_desc, output1_desc])
learning_rate = torch.tensor([1.0000000e00]).to(device)
input_args = [input, learning_rate]
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
count_nodes = self.count_all_nodes(onnx_model)
assert count_layer_norm == 0
assert count_nodes == 3
def test_expand(self):
class ExpandNet(nn.Module):
def __init__(self, target):
super().__init__()
self.loss = nn.CrossEntropyLoss()
self.target = target
self.linear = torch.nn.Linear(2, 2)
def forward(self, x, x1):
output = x.expand_as(x1)
output = self.linear(output)
output = output + output
loss = self.loss(output, self.target)
return loss, output
device = torch.device("cpu")
target = torch.ones(5, 5, 2, dtype=torch.int64).to(device)
model = ExpandNet(target).to(device)
x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device)
x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device)
input0_desc = IODescription("x", [5, 3, 1, 2], "float32")
input1_desc = IODescription("x1", [5, 3, 5, 2], "float32")
output0_desc = IODescription("output0", [], "float32")
output1_desc = IODescription("output1", [5, 3, 5, 2], "float32")
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
learning_rate = torch.tensor([1.0000000e00]).to(device)
input_args = [x, x1, learning_rate]
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
# check that expand output has shape
expand_nodes = self.find_nodes(onnx_model, "Expand")
assert len(expand_nodes) == 1
model_info = onnx_model.graph.value_info
assert model_info[0].name == expand_nodes[0].output[0]
assert model_info[0].type == onnx_model.graph.input[1].type
def test_bert(self):
device = torch.device("cpu")
model_tester = BertModelTest.BertModelTester(self)
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = model_tester.prepare_config_and_inputs()
model = BertForPreTraining(config=config)
model.eval()
loss, prediction_scores, seq_relationship_score = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
next_sentence_label=sequence_labels,
)
model_desc = ModelDescription(
[
model_tester.input_ids_desc,
model_tester.attention_mask_desc,
model_tester.token_type_ids_desc,
model_tester.masked_lm_labels_desc,
model_tester.next_sentence_label_desc,
],
[model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc],
)
from collections import namedtuple
MyArgs = namedtuple(
"MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
)
args = MyArgs(
local_rank=0,
world_size=1,
max_steps=100,
learning_rate=0.00001,
warmup_proportion=0.01,
batch_size=13,
seq_len=7,
)
dataset_len = 100
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)
learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device)
for b in dataloader:
batch = b
break
learning_rate = torch.tensor([1.00e00]).to(device)
inputs = [*batch, learning_rate]
onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model)
self._bert_helper(onnx_model)
def _bert_helper(self, onnx_model):
# count layer_norm
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
assert count_layer_norm == 0
# get expand node and check output shape
expand_nodes = self.find_nodes(onnx_model, "Expand")
assert len(expand_nodes) == 1
model_info = onnx_model.graph.value_info
assert model_info[0].name == expand_nodes[0].output[0]
assert model_info[0].type == onnx_model.graph.input[0].type
def test_extra_postpass(self):
def postpass_replace_first_add_with_sub(model):
# this post pass replaces the first Add node with Sub in the model.
# Previous graph
# (subgraph 1) (subgraph 2)
# | |
# | |
# |________ ________|
# | |
# Add
# |
# (subgraph 3)
#
# Post graph
# (subgraph 1) (subgraph 2)
# | |
# | |
# |________ ________|
# | |
# Sub
# |
# (subgraph 3)
add_nodes = [n for n in model.graph.node if n.op_type == "Add"]
add_nodes[0].op_type = "Sub"
class MultiAdd(nn.Module):
def __init__(self, target):
super().__init__()
self.loss = nn.CrossEntropyLoss()
self.target = target
self.linear = torch.nn.Linear(2, 2, bias=False)
def forward(self, x, x1):
output = x + x1
output = output + x
output = output + x1
output = self.linear(output)
loss = self.loss(output, self.target)
return loss, output
device = torch.device("cpu")
target = torch.ones(5, 2, dtype=torch.int64).to(device)
model = MultiAdd(target).to(device)
x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
input0_desc = IODescription("x", [5, 5, 2], "float32")
input1_desc = IODescription("x1", [5, 5, 2], "float32")
output0_desc = IODescription("output0", [], "float32")
output1_desc = IODescription("output1", [5, 5, 2], "float32")
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
learning_rate = torch.tensor([1.0000000e00]).to(device)
input_args = [x, x1, learning_rate]
onnx_model = self.get_onnx_model(
model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub
)
# check that extra postpass is called, and called only once.
add_nodes = self.find_nodes(onnx_model, "Add")
sub_nodes = self.find_nodes(onnx_model, "Sub")
assert len(add_nodes) == 2
assert len(sub_nodes) == 1
unprocessed_onnx_model = self.get_onnx_model(
model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False
)
# check that the model is unchanged.
add_nodes = self.find_nodes(unprocessed_onnx_model, "Add")
sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub")
assert len(add_nodes) == 3
assert len(sub_nodes) == 0
processed_onnx_model = self.get_onnx_model(
unprocessed_onnx_model,
model_desc,
input_args,
device,
_extra_postprocess=postpass_replace_first_add_with_sub,
)
# check that extra postpass is called, and called only once.
add_nodes = self.find_nodes(processed_onnx_model, "Add")
sub_nodes = self.find_nodes(processed_onnx_model, "Sub")
assert len(add_nodes) == 2
assert len(sub_nodes) == 1
if __name__ == "__main__":
unittest.main(module=__name__, buffer=True)

View file

@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache):
env = get_env_with_transformers_cache(transformers_cache)
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"]
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"]
run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode()
@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log):
def run_hooks_tests(cwd, log):
log.debug("Running: Data hooks tests")
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"]
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"]
run_subprocess(command, cwd=cwd, log=log).check_returncode()

View file

@ -1,801 +0,0 @@
# ==================
import dataclasses
import datetime
import glob
import json
import logging
import os
import random
import shutil
import unittest
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import h5py
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import BertConfig, BertForPreTraining, HfArgumentParser
import onnxruntime as ort
# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases.
# the fix for ignore_index == -100 cases is already in pytorch master.
# however, using the current torch master causes computation changes in many tests.
# eventually we will use pytorch with fixed nll_loss once computation
# issues are understood and solved.
import onnxruntime.capi.pt_patch
from onnxruntime.training import amp, optim, orttrainer
from onnxruntime.training.checkpoint import aggregate_checkpoints
from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401
# we cannot do a full convergence run in the nightly pipeline because of its timeout limit.
# max_steps is still needed to calculate the learning rate. force_to_stop_max_steps is used to
# terminate the training before the pipeline run hits its timeout.
force_to_stop_max_steps = 2500
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def is_main_process(args):
if hasattr(args, "world_rank"):
return args.world_rank in [-1, 0]
else:
return get_rank() == 0
def bert_model_description(config):
vocab_size = config.vocab_size
new_model_desc = {
"inputs": [
(
"input_ids",
["batch", "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", "max_seq_len_in_batch"],
),
(
"token_type_ids",
["batch", "max_seq_len_in_batch"],
),
(
"masked_lm_labels",
["batch", "max_seq_len_in_batch"],
),
(
"next_sentence_label",
[
"batch",
],
),
],
"outputs": [
("loss", [], True),
(
"prediction_scores",
["batch", "max_seq_len_in_batch", vocab_size],
),
(
"seq_relationship_scores",
["batch", 2],
),
],
}
return new_model_desc
def create_pretraining_dataset(input_file, max_pred_length, args):
train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True
)
return train_dataloader, input_file
class pretraining_dataset(Dataset): # noqa: N801
def __init__(self, input_file, max_pred_length):
logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length)
self.input_file = input_file
self.max_pred_length = max_pred_length
f = h5py.File(input_file, "r")
keys = [
"input_ids",
"input_mask",
"segment_ids",
"masked_lm_positions",
"masked_lm_ids",
"next_sentence_labels",
]
self.inputs = [np.asarray(f[key][:]) for key in keys]
f.close()
def __len__(self):
"Denotes the total number of samples"
return len(self.inputs[0])
def __getitem__(self, index):
[input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [
torch.from_numpy(input[index].astype(np.int64))
if indice < 5
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
for indice, input in enumerate(self.inputs)
]
# HF models use the default ignore_index value (-100) for CrossEntropyLoss
masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100
index = self.max_pred_length
# store number of masked tokens in index
padded_mask_indices = (masked_lm_positions == 0).nonzero()
if len(padded_mask_indices) != 0:
index = padded_mask_indices[0].item()
masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]
return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels]
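To make the masking logic above concrete, a small worked example with made-up values (sequence length 8, max_pred_length 4, two real masked positions followed by padding zeros):
import torch
masked_lm_positions = torch.tensor([2, 5, 0, 0])
masked_lm_ids = torch.tensor([1037, 2054, 0, 0])
labels = torch.ones(8, dtype=torch.long) * -100
index = 4  # max_pred_length
padded = (masked_lm_positions == 0).nonzero()
if len(padded) != 0:
    index = padded[0].item()  # 2: only the first two entries are real masks
labels[masked_lm_positions[:index]] = masked_lm_ids[:index]
# labels == [-100, -100, 1037, -100, -100, 2054, -100, -100]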
import argparse # noqa: E402
def parse_arguments():
parser = argparse.ArgumentParser()
# batch size test config parameters
parser.add_argument(
"--enable_mixed_precision",
default=False,
action="store_true",
help="Whether to use 16-bit float precision instead of 32-bit",
)
parser.add_argument(
"--sequence_length",
default=512,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.",
)
parser.add_argument(
"--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence"
)
parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.")
parser.add_argument("--gelu_recompute", default=False, action="store_true")
parser.add_argument("--attn_dropout_recompute", default=False, action="store_true")
parser.add_argument("--transformer_layer_recompute", default=False, action="store_true")
args = parser.parse_args()
return args
@dataclass
class PretrainArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
input_dir: str = field(
default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"}
)
bert_model: str = field(
default=None,
metadata={
"help": "Bert pre-trained model selected in the list: bert-base-uncased, \
bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
},
)
output_dir: str = field(
default=None, metadata={"help": "The output directory where the model checkpoints will be written."}
)
cache_dir: str = field(
default="/tmp/bert_pretrain/",
metadata={"help": "The output directory where the model checkpoints will be written."},
)
max_seq_length: Optional[int] = field(
default=512,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer \
than this will be truncated, sequences shorter will be padded."
},
)
max_predictions_per_seq: Optional[int] = field(
default=80, metadata={"help": "The maximum total of masked tokens in input sequence."}
)
train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."})
learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."})
num_train_epochs: Optional[float] = field(
default=3.0, metadata={"help": "Total number of training epochs to perform."}
)
max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."})
warmup_proportion: Optional[float] = field(
default=0.01,
metadata={
"help": "Proportion of training to perform linear learning rate warmup for. \
E.g., 0.1 = 10%% of training."
},
)
local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."})
world_rank: Optional[int] = field(default=-1)
world_size: Optional[int] = field(default=1)
seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."})
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "Number of update steps to accumulate before performing a backward/update pass."}
)
fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."})
gelu_recompute: bool = field(
default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."}
)
attn_dropout_recompute: bool = field(
default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."}
)
transformer_layer_recompute: bool = field(
default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."}
)
loss_scale: Optional[float] = field(
default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."}
)
deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"})
log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."})
checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."})
resume_from_checkpoint: bool = field(
default=False, metadata={"help": "Whether to resume training from checkpoint."}
)
resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."})
num_steps_per_checkpoint: Optional[int] = field(
default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."}
)
save_checkpoint: Optional[bool] = field(
default=False, metadata={"help": "Enable for saving a model checkpoint to disk."}
)
init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."})
phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."})
allreduce_post_accumulation: bool = field(
default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."}
)
allreduce_post_accumulation_fp16: bool = field(
default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."}
)
accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."})
phase1_end_step: Optional[int] = field(
default=7038, metadata={"help": "Step at which pretraining phase 1 ends."}
)
tensorboard_dir: Optional[str] = field(
default=None,
)
schedule: Optional[str] = field(
default="warmup_poly",
)
# this argument is test specific. running the full bert model would take too long, so we reduce the
# number of hidden layers so that the test can show convergence well enough to detect regressions.
force_num_hidden_layers: Optional[int] = field(
default=None, metadata={"help": "Override the number of hidden layers in the model config (test specific)."}
)
def to_json_string(self):
"""
Serializes this instance to a JSON string.
"""
return json.dumps(dataclasses.asdict(self), indent=2)
def to_sanitized_dict(self) -> Dict[str, Any]:
"""
Sanitized serialization to use with TensorBoard`s hparams
"""
d = dataclasses.asdict(self)
valid_types = [bool, int, float, str, torch.Tensor]
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
def setup_training(args):
assert torch.cuda.is_available()
if args.local_rank == -1:
args.local_rank = 0
args.world_rank = 0
print("args.local_rank: ", args.local_rank)
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
args.n_gpu = 1
if args.gradient_accumulation_steps < 1:
raise ValueError(
f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
)
if args.train_batch_size % args.gradient_accumulation_steps != 0:
raise ValueError(
"Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
args.gradient_accumulation_steps, args.train_batch_size
)
)
# args.train_batch_size is per global step (optimization step) batch size
# now make it a per gpu batch size
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
args.train_batch_size = args.train_batch_size // args.world_size
logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size)
return device, args
def setup_torch_distributed(world_rank, world_size):
os.environ["RANK"] = str(world_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank)
return
def prepare_model(args, device):
config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)
# config.num_hidden_layers = 12
if args.force_num_hidden_layers:
logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
config.num_hidden_layers = args.force_num_hidden_layers
model = BertForPreTraining(config)
if args.init_state_dict is not None:
model.load_state_dict(args.init_state_dict)
model_desc = bert_model_description(config)
lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)
loss_scaler = amp.DynamicLossScaler() if args.fp16 else None
options = orttrainer.ORTTrainerOptions(
{
"batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps},
"device": {"id": str(device)},
"mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler},
"graph_transformer": {
"attn_dropout_recompute": args.attn_dropout_recompute,
"gelu_recompute": args.gelu_recompute,
"transformer_layer_recompute": args.transformer_layer_recompute,
},
"debug": {
"deterministic_compute": True,
},
"utils": {"grad_norm_clip": True},
"distributed": {
"world_rank": max(0, args.local_rank),
"world_size": args.world_size,
"local_rank": max(0, args.local_rank),
"allreduce_post_accumulation": args.allreduce_post_accumulation,
"deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage},
"enable_adasum": False,
},
"lr_scheduler": lr_scheduler,
}
)
param_optimizer = list(model.named_parameters())
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
params = [
{
"params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
"alpha": 0.9,
"beta": 0.999,
"lambda": 0.0,
"epsilon": 1e-6,
},
{
"params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
"alpha": 0.9,
"beta": 0.999,
"lambda": 0.0,
"epsilon": 1e-6,
},
]
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)
return model
def get_data_file(f_id, world_rank, world_size, files):
num_files = len(files)
if world_size > num_files:
remainder = world_size % num_files
return files[(f_id * world_size + world_rank + remainder * f_id) % num_files]
elif world_size > 1:
return files[(f_id * world_size + world_rank) % num_files]
else:
return files[f_id % num_files]
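A small illustration of how get_data_file spreads files across ranks when world_size (here 2) does not exceed the number of files (here 3); the file names are made up:
files = ["part0.hdf5", "part1.hdf5", "part2.hdf5"]
for f_id in range(3):
    print(f_id, [get_data_file(f_id, rank, 2, files) for rank in range(2)])
# 0 ['part0.hdf5', 'part1.hdf5']
# 1 ['part2.hdf5', 'part0.hdf5']
# 2 ['part1.hdf5', 'part2.hdf5']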
def main():
parser = HfArgumentParser(PretrainArguments)
args = parser.parse_args_into_dataclasses()[0]
do_pretrain(args)
def do_pretrain(args):
if is_main_process(args) and args.tensorboard_dir:
tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
tb_writer.add_text("args", args.to_json_string())
tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
else:
tb_writer = None
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
ort.set_seed(args.seed)
device, args = setup_training(args)
model = prepare_model(args, device)
logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate)
average_loss = 0.0
epoch = 0
training_steps = 0
pool = ProcessPoolExecutor(1)
while True:
files = [
os.path.join(args.input_dir, f)
for f in os.listdir(args.input_dir)
if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f
]
files.sort()
random.shuffle(files)
f_id = 0
train_dataloader, data_file = create_pretraining_dataset(
get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args
)
for f_id in range(1, len(files)):
logger.info("data file %s" % (data_file))
dataset_future = pool.submit(
create_pretraining_dataset,
get_data_file(f_id, args.world_rank, args.world_size, files),
args.max_predictions_per_seq,
args,
)
train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader
for _step, batch in enumerate(train_iter):
training_steps += 1
batch = [t.to(device) for t in batch] # noqa: PLW2901
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
loss, _, _ = model.train_step(
input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels
)
average_loss += loss.item()
global_step = model._train_step_info.optimization_step
if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
if is_main_process(args):
divisor = args.log_freq * args.gradient_accumulation_steps
if tb_writer:
lr = model.options.lr_scheduler.get_last_lr()[0]
tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step)
if args.fp16:
tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step)
# TODO: ORTTrainer to expose all_finite
# tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step)
tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step)
print(f"Step:{global_step} Average Loss = {average_loss / divisor}")
if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
if tb_writer:
tb_writer.close()
if global_step >= args.max_steps:
if args.save_checkpoint:
model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp"))
final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps)
return final_loss
average_loss = 0
del train_dataloader
train_dataloader, data_file = dataset_future.result(timeout=None)
epoch += 1
def generate_tensorboard_logdir(root_dir):
current_date_time = datetime.datetime.today()
dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S")
return os.path.join(root_dir, dt_string)
class ORTBertPretrainTest(unittest.TestCase):
def setUp(self):
self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results"
self.bert_model = "bert-base-uncased"
self.local_rank = -1
self.world_rank = -1
self.world_size = 1
self.max_steps = 300000
self.learning_rate = 5e-4
self.max_seq_length = 512
self.max_predictions_per_seq = 20
self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
self.train_batch_size = 4096
self.gradient_accumulation_steps = 64
self.fp16 = True
self.allreduce_post_accumulation = True
self.tensorboard_dir = "/bert_data/hf_data/test_out"
def test_pretrain_throughput(self, process_args=None):
if process_args.sequence_length == 128:
input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
else:
input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision)
print("process_args.sequence_length: ", process_args.sequence_length)
print("process_args.max_batch_size: ", process_args.max_batch_size)
print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq)
print("process_args.gelu_recompute: ", process_args.gelu_recompute)
print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute)
print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute)
args = PretrainArguments(
input_dir=input_dir,
output_dir="/bert_data/hf_data/test_out/bert_pretrain_results",
bert_model="bert-large-uncased",
local_rank=self.local_rank,
world_rank=self.world_rank,
world_size=self.world_size,
max_steps=10,
learning_rate=5e-4,
max_seq_length=process_args.sequence_length,
max_predictions_per_seq=process_args.max_predictions_per_seq,
train_batch_size=process_args.max_batch_size,
gradient_accumulation_steps=1,
fp16=process_args.enable_mixed_precision,
gelu_recompute=process_args.gelu_recompute,
attn_dropout_recompute=process_args.attn_dropout_recompute,
transformer_layer_recompute=process_args.transformer_layer_recompute,
allreduce_post_accumulation=True,
# TODO: remove
force_num_hidden_layers=2,
)
do_pretrain(args)
def test_pretrain_convergence(self):
args = PretrainArguments(
output_dir=self.output_dir,
bert_model=self.bert_model,
local_rank=self.local_rank,
world_rank=self.world_rank,
world_size=self.world_size,
max_steps=self.max_steps,
learning_rate=self.learning_rate,
max_seq_length=self.max_seq_length,
max_predictions_per_seq=self.max_predictions_per_seq,
train_batch_size=self.train_batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
input_dir=self.input_dir,
fp16=self.fp16,
allreduce_post_accumulation=self.allreduce_post_accumulation,
force_num_hidden_layers=self.force_num_hidden_layers,
tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"),
)
final_loss = do_pretrain(args)
return final_loss
def test_pretrain_zero(self):
assert self.world_size > 0, "ZeRO test requires a distributed run."
setup_torch_distributed(self.world_rank, self.world_size)
per_gpu_batch_size = 32
optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation
self.train_batch_size = optimization_batch_size
self.gradient_accumulation_steps = 1
self.deepspeed_zero_stage = 1
self.force_num_hidden_layers = 2
self.max_seq_length = 32
self.output_dir = "./bert_pretrain_ckpt"
if self.world_rank == 0:
if os.path.isdir(self.output_dir):
shutil.rmtree(self.output_dir)
os.makedirs(self.output_dir, exist_ok=True)
torch.distributed.barrier()
assert os.path.exists(self.output_dir)
# run a few optimization steps
self.max_steps = 200
args = PretrainArguments(
output_dir=self.output_dir,
bert_model=self.bert_model,
local_rank=self.local_rank,
world_rank=self.world_rank,
world_size=self.world_size,
max_steps=self.max_steps,
learning_rate=self.learning_rate,
max_seq_length=self.max_seq_length,
max_predictions_per_seq=self.max_predictions_per_seq,
train_batch_size=self.train_batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
input_dir=self.input_dir,
fp16=self.fp16,
allreduce_post_accumulation=self.allreduce_post_accumulation,
force_num_hidden_layers=self.force_num_hidden_layers,
deepspeed_zero_stage=self.deepspeed_zero_stage,
save_checkpoint=True,
)
do_pretrain(args)
# ensure all workers reach this point before loading the checkpointed state
torch.distributed.barrier()
# on rank 0, load the trained state
if args.world_rank == 0:
checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp"))
args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True)
torch.distributed.barrier()
# run a single step to get the loss; on rank 0 it should be less than the starting loss
args.save_checkpoint = False
args.max_steps = 1
args.deepspeed_zero_stage = 0
final_loss = do_pretrain(args)
return final_loss
if __name__ == "__main__":
import sys
logger.warning("sys.argv: %s", sys.argv)
# usage:
# data parallel training
# mpirun -n 4 python orttraining_run_bert_pretrain.py
#
# single gpu:
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput
# [batch size test arguments]
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence
#
# torch.distributed.launch will not work because the ORT backend requires MPI to broadcast ncclUniqueId;
# instead, the unpublished get_mpi_context_xxx APIs are called to get rank/size numbers.
try:
# In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis.
from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401
from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size
has_get_mpi_context_internal_api = True
except ImportError:
has_get_mpi_context_internal_api = False
pass
if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1:
world_size = get_mpi_context_world_size()
print("get_mpi_context_world_size(): ", world_size)
local_rank = get_mpi_context_local_rank()
if local_rank == 0:
print("================================================================> os.getpid() = ", os.getpid())
test = ORTBertPretrainTest()
test.setUp()
test.local_rank = local_rank
test.world_rank = local_rank
test.world_size = world_size
if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero":
logger.info("running ORTBertPretrainTest.test_pretrain_zero()...")
final_loss = test.test_pretrain_zero()
logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss)
if local_rank == 0:
test.assertLess(final_loss, 10.2)
else:
test.assertGreater(final_loss, 11.0)
logger.info("ORTBertPretrainTest.test_pretrain_zero() passed")
elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence":
logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...")
test.max_steps = 200
test.force_num_hidden_layers = 8
final_loss = test.test_pretrain_convergence()
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
test.assertLess(final_loss, 8.5)
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
else:
# https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29
# to make equivalent args for cpp convergence test
test.max_seq_length = 128
test.max_predictions_per_seq = 20
test.gradient_accumulation_steps = 16
# cpp_batch_size (=64) * grad_acc * world_size
test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size
test.max_steps = 300000
test.force_num_hidden_layers = None
# already using Adam (e.g. AdamConfig)
test.learning_rate = 5e-4
test.warmup_proportion = 0.1
final_loss = test.test_pretrain_convergence()
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
else:
# unittest does not accept user-defined arguments,
# so this script parses them itself before dispatching to the tests
if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput":
run_test_pretrain_throughput, run_test_pretrain_convergence = True, False
sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput")
elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence":
run_test_pretrain_throughput, run_test_pretrain_convergence = False, True
sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence")
else:
run_test_pretrain_throughput, run_test_pretrain_convergence = True, True
process_args = parse_arguments()
test = ORTBertPretrainTest()
test.setUp()
if run_test_pretrain_throughput:
logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...")
test.test_pretrain_throughput(process_args)
logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed")
# unittest.main()
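For reference, the removed runner above resolves its MPI rank and world size through internal pybind bindings and falls back to a single-process run when ORT is built without MPI/NCCL. A condensed sketch of that discovery pattern, assuming the same internal (unpublished) get_mpi_context_* functions used above, is:

try:
    # internal, unpublished bindings; present only in MPI/NCCL-enabled ORT builds
    from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size
    local_rank = get_mpi_context_local_rank()
    world_size = get_mpi_context_world_size()
except ImportError:
    # non-MPI build: run as a single process
    local_rank, world_size = 0, 1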


@ -1,67 +0,0 @@
import collections
import subprocess
import sys
Config = collections.namedtuple(
"Config",
[
"enable_mixed_precision",
"sequence_length",
"max_batch_size",
"max_predictions_per_seq",
"gelu_recompute",
"attn_dropout_recompute",
"transformer_layer_recompute",
],
)
configs = [
Config(True, 128, 46, 20, False, False, False),
Config(True, 512, 8, 80, False, False, False),
Config(False, 128, 26, 20, False, False, False),
Config(False, 512, 4, 80, False, False, False),
Config(True, 128, 50, 20, True, False, False),
Config(True, 128, 50, 20, False, True, False),
Config(True, 128, 76, 20, False, False, True),
Config(True, 512, 8, 80, True, False, False),
Config(True, 512, 9, 80, False, True, False),
Config(True, 512, 15, 80, False, False, True),
]
def run_with_config(config):
print(
"##### testing name - {}-{} #####".format(
"fp16" if config.enable_mixed_precision else "fp32", config.sequence_length
)
)
print("gelu_recompute: ", config.gelu_recompute)
print("attn_dropout_recompute: ", config.attn_dropout_recompute)
print("transformer_layer_recompute: ", config.transformer_layer_recompute)
cmds = [
sys.executable,
"orttraining_run_bert_pretrain.py",
"ORTBertPretrainTest.test_pretrain_throughput",
"--sequence_length",
str(config.sequence_length),
"--max_batch_size",
str(config.max_batch_size),
"--max_predictions_per_seq",
str(config.max_predictions_per_seq),
]
if config.enable_mixed_precision:
cmds.append("--enable_mixed_precision")
if config.gelu_recompute:
cmds.append("--gelu_recompute")
if config.attn_dropout_recompute:
cmds.append("--attn_dropout_recompute")
if config.transformer_layer_recompute:
cmds.append("--transformer_layer_recompute")
# access to azure storage shared disk is much slower so we need a longer timeout.
subprocess.run(cmds, timeout=1200).check_returncode() # noqa: PLW1510
for config in configs:
run_with_config(config)
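As a worked example, the first configuration above, Config(True, 128, 46, 20, False, False, False), makes run_with_config build the following cmds list (only --enable_mixed_precision is appended, since all recompute flags are False):

# cmds == [sys.executable, "orttraining_run_bert_pretrain.py",
#          "ORTBertPretrainTest.test_pretrain_throughput",
#          "--sequence_length", "128", "--max_batch_size", "46",
#          "--max_predictions_per_seq", "20", "--enable_mixed_precision"]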


@ -1,323 +0,0 @@
# adapted from run_glue.py of huggingface transformers
import dataclasses # noqa: F401
import logging
import os
import unittest
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
from numpy.testing import assert_allclose
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
EvalPrediction,
GlueDataset,
GlueDataTrainingArguments,
TrainingArguments,
glue_compute_metrics,
glue_output_modes,
glue_tasks_num_labels,
set_seed,
)
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
try:
from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401
from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size
has_get_mpi_context_internal_api = True
except ImportError:
has_get_mpi_context_internal_api = False
pass
import torch # noqa: F401
from orttraining_transformer_trainer import ORTTransformerTrainer
logger = logging.getLogger(__name__)
def verify_old_and_new_api_are_equal(results_per_api):
new_api_results = results_per_api[True]
old_api_results = results_per_api[False]
for key in new_api_results:
assert_allclose(new_api_results[key], old_api_results[key])
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"})
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
class ORTGlueTest(unittest.TestCase):
def setUp(self):
# configurations not to be changed across tests
self.max_seq_length = 128
self.train_batch_size = 8
self.learning_rate = 2e-5
self.num_train_epochs = 3.0
self.local_rank = -1
self.world_size = 1
self.overwrite_output_dir = True
self.gradient_accumulation_steps = 1
self.data_dir = "/bert_data/hf_data/glue_data/"
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/")
self.cache_dir = "/tmp/glue/"
self.logging_steps = 10
def test_roberta_with_mrpc(self):
expected_acc = 0.85
expected_f1 = 0.88
expected_loss = 0.35
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False)
assert results["acc"] >= expected_acc
assert results["f1"] >= expected_f1
assert results["loss"] <= expected_loss
def test_roberta_fp16_with_mrpc(self):
expected_acc = 0.87
expected_f1 = 0.90
expected_loss = 0.33
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True)
assert results["acc"] >= expected_acc
assert results["f1"] >= expected_f1
assert results["loss"] <= expected_loss
def test_bert_with_mrpc(self):
if self.local_rank == -1:
expected_acc = 0.83
expected_f1 = 0.88
expected_loss = 0.44
elif self.local_rank == 0:
expected_acc = 0.81
expected_f1 = 0.86
expected_loss = 0.44
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False)
if self.local_rank in [-1, 0]:
assert results["acc"] >= expected_acc
assert results["f1"] >= expected_f1
assert results["loss"] <= expected_loss
def test_bert_fp16_with_mrpc(self):
expected_acc = 0.84
expected_f1 = 0.88
expected_loss = 0.46
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True)
assert results["acc"] >= expected_acc
assert results["f1"] >= expected_f1
assert results["loss"] <= expected_loss
def model_to_desc(self, model_name, model):
if model_name.startswith("bert") or model_name.startswith("xlnet"):
model_desc = {
"inputs": [
(
"input_ids",
["batch", "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", "max_seq_len_in_batch"],
),
(
"token_type_ids",
["batch", "max_seq_len_in_batch"],
),
(
"labels",
[
"batch",
],
),
],
"outputs": [("loss", [], True), ("logits", ["batch", 2])],
}
elif model_name.startswith("roberta"):
model_desc = {
"inputs": [
(
"input_ids",
["batch", "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", "max_seq_len_in_batch"],
),
(
"labels",
[
"batch",
],
),
],
"outputs": [("loss", [], True), ("logits", ["batch", 2])],
}
else:
raise RuntimeError(f"unsupported base model name {model_name}.")
return model_desc
def run_glue(self, model_name, task_name, fp16):
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
data_args = GlueDataTrainingArguments(
task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length
)
training_args = TrainingArguments(
output_dir=os.path.join(self.output_dir, task_name),
do_train=True,
do_eval=True,
per_gpu_train_batch_size=self.train_batch_size,
learning_rate=self.learning_rate,
num_train_epochs=self.num_train_epochs,
local_rank=self.local_rank,
overwrite_output_dir=self.overwrite_output_dir,
gradient_accumulation_steps=self.gradient_accumulation_steps,
fp16=fp16,
logging_steps=self.logging_steps,
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
set_seed(training_args.seed)
onnxruntime.set_seed(training_args.seed)
try:
num_labels = glue_tasks_num_labels[data_args.task_name]
output_mode = glue_output_modes[data_args.task_name]
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
def compute_metrics(p: EvalPrediction) -> Dict:
if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1)
elif output_mode == "regression":
preds = np.squeeze(p.predictions)
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
model_desc = self.model_to_desc(model_name, model)
# Initialize the ORTTrainer within ORTTransformerTrainer
trainer = ORTTransformerTrainer(
model=model,
model_desc=model_desc,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
world_size=self.world_size,
)
# Training
if training_args.do_train:
trainer.train()
trainer.save_model()
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
logger.info(f"***** Eval results {data_args.task_name} *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
results.update(result)
return results
if __name__ == "__main__":
if has_get_mpi_context_internal_api:
local_rank = get_mpi_context_local_rank()
world_size = get_mpi_context_world_size()
else:
local_rank = -1
world_size = 1
if world_size > 1:
# mpi launch
logger.warning("mpirun launch, local_rank / world_size: %s : % s", local_rank, world_size)
# TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl")
# PyTorch expects the following environment settings (which would be set if launched with torch.distributed.launch).
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
from onnxruntime.capi._pybind_state import set_cuda_device_id
set_cuda_device_id(local_rank)
test = ORTGlueTest()
test.setUp()
test.local_rank = local_rank
test.world_size = world_size
test.test_bert_with_mrpc()
else:
unittest.main()

Просмотреть файл

@ -1,281 +0,0 @@
# adapted from run_multiple_choice.py of huggingface transformers
# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py
import dataclasses # noqa: F401
import logging
import os
import unittest
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
import torch # noqa: F401
from numpy.testing import assert_allclose # noqa: F401
from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401
from orttraining_transformer_trainer import ORTTransformerTrainer
from transformers import HfArgumentParser # noqa: F401
from transformers import Trainer # noqa: F401
from transformers import (
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
EvalPrediction,
TrainingArguments,
set_seed,
)
from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
logger = logging.getLogger(__name__)
def simple_accuracy(preds, labels):
return (preds == labels).mean()
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"})
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
task_name: str = field(metadata={"help": "The name of the task to train on."})
data_dir: str = field(metadata={"help": "Should contain the data files for the task."})
max_seq_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
class ORTMultipleChoiceTest(unittest.TestCase):
def setUp(self):
# configurations not to be changed across tests
self.max_seq_length = 80
self.train_batch_size = 16
self.eval_batch_size = 2
self.learning_rate = 2e-5
self.num_train_epochs = 1.0
self.local_rank = -1
self.overwrite_output_dir = True
self.gradient_accumulation_steps = 8
self.data_dir = "/bert_data/hf_data/swag/swagaf/data"
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/")
self.cache_dir = "/tmp/multiple_choice/"
self.logging_steps = 10
self.rtol = 2e-01
def test_bert_with_swag(self):
expected_acc = 0.75
expected_loss = 0.64
results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False)
assert results["acc"] >= expected_acc
assert results["loss"] <= expected_loss
def test_bert_fp16_with_swag(self):
# larger batch can be handled with mixed precision
self.train_batch_size = 32
expected_acc = 0.73
expected_loss = 0.68
results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True)
assert results["acc"] >= expected_acc
assert results["loss"] <= expected_loss
def run_multiple_choice(self, model_name, task_name, fp16):
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
data_args = DataTrainingArguments(
task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length
)
training_args = TrainingArguments(
output_dir=os.path.join(self.output_dir, task_name),
do_train=True,
do_eval=True,
per_gpu_train_batch_size=self.train_batch_size,
per_gpu_eval_batch_size=self.eval_batch_size,
learning_rate=self.learning_rate,
num_train_epochs=self.num_train_epochs,
local_rank=self.local_rank,
overwrite_output_dir=self.overwrite_output_dir,
gradient_accumulation_steps=self.gradient_accumulation_steps,
fp16=fp16,
logging_steps=self.logging_steps,
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
set_seed(training_args.seed)
onnxruntime.set_seed(training_args.seed)
try:
processor = SwagProcessor()
label_list = processor.get_labels()
num_labels = len(label_list)
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
model = AutoModelForMultipleChoice.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Get datasets
train_dataset = (
MultipleChoiceDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
task=data_args.task_name,
processor=processor,
max_seq_length=data_args.max_seq_length,
overwrite_cache=data_args.overwrite_cache,
mode=Split.train,
)
if training_args.do_train
else None
)
eval_dataset = (
MultipleChoiceDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
task=data_args.task_name,
processor=processor,
max_seq_length=data_args.max_seq_length,
overwrite_cache=data_args.overwrite_cache,
mode=Split.dev,
)
if training_args.do_eval
else None
)
def compute_metrics(p: EvalPrediction) -> Dict:
preds = np.argmax(p.predictions, axis=1)
return {"acc": simple_accuracy(preds, p.label_ids)}
if model_name.startswith("bert"):
model_desc = {
"inputs": [
(
"input_ids",
["batch", num_labels, "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", num_labels, "max_seq_len_in_batch"],
),
(
"token_type_ids",
["batch", num_labels, "max_seq_len_in_batch"],
),
(
"labels",
["batch", num_labels],
),
],
"outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])],
}
else:
model_desc = {
"inputs": [
(
"input_ids",
["batch", num_labels, "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", num_labels, "max_seq_len_in_batch"],
),
(
"labels",
["batch", num_labels],
),
],
"outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])],
}
# Initialize the ORTTrainer within ORTTransformerTrainer
trainer = ORTTransformerTrainer(
model=model,
model_desc=model_desc,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
)
# Training
if training_args.do_train:
trainer.train()
trainer.save_model()
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
logger.info(f"***** Eval results {data_args.task_name} *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
results.update(result)
return results
if __name__ == "__main__":
unittest.main()


@ -1,6 +0,0 @@
from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401
from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401
def postprocess_model(model):
add_name(model)


@ -1,257 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# orttraining_test_checkpoint_storage.py
import os
import pickle
import shutil
import numpy as np
import pytest
import torch
from onnxruntime.training import _checkpoint_storage
# Helper functions
def _equals(a, b):
"""Checks recursively if two dictionaries are equal"""
if isinstance(a, dict):
return all(key in b and _equals(a[key], b[key]) for key in a)
else:
if isinstance(a, bytes):
a = a.decode()
if isinstance(b, bytes):
b = b.decode()
are_equal = a == b
return are_equal if isinstance(are_equal, bool) else are_equal.all()
return False
def _numpy_types(obj_value):
"""Return a bool indicating whether or not the input obj_value is a numpy type object
Recursively checks if the obj_value (could be a dictionary) is a numpy type object.
Exceptions are str and bytes.
Returns true if object is numpy type, str, or bytes
False if any other type
"""
if not isinstance(obj_value, dict):
return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__
return all(_numpy_types(value) for _, value in obj_value.items())
def _get_dict(separated_key):
"""Create dummy dictionary with different datatypes
Returns a tuple of: the entire dummy dictionary created, the key argument (as a dictionary) for the
_checkpoint_storage.load function, and the value for that key in the original dictionary.
For example, the complete dictionary is represented by:
{
'int1':1,
'int2': 2,
'int_list': [1,2,3,5,6],
'dict1': {
'np_array': np.arange(100),
'dict2': {'int3': 3, 'int4': 4},
'str1': "onnxruntime"
},
'bool1': bool(True),
'int5': 5,
'float1': 2.345,
'np_array_float': np.array([1.234, 2.345, 3.456]),
'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]])
}
if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1'
and the value corresponding to that is "onnxruntime"
so, for the above example, the returned tuple is:
(original_dict, {'key': 'dict1/str1'}, "onnxruntime")
"""
test_dict = {
"int1": 1,
"int2": 2,
"int_list": [1, 2, 3, 5, 6],
"dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"},
"bool1": True,
"int5": 5,
"float1": 2.345,
"np_array_float": np.array([1.234, 2.345, 3.456]),
"np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
}
key = ""
expected_val = test_dict
for single_key in separated_key:
key += single_key + "/"
expected_val = expected_val[single_key]
return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val
class _CustomClass:
"""Custom object that encpsulates dummy values for loss, epoch and train_step"""
def __init__(self):
self._loss = 1.23
self._epoch = 12000
self._train_step = 25
def __eq__(self, other):
if isinstance(other, _CustomClass):
return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step
# Test fixtures
@pytest.yield_fixture(scope="function")
def checkpoint_storage_test_setup():
checkpoint_dir = os.path.abspath("checkpoint_dir/")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp")
yield "checkpoint_storage_test_setup"
shutil.rmtree(checkpoint_dir)
@pytest.yield_fixture(scope="function")
def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup):
yield request.param
# Tests
@pytest.mark.parametrize(
"checkpoint_storage_test_parameterized_setup",
[
_get_dict([]),
_get_dict(["int1"]),
_get_dict(["dict1"]),
_get_dict(["dict1", "dict2"]),
_get_dict(["dict1", "dict2", "int4"]),
_get_dict(["dict1", "str1"]),
_get_dict(["bool1"]),
_get_dict(["float1"]),
_get_dict(["np_array_float"]),
],
indirect=True,
)
def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup):
to_save = checkpoint_storage_test_parameterized_setup[0]
key_arg = checkpoint_storage_test_parameterized_setup[1]
expected = checkpoint_storage_test_parameterized_setup[2]
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg)
assert _equals(loaded, expected)
assert _numpy_types(loaded)
@pytest.mark.parametrize(
"checkpoint_storage_test_parameterized_setup",
[{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352],
indirect=True,
)
def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup):
to_save = checkpoint_storage_test_parameterized_setup
with pytest.raises(Exception): # noqa: B017
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
@pytest.mark.parametrize(
"checkpoint_storage_test_parameterized_setup",
[
({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64),
({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32),
({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16),
({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8),
({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64),
(
{"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)},
"float32_tensor",
torch.float32,
np.float32,
),
(
{"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)},
"float16_tensor",
torch.float16,
np.float16,
),
],
indirect=True,
)
def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup):
tensor_dict = checkpoint_storage_test_parameterized_setup[0]
tensor_name = checkpoint_storage_test_parameterized_setup[1]
tensor_dtype = checkpoint_storage_test_parameterized_setup[2]
np_dtype = checkpoint_storage_test_parameterized_setup[3]
_checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
assert isinstance(loaded[tensor_name], np.ndarray)
assert tensor_dict[tensor_name].dtype == tensor_dtype
assert loaded[tensor_name].dtype == np_dtype
assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all()
@pytest.mark.parametrize(
"checkpoint_storage_test_parameterized_setup",
[
({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"),
({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"),
({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"),
],
indirect=True,
)
def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup):
tensor_dict = checkpoint_storage_test_parameterized_setup[0]
tensor_name = checkpoint_storage_test_parameterized_setup[1]
_checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
assert isinstance(loaded[tensor_name], np.ndarray)
assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all()
@pytest.mark.parametrize(
"checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True
)
def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup):
saved = checkpoint_storage_test_parameterized_setup
_checkpoint_storage.save(saved, pytest.checkpoint_path)
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
assert _equals(saved, loaded)
def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup):
with pytest.raises(Exception): # noqa: B017
_checkpoint_storage.load(pytest.checkpoint_path)
def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup):
custom_class = _CustomClass()
user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class}
pickled_bytes = pickle.dumps(user_dict).hex()
to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes}
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path)
assert (loaded_dict["a"] == to_save["a"].numpy()).all()
try: # noqa: SIM105
loaded_dict["user_dict"] = loaded_dict["user_dict"].decode()
except AttributeError:
pass
loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"]))
assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"]))
assert loaded_obj["custom_class"] == custom_class


@ -4,8 +4,6 @@ from enum import Enum
import torch
from torch.utils.data import DataLoader, Dataset
from onnxruntime.capi.ort_trainer import generate_sample
global_rng = random.Random()
@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
def generate_sample(desc, device=None):
"""Generate a sample based on the description"""
# symbolic dimensions are described with strings. set symbolic dimensions to be 1
size = [s if isinstance(s, (int)) else 1 for s in desc.shape_]
if desc.num_classes_:
return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device)
else:
return torch.randn(size, dtype=desc.dtype_).to(device)
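# Usage sketch for generate_sample (hypothetical description object; attribute names as used above):
# a desc with shape_ == ["batch", "max_seq_len_in_batch"], dtype_ == torch.int64 and num_classes_ == 2
# yields torch.randint(0, 2, [1, 1], ...) on the requested device, because every symbolic (string)
# dimension is replaced by 1; with num_classes_ unset it falls back to torch.randn.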
class OrtTestDataset(Dataset):
def __init__(self, input_desc, seq_len, dataset_len, device):
import copy


@ -1,40 +0,0 @@
import pytest
import torch
from _test_commons import _load_pytorch_transformer_model
from onnxruntime import set_seed
from onnxruntime.training import optim, orttrainer
###############################################################################
# Testing starts here #########################################################
###############################################################################
@pytest.mark.parametrize(
"seed, device",
[
(24, "cuda"),
],
)
def testORTTransformerModelExport(seed, device):
# Common setup
optim_config = optim.LambConfig()
opts = orttrainer.ORTTrainerOptions(
{
"debug": {
"check_model_export": True,
},
"device": {
"id": device,
},
}
)
# Setup for the first ORTTrainer run
torch.manual_seed(seed)
set_seed(seed)
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
data, targets = batcher_fn(train_data, 0)
_ = first_trainer.train_step(data, targets)
assert first_trainer._onnx_model is not None


@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log):
log.debug("Running: ort training api tests")
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"]
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"]
run_subprocess(command, cwd=cwd, log=log).check_returncode()
@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log):
log.debug("Running: onnxblock tests")
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"]
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"]
run_subprocess(command, cwd=cwd, log=log).check_returncode()


@ -11,7 +11,7 @@ import numpy as np
import onnx
import pytest
import torch
from orttraining_test_onnxblock import _get_models
from orttraining_test_ort_apis_onnxblock import _get_models
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime import OrtValue, SessionOptions

Diff not shown because of its large size.


@ -1,722 +0,0 @@
from unittest.mock import Mock, patch
import numpy as np
import onnx
import pytest
import torch
from _test_commons import _load_pytorch_transformer_model
from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401
# Helper functions
def _create_trainer(zero_enabled=False):
"""Cerates a simple ORTTrainer for ORTTrainer functional tests"""
device = "cuda"
optim_config = optim.LambConfig(lr=0.1)
opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
if zero_enabled:
opts["distributed"] = {
"world_rank": 0,
"world_size": 1,
"horizontal_parallel_size": 1,
"data_parallel_size": 1,
"allreduce_post_accumulation": True,
"deepspeed_zero_optimization": {"stage": 1},
}
model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
trainer = orttrainer.ORTTrainer(
model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts)
)
return trainer
class _training_session_mock: # noqa: N801
"""Mock object for the ORTTrainer _training_session member"""
def __init__(self, model_states, optimizer_states, partition_info):
self.model_states = model_states
self.optimizer_states = optimizer_states
self.partition_info = partition_info
def get_model_state(self, include_mixed_precision_weights=False):
return self.model_states
def get_optimizer_state(self):
return self.optimizer_states
def get_partition_info_map(self):
return self.partition_info
def _get_load_state_dict_strict_error_arguments():
"""Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing
Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments)
The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and
throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys.
"""
training_session_state_dict = {
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
"optimizer": {
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(5)},
},
}
# input state dictionaries
precision_key_missing = {"model": {}, "optimizer": {}}
precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}}
model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}}
model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}}
optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}}
optimizer_model_state_key_unexpected = {
"model": {"full_precision": {"a": 2, "b": 3}},
"optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}},
}
optimizer_state_key_missing = {
"model": {"full_precision": {"a": 2, "b": 3}},
"optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}},
}
optimizer_state_key_unexpected = {
"model": {"full_precision": {"a": 2, "b": 3}},
"optimizer": {
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)},
},
}
input_arguments = [
(training_session_state_dict, precision_key_missing, ["full_precision"]),
(training_session_state_dict, precision_key_unexpected, ["mixed_precision"]),
(training_session_state_dict, model_state_key_missing, ["a", "b"]),
(training_session_state_dict, model_state_key_unexpected, ["c"]),
(training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]),
(training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]),
(training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]),
(training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]),
]
return input_arguments
# Tests
def test_empty_state_dict_when_training_session_uninitialized():
trainer = _create_trainer()
with pytest.warns(UserWarning) as user_warning:
state_dict = trainer.state_dict()
assert len(state_dict.keys()) == 0
assert (
user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()."
)
@patch("onnx.ModelProto")
def test_training_session_provides_empty_model_states(onnx_model_mock):
trainer = _create_trainer()
training_session_mock = _training_session_mock({}, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert len(state_dict["model"].keys()) == 0
@patch("onnx.ModelProto")
def test_training_session_provides_model_states(onnx_model_mock):
trainer = _create_trainer()
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
training_session_mock = _training_session_mock(model_states, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
@patch("onnx.ModelProto")
def test_training_session_provides_model_states_pytorch_format(onnx_model_mock):
trainer = _create_trainer()
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
training_session_mock = _training_session_mock(model_states, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict(pytorch_format=True)
assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5))))
assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7))))
@patch("onnx.ModelProto")
def test_onnx_graph_provides_frozen_model_states(onnx_model_mock):
trainer = _create_trainer()
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
training_session_mock = _training_session_mock(model_states, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"]
trainer._onnx_model.graph.initializer = [
onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"),
onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"),
onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"),
]
state_dict = trainer.state_dict()
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all()
assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"]
assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all()
@patch("onnx.ModelProto")
def test_training_session_provides_empty_optimizer_states(onnx_model_mock):
trainer = _create_trainer()
training_session_mock = _training_session_mock({}, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert len(state_dict["optimizer"].keys()) == 0
@patch("onnx.ModelProto")
def test_training_session_provides_optimizer_states(onnx_model_mock):
trainer = _create_trainer()
optimizer_states = {
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(1)},
}
training_session_mock = _training_session_mock({}, optimizer_states, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all()
assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all()
assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all()
@patch("onnx.ModelProto")
def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock):
trainer = _create_trainer()
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
optimizer_states = {
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(1)},
}
training_session_mock = _training_session_mock(model_states, optimizer_states, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict(pytorch_format=True)
assert "optimizer" not in state_dict
@patch("onnx.ModelProto")
def test_training_session_provides_empty_partition_info_map(onnx_model_mock):
trainer = _create_trainer(zero_enabled=True)
training_session_mock = _training_session_mock({}, {}, {})
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert len(state_dict["partition_info"].keys()) == 0
@patch("onnx.ModelProto")
def test_training_session_provides_partition_info_map(onnx_model_mock):
trainer = _create_trainer(zero_enabled=True)
partition_info = {"a": {"original_dim": [1, 2, 3]}}
training_session_mock = _training_session_mock({}, {}, partition_info)
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3]
@patch("onnx.ModelProto")
def test_training_session_provides_all_states(onnx_model_mock):
trainer = _create_trainer(zero_enabled=True)
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
optimizer_states = {
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(1)},
}
partition_info = {"a": {"original_dim": [1, 2, 3]}}
training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info)
trainer._training_session = training_session_mock
trainer._onnx_model = onnx_model_mock()
state_dict = trainer.state_dict()
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all()
assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all()
assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all()
assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3]
def test_load_state_dict_holds_when_training_session_not_initialized():
trainer = _create_trainer()
state_dict = {
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
"optimizer": {
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(5)},
},
}
assert not trainer._load_state_dict
state_dict = trainer.load_state_dict(state_dict)
assert trainer._load_state_dict
@pytest.mark.parametrize(
"state_dict, input_state_dict, error_key",
[
(
{"model": {}, "optimizer": {}},
{"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}},
"train_step_info",
),
(
{"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}},
{
"optimizer": {},
"trainer_options": {"optimizer_name": "LambOptimizer"},
"train_step_info": {"optimization_step": 0, "step": 0},
},
"model",
),
(
{"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}},
{
"model": {},
"trainer_options": {"optimizer_name": "LambOptimizer"},
"train_step_info": {"optimization_step": 0, "step": 0},
},
"optimizer",
),
],
)
def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key):
trainer = _create_trainer()
trainer._training_session = _training_session_mock({}, {}, {})
trainer.state_dict = Mock(return_value=state_dict)
trainer._update_onnx_model_initializers = Mock()
trainer._init_session = Mock()
with patch("onnx.ModelProto") as onnx_model_mock:
trainer._onnx_model = onnx_model_mock()
trainer._onnx_model.graph.initializer = []
with pytest.warns(UserWarning) as user_warning:
trainer.load_state_dict(input_state_dict)
assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict"
@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments())
def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys):
trainer = _create_trainer()
trainer._training_session = _training_session_mock({}, {}, {})
trainer.state_dict = Mock(return_value=state_dict)
with pytest.raises(RuntimeError) as runtime_error:
trainer.load_state_dict(input_state_dict)
assert any(key in str(runtime_error.value) for key in error_keys)
@patch("onnx.ModelProto")
def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock):
trainer = _create_trainer()
training_session_state_dict = {
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
"optimizer": {
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
"shared_optimizer_state": {"step": np.arange(1)},
},
}
input_state_dict = {
"model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}},
"optimizer": {
"a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])},
"shared_optimizer_state": {"step": np.array([9])},
},
"trainer_options": {"optimizer_name": "LambOptimizer"},
}
trainer._training_session = _training_session_mock({}, {}, {})
trainer.state_dict = Mock(return_value=training_session_state_dict)
trainer._onnx_model = onnx_model_mock()
trainer._onnx_model.graph.initializer = [
onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"),
onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"),
]
trainer._update_onnx_model_initializers = Mock()
trainer._init_session = Mock()
trainer.load_state_dict(input_state_dict)
loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args
state_dict_to_load, _ = trainer._init_session.call_args
assert "a" in loaded_initializers[0]
assert (loaded_initializers[0]["a"] == np.array([1, 2])).all()
assert "b" in loaded_initializers[0]
assert (loaded_initializers[0]["b"] == np.array([3, 4])).all()
assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all()
assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all()
assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all()
@patch("onnxruntime.training._checkpoint_storage.save")
def test_save_checkpoint_calls_checkpoint_storage_save(save_mock):
trainer = _create_trainer()
state_dict = {"model": {}, "optimizer": {}}
trainer.state_dict = Mock(return_value=state_dict)
trainer.save_checkpoint("abc")
save_args, _ = save_mock.call_args
assert "model" in save_args[0]
assert not bool(save_args[0]["model"])
assert "optimizer" in save_args[0]
assert not bool(save_args[0]["optimizer"])
assert save_args[1] == "abc"
@patch("onnxruntime.training._checkpoint_storage.save")
def test_save_checkpoint_exclude_optimizer_states(save_mock):
trainer = _create_trainer()
state_dict = {"model": {}, "optimizer": {}}
trainer.state_dict = Mock(return_value=state_dict)
trainer.save_checkpoint("abc", include_optimizer_states=False)
save_args, _ = save_mock.call_args
assert "model" in save_args[0]
assert not bool(save_args[0]["model"])
assert "optimizer" not in save_args[0]
assert save_args[1] == "abc"
@patch("onnxruntime.training._checkpoint_storage.save")
def test_save_checkpoint_user_dict(save_mock):
trainer = _create_trainer()
state_dict = {"model": {}, "optimizer": {}}
trainer.state_dict = Mock(return_value=state_dict)
trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)})
save_args, _ = save_mock.call_args
assert "user_dict" in save_args[0]
assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)})
@patch("onnxruntime.training._checkpoint_storage.load")
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
def test_load_checkpoint(aggregate_checkpoints_mock, load_mock):
trainer = _create_trainer()
trainer_options = {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
}
state_dict = {
"model": {},
"optimizer": {},
"trainer_options": {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
},
}
trainer.load_state_dict = Mock()
load_mock.side_effect = [trainer_options, state_dict]
trainer.load_checkpoint("abc")
args_list = load_mock.call_args_list
load_args, load_kwargs = args_list[0]
assert load_args[0] == "abc"
assert load_kwargs["key"] == "trainer_options"
load_args, load_kwargs = args_list[1]
assert load_args[0] == "abc"
assert "key" not in load_kwargs
assert not aggregate_checkpoints_mock.called
@patch("onnxruntime.training._checkpoint_storage.load")
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
@pytest.mark.parametrize(
"trainer_options",
[
{
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(4),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(4),
"zero_stage": np.int64(1),
},
{
"mixed_precision": np.bool_(True),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(1),
},
{
"mixed_precision": np.bool_(True),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(1),
},
],
)
def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options):
trainer = _create_trainer()
trainer.load_state_dict = Mock()
load_mock.side_effect = [trainer_options]
trainer.load_checkpoint("abc")
args_list = load_mock.call_args_list
load_args, load_kwargs = args_list[0]
assert load_args[0] == "abc"
assert load_kwargs["key"] == "trainer_options"
assert aggregate_checkpoints_mock.called
call_args, _ = aggregate_checkpoints_mock.call_args
assert call_args[0] == tuple(["abc"])
@patch("onnxruntime.training._checkpoint_storage.load")
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock):
trainer = _create_trainer()
trainer_options = {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
}
state_dict = {
"model": {},
"optimizer": {},
"trainer_options": {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
},
"user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}),
}
trainer.load_state_dict = Mock()
load_mock.side_effect = [trainer_options, state_dict]
user_dict = trainer.load_checkpoint("abc")
assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5))))
@patch("onnxruntime.training._checkpoint_storage.load")
def test_checkpoint_aggregation(load_mock):
trainer_options1 = {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(2),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(2),
"zero_stage": np.int64(1),
"optimizer_name": b"Adam",
}
trainer_options2 = {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(1),
"world_size": np.int64(2),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(2),
"zero_stage": np.int64(1),
"optimizer_name": b"Adam",
}
state_dict1 = {
"model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
"optimizer": {
"optimizer_sharded": {
"Moment_1": np.array([9, 8, 7]),
"Moment_2": np.array([99, 88, 77]),
"Step": np.array([5]),
},
"non_sharded": {
"Moment_1": np.array([666, 555, 444]),
"Moment_2": np.array([6666, 5555, 4444]),
"Step": np.array([55]),
},
},
"trainer_options": {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
"optimizer_name": b"Adam",
},
"partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}},
}
state_dict2 = {
"model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
"optimizer": {
"optimizer_sharded": {
"Moment_1": np.array([6, 5, 4]),
"Moment_2": np.array([66, 55, 44]),
"Step": np.array([5]),
},
"non_sharded": {
"Moment_1": np.array([666, 555, 444]),
"Moment_2": np.array([6666, 5555, 4444]),
"Step": np.array([55]),
},
},
"trainer_options": {
"mixed_precision": np.bool_(False),
"world_rank": np.int64(1),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
"optimizer_name": b"Adam",
},
"partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}},
}
load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]
state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False)
assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all()
assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all()
assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all()
assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all()
assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all()
assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all()
assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all()
assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all()
assert state_dict["trainer_options"]["mixed_precision"] is False
assert state_dict["trainer_options"]["world_rank"] == 0
assert state_dict["trainer_options"]["world_size"] == 1
assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1
assert state_dict["trainer_options"]["data_parallel_size"] == 1
assert state_dict["trainer_options"]["zero_stage"] == 0
assert state_dict["trainer_options"]["optimizer_name"] == b"Adam"
@patch("onnxruntime.training._checkpoint_storage.load")
def test_checkpoint_aggregation_mixed_precision(load_mock):
trainer_options1 = {
"mixed_precision": np.bool_(True),
"world_rank": np.int64(0),
"world_size": np.int64(2),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(2),
"zero_stage": np.int64(1),
"optimizer_name": b"Adam",
}
trainer_options2 = {
"mixed_precision": np.bool_(True),
"world_rank": np.int64(1),
"world_size": np.int64(2),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(2),
"zero_stage": np.int64(1),
"optimizer_name": b"Adam",
}
state_dict1 = {
"model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
"optimizer": {
"sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])},
"non_sharded": {
"Moment_1": np.array([666, 555, 444]),
"Moment_2": np.array([6666, 5555, 4444]),
"Step": np.array([55]),
},
},
"trainer_options": {
"mixed_precision": np.bool_(True),
"world_rank": np.int64(0),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
"optimizer_name": b"Adam",
},
"partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
}
state_dict2 = {
"model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}},
"optimizer": {
"sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])},
"non_sharded": {
"Moment_1": np.array([666, 555, 444]),
"Moment_2": np.array([6666, 5555, 4444]),
"Step": np.array([55]),
},
},
"trainer_options": {
"mixed_precision": np.bool_(True),
"world_rank": np.int64(1),
"world_size": np.int64(1),
"horizontal_parallel_size": np.int64(1),
"data_parallel_size": np.int64(1),
"zero_stage": np.int64(0),
"optimizer_name": b"Adam",
},
"partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
}
load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]
state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False)
assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all()
assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all()
assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all()
assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all()
assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all()
assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all()
assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all()
assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all()
assert state_dict["trainer_options"]["mixed_precision"] is True
assert state_dict["trainer_options"]["world_rank"] == 0
assert state_dict["trainer_options"]["world_size"] == 1
assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1
assert state_dict["trainer_options"]["data_parallel_size"] == 1
assert state_dict["trainer_options"]["zero_stage"] == 0
assert state_dict["trainer_options"]["optimizer_name"] == b"Adam"

File diff not shown because it is too large.

View File

@ -1,480 +0,0 @@
import random
import unittest
import numpy as np
import torch
from numpy.testing import assert_allclose
from orttraining_test_data_loader import BatchArgsOption, ids_tensor
from orttraining_test_utils import get_lr, run_test
from transformers import BertConfig, BertForPreTraining
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
class BertModelTest(unittest.TestCase):
class BertModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
device="cpu",
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.device = device
# 1. superset of bert input/output descs
# see BertPreTrainedModel doc
self.input_ids_desc = IODescription(
"input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size
)
self.attention_mask_desc = IODescription(
"attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2
)
self.token_type_ids_desc = IODescription(
"token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2
)
self.position_ids_desc = IODescription(
"position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings
)
self.head_mask_desc = IODescription(
"head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2
)
self.inputs_embeds_desc = IODescription(
"inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
)
self.encoder_hidden_states_desc = IODescription(
"encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
)
self.encoder_attention_mask_desc = IODescription(
"encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32
)
# see BertForPreTraining doc
self.masked_lm_labels_desc = IODescription(
"masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size
)
self.next_sentence_label_desc = IODescription(
"next_sentence_label",
[
"batch",
],
torch.int64,
num_classes=2,
)
# outputs
self.loss_desc = IODescription(
"loss",
[
1,
],
torch.float32,
)
self.prediction_scores_desc = IODescription(
"prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32
)
self.seq_relationship_scores_desc = IODescription(
"seq_relationship_scores", ["batch", 2], torch.float32
) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
self.hidden_states_desc = IODescription(
"hidden_states",
[self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size],
torch.float32,
)
self.attentions_desc = IODescription(
"attentions",
[
self.num_hidden_layers,
"batch",
self.num_attention_heads,
"max_seq_len_in_batch",
"max_seq_len_in_batch",
],
torch.float32,
)
self.last_hidden_state_desc = IODescription(
"last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
)
self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32)
def BertForPreTraining_descs(self):
return ModelDescription(
[
self.input_ids_desc,
self.attention_mask_desc,
self.token_type_ids_desc,
self.masked_lm_labels_desc,
self.next_sentence_label_desc,
],
# returns loss_desc if both masked_lm_labels_desc and next_sentence_label_desc are provided
# hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states
[
self.loss_desc,
self.prediction_scores_desc,
self.seq_relationship_scores_desc,
# hidden_states_desc, attentions_desc
],
)
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
config = BertConfig(
vocab_size=self.vocab_size,
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_bert_for_pretraining(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
option_use_internal_get_lr_this_step=[True], # noqa: B006
option_use_internal_loss_scaler=[True], # noqa: B006
):
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
onnxruntime.set_seed(seed)
model = BertForPreTraining(config=config)
model.eval()
loss, prediction_scores, seq_relationship_score = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
next_sentence_label=sequence_labels,
)
model_desc = ModelDescription(
[
self.input_ids_desc,
self.attention_mask_desc,
self.token_type_ids_desc,
self.masked_lm_labels_desc,
self.next_sentence_label_desc,
],
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc],
)
from collections import namedtuple
MyArgs = namedtuple(
"MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
)
dataset_len = 100
epochs = 8
max_steps = epochs * dataset_len
args = MyArgs(
local_rank=0,
world_size=1,
max_steps=max_steps,
learning_rate=0.00001,
warmup_proportion=0.01,
batch_size=13,
seq_len=7,
)
def get_lr_this_step(global_step):
return get_lr(args, global_step)
loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000)
for fp16 in option_fp16:
for allreduce_post_accumulation in option_allreduce_post_accumulation:
for gradient_accumulation_steps in option_gradient_accumulation_steps:
for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
for use_internal_loss_scaler in option_use_internal_loss_scaler:
for split_batch in option_split_batch:
print("gradient_accumulation_steps:", gradient_accumulation_steps)
print("split_batch:", split_batch)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
onnxruntime.set_seed(seed)
(
old_api_loss_ort,
old_api_prediction_scores_ort,
old_api_seq_relationship_score_ort,
) = run_test(
model,
model_desc,
self.device,
args,
gradient_accumulation_steps,
fp16,
allreduce_post_accumulation,
get_lr_this_step,
use_internal_get_lr_this_step,
loss_scaler,
use_internal_loss_scaler,
split_batch,
dataset_len,
epochs,
use_new_api=False,
)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
onnxruntime.set_seed(seed)
if use_internal_get_lr_this_step and use_internal_loss_scaler:
(
new_api_loss_ort,
new_api_prediction_scores_ort,
new_api_seq_relationship_score_ort,
) = run_test(
model,
model_desc,
self.device,
args,
gradient_accumulation_steps,
fp16,
allreduce_post_accumulation,
get_lr_this_step,
use_internal_get_lr_this_step,
loss_scaler,
use_internal_loss_scaler,
split_batch,
dataset_len,
epochs,
use_new_api=True,
)
assert_allclose(old_api_loss_ort, new_api_loss_ort)
assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort)
assert_allclose(
old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort
)
def setUp(self):
self.model_tester = BertModelTest.BertModelTester(self)
def test_for_pretraining_mixed_precision(self):
# It would be better to test both with/without mixed precision and allreduce_post_accumulation.
# However, a stress test of all 4 cases is not stable, at least on the test machine.
# Here we only test mixed precision and allreduce_post_accumulation because they are the most useful use cases.
option_fp16 = [True]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1]
option_split_batch = [BatchArgsOption.ListAndDict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_mixed_precision_with_gradient_accumulation(self):
# It would be better to test both with/without mixed precision and allreduce_post_accumulation.
# However, a stress test of all 4 cases is not stable, at least on the test machine.
# Here we only test mixed precision and allreduce_post_accumulation because they are the most useful use cases.
option_fp16 = [True]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [8]
option_split_batch = [BatchArgsOption.ListAndDict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_all(self):
# This test is not stable because it creates and runs ORTSession multiple times.
# It occasionally gets a seg fault at ~MemoryPattern()
# when releasing patterns_. In order not to block the PR merge CI,
# this test is broken into the following individual tests.
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1, 8]
option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_list_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1]
option_split_batch = [BatchArgsOption.List]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_dict_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1]
option_split_batch = [BatchArgsOption.Dict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_list_and_dict_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1]
option_split_batch = [BatchArgsOption.ListAndDict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_grad_accumulation_list_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [8]
option_split_batch = [BatchArgsOption.List]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_grad_accumulation_dict_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [8]
option_split_batch = [BatchArgsOption.Dict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self):
option_fp16 = [False]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [8]
option_split_batch = [BatchArgsOption.ListAndDict]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(
*config_and_inputs,
option_fp16,
option_allreduce_post_accumulation,
option_gradient_accumulation_steps,
option_split_batch,
)
if __name__ == "__main__":
unittest.main()

View File

@ -1,246 +0,0 @@
import math
import torch
from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch
from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer
from onnxruntime.training import amp, optim, orttrainer
from onnxruntime.training.optim import _LRScheduler
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x / warmup
return 0.5 * (1.0 + torch.cos(math.pi * x))
def warmup_constant(x, warmup=0.002):
if x < warmup:
return x / warmup
return 1.0
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x / warmup
return max((x - 1.0) / (warmup - 1.0), 0.0)
def warmup_poly(x, warmup=0.002, degree=0.5):
if x < warmup:
return x / warmup
return (1.0 - x) ** degree
SCHEDULES = {
"warmup_cosine": warmup_cosine,
"warmup_constant": warmup_constant,
"warmup_linear": warmup_linear,
"warmup_poly": warmup_poly,
}
def get_lr(args, training_steps, schedule="warmup_poly"):
if args.max_steps == -1:
return args.learning_rate
schedule_fct = SCHEDULES[schedule]
return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion)
def map_optimizer_attributes(name):
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
if no_decay:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
else:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
class WrapLRScheduler(_LRScheduler):
def __init__(self, get_lr_this_step):
super().__init__()
self.get_lr_this_step = get_lr_this_step
def get_lr(self, train_step_info):
return [self.get_lr_this_step(train_step_info.optimization_step)]
def run_test(
model,
model_desc,
device,
args,
gradient_accumulation_steps,
fp16,
allreduce_post_accumulation,
get_lr_this_step,
use_internal_get_lr_this_step,
loss_scaler,
use_internal_loss_scaler,
batch_args_option,
dataset_len,
epochs,
use_new_api,
):
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)
if use_new_api:
assert use_internal_loss_scaler, "new api should always use internal loss scaler"
new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)
new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
options = orttrainer.ORTTrainerOptions(
{
"batch": {"gradient_accumulation_steps": gradient_accumulation_steps},
"device": {"id": device},
"mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler},
"debug": {
"deterministic_compute": True,
},
"utils": {"grad_norm_clip": True},
"distributed": {"allreduce_post_accumulation": True},
"lr_scheduler": new_api_lr_scheduler,
}
)
param_optimizer = list(model.named_parameters())
params = [
{
"params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
"alpha": 0.9,
"beta": 0.999,
"lambda": 0.0,
"epsilon": 1e-6,
},
{
"params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
"alpha": 0.9,
"beta": 0.999,
"lambda": 0.0,
"epsilon": 1e-6,
},
]
vocab_size = 99
new_model_desc = {
"inputs": [
(
"input_ids",
["batch", "max_seq_len_in_batch"],
),
(
"attention_mask",
["batch", "max_seq_len_in_batch"],
),
(
"token_type_ids",
["batch", "max_seq_len_in_batch"],
),
(
"masked_lm_labels",
["batch", "max_seq_len_in_batch"],
),
(
"next_sentence_label",
[
"batch",
],
),
],
"outputs": [
(
"loss",
[
1,
],
True,
),
("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]),
("seq_relationship_scores", ["batch", 2]),
],
}
optim_config = optim.LambConfig(params=params, lr=2e-5)
model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
print("running with new frontend API")
else:
model = ORTTrainer(
model,
None,
model_desc,
"LambOptimizer",
map_optimizer_attributes=map_optimizer_attributes,
learning_rate_description=IODescription(
"Learning_Rate",
[
1,
],
torch.float32,
),
device=device,
_enable_internal_postprocess=True,
gradient_accumulation_steps=gradient_accumulation_steps,
# BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
world_rank=args.local_rank,
world_size=args.world_size,
use_mixed_precision=fp16,
allreduce_post_accumulation=allreduce_post_accumulation,
get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
loss_scaler=loss_scaler if use_internal_loss_scaler else None,
_opset_version=14,
_use_deterministic_compute=True,
)
print("running with old frontend API")
# training loop
eval_batch = None
if not use_new_api:
model.train()
for _epoch in range(epochs):
for step, batch in enumerate(dataloader):
if eval_batch is None:
eval_batch = batch
if not use_internal_get_lr_this_step:
lr = get_lr_this_step(step)
learning_rate = torch.tensor([lr])
if not use_internal_loss_scaler and fp16:
loss_scale = torch.tensor([loss_scaler.loss_scale_])
if batch_args_option == BatchArgsOption.List:
if not use_internal_get_lr_this_step:
batch = [*batch, learning_rate] # noqa: PLW2901
if not use_internal_loss_scaler and fp16:
batch = [*batch, loss_scale] # noqa: PLW2901
outputs = model.train_step(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
if not use_internal_get_lr_this_step:
kwargs["Learning_Rate"] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model.train_step(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approximately half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
if not use_internal_get_lr_this_step:
kwargs["Learning_Rate"] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model.train_step(*args, **kwargs)
# eval
if batch_args_option == BatchArgsOption.List:
outputs = model.eval_step(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
outputs = model.eval_step(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approximately half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
outputs = model.eval_step(*args, **kwargs)
return (output.cpu().numpy() for output in outputs)

View File

@ -1,357 +0,0 @@
# adapted from Trainer.py of huggingface transformers
import json
import logging
import os
import random
from typing import Callable, Dict, List, NamedTuple, Optional
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from tqdm import tqdm, trange
from transformers.data.data_collator import DefaultDataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.training_args import TrainingArguments
import onnxruntime
from onnxruntime.training import amp, optim, orttrainer
try:
from torch.utils.tensorboard import SummaryWriter
_has_tensorboard = True
except ImportError:
try:
from tensorboardX import SummaryWriter # noqa: F401
_has_tensorboard = True
except ImportError:
_has_tensorboard = False
def is_tensorboard_available():
return _has_tensorboard
logger = logging.getLogger(__name__)
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
onnxruntime.set_seed(seed)
class EvalPrediction(NamedTuple):
predictions: np.ndarray
label_ids: np.ndarray
class PredictionOutput(NamedTuple):
predictions: np.ndarray
label_ids: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
class TrainOutput(NamedTuple):
global_step: int
training_loss: float
def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr):
def lr_lambda_linear(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
def lambda_lr_get_lr(current_global_step):
# LambdaLR increments self.last_epoch at every step()
return base_lr * lr_lambda_linear(current_global_step)
return lambda_lr_get_lr
class ORTTransformerTrainer:
""" """
model: PreTrainedModel
args: TrainingArguments
train_dataset: Dataset
eval_dataset: Dataset
compute_metrics: Callable[[EvalPrediction], Dict]
def __init__(
self,
model: PreTrainedModel,
model_desc: dict,
args: TrainingArguments,
train_dataset: Dataset,
eval_dataset: Dataset,
compute_metrics: Callable[[EvalPrediction], Dict],
world_size: Optional[int] = 1,
):
""" """
self.model = model
self.model_desc = model_desc
self.args = args
self.world_size = world_size
self.data_collator = DefaultDataCollator()
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
set_seed(self.args.seed)
# Create output directory if needed
if self.args.local_rank in [-1, 0]:
os.makedirs(self.args.output_dir, exist_ok=True)
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_sampler = (
SequentialSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
return DataLoader(
self.train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator.collate_batch,
)
def get_eval_dataloader(self) -> DataLoader:
return DataLoader(
self.eval_dataset,
batch_size=self.args.eval_batch_size,
shuffle=False,
collate_fn=self.data_collator.collate_batch,
)
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
# We use the same batch_size as for eval.
return DataLoader(
test_dataset,
batch_size=self.args.eval_batch_size,
shuffle=False,
collate_fn=self.data_collator.collate_batch,
)
def train(self):
"""
Main training entry point.
"""
train_dataloader = self.get_train_dataloader()
if self.args.max_steps > 0:
t_total = self.args.max_steps
num_train_epochs = (
self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
)
else:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total))
loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
device = self.args.device.type
device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0"
options = orttrainer.ORTTrainerOptions(
{
"batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps},
"device": {"id": device},
"mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler},
"debug": {
"deterministic_compute": True,
},
"utils": {"grad_norm_clip": False},
"distributed": {
# we are running a single-node multi-GPU test, thus world_rank = local_rank
# and world_size = self.args.n_gpu
"world_rank": max(0, self.args.local_rank),
"world_size": int(self.world_size),
"local_rank": max(0, self.args.local_rank),
"allreduce_post_accumulation": True,
},
"lr_scheduler": lr_scheduler,
}
)
param_optimizer = list(self.model.named_parameters())
params = [
{
"params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
"weight_decay_mode": 1,
},
{
"params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
"weight_decay_mode": 1,
},
]
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataloader.dataset))
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
self.args.train_batch_size
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
tr_loss = 0.0
logging_loss = 0.0
train_iterator = trange(
epochs_trained,
int(num_train_epochs),
desc="Epoch",
disable=self.args.local_rank not in [-1, 0],
)
for _epoch in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
tr_loss += self._training_step(self.model, inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)
):
global_step += 1
if self.args.local_rank in [-1, 0]:
if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
global_step == 1 and self.args.logging_first_step
):
logs = {}
if self.args.evaluate_during_training:
results = self.evaluate()
for key, value in results.items():
eval_key = f"eval_{key}"
logs[eval_key] = value
loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
logs["loss"] = loss_scalar
logging_loss = tr_loss
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
if self.args.max_steps > 0 and global_step > self.args.max_steps:
epoch_iterator.close()
break
if self.args.max_steps > 0 and global_step > self.args.max_steps:
train_iterator.close()
break
logger.info("\n\nTraining completed. \n\n")
return TrainOutput(global_step, tr_loss / global_step)
def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float:
for k, v in inputs.items():
inputs[k] = v.to(self.args.device)
outputs = model.train_step(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
return loss.item()
def save_model(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx"))
def evaluate(self) -> Dict[str, float]:
"""
Run evaluation and return metrics.
Returns:
A dict containing:
- the eval loss
- the potential metrics computed from the predictions
"""
eval_dataloader = self.get_eval_dataloader()
output = self._prediction_loop(eval_dataloader, description="Evaluation")
return output.metrics
def predict(self, test_dataset: Dataset) -> PredictionOutput:
"""
Run prediction and return predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels.
In that case, this method will also return metrics, like in evaluate().
"""
test_dataloader = self.get_test_dataloader(test_dataset)
return self._prediction_loop(test_dataloader, description="Prediction")
def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
Works both with or without labels.
"""
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", len(dataloader.dataset))
logger.info(" Batch size = %d", dataloader.batch_size)
eval_losses: List[float] = []
preds: np.ndarray = None
label_ids: np.ndarray = None
for inputs in tqdm(dataloader, desc=description):
has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"])
for k, v in inputs.items():
inputs[k] = v.to(self.args.device)
with torch.no_grad():
outputs = self.model.eval_step(**inputs)
if has_labels:
step_eval_loss, logits = outputs[:2]
eval_losses += [step_eval_loss.mean().item()]
else:
logits = outputs[0]
if preds is None:
preds = logits.detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
if inputs.get("labels") is not None:
if label_ids is None:
label_ids = inputs["labels"].detach().cpu().numpy()
else:
label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
if len(eval_losses) > 0:
metrics["loss"] = np.mean(eval_losses)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

View File

@ -1,269 +0,0 @@
# adapted from run_multiple_choice.py of huggingface transformers
# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py
import csv
import glob # noqa: F401
import json # noqa: F401
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
import torch
import tqdm
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class InputExample:
"""
A single training/test example for multiple choice
Args:
example_id: Unique id for the example.
question: string. The untokenized text of the second sequence (question).
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
example_id: str
question: str
contexts: List[str]
endings: List[str]
label: Optional[str]
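A minimal sketch of how one such example might be constructed, assuming the InputExample dataclass defined above; the id, question, contexts, and endings below are invented for illustration and do not come from the SWAG dataset:

```python
# Hypothetical 4-way multiple-choice example (illustrative values only).
example = InputExample(
    example_id="00001",
    question="The chef reaches for a knife and",
    contexts=["A chef stands at a cutting board in a busy kitchen."] * 4,
    endings=["chops the onions.", "sings a song.", "leaves the kitchen.", "paints a wall."],
    label="0",  # one of get_labels() == ["0", "1", "2", "3"]
)
```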
@dataclass(frozen=True)
class InputFeatures:
"""
A single set of features of data.
Property names are the same names as the corresponding inputs to a model.
"""
example_id: str
input_ids: List[List[int]]
attention_mask: Optional[List[List[int]]]
token_type_ids: Optional[List[List[int]]]
label: Optional[int]
class Split(Enum):
train = "train"
dev = "dev"
test = "test"
class DataProcessor:
"""Base class for data converters for multiple choice data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the test set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
class MultipleChoiceDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach
soon.
"""
features: List[InputFeatures]
def __init__(
self,
data_dir: str,
tokenizer: PreTrainedTokenizer,
task: str,
processor: DataProcessor,
max_seq_length: Optional[int] = None,
overwrite_cache=False,
mode: Split = Split.train,
):
cached_features_file = os.path.join(
data_dir,
"cached_{}_{}_{}_{}".format(
mode.value,
tokenizer.__class__.__name__,
str(max_seq_length),
task,
),
)
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)
else:
logger.info(f"Creating features from dataset file at {data_dir}")
label_list = processor.get_labels()
if mode == Split.dev:
examples = processor.get_dev_examples(data_dir)
elif mode == Split.test:
examples = processor.get_test_examples(data_dir)
else:
examples = processor.get_train_examples(data_dir)
logger.info("Training examples: %s", len(examples))
# TODO clean up all this to leverage built-in features of tokenizers
self.features = convert_examples_to_features(
examples,
label_list,
max_seq_length,
tokenizer,
pad_on_left=bool(tokenizer.padding_side == "left"),
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
)
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(self.features, cached_features_file)
def __len__(self):
return len(self.features)
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
class SwagProcessor(DataProcessor):
"""Processor for the SWAG data set."""
def get_train_examples(self, data_dir):
"""See base class."""
logger.info(f"LOOKING AT {data_dir} train")
return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
logger.info(f"LOOKING AT {data_dir} dev")
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
logger.info(f"LOOKING AT {data_dir} dev")
raise ValueError(
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
"setting!"
)
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1", "2", "3"]
def _read_csv(self, input_file):
with open(input_file, encoding="utf-8") as f:
return list(csv.reader(f))
def _create_examples(self, lines: List[List[str]], type: str):
"""Creates examples for the training and dev sets."""
if type == "train" and lines[0][-1] != "label":
raise ValueError("For training, the input file must contain a label column.")
examples = [
InputExample(
example_id=line[2],
question=line[5], # in the swag dataset, the
# common beginning of each
# choice is stored in "sent2".
contexts=[line[4], line[4], line[4], line[4]],
endings=[line[7], line[8], line[9], line[10]],
label=line[11],
)
for line in lines[1:] # we skip the line with the column names
]
return examples
def convert_examples_to_features(
examples: List[InputExample],
label_list: List[str],
max_length: int,
tokenizer: PreTrainedTokenizer,
pad_token_segment_id=0,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
) -> List[InputFeatures]:
"""
Loads a data file into a list of `InputFeatures`
"""
label_map = {label: i for i, label in enumerate(label_list)}
features = []
for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
choices_inputs = []
for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
text_a = context
if example.question.find("_") != -1:
# this is for cloze question
text_b = example.question.replace("_", ending)
else:
text_b = example.question + " " + ending
inputs = tokenizer.encode_plus(
text_a,
text_b,
add_special_tokens=True,
max_length=max_length,
pad_to_max_length=True,
return_overflowing_tokens=True,
)
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
logger.info(
"Attention! you are cropping tokens (swag task is ok). "
"If you are training ARC and RACE and you are poping question + options,"
"you need to try to use a bigger max seq length!"
)
choices_inputs.append(inputs)
label = label_map[example.label]
input_ids = [x["input_ids"] for x in choices_inputs]
attention_mask = (
[x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
)
token_type_ids = (
[x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
)
features.append(
InputFeatures(
example_id=example.example_id,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
label=label,
)
)
for f in features[:2]:
logger.info("*** Example ***")
logger.info("feature: %s" % f)
return features

View File

@ -1,200 +0,0 @@
## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
## with modifications to do training using onnxruntime as the backend on a CUDA device.
## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo.
## Model testing is not complete.
import argparse
import os
import numpy as np # noqa: F401
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim # noqa: F401
from mpi4py import MPI
from torchvision import datasets, transforms
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer
try: # noqa: SIM105
from onnxruntime.capi._pybind_state import set_cuda_device_id
except ImportError:
pass
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
def my_loss(x, target):
return F.nll_loss(F.log_softmax(x, dim=1), target)
def train_with_trainer(args, trainer, device, train_loader, epoch):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
learning_rate = torch.tensor([args.lr])
loss = trainer.train_step(data, target, learning_rate)
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
if batch_idx % args.log_interval == 0:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss[0],
)
)
# TODO: complete this once ORT training can do evaluation.
def test_with_trainer(args, trainer, device, test_loader):
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
)
)
def mnist_model_description():
input_desc = IODescription("input1", ["batch", 784], torch.float32)
label_desc = IODescription(
"label",
[
"batch",
],
torch.int64,
num_classes=10,
)
loss_desc = IODescription("loss", [], torch.float32)
probability_desc = IODescription("probability", ["batch", 10], torch.float32)
return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
def main():
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
"--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
)
parser.add_argument(
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
)
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)")
parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
kwargs = {"num_workers": 0, "pin_memory": True}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.batch_size,
shuffle=True,
**kwargs,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=False,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs,
)
comm = MPI.COMM_WORLD
args.local_rank = (
int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0
)
args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0
args.world_size = comm.Get_size()
if use_cuda:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
args.n_gpu = 1
set_cuda_device_id(args.local_rank)
else:
device = torch.device("cpu")
input_size = 784
hidden_size = 500
num_classes = 10
model = NeuralNet(input_size, hidden_size, num_classes)
model_desc = mnist_model_description()
# use log_interval as gradient accumulation steps
trainer = ORTTrainer(
model,
my_loss,
model_desc,
"SGDOptimizer",
None,
IODescription(
"Learning_Rate",
[
1,
],
torch.float32,
),
device,
1,
args.world_rank,
args.world_size,
use_mixed_precision=False,
allreduce_post_accumulation=True,
)
print("\nBuild ort model done.")
for epoch in range(1, args.epochs + 1):
train_with_trainer(args, trainer, device, train_loader, epoch)
test_with_trainer(args, trainer, device, test_loader)
if __name__ == "__main__":
main()

Binary file not shown.

View File

@ -1,174 +0,0 @@
# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
# with modifications to do training using onnxruntime as the backend on a CUDA device.
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import onnxruntime
from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim
# Pytorch model
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
# ONNX Runtime training
def mnist_model_description():
return {
"inputs": [("input1", ["batch", 784]), ("label", ["batch"])],
"outputs": [("loss", [], True), ("probability", ["batch", 10])],
}
def my_loss(x, target):
return F.nll_loss(F.log_softmax(x, dim=1), target)
# Helpers
def train(log_interval, trainer, device, train_loader, epoch, train_steps):
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx == train_steps:
break
# Fetch data
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
# Train step
loss, prob = trainer.train_step(data, target)
# Stats
if batch_idx % log_interval == 0:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss
)
)
def test(trainer, device, test_loader):
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
# Set fetches to ["probability"] around eval_step so that 'target' does not need to be passed as an input
trainer._train_step_info.fetches = ["probability"]
output = F.log_softmax(trainer.eval_step(data), dim=1)
trainer._train_step_info.fetches = []
# Stats
test_loss += F.nll_loss(output, target, reduction="sum").item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
)
)
def main():
# Training settings
parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example")
parser.add_argument(
"--train-steps",
type=int,
default=-1,
metavar="N",
help="number of steps to train. Set -1 to run through whole dataset (default: -1)",
)
parser.add_argument(
"--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
)
parser.add_argument(
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
)
parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)")
parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state")
# Basic setup
args = parser.parse_args()
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
torch.manual_seed(args.seed)
onnxruntime.set_seed(args.seed)
# Data loader
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data",
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.batch_size,
shuffle=True,
)
if args.test_batch_size > 0:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data",
train=False,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.test_batch_size,
shuffle=True,
)
# Modeling
model = NeuralNet(784, 500, 10)
model_desc = mnist_model_description()
optim_config = optim.SGDConfig(lr=args.lr)
opts = {"device": {"id": device}}
opts = ORTTrainerOptions(opts)
trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
# Train loop
for epoch in range(1, args.epochs + 1):
train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps)
if args.test_batch_size > 0:
test(trainer, device, test_loader)
# Save model
if args.save_path:
torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt"))
if __name__ == "__main__":
main()

View File

@ -1,157 +0,0 @@
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Pytorch model
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
def my_loss(x, target, is_train=True):
if is_train:
return F.nll_loss(F.log_softmax(x, dim=1), target)
else:
return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum")
# Helpers
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx == args.train_steps:
break
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
optimizer.zero_grad()
output = model(data)
loss = my_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device) # noqa: PLW2901
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
output = model(data)
# Stats
test_loss += my_loss(output, target, False).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
)
)
def main():
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
"--train-steps",
type=int,
default=-1,
metavar="N",
help="number of steps to train. Set -1 to run through whole dataset (default: -1)",
)
parser.add_argument(
"--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
)
parser.add_argument(
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
)
parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)")
parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model")
# Basic setup
args = parser.parse_args()
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
torch.manual_seed(args.seed)
# Data loader
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data",
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.batch_size,
shuffle=True,
)
if args.test_batch_size > 0:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data",
train=False,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
),
batch_size=args.test_batch_size,
shuffle=True,
)
# Modeling
model = NeuralNet(784, 500, 10).to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr)
# Train loop
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
if args.test_batch_size > 0:
test(model, device, test_loader)
# Save model
if args.save_path:
torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt"))
if __name__ == "__main__":
main()


@ -1,33 +0,0 @@
# TransformerModel example
This example was adapted from PyTorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial.
## Requirements
* PyTorch 1.6+
* TorchText 0.6+
* ONNX Runtime 1.5+
## Running PyTorch version
```bash
python pt_train.py
```
## Running ONNX Runtime version
```bash
python ort_train.py
```
## Optional arguments
| Argument | Description | Default |
| :---------------- | :-----------------------------------------------------: | --------: |
| --batch-size | input batch size for training | 20 |
| --test-batch-size | input batch size for testing | 20 |
| --epochs | number of epochs to train | 2 |
| --lr | learning rate | 0.001 |
| --no-cuda | disables CUDA training | False |
| --seed | random seed | 1 |
| --log-interval | how many batches to wait before logging training status | 200 |
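For example, to train the ONNX Runtime version for three epochs with a smaller learning rate and less frequent logging (the flag names come from the argparse setup in `ort_train.py` and `pt_train.py`), one might run:
```bash
python ort_train.py --epochs 3 --lr 0.0005 --log-interval 400
```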


@ -1,89 +0,0 @@
import argparse
import torch
from ort_utils import my_loss, transformer_model_description_dynamic_axes
from pt_model import TransformerModel
from utils import get_batch, prepare_data
import onnxruntime
def train(trainer, data_source, device, epoch, args, bptt=35):
total_loss = 0.0
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
data, targets = get_batch(data_source, i)
loss, pred = trainer.train_step(data, targets)
total_loss += loss.item()
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
print(
"epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format(
epoch, batch, len(data_source) // bptt, cur_loss
)
)
total_loss = 0
def evaluate(trainer, data_source, bptt=35):
total_loss = 0.0
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, bptt):
data, targets = get_batch(data_source, i)
loss, pred = trainer.eval_step(data, targets)
total_loss += len(data) * loss.item()
return total_loss / (len(data_source) - 1)
if __name__ == "__main__":
# Training settings
parser = argparse.ArgumentParser(description="PyTorch TransformerModel example")
parser.add_argument(
"--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
)
parser.add_argument(
"--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)"
)
parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)")
parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)")
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=200,
metavar="N",
help="how many batches to wait before logging training status (default: 200)",
)
# Basic setup
args = parser.parse_args()
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
torch.manual_seed(args.seed)
onnxruntime.set_seed(args.seed)
# Model
optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr)
model_desc = transformer_model_description_dynamic_axes()
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
# Preparing data
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss)
# Train
for epoch in range(1, args.epochs + 1):
train(trainer, train_data, device, epoch, args)
val_loss = evaluate(trainer, val_data)
print("-" * 89)
print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ")
print("-" * 89)
# Evaluate
test_loss = evaluate(trainer, test_data)
print("=" * 89)
print(f"| End of training | test loss {test_loss:5.2f}")
print("=" * 89)


@ -1,47 +0,0 @@
import torch
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription
from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription
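# Flatten predictions to (tokens, vocab_size) so CrossEntropyLoss lines up with the flattened targets.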
def my_loss(x, target):
x = x.view(-1, 28785)
return torch.nn.CrossEntropyLoss()(x, target)
def transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
model_desc = {
"inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])],
"outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])],
}
return model_desc
def transformer_model_description_dynamic_axes(ntokens=28785):
model_desc = {
"inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])],
"outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])],
}
return model_desc
def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
input_desc = Legacy_IODescription("input1", [bptt, batch_size])
label_desc = Legacy_IODescription("label", [bptt * batch_size])
loss_desc = Legacy_IODescription("loss", [])
predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens])
return (
Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),
Legacy_IODescription("__learning_rate", [1]),
)
def legacy_transformer_model_description_dynamic_axes(ntokens=28785):
input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"])
label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"])
loss_desc = Legacy_IODescription("loss", [])
predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens])
return (
Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),
Legacy_IODescription("__learning_rate", [1]),
)


@ -1,62 +0,0 @@
import math
import torch
import torch.nn as nn
class TransformerModel(nn.Module):
def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
super().__init__()
from torch.nn import TransformerEncoder, TransformerEncoderLayer
self.model_type = "Transformer"
self.input1_mask = None
self.pos_encoder = PositionalEncoding(ninp, dropout)
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, ninp)
self.ninp = ninp
self.decoder = nn.Linear(ninp, ntoken)
self.init_weights()
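# Causal attention mask: each position may attend to itself and earlier positions; future positions are filled with -inf.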
def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
return mask
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, input1):
if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0):
device = input1.device
mask = self._generate_square_subsequent_mask(input1.size(0)).to(device)
self.input1_mask = mask
input1 = self.encoder(input1) * math.sqrt(self.ninp)
input1 = self.pos_encoder(input1)
output = self.transformer_encoder(input1, self.input1_mask)
output = self.decoder(output)
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
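# Precompute sinusoidal position encodings up to max_len; registered as a buffer so they are saved with the model but never trained.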
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)


@ -1,94 +0,0 @@
import argparse
import torch
import torch.nn as nn
from pt_model import TransformerModel
from utils import get_batch, prepare_data
def train(model, data_source, optimizer, criterion, device, epoch, args, bptt=35):
total_loss = 0.0
model.train()
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
data, targets = get_batch(data_source, i)
optimizer.zero_grad()
output = model(data)
loss = criterion(output.view(-1, 28785), targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
print(
"epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format(
epoch, batch, len(data_source) // bptt, cur_loss
)
)
total_loss = 0
def evaluate(model, data_source, criterion, bptt=35):
total_loss = 0.0
model.eval()
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, bptt):
data, targets = get_batch(data_source, i)
output = model(data)
output_flat = output.view(-1, 28785)
total_loss += len(data) * criterion(output_flat, targets).item()
return total_loss / (len(data_source) - 1)
if __name__ == "__main__":
# Training settings
parser = argparse.ArgumentParser(description="PyTorch TransformerModel example")
parser.add_argument(
"--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
)
parser.add_argument(
"--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)"
)
parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)")
parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)")
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=200,
metavar="N",
help="how many batches to wait before logging training status (default: 200)",
)
# Basic setup
args = parser.parse_args()
if not args.no_cuda and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
torch.manual_seed(args.seed)
# Model
criterion = nn.CrossEntropyLoss()
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
# Preparing data
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
# Train
for epoch in range(1, args.epochs + 1):
train(model, train_data, optimizer, criterion, device, epoch, args)
val_loss = evaluate(model, val_data, criterion)
print("-" * 89)
print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ")
print("-" * 89)
# Evaluate
test_loss = evaluate(model, test_data, criterion)
print("=" * 89)
print(f"| End of training | test loss {test_loss:5.2f}")
print("=" * 89)


@ -1,59 +0,0 @@
import os
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
from torchtext.vocab import build_vocab_from_iterator
def batchify(data, bsz, device):
# Divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).t().contiguous()
return data.to(device)
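# Return a bptt-length slice starting at position i, with targets shifted one token ahead.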
def get_batch(source, i, bptt=35):
seq_len = min(bptt, len(source) - 1 - i)
data = source[i : i + seq_len]
target = source[i + 1 : i + 1 + seq_len].view(-1)
return data, target
def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None):
url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
download_path = ".data_wikitext_2_v1"
extract_path = None
if data_dir:
download_path = os.path.join(data_dir, "download")
os.makedirs(download_path, exist_ok=True)
download_path = os.path.join(download_path, "wikitext-2-v1.zip")
extract_path = os.path.join(data_dir, "extracted")
os.makedirs(extract_path, exist_ok=True)
test_filepath, valid_filepath, train_filepath = extract_archive(
download_from_url(url, root=download_path), to_path=extract_path
)
tokenizer = get_tokenizer("basic_english")
vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8")))) # noqa: SIM115
def data_process(raw_text_iter):
data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
train_data = data_process(iter(open(train_filepath, encoding="utf8"))) # noqa: SIM115
val_data = data_process(iter(open(valid_filepath, encoding="utf8"))) # noqa: SIM115
test_data = data_process(iter(open(test_filepath, encoding="utf8"))) # noqa: SIM115
device = torch.device(device)
train_data = batchify(train_data, train_batch_size, device)
val_data = batchify(val_data, eval_batch_size, device)
test_data = batchify(test_data, eval_batch_size, device)
return train_data, val_data, test_data


@ -398,7 +398,6 @@ packages = [
"onnxruntime",
"onnxruntime.backend",
"onnxruntime.capi",
"onnxruntime.capi.training",
"onnxruntime.datasets",
"onnxruntime.tools",
"onnxruntime.tools.mobile_helpers",