Removed all the deprecated python training code and related tests and utils (#18333)
### Description
The motivation for this PR is code cleanup:
1. Remove all deprecated Python code related to ORTTrainer and the old checkpoint implementation, along with the related tests and utils.
2. Clean up orttraining_pybind_state.cc to remove all deprecated bindings.
This commit is contained in:
Parent: cbb85b4874
Commit: 02333293de
@@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in
              ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py)

if (onnxruntime_ENABLE_TRAINING)
  file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
    "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py"
  )
  file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS
    "${ORTTRAINING_SOURCE_DIR}/python/training/*.py"
  )

@@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING)
      "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*"
    )
  endif()
else()
  file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
    "${ONNXRUNTIME_ROOT}/python/training/*.py"
  )
endif()

if (onnxruntime_BUILD_UNIT_TESTS)

@@ -577,9 +570,6 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E copy_if_different
    ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py
    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
  COMMAND ${CMAKE_COMMAND} -E copy
    ${onnxruntime_python_capi_training_srcs}
    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
  COMMAND ${CMAKE_COMMAND} -E copy
    $<TARGET_FILE:onnxruntime_pybind11_state>
    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/

@@ -750,9 +740,6 @@ if (onnxruntime_ENABLE_TRAINING)
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/data/
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/hooks/
  COMMAND ${CMAKE_COMMAND} -E copy
    ${onnxruntime_python_capi_training_srcs}
    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
  COMMAND ${CMAKE_COMMAND} -E copy
    ${onnxruntime_python_root_srcs}
    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/
@@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import IOBinding  # noqa:
from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice  # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue  # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor  # noqa: F401
from onnxruntime.capi.training import *  # noqa: F403

# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
try:  # noqa: SIM105
File diff suppressed because it is too large
@@ -1,102 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest

from numpy.testing import assert_allclose, assert_array_equal
from onnxruntime_test_ort_trainer import run_bert_training_test


class TestOrtTrainer(unittest.TestCase):
    def test_bert_training_mixed_precision(self):
        expected_losses = [
            11.034248352050781,
            11.125300407409668,
            11.006105422973633,
            11.047048568725586,
            11.027417182922363,
            11.015759468078613,
            11.060905456542969,
            10.971782684326172,
        ]
        expected_all_finites = [True, True, True, True, True, True, True, True]
        expected_eval_loss = [10.959012985229492]
        actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test(
            gradient_accumulation_steps=1,
            use_mixed_precision=True,
            allreduce_post_accumulation=False,
            use_simple_model_desc=False,
        )

        rtol = 1e-02
        assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
        assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
        assert_allclose(
            expected_eval_loss,
            actual_eval_loss,
            rtol=rtol,
            err_msg="evaluation loss mismatch",
        )

    def test_bert_training_mixed_precision_internal_loss_scale(self):
        expected_losses = [
            11.034248352050781,
            11.125300407409668,
            11.006105422973633,
            11.047048568725586,
            11.027417182922363,
            11.015759468078613,
            11.060905456542969,
            10.971782684326172,
        ]
        expected_eval_loss = [10.959012985229492]
        actual_losses, actual_eval_loss = run_bert_training_test(
            gradient_accumulation_steps=1,
            use_mixed_precision=True,
            allreduce_post_accumulation=False,
            use_simple_model_desc=False,
            use_internel_loss_scale=True,
        )

        rtol = 1e-02
        assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
        assert_allclose(
            expected_eval_loss,
            actual_eval_loss,
            rtol=rtol,
            err_msg="evaluation loss mismatch",
        )

    def test_bert_training_gradient_accumulation_mixed_precision(self):
        expected_losses = [
            11.034248352050781,
            11.125300407409668,
            11.006077766418457,
            11.047025680541992,
            11.027434349060059,
            11.0156831741333,
            11.060973167419434,
            10.971841812133789,
        ]
        expected_all_finites = [True, True]
        expected_eval_loss = [10.95903205871582]
        actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test(
            gradient_accumulation_steps=4,
            use_mixed_precision=True,
            allreduce_post_accumulation=False,
            use_simple_model_desc=False,
        )

        rtol = 1e-02
        assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
        assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
        assert_allclose(
            expected_eval_loss,
            actual_eval_loss,
            rtol=rtol,
            err_msg="evaluation loss mismatch",
        )


if __name__ == "__main__":
    unittest.main(module=__name__, buffer=True)
@@ -1,95 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest

import torch
import torch.nn as nn
from numpy.testing import assert_allclose
from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description
from onnxruntime_test_training_unittest_utils import process_dropout

import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer


class TestTrainingDropout(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(1)
        onnxruntime.set_seed(1)

    @unittest.skip(
        "Temporarily disable this test. The graph below will trigger ORT to "
        "sort backward graph before forward graph which gives incorrect result. "
        "https://github.com/microsoft/onnxruntime/issues/16801"
    )
    def test_training_and_eval_dropout(self):
        class TwoDropoutNet(nn.Module):
            def __init__(self, drop_prb_1, drop_prb_2, dim_size):
                super().__init__()
                self.drop_1 = nn.Dropout(drop_prb_1)
                self.drop_2 = nn.Dropout(drop_prb_2)
                self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))

            def forward(self, x):
                x = x + self.weight_1
                x = self.drop_1(x)
                x = self.drop_2(x)
                output = x
                return output[0]

        dim_size = 3
        device = torch.device("cuda", 0)
        # This will drop all values, therefore expecting all 0 in output tensor
        model = TwoDropoutNet(0.999, 0.999, dim_size)
        input_desc = IODescription("input", [dim_size], torch.float32)
        output_desc = IODescription("output", [], torch.float32)
        model_desc = ModelDescription([input_desc], [output_desc])
        lr_desc = ort_trainer_learning_rate_description()
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes,
            lr_desc,
            device,
            postprocess_model=process_dropout,
            world_rank=0,
            world_size=1,
        )
        input = torch.ones(dim_size, dtype=torch.float32).to(device)
        expected_training_output = [0.0]
        expected_eval_output = [1.0]
        learning_rate = torch.tensor([1.0000000e00]).to(device)
        input_args = [input, learning_rate]
        train_output = model.train_step(*input_args)

        rtol = 1e-04
        assert_allclose(
            expected_training_output,
            train_output.item(),
            rtol=rtol,
            err_msg="dropout training loss mismatch",
        )

        eval_output = model.eval_step(input)
        assert_allclose(
            expected_eval_output,
            eval_output.item(),
            rtol=rtol,
            err_msg="dropout eval loss mismatch",
        )

        # Do another train step to make sure it's using original ratios
        train_output_2 = model.train_step(*input_args)
        assert_allclose(
            expected_training_output,
            train_output_2.item(),
            rtol=rtol,
            err_msg="dropout training loss 2 mismatch",
        )


if __name__ == "__main__":
    unittest.main(module=__name__, buffer=True)
@@ -1,56 +0,0 @@
import numpy as np
from onnx import numpy_helper


def get_node_index(model, node):
    i = 0
    while i < len(model.graph.node):
        if model.graph.node[i] == node:
            break
        i += 1
    return i if i < len(model.graph.node) else None


def add_const(model, name, output, t_value=None, f_value=None):
    const_node = model.graph.node.add()
    const_node.op_type = "Constant"
    const_node.name = name
    const_node.output.extend([output])
    attr = const_node.attribute.add()
    attr.name = "value"
    if t_value is not None:
        attr.type = 4
        attr.t.CopyFrom(t_value)
    else:
        attr.type = 1
        attr.f = f_value
    return const_node


def process_dropout(model):
    dropouts = []
    index = 0
    for node in model.graph.node:
        if node.op_type == "Dropout":
            new_dropout = model.graph.node.add()
            new_dropout.op_type = "TrainableDropout"
            new_dropout.name = "TrainableDropout_%d" % index
            # make ratio node
            ratio = np.asarray([node.attribute[0].f], dtype=np.float32)
            print(ratio.shape)
            ratio_value = numpy_helper.from_array(ratio)
            ratio_node = add_const(
                model,
                "dropout_node_ratio_%d" % index,
                "dropout_node_ratio_%d" % index,
                t_value=ratio_value,
            )
            print(ratio_node)
            new_dropout.input.extend([node.input[0], ratio_node.output[0]])
            new_dropout.output.extend(node.output)
            dropouts.append(get_node_index(model, node))
            index += 1
    dropouts.sort(reverse=True)
    for d in dropouts:
        del model.graph.node[d]
    model.opset_import[0].version = 10
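For reference, a minimal sketch of how the removed `process_dropout` helper could be exercised; the tiny model built here with `onnx.helper` is hypothetical and only assumes an old opset (< 12) where `Dropout` still carries a `ratio` attribute, which is what the helper expects.

```python
# Hedged sketch, not part of this PR: builds a throwaway opset-10 model and runs
# the removed helper over it.
import onnx
from onnx import TensorProto, helper

from onnxruntime_test_training_unittest_utils import process_dropout  # removed in this PR

x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3])
dropout = helper.make_node("Dropout", ["X"], ["Y"], ratio=0.1)  # ratio is an attribute pre-opset-12
graph = helper.make_graph([dropout], "dropout_graph", [x], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

process_dropout(model)
# Dropout is replaced by TrainableDropout plus a Constant node feeding the ratio.
print([n.op_type for n in model.graph.node])  # ['TrainableDropout', 'Constant']
```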
@@ -1,127 +0,0 @@
import os

import torch


def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"):
    ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)]
    ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)]
    ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names]

    assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format(
        checkpoint_prefix, checkpoint_dir
    )
    return ckpt_file_names


def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None):
    SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt"  # noqa: N806
    MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt"  # noqa: N806

    if is_partitioned:
        filename = MULTIPLE_CHECKPOINT_FILENAME.format(
            prefix=prefix, world_rank=world_rank, world_size=(world_size - 1)
        )
    else:
        filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix)

    return filename


def _split_state_dict(state_dict):
    optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"]
    split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}}
    for k, v in state_dict.items():
        mode = "fp32_param"
        for optim_key in optimizer_keys:
            if k.startswith(optim_key):
                mode = "optimizer"
                break
        if k.endswith("_fp16"):
            mode = "fp16_param"
        split_sd[mode][k] = v
    return split_sd


class CombineZeroCheckpoint:
    def __init__(self, checkpoint_files, clean_state_dict=None):
        assert len(checkpoint_files) > 0, "No checkpoint files passed"
        self.checkpoint_files = checkpoint_files
        self.clean_state_dict = clean_state_dict
        self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1
        assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files"
        self.weight_shape_map = dict()
        self.sharded_params = set()

    def _split_name(self, name: str):
        name_split = name.split("_view_")
        view_num = None
        if len(name_split) > 1:
            view_num = int(name_split[1])
        optimizer_key = ""
        mp_suffix = ""
        if name_split[0].startswith("Moment_1"):
            optimizer_key = "Moment_1_"
        elif name_split[0].startswith("Moment_2"):
            optimizer_key = "Moment_2_"
        elif name_split[0].startswith("Update_Count"):
            optimizer_key = "Update_Count_"
        elif name_split[0].endswith("_fp16"):
            mp_suffix = "_fp16"
        param_name = name_split[0]
        if optimizer_key:
            param_name = param_name.split(optimizer_key)[1]
        param_name = param_name.split("_fp16")[0]
        return param_name, optimizer_key, view_num, mp_suffix

    def _update_weight_statistics(self, name, value):
        if name not in self.weight_shape_map:
            self.weight_shape_map[name] = value.size()  # original shape of tensor

    def _reshape_tensor(self, key):
        value = self.aggregate_state_dict[key]
        weight_name, _, _, _ = self._split_name(key)
        set_size = self.weight_shape_map[weight_name]
        self.aggregate_state_dict[key] = value.reshape(set_size)

    def _aggregate(self, param_dict):
        for k, v in param_dict.items():
            weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k)
            if view_num is not None:
                # parameter is sharded
                param_name = optimizer_key + weight_name + mp_suffix

                if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]:
                    self.sharded_params.add(param_name)
                    # Found a previous shard of the param, concatenate shards ordered by ranks
                    self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v))
                else:
                    self.aggregate_state_dict[param_name] = v
            else:
                if k in self.aggregate_state_dict:
                    assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value"
                else:
                    self.aggregate_state_dict[k] = v
                self._update_weight_statistics(weight_name, v)

    def aggregate_checkpoints(self):
        checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0]
        self.aggregate_state_dict = dict()

        for i in range(self.world_size):
            checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size)
            rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu"))
            if "model" in rank_state_dict:
                rank_state_dict = rank_state_dict["model"]

            if self.clean_state_dict:
                rank_state_dict = self.clean_state_dict(rank_state_dict)

            rank_state_dict = _split_state_dict(rank_state_dict)
            self._aggregate(rank_state_dict["fp16_param"])
            self._aggregate(rank_state_dict["fp32_param"])
            self._aggregate(rank_state_dict["optimizer"])

        for k in self.sharded_params:
            self._reshape_tensor(k)
        return self.aggregate_state_dict
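For context, a short sketch of the conventions encoded by the removed helpers above; the import path is hypothetical since the module is deleted, and the weight names are placeholders.

```python
# Illustrative only: checkpoint naming and state-dict bucketing as implemented by the removed code.
from checkpoint_utils import _split_state_dict, get_checkpoint_name  # hypothetical import path

# ZeRO-partitioned checkpoints embed the rank and the highest rank (world_size - 1) in the name.
name = get_checkpoint_name("ORT_checkpoint", is_partitioned=True, world_rank=0, world_size=4)
assert name == "ORT_checkpoint.ZeRO.0.3.ort.pt"
assert get_checkpoint_name("ORT_checkpoint", is_partitioned=False) == "ORT_checkpoint.ort.pt"

# Keys are routed by prefix/suffix into optimizer, fp32 and fp16 buckets.
split = _split_state_dict({
    "encoder.weight": 1.0,           # -> fp32_param
    "encoder.weight_fp16": 2.0,      # -> fp16_param
    "Moment_1_encoder.weight": 3.0,  # -> optimizer
})
assert set(split["optimizer"]) == {"Moment_1_encoder.weight"}
```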
@@ -1,6 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import TrainingParameters  # noqa: F401
from onnxruntime.capi.training.training_session import TrainingSession  # noqa: F401
@@ -1,68 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os  # noqa: F401
import sys  # noqa: F401

from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding  # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import (
    InferenceSession,
    Session,
    check_and_normalize_provider_args,
)


class TrainingSession(InferenceSession):
    def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
        Session.__init__(self)

        if sess_options:
            self._sess = C.TrainingSession(sess_options)
        else:
            self._sess = C.TrainingSession()

        # providers needs to be passed explicitly as of ORT 1.10
        # retain the pre-1.10 behavior by setting to the available providers.
        if providers is None:
            providers = C.get_available_providers()

        providers, provider_options = check_and_normalize_provider_args(
            providers, provider_options, C.get_available_providers()
        )

        if isinstance(path_or_bytes, str):
            config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
        elif isinstance(path_or_bytes, bytes):
            config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
        else:
            raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'")

        self.loss_scale_input_name = config_result.loss_scale_input_name

        self._inputs_meta = self._sess.inputs_meta
        self._outputs_meta = self._sess.outputs_meta

    def __del__(self):
        if self._sess:
            self._sess.finalize()

    def get_state(self):
        return self._sess.get_state()

    def get_model_state(self, include_mixed_precision_weights=False):
        return self._sess.get_model_state(include_mixed_precision_weights)

    def get_optimizer_state(self):
        return self._sess.get_optimizer_state()

    def get_partition_info_map(self):
        return self._sess.get_partition_info_map()

    def load_state(self, dict, strict=False):
        self._sess.load_state(dict, strict)

    def is_output_fp32_node(self, output_name):
        return self._sess.is_output_fp32_node(output_name)
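For readers tracking what is being deleted, a rough sketch of how this wrapper used to be driven; the model path, loss name, and weight name below are placeholders, and none of this API exists after this PR.

```python
# Deprecated usage sketch (removed by this PR); values are placeholders.
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession

params = TrainingParameters()
params.loss_output_name = "loss"              # graph output the training graph is built around
params.weights_to_train = {"encoder.weight"}  # placeholder initializer name

session = TrainingSession("model.onnx", params)  # configures and initializes the training session
state = session.get_state()                      # weight name -> numpy array
```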
File diff suppressed because it is too large
@@ -18,7 +18,6 @@
#include "core/session/environment.h"
#include "core/session/custom_ops.h"
#include "core/dlpack/dlpack_converter.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/agent/training_agent.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/optimizer_config.h"

@@ -113,14 +112,11 @@ struct TrainingParameters {
  std::unordered_set<std::string> weights_to_train;
  std::unordered_set<std::string> weights_not_to_train;

  onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights;

  // optimizer
  std::string training_optimizer_name;
  std::string lr_params_feed_name = "Learning_Rate";
  std::unordered_map<std::string, std::unordered_map<std::string, float>> optimizer_attributes_map;
  std::unordered_map<std::string, std::unordered_map<std::string, int64_t>> optimizer_int_attributes_map;
  onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state;
  std::unordered_map<std::string, std::vector<int>> sliced_schema;
  std::unordered_map<std::string, int> sliced_axes;
  std::vector<std::string> sliced_tensor_names;

@@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext {
        local_registries_(local_registries) {}
};

// TODO: this method does not handle parallel optimization.
TrainingConfigurationResult ConfigureSessionForTraining(
    training::PipelineTrainingSession* sess, TrainingParameters& parameters) {
  // TODO tix, refactor the mpi related code to populate all fields correctly by default.
  ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size);
  ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size);
  ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size);

  // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks.
  if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) {
    ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0,
                "D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP.");

    const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size);
    parameters.data_parallel_size = new_data_parallel_size;

    const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) +
                            ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size);
    LOGS(*(sess->GetLogger()), WARNING) << msg;
  }

  training::PipelineTrainingSession::TrainingConfiguration config{};
  config.weight_names_to_train = parameters.weights_to_train;
  config.weight_names_to_not_train = parameters.weights_not_to_train;
  config.immutable_weights = parameters.immutable_weights;
  config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;

  config.distributed_config.world_rank = parameters.world_rank;
  config.distributed_config.world_size = parameters.world_size;
  config.distributed_config.local_rank = parameters.local_rank;
  config.distributed_config.local_size = parameters.local_size;
  config.distributed_config.data_parallel_size = parameters.data_parallel_size;
  config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size;
  config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size;
  config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches;
  config.distributed_config.sliced_schema = parameters.sliced_schema;
  config.distributed_config.sliced_axes = parameters.sliced_axes;
  config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names;

  if (parameters.use_mixed_precision) {
    training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{};
    mp.use_mixed_precision_initializers = true;

    config.mixed_precision_config = mp;
  }

  if (config.distributed_config.pipeline_parallel_size > 1) {
    training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config;

    // Currently don't support auto-partition. User needs to pass in cut information for pipeline
    pipeline_config.do_partition = true;
    assert(!parameters.pipeline_cut_info_string.empty());

    auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) {
      std::vector<std::string> result;
      size_t pos = 0;
      while ((pos = input_str.find(delimiter)) != std::string::npos) {
        std::string token = input_str.substr(0, pos);
        result.emplace_back(token);
        input_str.erase(0, pos + delimiter.length());
      }
      // push the last split of substring into result.
      result.emplace_back(input_str);
      return result;
    };

    auto process_cut_info = [&](std::string& cut_info_string) {
      std::vector<PipelineTrainingSession::TrainingConfiguration::CutInfo> cut_list;
      const std::string group_delimiter = ",";
      const std::string edge_delimiter = ":";
      const std::string consumer_delimiter = "/";
      const std::string producer_consumer_delimiter = "-";

      auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter);
      for (auto& cut_info_group : cut_info_groups) {
        PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info;
        auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter);
        for (auto& cut_edge : cut_edges) {
          auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
          if (process_edge.size() == 1) {
            PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
            cut_info.emplace_back(edge);
          } else {
            ORT_ENFORCE(process_edge.size() == 2);
            auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);

            PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
            cut_info.emplace_back(edge);
          }
        }
        cut_list.emplace_back(cut_info);
      }
      return cut_list;
    };

    pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string);
    config.pipeline_config = pipeline_config;
  }
  config.loss_name = parameters.loss_output_name;

  if (!parameters.training_optimizer_name.empty()) {
    training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{};
    opt.name = parameters.training_optimizer_name;
    opt.learning_rate_input_name = parameters.lr_params_feed_name;
    opt.weight_attributes_generator = [&parameters](const std::string& weight_name) {
      const auto it = parameters.optimizer_attributes_map.find(weight_name);
      ORT_ENFORCE(
          it != parameters.optimizer_attributes_map.end(),
          "Failed to find attribute map for weight ", weight_name);
      return it->second;
    };
    opt.weight_int_attributes_generator = [&parameters](const std::string& weight_name) {
      const auto it = parameters.optimizer_int_attributes_map.find(weight_name);
      ORT_ENFORCE(
          it != parameters.optimizer_int_attributes_map.end(),
          "Failed to find int attribute map for weight ", weight_name);
      return it->second;
    };
    opt.use_mixed_precision_moments = parameters.use_fp16_moments;
    opt.do_all_reduce_in_mixed_precision_type = true;
    // TODO: this mapping is temporary.
    // For now, nccl allreduce kernel only implements for allreduce_post_accumulation
    // hovorod allreduce kernel only implements for not allreduce_post_accumulation.
    // eventually we will have one all reduce kernel and let opt to have
    // an allreduce_post_accumulation option and remove the use_nccl option.
    opt.use_nccl = parameters.allreduce_post_accumulation;
    opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage);
    opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip;

    // TODO reduction types
    if (parameters.enable_adasum) {
#ifdef USE_CUDA
      opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction;
#else
      opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction;
#endif
    }

    config.optimizer_config = opt;
  }

  if (!parameters.optimizer_initial_state.empty()) {
    config.init_optimizer_states = parameters.optimizer_initial_state;
  }

  config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient;
  config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;

  config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute;
  config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute;
  config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute;
  config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers;
  config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy;
  config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level;
  config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow;

  if (!parameters.model_after_graph_transforms_path.empty()) {
    config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path);
  }
  if (!parameters.model_with_gradient_graph_path.empty()) {
    config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path);
  }
  if (!parameters.model_with_training_graph_path.empty()) {
    config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path);
  }

  training::PipelineTrainingSession::TrainingConfigurationResult config_result{};

  OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));

  TrainingConfigurationResult python_config_result{};
  if (config_result.mixed_precision_config_result.has_value()) {
    const auto& mp_config_result = config_result.mixed_precision_config_result.value();
    python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name;
  }

  return python_config_result;
}

#if defined(USE_MPI)
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
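The removed ConfigureSessionForTraining parsed parameters.pipeline_cut_info_string with ',' between cut groups, ':' between edges, '-' between a producer and its consumers, and '/' between consumers. A rough Python rendering of that parsing, with a made-up example string, is:

```python
# Sketch of the cut-info grammar handled by the removed C++ code; not part of this PR.
def parse_cut_info(cut_info_string: str):
    cut_list = []
    for group in cut_info_string.split(","):      # one CutInfo per group
        cut_info = []
        for edge in group.split(":"):              # one CutEdge per edge
            parts = edge.split("-")                # producer, optionally "-consumer/consumer..."
            consumers = parts[1].split("/") if len(parts) > 1 else None
            cut_info.append((parts[0], consumers))
        cut_list.append(cut_info)
    return cut_list

print(parse_cut_info("layer1_out:layer2_out-dec1/dec2,layer3_out"))
# [[('layer1_out', None), ('layer2_out', ['dec1', 'dec2'])], [('layer3_out', None)]]
```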
@@ -424,7 +241,7 @@ std::unordered_map<std::string, std::unordered_map<std::string, py::object>> Con
  return py_tensor_state;
}

void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
void addObjectMethodsForTraining(py::module& m) {
  py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
      .def(py::init<>())
      .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) {

@@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
  py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
  parameters.def(py::init())
      .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
      .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
      .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
      .def_readwrite("weights_to_train", &TrainingParameters::weights_to_train)
      .def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names)

@@ -484,25 +300,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
      .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
      .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
      .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
      .def("set_optimizer_initial_state",
           [](TrainingParameters& parameters, const std::unordered_map<std::string, std::unordered_map<std::string, py::object>>& py_state) -> void {
             onnxruntime::training::TrainingSession::OptimizerState optim_state;
             for (const auto& weight_it : py_state) {
               auto state = weight_it.second;
               NameMLValMap state_tensors;
               for (auto& initializer : state) {
                 OrtValue ml_value;

                 // InputDeflist is null because parameters havent been tied to session yet
                 // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
                 CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true);
                 ThrowIfPyErrOccured();
                 state_tensors.emplace(initializer.first, ml_value);
               }
               optim_state.emplace(weight_it.first, state_tensors);
             }
             parameters.optimizer_initial_state = optim_state;
           })
      .def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path)
      .def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path)
      .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)

@@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
      });
#endif

  py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
  config_result.def(py::init())
      .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
        if (result.loss_scale_input_name.has_value()) {
          return py::str{result.loss_scale_input_name.value()};
        }
        return py::none();
      });

  // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
  struct PyTrainingSession : public PyInferenceSession {
    PyTrainingSession(std::shared_ptr<Environment> env, const PySessionOptions& so)
        : PyInferenceSession(env, std::make_unique<PipelineTrainingSession>(so.value, *env)) {
    }
    ~PyTrainingSession() = default;
  };

  py::class_<PyTrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
  training_session
      .def(py::init([](const PySessionOptions& so) {
        auto& training_env = GetTrainingEnv();
        return std::make_unique<PyTrainingSession>(training_env.GetORTEnv(), so);
      }))
      .def(py::init([]() {
        auto& training_env = GetTrainingEnv();
        return std::make_unique<PyTrainingSession>(training_env.GetORTEnv(), GetDefaultCPUSessionOptions());
      }))
      .def("finalize", [](py::object) {
#if defined(USE_MPI)
#ifdef _WIN32
        // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
        // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
        // call shutdown_mpi() here instead.
        MPIContext::shutdown_mpi();
#endif
#endif
      })
      .def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));

#if defined(USE_MPI)
        bool use_nccl = parameters.allreduce_post_accumulation;
        if (!use_nccl && parameters.world_size > 1)
          CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
        const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);

        ProviderOptionsVector merged_options;
        ResolveExtraProviderOptions(provider_types, provider_options, merged_options);

        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);

        return config_result;
      })
      .def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
        std::istringstream buffer(serialized_model);
        OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));

#if defined(USE_MPI)
        bool use_nccl = parameters.allreduce_post_accumulation;
        if (!use_nccl && parameters.world_size > 1)
          CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
        const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
        ProviderOptionsVector merged_options;
        ResolveExtraProviderOptions(provider_types, provider_options, merged_options);

        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);

        return config_result;
      })
      .def("get_state", [](PyTrainingSession* sess) {
        NameMLValMap state_tensors;
        ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
        auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
        // convert to numpy array
        std::map<std::string, py::object> rmap;
        for (auto& kv : state_tensors) {
          if (kv.second.IsTensor()) {
            py::object obj;
            const Tensor& rtensor = kv.second.Get<Tensor>();
            GetPyObjFromTensor(rtensor, obj, &data_transfer_manager);
            rmap.insert({kv.first, obj});
          } else {
            throw std::runtime_error("Non tensor type in session state tensors is not expected.");
          }
        }
        return rmap;
      })
      .def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) {
        std::unordered_map<std::string, NameMLValMap> model_state_tensors;
        ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights));
        auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
        return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager);
      })
      .def("get_optimizer_state", [](PyTrainingSession* sess) {
        std::unordered_map<std::string, NameMLValMap> opt_state_tensors;
        ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors));
        auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
        return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager);
      })
      .def("get_partition_info_map", [](PyTrainingSession* sess) {
        std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> part_info_map;
        ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map));
        return part_info_map;
      })
      .def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
        NameMLValMap state_tensors;
        for (auto initializer : state) {
          OrtValue ml_value;
          auto px = sess->GetSessionHandle()->GetModelInputs();
          if (!px.first.IsOK() || !px.second) {
            throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
          }
          CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value);
          ThrowIfPyErrOccured();
          state_tensors.insert(std::make_pair(initializer.first, ml_value));
        }
        ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
      })
      .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
        return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
      });

  py::class_<PartialGraphExecutionState>(m, "PartialGraphExecutionState")
      .def(py::init([]() {
        return std::make_unique<PartialGraphExecutionState>();
@@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*

void addGlobalMethods(py::module& m);
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn);
void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn);
void addObjectMethodsForTraining(py::module& m);
void addObjectMethodsForEager(py::module& m);
#ifdef ENABLE_LAZY_TENSOR
void addObjectMethodsForLazyTensor(py::module& m);

@@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
  }
#endif

  addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders);
  addObjectMethodsForTraining(m);

#ifdef ENABLE_LAZY_TENSOR
  addObjectMethodsForLazyTensor(m);
@@ -8,26 +8,16 @@ from onnxruntime.capi._pybind_state import (
    TrainingParameters,
    is_ortmodule_available,
)
from onnxruntime.capi.training.training_session import TrainingSession


# Options need to be imported before `ORTTrainer`.
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, artifacts, checkpoint, model_desc_validation, optim
from . import amp, artifacts, optim

__all__ = [
    "PropagateCastOpsStrategy",
    "TrainingParameters",
    "is_ortmodule_available",
    "TrainingSession",
    "ORTTrainerOptions",
    "ORTTrainer",
    "TrainStepInfo",
    "amp",
    "artifacts",
    "checkpoint",
    "model_desc_validation",
    "optim",
]
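As the retained import line above indicates, only the non-deprecated submodules stay on the public surface; roughly (a sketch — the exact exports are defined by the updated __init__.py):

```python
# These imports remain valid after the cleanup; ORTTrainer, checkpoint and
# model_desc_validation are gone from onnxruntime.training.
from onnxruntime.training import amp, artifacts, optim
```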
@@ -1,107 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import pickle
from collections.abc import Mapping

import h5py


def _dfs_save(group, save_obj):
    """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group"""

    for key, value in save_obj.items():
        if isinstance(value, Mapping):
            subgroup = group.create_group(key)
            _dfs_save(subgroup, value)
        else:
            group[key] = value


def save(save_obj: dict, path):
    """Persists the input dictionary to a file specified by path.

    Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path.
    Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a
    numpy object. So, user types may be converted from their original types to numpy equivalent types.

    Args:
        save_obj: dictionary that needs to be saved.
            save_obj should consist of types supported by hdf5 file format.
            if hdf5 does not recognize a type, an exception is raised.
            if save_obj is not a dictionary, a ValueError is raised.
        path: string representation to a file path or a python file-like object.
            if file already exists at path, an exception is raised.
    """
    if not isinstance(save_obj, Mapping):
        raise ValueError("Object to be saved must be a dictionary")

    with h5py.File(path, "w-") as f:
        _dfs_save(f, save_obj)


def _dfs_load(group, load_obj):
    """Recursively go over each level in the hdf5 group and load the values into the given dictionary"""

    for key in group:
        if isinstance(group[key], h5py.Group):
            load_obj[key] = {}
            _dfs_load(group[key], load_obj[key])
        else:
            load_obj[key] = group[key][()]


def load(path, key=None):
    """Loads the data stored in the binary file specified at the given path into a dictionary and returns it.

    Loads the data from an hdf5 file specified at the given path into a python dictionary.
    Loaded dictionary contains numpy equivalents of python data types. For example:
        PyTorch tensor -> saved as a numpy array and loaded as a numpy array.
        bool -> saved as a numpy bool and loaded as a numpy bool
    If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned.

    Args:
        path: string representation to a file path or a python file-like object.
            if file does not already exist at path, an exception is raised.
        key: '/' separated representation of the hierarchy level value that needs to be returned/
            for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like
            to query the value for c, the key provided should be 'a/c'.
            the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned.

    Returns:
        a dictionary loaded from the specified binary hdf5 file.
    """
    if not h5py.is_hdf5(path):
        raise ValueError(f"{path} is not an hdf5 file or a python file-like object.")

    load_obj = {}
    with h5py.File(path, "r") as f:
        if key:
            f = f[key]  # noqa: PLW2901
        if isinstance(f, h5py.Dataset):
            return f[()]

        _dfs_load(f, load_obj)

    return load_obj


def to_serialized_hex(user_dict):
    """Serialize the user_dict and convert the serialized bytes to a hex string and return"""

    return pickle.dumps(user_dict).hex()


def from_serialized_hex(serialized_hex):
    """Convert serialized_hex to bytes and deserialize it and return"""

    # serialized_hex can be either a regular string or a byte string.
    # if it is a byte string, convert to regular string using decode()
    # if it is a regular string, do nothing to it
    try:  # noqa: SIM105
        serialized_hex = serialized_hex.decode()
    except AttributeError:
        pass
    return pickle.loads(bytes.fromhex(serialized_hex))
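A small round-trip sketch of the removed _checkpoint_storage helpers, mainly to record the old behavior; the file name is a placeholder and the module no longer exists after this PR.

```python
# Illustrative only: the hdf5-backed save/load helpers deleted above.
import numpy as np
from onnxruntime.training import _checkpoint_storage  # removed in this PR

ckpt = {
    "model": {"full_precision": {"weight": np.arange(4, dtype=np.float32)}},
    "trainer_options": {"world_rank": 0},
}

_checkpoint_storage.save(ckpt, "state.h5")         # "w-" mode: fails if state.h5 already exists
everything = _checkpoint_storage.load("state.h5")  # nested dict of numpy values
rank = _checkpoint_storage.load("state.h5", key="trainer_options/world_rank")  # single leaf value
```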
@@ -6,11 +6,9 @@
import importlib.util
import os
import sys
from functools import wraps  # noqa: F401

import numpy as np
import torch
from onnx import TensorProto  # noqa: F401
from packaging.version import Version


@@ -23,16 +21,6 @@ def get_device_index(device):
    return 0 if device.index is None else device.index


def get_device_index_from_input(input):
    """Returns device index from a input PyTorch Tensor"""

    if isinstance(input, (list, tuple)):
        device_index = get_device_index(input[0].device)
    else:
        device_index = get_device_index(input.device)
    return device_index


def get_device_str(device):
    if isinstance(device, str):
        # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0

@@ -50,24 +38,6 @@ def get_device_str(device):
    return device


def get_all_gradients_finite_name_from_session(session):
    """Find all_gradients_finite node on Session graph and return its name"""

    nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name]
    if len(nodes) != 1:
        raise RuntimeError("'all_gradients_finite' node not found within training session")
    return nodes[0].name


def get_gradient_accumulation_name_from_session(session):
    """Find Group_Accumulated_Gradients node on Session graph and return its name"""

    nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name]
    if len(nodes) != 1:
        raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session")
    return nodes[0].name


def dtype_torch_to_numpy(torch_dtype):
    """Converts PyTorch types to Numpy types

@@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None):
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def state_dict_model_key():
    """Returns the model key name in the state dictionary"""

    return "model"


def state_dict_optimizer_key():
    """Returns the optimizer key name in the state dictionary"""

    return "optimizer"


def state_dict_partition_info_key():
    """Returns the partition info key name in the state dictionary"""

    return "partition_info"


def state_dict_trainer_options_key():
    """Returns the trainer options key name in the state dictionary"""

    return "trainer_options"


def state_dict_full_precision_key():
    """Returns the full precision key name in the state dictionary"""

    return "full_precision"


def state_dict_original_dimension_key():
    """Returns the original dimension key name in the state dictionary"""

    return "original_dim"


def state_dict_sharded_optimizer_keys():
    """Returns the optimizer key names that can be sharded in the state dictionary"""

    return {"Moment_1", "Moment_2"}


def state_dict_user_dict_key():
    """Returns the user dict key name in the state dictionary"""

    return "user_dict"


def state_dict_trainer_options_mixed_precision_key():
    """Returns the trainer options mixed precision key name in the state dictionary"""

    return "mixed_precision"


def state_dict_trainer_options_zero_stage_key():
    """Returns the trainer options zero_stage key name in the state dictionary"""

    return "zero_stage"


def state_dict_trainer_options_world_rank_key():
    """Returns the trainer options world_rank key name in the state dictionary"""

    return "world_rank"


def state_dict_trainer_options_world_size_key():
    """Returns the trainer options world_size key name in the state dictionary"""

    return "world_size"


def state_dict_trainer_options_data_parallel_size_key():
    """Returns the trainer options data_parallel_size key name in the state dictionary"""

    return "data_parallel_size"


def state_dict_trainer_options_horizontal_parallel_size_key():
    """Returns the trainer options horizontal_parallel_size key name in the state dictionary"""

    return "horizontal_parallel_size"


def state_dict_trainer_options_optimizer_name_key():
    """Returns the trainer options optimizer_name key name in the state dictionary"""

    return "optimizer_name"


def state_dict_train_step_info_key():
    """Returns the train step info key name in the state dictionary"""

    return "train_step_info"


def state_dict_train_step_info_optimization_step_key():
    """Returns the train step info optimization step key name in the state dictionary"""

    return "optimization_step"


def state_dict_train_step_info_step_key():
    """Returns the train step info step key name in the state dictionary"""

    return "step"
@ -1,748 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
|
||||
from . import _checkpoint_storage, _utils
|
||||
|
||||
################################################################################
|
||||
# Experimental Checkpoint APIs
|
||||
################################################################################
|
||||
|
||||
|
||||
def experimental_state_dict(ort_trainer, include_optimizer_state=True):
|
||||
warnings.warn(
|
||||
"experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if not ort_trainer._training_session:
|
||||
warnings.warn(
|
||||
"ONNX Runtime training session is not initialized yet. "
|
||||
"Please run train_step or eval_step at least once before calling state_dict()."
|
||||
)
|
||||
return ort_trainer._state_dict
|
||||
|
||||
# extract trained weights
|
||||
session_state = ort_trainer._training_session.get_state()
|
||||
torch_state = {}
|
||||
for name in session_state:
|
||||
torch_state[name] = torch.from_numpy(session_state[name])
|
||||
|
||||
# extract untrained weights and buffer
|
||||
for n in ort_trainer._onnx_model.graph.initializer:
|
||||
if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights:
|
||||
torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n)))
|
||||
|
||||
# Need to remove redundant (optimizer) initializers to map back to original torch state names
|
||||
if not include_optimizer_state and ort_trainer._torch_state_dict_keys:
|
||||
return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state}
|
||||
return torch_state
|
||||
|
||||
|
||||
def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
|
||||
warnings.warn(
|
||||
"experimental_load_state_dict() will be deprecated soon. Please use ORTTrainer.load_state_dict() instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
# Note: It may happen ONNX model has not yet been initialized
|
||||
# In this case we cache a reference to desired state and delay the restore until after initialization
|
||||
# Unexpected behavior will result if the user changes the reference before initialization
|
||||
if not ort_trainer._training_session:
|
||||
ort_trainer._state_dict = state_dict
|
||||
ort_trainer._load_state_dict_strict = strict
|
||||
return
|
||||
|
||||
# Update onnx model from loaded state dict
|
||||
cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer]
|
||||
new_initializers = {}
|
||||
|
||||
for name in state_dict:
|
||||
if name in cur_initializers_names:
|
||||
new_initializers[name] = state_dict[name].numpy()
|
||||
elif strict:
|
||||
raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.")
|
||||
|
||||
ort_trainer._update_onnx_model_initializers(new_initializers)
|
||||
|
||||
# create new session based on updated onnx model
|
||||
ort_trainer._state_dict = None
|
||||
ort_trainer._init_session()
|
||||
|
||||
# load training state
|
||||
session_state = {name: state_dict[name].numpy() for name in state_dict}
|
||||
ort_trainer._training_session.load_state(session_state, strict)
|
||||
|
||||
|
||||
def experimental_save_checkpoint(
|
||||
ort_trainer,
|
||||
checkpoint_dir,
|
||||
checkpoint_prefix="ORT_checkpoint",
|
||||
checkpoint_state_dict=None,
|
||||
include_optimizer_state=True,
|
||||
):
|
||||
warnings.warn(
|
||||
"experimental_save_checkpoint() will be deprecated soon. Please use ORTTrainer.save_checkpoint() instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if checkpoint_state_dict is None:
|
||||
checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)}
|
||||
else:
|
||||
checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)})
|
||||
|
||||
assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist"
|
||||
|
||||
checkpoint_name = _get_checkpoint_name(
|
||||
checkpoint_prefix,
|
||||
ort_trainer.options.distributed.deepspeed_zero_optimization.stage,
|
||||
ort_trainer.options.distributed.world_rank,
|
||||
ort_trainer.options.distributed.world_size,
|
||||
)
|
||||
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
if os.path.exists(checkpoint_file):
|
||||
msg = f"{checkpoint_file} already exists, overwriting."
|
||||
warnings.warn(msg)
|
||||
torch.save(checkpoint_state_dict, checkpoint_file)
|
||||
|
||||
|
||||
def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
|
||||
warnings.warn(
|
||||
"experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
||||
is_partitioned = False
|
||||
if len(checkpoint_files) > 1:
|
||||
msg = (
|
||||
f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}."
|
||||
" Attempting to load ZeRO checkpoint."
|
||||
)
|
||||
warnings.warn(msg)
|
||||
is_partitioned = True
|
||||
if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned:
|
||||
return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict)
|
||||
else:
|
||||
return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
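For reference, a minimal usage sketch of the deprecated helpers removed here, assuming an already-constructed ORTTrainer instance named trainer that has run at least one train_step; the directory name is hypothetical and the module path is the one that existed before this PR:

import os

from onnxruntime.training import checkpoint

ckpt_dir = "outputs/checkpoints"  # hypothetical location; must exist before saving
os.makedirs(ckpt_dir, exist_ok=True)

# write one checkpoint file per rank (ZeRO runs produce multiple files)
checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, checkpoint_prefix="ORT_checkpoint")

# later, restore the saved state into a trainer built with the same model description
extra_state = checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, checkpoint_prefix="ORT_checkpoint")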
|
||||
|
||||
|
||||
class _AGGREGATION_MODE(Enum): # noqa: N801
|
||||
Zero = 0
|
||||
Megatron = 1
|
||||
|
||||
|
||||
def _order_paths(paths, D_groups, H_groups):
|
||||
"""Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively
|
||||
and returns the ordered dict"""
|
||||
|
||||
trainer_options_path_tuples = []
|
||||
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
||||
|
||||
for path in paths:
|
||||
trainer_options_path_tuples.append(
|
||||
(_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path)
|
||||
)
|
||||
|
||||
# sort paths according to rank
|
||||
sorted_paths = [
|
||||
path
|
||||
for _, path in sorted(
|
||||
trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank]
|
||||
)
|
||||
]
|
||||
|
||||
ordered_paths = dict()
|
||||
ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))]
|
||||
ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))]
|
||||
|
||||
return ordered_paths
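To make the ordering concrete, a sketch with four hypothetical checkpoint paths whose stored world_rank values are 0..3, and groups D_groups=[[0, 2], [1, 3]], H_groups=[[0, 1], [2, 3]]:

# sorted_paths == [path_rank0, path_rank1, path_rank2, path_rank3]
# ordered_paths == {
#     "D": [[path_rank0, path_rank2], [path_rank1, path_rank3]],  # one sublist per data-parallel group
#     "H": [[path_rank0, path_rank1], [path_rank2, path_rank3]],  # one sublist per horizontal-parallel group
# }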
|
||||
|
||||
|
||||
def _add_or_update_sharded_key(
|
||||
state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode
|
||||
):
|
||||
"""Add or update the record for the sharded state_key in the state_sub_dict"""
|
||||
|
||||
# record the original dimension for this state
|
||||
original_dim = _utils.state_dict_original_dimension_key()
|
||||
sharded_states_original_dims[model_state_key] = state_partition_info[original_dim]
|
||||
|
||||
axis = 0
|
||||
if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0:
|
||||
axis = -1
|
||||
|
||||
if state_key in state_sub_dict:
|
||||
# state_dict already contains a record for this state
|
||||
# since this state is sharded, concatenate the state value to
|
||||
# the record in the state_dict
|
||||
state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis)
|
||||
else:
|
||||
# create a new entry for this state in the state_dict
|
||||
state_sub_dict[state_key] = state_value
|
||||
|
||||
|
||||
def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string):
|
||||
"""Add or validate the record for the unsharded state_key in the state_sub_dict"""
|
||||
|
||||
if state_key in state_sub_dict:
|
||||
# state_dict already contains a record for this unsharded state.
|
||||
# assert that all values are the same for this previously loaded state
|
||||
assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string
|
||||
else:
|
||||
# create a new entry for this state in the state_sub_dict
|
||||
state_sub_dict[state_key] = state_value
|
||||
|
||||
|
||||
def _aggregate_model_states(
|
||||
rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero
|
||||
):
|
||||
"""Aggregates all model states from the rank_state_dict into state_dict"""
|
||||
|
||||
model = _utils.state_dict_model_key()
|
||||
full_precision = _utils.state_dict_full_precision_key()
|
||||
partition_info = _utils.state_dict_partition_info_key()
|
||||
|
||||
# if there are no model states in the rank_state_dict, no model aggregation is needed
|
||||
if model not in rank_state_dict:
|
||||
return
|
||||
|
||||
if model not in state_dict:
|
||||
state_dict[model] = {}
|
||||
|
||||
if full_precision not in state_dict[model]:
|
||||
state_dict[model][full_precision] = {}
|
||||
|
||||
# iterate over all model state keys
|
||||
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
|
||||
# ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed
|
||||
# precision training was enabled. for full precision training, full precision model states are not sharded
|
||||
# MEGATRON : full precision model states are sharded when they exist in the partition_info subdict
|
||||
if (model_state_key in rank_state_dict[partition_info]) and (
|
||||
mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled
|
||||
):
|
||||
# this model state is sharded
|
||||
_add_or_update_sharded_key(
|
||||
model_state_key,
|
||||
model_state_value,
|
||||
state_dict[model][full_precision],
|
||||
model_state_key,
|
||||
rank_state_dict[partition_info][model_state_key],
|
||||
sharded_states_original_dims,
|
||||
mode,
|
||||
)
|
||||
else:
|
||||
# this model state is not sharded since a record for it does not exist in the partition_info subdict
|
||||
_add_or_validate_unsharded_key(
|
||||
model_state_key,
|
||||
model_state_value,
|
||||
state_dict[model][full_precision],
|
||||
f"Value mismatch for model state {model_state_key}",
|
||||
)
|
||||
|
||||
|
||||
def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero):
|
||||
"""Aggregates all optimizer states from the rank_state_dict into state_dict"""
|
||||
|
||||
optimizer = _utils.state_dict_optimizer_key()
|
||||
partition_info = _utils.state_dict_partition_info_key()
|
||||
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
|
||||
|
||||
# if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed
|
||||
if optimizer not in rank_state_dict:
|
||||
return
|
||||
|
||||
if optimizer not in state_dict:
|
||||
state_dict[optimizer] = {}
|
||||
|
||||
# iterate over all optimizer state keys
|
||||
for model_state_key, optimizer_dict in rank_state_dict[optimizer].items():
|
||||
for optimizer_key, optimizer_value in optimizer_dict.items():
|
||||
if model_state_key not in state_dict[optimizer]:
|
||||
state_dict[optimizer][model_state_key] = {}
|
||||
|
||||
if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]:
|
||||
# this optimizer state is sharded since a record exists in the partition_info subdict
|
||||
_add_or_update_sharded_key(
|
||||
optimizer_key,
|
||||
optimizer_value,
|
||||
state_dict[optimizer][model_state_key],
|
||||
model_state_key,
|
||||
rank_state_dict[partition_info][model_state_key],
|
||||
sharded_states_original_dims,
|
||||
mode,
|
||||
)
|
||||
else:
|
||||
# this optimizer state is not sharded since a record for it does not exist in the partition_info subdict
|
||||
# or this optimizer key is not one of the sharded optimizer keys
|
||||
_add_or_validate_unsharded_key(
|
||||
optimizer_key,
|
||||
optimizer_value,
|
||||
state_dict[optimizer][model_state_key],
|
||||
f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}",
|
||||
)
|
||||
|
||||
|
||||
def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled):
|
||||
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""
|
||||
|
||||
model = _utils.state_dict_model_key()
|
||||
full_precision = _utils.state_dict_full_precision_key()
|
||||
optimizer = _utils.state_dict_optimizer_key()
|
||||
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()
|
||||
|
||||
for sharded_state_key, original_dim in sharded_states_original_dims.items():
|
||||
# reshape model states to original_dim only when mixed precision is enabled
|
||||
if mixed_precision_enabled and (model in state_dict):
|
||||
state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][
|
||||
sharded_state_key
|
||||
].reshape(original_dim)
|
||||
|
||||
# reshape optimizer states to original_dim
|
||||
if optimizer in state_dict:
|
||||
for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items():
|
||||
if optimizer_key in sharded_optimizer_keys:
|
||||
state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim)
|
||||
|
||||
|
||||
def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation):
|
||||
"""Extracts trainer options from rank_state_dict and loads them accordingly on state_dict"""
|
||||
trainer_options = _utils.state_dict_trainer_options_key()
|
||||
state_dict[trainer_options] = {}
|
||||
|
||||
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
|
||||
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
|
||||
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
||||
world_size = _utils.state_dict_trainer_options_world_size_key()
|
||||
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
|
||||
D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806
|
||||
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806
|
||||
|
||||
state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision]
|
||||
state_dict[trainer_options][zero_stage] = 0
|
||||
state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0
|
||||
state_dict[trainer_options][world_size] = 1
|
||||
state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name]
|
||||
state_dict[trainer_options][D_size] = 1
|
||||
state_dict[trainer_options][H_size] = 1
|
||||
|
||||
|
||||
def _aggregate_megatron_partition_info(rank_state_dict, state_dict):
|
||||
"""Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights"""
|
||||
partition_info = _utils.state_dict_partition_info_key()
|
||||
if partition_info not in state_dict:
|
||||
state_dict[partition_info] = {}
|
||||
|
||||
rank_partition_info = rank_state_dict[partition_info]
|
||||
for model_state_key, partition_info_dict in rank_partition_info.items():
|
||||
if model_state_key not in state_dict[partition_info]:
|
||||
# add partition info only if weight is megatron partitioned
|
||||
if partition_info_dict["megatron_row_partition"] >= 0:
|
||||
state_dict[partition_info][model_state_key] = partition_info_dict
|
||||
|
||||
|
||||
def _to_pytorch_format(state_dict):
|
||||
"""Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)"""
|
||||
|
||||
pytorch_state_dict = {}
|
||||
for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][
|
||||
_utils.state_dict_full_precision_key()
|
||||
].items():
|
||||
# convert numpy array to a torch tensor
|
||||
pytorch_state_dict[model_state_key] = torch.tensor(model_state_value)
|
||||
return pytorch_state_dict
|
||||
|
||||
|
||||
def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size):
|
||||
"""Returns the D and H groups for the given sizes"""
|
||||
num_data_groups = world_size // data_parallel_size
|
||||
data_groups = []
|
||||
for data_group_id in range(num_data_groups):
|
||||
data_group_ranks = []
|
||||
for r in range(data_parallel_size):
|
||||
data_group_ranks.append(data_group_id + horizontal_parallel_size * r)
|
||||
data_groups.append(data_group_ranks)
|
||||
|
||||
num_horizontal_groups = world_size // horizontal_parallel_size
|
||||
horizontal_groups = []
|
||||
for hori_group_id in range(num_horizontal_groups):
|
||||
hori_group_ranks = []
|
||||
for r in range(horizontal_parallel_size):
|
||||
hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r)
|
||||
horizontal_groups.append(hori_group_ranks)
|
||||
|
||||
return data_groups, horizontal_groups
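A worked example of the grouping above, as a sketch assuming world_size=4, data_parallel_size=2 and horizontal_parallel_size=2:

data_groups, horizontal_groups = _get_parallellism_groups(2, 2, 4)
# data_groups       == [[0, 2], [1, 3]]  # ranks combined during ZeRO (D) aggregation
# horizontal_groups == [[0, 1], [2, 3]]  # ranks combined during Megatron (H) aggregation
assert data_groups == [[0, 2], [1, 3]]
assert horizontal_groups == [[0, 1], [2, 3]]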
|
||||
|
||||
|
||||
def _aggregate_over_ranks(
|
||||
ordered_paths,
|
||||
ranks,
|
||||
sharded_states_original_dims=None,
|
||||
mode=_AGGREGATION_MODE.Zero,
|
||||
partial_aggregation=False,
|
||||
pytorch_format=True,
|
||||
):
|
||||
"""Aggregate checkpoint files over set of ranks and return a single state dictionary
|
||||
|
||||
Args:
|
||||
ordered_paths: list of paths in the order in which they must be aggregated
|
||||
ranks: list of ranks that are to be aggregated
|
||||
sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over
|
||||
multiple calls to _aggregate_over_ranks()
|
||||
mode: mode of aggregation: Zero or Megatron
|
||||
partial_aggregation: boolean flag to indicate whether to produce a partially
|
||||
aggregated state which can be further aggregated over
|
||||
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
|
||||
Returns:
|
||||
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
|
||||
"""
|
||||
state_dict = {}
|
||||
if sharded_states_original_dims is None:
|
||||
sharded_states_original_dims = dict()
|
||||
world_rank = _utils.state_dict_trainer_options_world_rank_key()
|
||||
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
|
||||
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
|
||||
world_size = _utils.state_dict_trainer_options_world_size_key()
|
||||
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
|
||||
|
||||
loaded_mixed_precision = None
|
||||
loaded_world_size = None
|
||||
loaded_zero_stage = None
|
||||
loaded_optimizer_name = None
|
||||
|
||||
for i, path in enumerate(ordered_paths):
|
||||
rank_state_dict = _checkpoint_storage.load(path)
|
||||
|
||||
assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info"
|
||||
assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options"
|
||||
assert (
|
||||
ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]
|
||||
), "Unexpected rank in file at path {}. Expected {}, got {}".format(
|
||||
path, ranks[i], rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]
|
||||
)
|
||||
if loaded_mixed_precision is None:
|
||||
loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
|
||||
else:
|
||||
assert (
|
||||
loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
|
||||
), f"Mixed precision state mismatch among checkpoint files. File: {path}"
|
||||
if loaded_world_size is None:
|
||||
loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
|
||||
else:
|
||||
assert (
|
||||
loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
|
||||
), f"World size state mismatch among checkpoint files. File: {path}"
|
||||
if loaded_zero_stage is None:
|
||||
loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
|
||||
else:
|
||||
assert (
|
||||
loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
|
||||
), f"Zero stage mismatch among checkpoint files. File: {path}"
|
||||
if loaded_optimizer_name is None:
|
||||
loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
|
||||
else:
|
||||
assert (
|
||||
loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
|
||||
), f"Optimizer name mismatch among checkpoint files. File: {path}"
|
||||
|
||||
# aggregate all model states
|
||||
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode)
|
||||
|
||||
if not pytorch_format:
|
||||
# aggregate all optimizer states if pytorch_format is False
|
||||
_aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode)
|
||||
|
||||
# for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
|
||||
# to aggregate over Zero, and another pass to aggregate Megatron partitioned
|
||||
# states. Preserve the relevant partition info only for weights that are megatron partitioned for
|
||||
# a partial aggregation call
|
||||
if partial_aggregation:
|
||||
_aggregate_megatron_partition_info(rank_state_dict, state_dict)
|
||||
|
||||
# entry for trainer_options in the state_dict to perform other sanity checks
|
||||
if _utils.state_dict_trainer_options_key() not in state_dict:
|
||||
_aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation)
|
||||
|
||||
# entry for user_dict in the state_dict if not already present
|
||||
if (
|
||||
_utils.state_dict_user_dict_key() not in state_dict
|
||||
and _utils.state_dict_user_dict_key() in rank_state_dict
|
||||
):
|
||||
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]
|
||||
|
||||
# for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape
|
||||
if not partial_aggregation:
|
||||
# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
|
||||
_reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision)
|
||||
|
||||
# return a flat structure for PyTorch model in case pytorch_format is True
|
||||
# else return the hierarchical structure for ORTTrainer
|
||||
return _to_pytorch_format(state_dict) if pytorch_format else state_dict
|
||||
|
||||
|
||||
def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802
|
||||
"""Aggregate checkpoint files and return a single state dictionary for the D+H
|
||||
(Zero+Megatron) partitioning strategy.
|
||||
For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups
|
||||
to aggregate over Zero, and another pass over the previously aggregated states
|
||||
to aggregate Megatron partitioned states.
|
||||
"""
|
||||
sharded_states_original_dims = {}
|
||||
aggregate_data_checkpoint_files = []
|
||||
|
||||
# combine for Zero over data groups and save to temp file
|
||||
with tempfile.TemporaryDirectory() as save_dir:
|
||||
for group_id, d_group in enumerate(D_groups):
|
||||
aggregate_state_dict = _aggregate_over_ranks(
|
||||
ordered_paths["D"][group_id],
|
||||
d_group,
|
||||
sharded_states_original_dims,
|
||||
partial_aggregation=True,
|
||||
pytorch_format=False,
|
||||
)
|
||||
|
||||
filename = "ort.data_group." + str(group_id) + ".ort.pt"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
_checkpoint_storage.save(aggregate_state_dict, filepath)
|
||||
aggregate_data_checkpoint_files.append(filepath)
|
||||
|
||||
assert len(aggregate_data_checkpoint_files) > 0
|
||||
|
||||
# combine for megatron:
|
||||
aggregate_state = _aggregate_over_ranks(
|
||||
aggregate_data_checkpoint_files,
|
||||
H_groups[0],
|
||||
sharded_states_original_dims,
|
||||
mode=_AGGREGATION_MODE.Megatron,
|
||||
pytorch_format=pytorch_format,
|
||||
)
|
||||
|
||||
return aggregate_state
|
||||
|
||||
|
||||
def aggregate_checkpoints(paths, pytorch_format=True):
|
||||
"""Aggregate checkpoint files and return a single state dictionary
|
||||
|
||||
Aggregates checkpoint files specified by paths and loads them one at a time, merging
|
||||
them into a single state dictionary.
|
||||
The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function.
|
||||
The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict()
|
||||
|
||||
Args:
|
||||
paths: list of more than one file represented as strings where the checkpoint is saved
|
||||
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
|
||||
Returns:
|
||||
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
|
||||
"""
|
||||
|
||||
loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key())
|
||||
D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806
|
||||
H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806
|
||||
world_size = _utils.state_dict_trainer_options_world_size_key()
|
||||
|
||||
D_size = loaded_trainer_options[D_size] # noqa: N806
|
||||
H_size = loaded_trainer_options[H_size] # noqa: N806
|
||||
world_size = loaded_trainer_options[world_size]
|
||||
D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806
|
||||
|
||||
combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0
|
||||
combine_megatron = len(H_groups[0]) > 1
|
||||
|
||||
# order the paths in the order of groups in which they must be aggregated according to
|
||||
# data-parallel groups and H-parallel groups obtained
|
||||
# eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]}
|
||||
ordered_paths = _order_paths(paths, D_groups, H_groups)
|
||||
|
||||
aggregate_state = None
|
||||
if combine_zero and combine_megatron:
|
||||
aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format)
|
||||
elif combine_zero:
|
||||
aggregate_state = _aggregate_over_ranks(
|
||||
ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format
|
||||
)
|
||||
elif combine_megatron:
|
||||
aggregate_state = _aggregate_over_ranks(
|
||||
ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format
|
||||
)
|
||||
|
||||
return aggregate_state
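A minimal usage sketch of aggregate_checkpoints as it worked before this removal; the paths and the PyTorch model are hypothetical:

import glob

from onnxruntime.training import checkpoint

# per-rank checkpoint files previously written by ORTTrainer.save_checkpoint()
ckpt_paths = sorted(glob.glob("outputs/checkpoints/*"))  # hypothetical location

# returns a flat PyTorch-style dict suitable for torch.nn.Module.load_state_dict()
full_state = checkpoint.aggregate_checkpoints(ckpt_paths, pytorch_format=True)
pytorch_model.load_state_dict(full_state, strict=False)  # pytorch_model: hypothetical torch.nn.Module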
|
||||
|
||||
|
||||
################################################################################
|
||||
# Helper functions
|
||||
################################################################################
|
||||
|
||||
|
||||
def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict):
|
||||
checkpoint_name = _get_checkpoint_name(
|
||||
checkpoint_prefix,
|
||||
is_partitioned,
|
||||
ort_trainer.options.distributed.world_rank,
|
||||
ort_trainer.options.distributed.world_size,
|
||||
)
|
||||
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
|
||||
if is_partitioned:
|
||||
assert_msg = (
|
||||
f"Couldn't find checkpoint file {checkpoint_file}."
|
||||
" Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists "
|
||||
f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}"
|
||||
)
|
||||
else:
|
||||
assert_msg = f"Couldn't find checkpoint file {checkpoint_file}."
|
||||
assert os.path.exists(checkpoint_file), assert_msg
|
||||
|
||||
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
||||
experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict)
|
||||
del checkpoint_state["model"]
|
||||
return checkpoint_state
|
||||
|
||||
|
||||
def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict):
|
||||
checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
ckpt_agg = _CombineZeroCheckpoint(checkpoint_files)
|
||||
aggregate_state_dict = ckpt_agg.aggregate_checkpoints()
|
||||
|
||||
experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict)
|
||||
|
||||
# aggregate other keys in the state_dict.
|
||||
# Values will be overwritten for matching keys among workers
|
||||
all_checkpoint_states = dict()
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
|
||||
del checkpoint_state["model"]
|
||||
all_checkpoint_states.update(checkpoint_state)
|
||||
return all_checkpoint_states
|
||||
|
||||
|
||||
def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"):
|
||||
ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)]
|
||||
ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)]
|
||||
ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names]
|
||||
|
||||
assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'"
|
||||
return ckpt_file_names
|
||||
|
||||
|
||||
def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None):
|
||||
SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806
|
||||
MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806
|
||||
|
||||
if is_partitioned:
|
||||
filename = MULTIPLE_CHECKPOINT_FILENAME.format(
|
||||
prefix=prefix, world_rank=world_rank, world_size=(world_size - 1)
|
||||
)
|
||||
else:
|
||||
filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix)
|
||||
return filename
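For illustration, the file names this helper produces; note that the partitioned form embeds world_size - 1, which _CombineZeroCheckpoint reverses by adding 1 when it parses the name:

assert _get_checkpoint_name("ORT_checkpoint", False) == "ORT_checkpoint.ort.pt"
assert _get_checkpoint_name("ORT_checkpoint", True, world_rank=2, world_size=4) == "ORT_checkpoint.ZeRO.2.3.ort.pt"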
|
||||
|
||||
|
||||
def _split_state_dict(state_dict):
|
||||
optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"]
|
||||
split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}}
|
||||
for k, v in state_dict.items():
|
||||
mode = "fp32_param"
|
||||
for optim_key in optimizer_keys:
|
||||
if k.startswith(optim_key):
|
||||
mode = "optimizer"
|
||||
break
|
||||
if k.endswith("_fp16"):
|
||||
mode = "fp16_param"
|
||||
split_sd[mode][k] = v
|
||||
return split_sd
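A small sketch of how _split_state_dict buckets entries, using hypothetical tensor names:

import torch

sd = {
    "bert.weight": torch.zeros(4),               # -> "fp32_param"
    "bert.weight_fp16": torch.zeros(4),          # -> "fp16_param"
    "Moment_1_bert.weight": torch.zeros(4),      # -> "optimizer"
    "Update_Count_bert.weight": torch.zeros(1),  # -> "optimizer"
}
split = _split_state_dict(sd)
assert set(split["optimizer"]) == {"Moment_1_bert.weight", "Update_Count_bert.weight"}
assert set(split["fp16_param"]) == {"bert.weight_fp16"}
assert set(split["fp32_param"]) == {"bert.weight"}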
|
||||
|
||||
|
||||
class _CombineZeroCheckpoint:
|
||||
def __init__(self, checkpoint_files, clean_state_dict=None):
|
||||
assert len(checkpoint_files) > 0, "No checkpoint files passed"
|
||||
self.checkpoint_files = checkpoint_files
|
||||
self.clean_state_dict = clean_state_dict
|
||||
self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1
|
||||
assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files"
|
||||
self.weight_shape_map = {}
|
||||
self.sharded_params = set()
|
||||
|
||||
def _split_name(self, name: str):
|
||||
name_split = name.split("_view_")
|
||||
view_num = None
|
||||
if len(name_split) > 1:
|
||||
view_num = int(name_split[1])
|
||||
optimizer_key = ""
|
||||
mp_suffix = ""
|
||||
if name_split[0].startswith("Moment_1"):
|
||||
optimizer_key = "Moment_1_"
|
||||
elif name_split[0].startswith("Moment_2"):
|
||||
optimizer_key = "Moment_2_"
|
||||
elif name_split[0].startswith("Update_Count"):
|
||||
optimizer_key = "Update_Count_"
|
||||
elif name_split[0].endswith("_fp16"):
|
||||
mp_suffix = "_fp16"
|
||||
param_name = name_split[0]
|
||||
if optimizer_key:
|
||||
param_name = param_name.split(optimizer_key)[1]
|
||||
param_name = param_name.split("_fp16")[0]
|
||||
return param_name, optimizer_key, view_num, mp_suffix
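# For illustration (hypothetical weight name "bert.weight"), sharded entries decompose as:
#     _split_name("Moment_1_bert.weight_view_1") -> ("bert.weight", "Moment_1_", 1, "")
#     _split_name("bert.weight_fp16")            -> ("bert.weight", "", None, "_fp16")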
|
||||
|
||||
def _update_weight_statistics(self, name, value):
|
||||
if name not in self.weight_shape_map:
|
||||
self.weight_shape_map[name] = value.size() # original shape of tensor
|
||||
|
||||
def _reshape_tensor(self, key):
|
||||
value = self.aggregate_state_dict[key]
|
||||
weight_name, _, _, _ = self._split_name(key)
|
||||
set_size = self.weight_shape_map[weight_name]
|
||||
self.aggregate_state_dict[key] = value.reshape(set_size)
|
||||
|
||||
def _aggregate(self, param_dict):
|
||||
for k, v in param_dict.items():
|
||||
weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k)
|
||||
if view_num is not None:
|
||||
# parameter is sharded
|
||||
param_name = optimizer_key + weight_name + mp_suffix
|
||||
|
||||
if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]:
|
||||
self.sharded_params.add(param_name)
|
||||
# Found a previous shard of the param, concatenate shards ordered by ranks
|
||||
self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v))
|
||||
else:
|
||||
self.aggregate_state_dict[param_name] = v
|
||||
else:
|
||||
if k in self.aggregate_state_dict:
|
||||
assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value"
|
||||
else:
|
||||
self.aggregate_state_dict[k] = v
|
||||
self._update_weight_statistics(weight_name, v)
|
||||
|
||||
def aggregate_checkpoints(self):
|
||||
warnings.warn(
|
||||
"_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. "
|
||||
"Please use aggregate_checkpoints() instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0]
|
||||
self.aggregate_state_dict = dict()
|
||||
|
||||
for i in range(self.world_size):
|
||||
checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size)
|
||||
rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu"))
|
||||
if "model" in rank_state_dict:
|
||||
rank_state_dict = rank_state_dict["model"]
|
||||
|
||||
if self.clean_state_dict:
|
||||
rank_state_dict = self.clean_state_dict(rank_state_dict)
|
||||
|
||||
rank_state_dict = _split_state_dict(rank_state_dict)
|
||||
self._aggregate(rank_state_dict["fp16_param"])
|
||||
self._aggregate(rank_state_dict["fp32_param"])
|
||||
self._aggregate(rank_state_dict["optimizer"])
|
||||
|
||||
for k in self.sharded_params:
|
||||
self._reshape_tensor(k)
|
||||
return self.aggregate_state_dict
|
|
@ -1,408 +0,0 @@
|
|||
from collections import namedtuple
|
||||
|
||||
import cerberus
|
||||
import torch
|
||||
|
||||
from ._utils import static_vars
|
||||
|
||||
LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate"
|
||||
ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite"
|
||||
LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name"
|
||||
GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name"
|
||||
|
||||
|
||||
class _ORTTrainerModelDesc:
|
||||
def __init__(self, model_desc):
|
||||
# Keep a copy of original input for debug
|
||||
self._original = dict(model_desc)
|
||||
|
||||
# Global counter used to validate occurrences of 'is_loss=True' within 'model_desc.outputs'
# A stateless validator is used for each tuple, but validation across the whole list of tuples is needed
# because just one 'is_loss=True' is allowed within the 'model_desc.outputs' list of tuples
|
||||
_model_desc_outputs_validation.loss_counter = 0
|
||||
|
||||
# Used for logging purposes
|
||||
self._main_class_name = self.__class__.__name__
|
||||
|
||||
# Validates user input
|
||||
self._validated = dict(self._original)
|
||||
validator = cerberus.Validator(MODEL_DESC_SCHEMA)
|
||||
self._validated = validator.validated(self._validated)
|
||||
if self._validated is None:
|
||||
raise ValueError(f"Invalid model_desc: {validator.errors}")
|
||||
|
||||
# Normalize inputs to a list of namedtuple(name, shape)
|
||||
self._InputDescription = namedtuple("InputDescription", ["name", "shape"])
|
||||
self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"])
|
||||
for idx, input in enumerate(self._validated["inputs"]):
|
||||
self._validated["inputs"][idx] = self._InputDescription(*input)
|
||||
|
||||
# Normalize outputs to a list of namedtuple(name, shape, is_loss)
|
||||
self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"])
|
||||
self._OutputDescriptionTyped = namedtuple(
|
||||
"OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"]
|
||||
)
|
||||
for idx, output in enumerate(self._validated["outputs"]):
|
||||
if len(output) == 2:
|
||||
self._validated["outputs"][idx] = self._OutputDescription(*output, False)
|
||||
else:
|
||||
self._validated["outputs"][idx] = self._OutputDescription(*output)
|
||||
|
||||
# Hard-code learning rate, all_finite descriptors
|
||||
self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32)
|
||||
|
||||
# Convert dict entries into object attributes
|
||||
for k, v in self._validated.items():
|
||||
setattr(self, k, self._wrap(v))
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty representation for a model description class"""
|
||||
|
||||
pretty_msg = "Model description:\n"
|
||||
|
||||
# Inputs
|
||||
inputs = []
|
||||
for i_desc in self.inputs:
|
||||
if isinstance(i_desc, self._InputDescription):
|
||||
inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})")
|
||||
elif isinstance(i_desc, self._InputDescriptionTyped):
|
||||
inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})")
|
||||
else:
|
||||
raise ValueError(f"Unexpected type {type(i_desc)} for input description")
|
||||
|
||||
pretty_msg += "\nInputs:"
|
||||
for idx, item in enumerate(inputs):
|
||||
pretty_msg += f"\n\t{idx}: {item}"
|
||||
|
||||
# Outputs
|
||||
outputs = []
|
||||
for o_desc in self.outputs:
|
||||
if isinstance(o_desc, self._OutputDescription):
|
||||
outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})")
|
||||
elif isinstance(o_desc, self._OutputDescriptionTyped):
|
||||
outputs.append(
|
||||
f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected type {type(o_desc)} for output description")
|
||||
pretty_msg += "\nOutputs:"
|
||||
for idx, item in enumerate(outputs):
|
||||
pretty_msg += f"\n\t{idx}: {item}"
|
||||
|
||||
# Learning rate
|
||||
if self.learning_rate:
|
||||
pretty_msg += "\nLearning rate: "
|
||||
pretty_msg += (
|
||||
f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})"
|
||||
)
|
||||
|
||||
# Mixed precision
|
||||
if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr(
|
||||
self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None
|
||||
):
|
||||
pretty_msg += "\nMixed Precision:"
|
||||
if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None):
|
||||
pretty_msg += "\n\tis gradients finite: "
|
||||
pretty_msg += (
|
||||
f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})"
|
||||
)
|
||||
if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None):
|
||||
pretty_msg += "\n\tloss scale input name: "
|
||||
pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})"
|
||||
|
||||
# Gradient Accumulation steps
|
||||
if self.gradient_accumulation:
|
||||
pretty_msg += "\nGradient Accumulation: "
|
||||
pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})"
|
||||
|
||||
return pretty_msg
|
||||
|
||||
def add_type_to_input_description(self, index, dtype):
|
||||
"""Updates an existing input description at position 'index' with 'dtype' type information
|
||||
|
||||
Args:
|
||||
index (int): position within 'inputs' description
|
||||
dtype (torch.dtype): input data type
|
||||
"""
|
||||
|
||||
assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int"
|
||||
assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type"
|
||||
existing_values = (*self.inputs[index],)
|
||||
if isinstance(self.inputs[index], self._InputDescriptionTyped):
|
||||
existing_values = (*existing_values[:-1],)
|
||||
self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype)
|
||||
|
||||
def add_type_to_output_description(self, index, dtype, dtype_amp=None):
|
||||
"""Updates an existing output description at position 'index' with 'dtype' type information
|
||||
|
||||
Args:
|
||||
index (int): position within 'inputs' description
|
||||
dtype (torch.dtype): input data type
|
||||
dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision
|
||||
"""
|
||||
|
||||
assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int"
|
||||
assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type"
|
||||
assert dtype_amp is None or isinstance(
|
||||
dtype_amp, torch.dtype
|
||||
), "output 'dtype_amp' must be either None or torch.dtype type"
|
||||
existing_values = (*self.outputs[index],)
|
||||
if isinstance(self.outputs[index], self._OutputDescriptionTyped):
|
||||
existing_values = (*existing_values[:-2],)
|
||||
self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp)
|
||||
|
||||
@property
|
||||
def gradient_accumulation(self):
|
||||
return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None)
|
||||
|
||||
@gradient_accumulation.setter
|
||||
def gradient_accumulation(self, name):
|
||||
self._add_output_description(
|
||||
self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True
|
||||
)
|
||||
|
||||
@property
|
||||
def all_finite(self):
|
||||
return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None)
|
||||
|
||||
@all_finite.setter
|
||||
def all_finite(self, name):
|
||||
self._add_output_description(
|
||||
self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True
|
||||
)
|
||||
|
||||
@property
|
||||
def loss_scale_input(self):
|
||||
return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None)
|
||||
|
||||
@loss_scale_input.setter
|
||||
def loss_scale_input(self, name):
|
||||
self._add_input_description(
|
||||
self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True
|
||||
)
|
||||
|
||||
def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False):
|
||||
"""Add a new input description into the node object
|
||||
|
||||
If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created.
|
||||
Otherwise an untyped input description namedtuple(name, shape) is created instead.
|
||||
|
||||
Args:
|
||||
node (list or object): node to append input description to. When 'node' is 'self.inputs',
|
||||
a new input description is appended to the list.
|
||||
Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name'
|
||||
name (str): name of input description
|
||||
shape (list): shape of input description
|
||||
dtype (torch.dtype): input data type
|
||||
attr_name (str, default is None): friendly name to allow direct access to the output description
|
||||
ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs
|
||||
"""
|
||||
|
||||
assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name"
|
||||
not_found = True
|
||||
if not ignore_duplicate:
|
||||
if id(node) == id(self.inputs):
|
||||
not_found = all([name not in i_desc.name for i_desc in node])
|
||||
assert not_found, f"'name' {name} already exists in the inputs description"
|
||||
else:
|
||||
not_found = attr_name not in dir(self)
|
||||
assert not_found, f"'attr_name' {attr_name} already exists in the 'node'"
|
||||
elif not not_found:
|
||||
return
|
||||
assert isinstance(shape, list) and all(
|
||||
[(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape]
|
||||
), "'shape' must be a list of int or str with length at least 1"
|
||||
assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type"
|
||||
if dtype:
|
||||
new_input_desc = self._InputDescriptionTyped(name, shape, dtype)
|
||||
else:
|
||||
new_input_desc = self._InputDescription(name, shape)
|
||||
|
||||
if id(node) == id(self.inputs):
|
||||
self.inputs.append(new_input_desc)
|
||||
else:
|
||||
assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'"
|
||||
setattr(node, attr_name, new_input_desc)
|
||||
|
||||
def _add_output_description(
|
||||
self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False
|
||||
):
|
||||
"""Add a new output description into the node object as a tuple
|
||||
|
||||
When (name, shape, is_loss, dtype) is specified, a typed output description is created
|
||||
Otherwise an untyped output description (name, shape, is_loss) is created instead
|
||||
|
||||
Args:
|
||||
node (list or object): node to append output description to. When 'node' is 'self.outputs',
|
||||
a new output description is appended to the list.
|
||||
Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name'
|
||||
name (str): name of output description
|
||||
shape (list): shape of output description
|
||||
is_loss (bool): specifies whether this output is a loss
|
||||
dtype (torch.dtype): input data type
|
||||
dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision.
|
||||
attr_name (str, default is None): friendly name to allow direct access to the output description
|
||||
ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs
|
||||
"""
|
||||
|
||||
assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name"
|
||||
assert isinstance(shape, list) and all(
|
||||
[(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape]
|
||||
), "'shape' must be a list of int or str with length at least 1"
|
||||
assert isinstance(is_loss, bool), "'is_loss' must be a bool"
|
||||
|
||||
not_found = True
|
||||
if not ignore_duplicate:
|
||||
if id(node) == id(self.outputs):
|
||||
not_found = all([name not in o_desc.name for o_desc in node])
|
||||
assert not_found, f"'name' {name} already exists in the outputs description"
|
||||
assert (
|
||||
all([not o_desc.is_loss for o_desc in node]) if is_loss else True
|
||||
), "Only one 'is_loss' is supported at outputs description"
|
||||
else:
|
||||
not_found = attr_name not in dir(self)
|
||||
assert not_found, f"'attr_name' {attr_name} already exists in the 'node'"
|
||||
elif not not_found:
|
||||
return
|
||||
|
||||
assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type"
|
||||
if dtype:
|
||||
new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None)
|
||||
else:
|
||||
new_output_desc = self._OutputDescription(name, shape, is_loss)
|
||||
|
||||
if id(node) == id(self.outputs):
|
||||
self.outputs.append(new_output_desc)
|
||||
else:
|
||||
assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'"
|
||||
setattr(node, attr_name, new_output_desc)
|
||||
|
||||
def _wrap(self, v):
|
||||
"""Add 'v' as self's attribute to allow direct access as self.v"""
|
||||
if isinstance(v, (list)):
|
||||
return type(v)([self._wrap(v) for v in v])
|
||||
elif isinstance(
|
||||
v,
|
||||
(
|
||||
self._InputDescription,
|
||||
self._InputDescriptionTyped,
|
||||
self._OutputDescription,
|
||||
self._OutputDescriptionTyped,
|
||||
),
|
||||
):
|
||||
return v
|
||||
elif isinstance(v, (tuple)):
|
||||
return type(v)([self._wrap(v) for v in v])
|
||||
elif isinstance(v, (dict, int, float, bool, str)):
|
||||
return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type for model_desc ({v})."
|
||||
"Only int, float, bool, str, list, tuple and dict are supported"
|
||||
)
|
||||
|
||||
|
||||
class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc):
|
||||
r"""Internal class used by ONNX Runtime training backend for input validation
|
||||
|
||||
NOTE: Users MUST NOT use this class in any way!
|
||||
"""
|
||||
|
||||
def __init__(self, main_class_name, model_desc):
|
||||
# Used for logging purposes
|
||||
self._main_class_name = main_class_name
|
||||
|
||||
# Convert dict entries into object attributes
|
||||
for k, v in dict(model_desc).items():
|
||||
setattr(self, k, self._wrap(v))
|
||||
|
||||
|
||||
def _model_desc_inputs_validation(field, value, error):
|
||||
r"""Cerberus custom check method for 'model_desc.inputs'
|
||||
|
||||
'model_desc.inputs' is a list of tuples.
|
||||
The list has variable length, but each tuple has size 2
|
||||
|
||||
The first element of the tuple is a string which represents the input name
|
||||
The second element is a list of shapes. Each shape must be either an int or string.
|
||||
Empty list represents a scalar input
|
||||
|
||||
Validation is done within each tuple to enforce the schema described above.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model_desc['inputs'] = [('input1', ['batch', 1024]),
|
||||
('input2', [])
|
||||
('input3', [512])]
|
||||
"""
|
||||
|
||||
if not isinstance(value, tuple) or len(value) != 2:
|
||||
error(field, "must be a tuple with size 2")
|
||||
if not isinstance(value[0], str):
|
||||
error(field, "the first element of the tuple (aka name) must be a string")
|
||||
if not isinstance(value[1], list):
|
||||
error(field, "the second element of the tuple (aka shape) must be a list")
|
||||
else:
|
||||
for shape in value[1]:
|
||||
if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool):
|
||||
error(field, "each shape must be either a string or integer")
|
||||
|
||||
|
||||
@static_vars(loss_counter=0)
|
||||
def _model_desc_outputs_validation(field, value, error):
|
||||
r"""Cerberus custom check method for 'model_desc.outputs'
|
||||
|
||||
'model_desc.outputs' is a list of tuples with variable length.
|
||||
The first element of the tuple is a string which represents the output name
|
||||
The second element is a list of shapes. Each shape must be either an int or string.
|
||||
Empty list represents a scalar output
|
||||
The third element is optional and is a flag that signals whether the output is a loss value
|
||||
|
||||
Validation is done within each tuple to enforce the schema described above, but also
|
||||
throughout the list of tuples to ensure a single 'is_loss=True' occurrence.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True),
|
||||
('output2', [], is_loss=False)
|
||||
('output3', [512])]
|
||||
"""
|
||||
|
||||
if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3:
|
||||
error(field, "must be a tuple with size 2 or 3")
|
||||
if len(value) == 3 and not isinstance(value[2], bool):
|
||||
error(field, "the third element of the tuple (aka is_loss) must be a boolean")
|
||||
elif len(value) == 3:
|
||||
if value[2]:
|
||||
_model_desc_outputs_validation.loss_counter += 1
|
||||
if _model_desc_outputs_validation.loss_counter > 1:
|
||||
error(field, "only one is_loss can bet set to True")
|
||||
if not isinstance(value[0], str):
|
||||
error(field, "the first element of the tuple (aka name) must be a string")
|
||||
if not isinstance(value[1], list):
|
||||
error(field, "the second element of the tuple (aka shape) must be a list")
|
||||
else:
|
||||
for shape in value[1]:
|
||||
if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool):
|
||||
error(field, "each shape must be either a string or integer")
|
||||
|
||||
|
||||
# Validation schema for model description dictionary
|
||||
MODEL_DESC_SCHEMA = {
|
||||
"inputs": {
|
||||
"type": "list",
|
||||
"required": True,
|
||||
"minlength": 1,
|
||||
"schema": {"check_with": _model_desc_inputs_validation},
|
||||
},
|
||||
"outputs": {
|
||||
"type": "list",
|
||||
"required": True,
|
||||
"minlength": 1,
|
||||
"schema": {"check_with": _model_desc_outputs_validation},
|
||||
},
|
||||
}
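A small sketch of a model_desc dictionary that passes the validation above; names and shapes are hypothetical, for a model with a single scalar loss output:

model_desc = {
    "inputs": [
        ("input_ids", ["batch", "seq_len"]),
        ("attention_mask", ["batch", "seq_len"]),
        ("labels", ["batch"]),
    ],
    "outputs": [
        ("loss", [], True),  # empty shape list means scalar; only one output may set is_loss=True
        ("logits", ["batch", "seq_len", 30522]),
    ],
}
desc = _ORTTrainerModelDesc(model_desc)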
|
File diff suppressed because it is too large
|
@ -1,692 +0,0 @@
|
|||
import cerberus
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy
|
||||
|
||||
from .amp import loss_scaler
|
||||
from .optim import lr_scheduler
|
||||
|
||||
|
||||
class ORTTrainerOptions:
|
||||
r"""Settings used by ONNX Runtime training backend
|
||||
|
||||
The parameters are hierarchically organized to facilitate configuration through semantic groups
|
||||
that encompasses features, such as distributed training, etc.
|
||||
|
||||
Input validation is performed on the input dict during instantiation to ensure
|
||||
that supported parameters and values are passed in. Invalid input results
|
||||
in :py:obj:`ValueError` exception with details on it.
|
||||
|
||||
Args:
|
||||
options (dict): contains all training options
|
||||
_validate (bool, default is True): for internal use only
|
||||
|
||||
Supported schema for kwargs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
schema = {
|
||||
'batch' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'gradient_accumulation_steps' : {
|
||||
'type' : 'integer',
|
||||
'min' : 1,
|
||||
'default' : 1
|
||||
}
|
||||
},
|
||||
},
|
||||
'device' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'id' : {
|
||||
'type' : 'string',
|
||||
'default' : 'cuda'
|
||||
},
|
||||
'mem_limit' : {
|
||||
'type' : 'integer',
|
||||
'min' : 0,
|
||||
'default' : 0
|
||||
}
|
||||
}
|
||||
},
|
||||
'distributed': {
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'required': False,
|
||||
'schema': {
|
||||
'world_rank': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'default': 0
|
||||
},
|
||||
'world_size': {
|
||||
'type': 'integer',
|
||||
'min': 1,
|
||||
'default': 1
|
||||
},
|
||||
'local_rank': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'default': 0
|
||||
},
|
||||
'data_parallel_size': {
|
||||
'type': 'integer',
|
||||
'min': 1,
|
||||
'default': 1
|
||||
},
|
||||
'horizontal_parallel_size': {
|
||||
'type': 'integer',
|
||||
'min': 1,
|
||||
'default': 1
|
||||
},
|
||||
'pipeline_parallel' : {
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'required': False,
|
||||
'schema': {
|
||||
'pipeline_parallel_size': {
|
||||
'type': 'integer',
|
||||
'min': 1,
|
||||
'default': 1
|
||||
},
|
||||
'num_pipeline_micro_batches': {
|
||||
'type': 'integer',
|
||||
'min': 1,
|
||||
'default': 1
|
||||
},
|
||||
'pipeline_cut_info_string': {
|
||||
'type': 'string',
|
||||
'default': ''
|
||||
},
|
||||
'sliced_schema': {
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'keysrules': {'type': 'string'},
|
||||
'valuesrules': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'integer'}
|
||||
}
|
||||
},
|
||||
'sliced_axes': {
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'keysrules': {'type': 'string'},
|
||||
'valuesrules': {'type': 'integer'}
|
||||
},
|
||||
'sliced_tensor_names': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'string'},
|
||||
'default': []
|
||||
}
|
||||
}
|
||||
},
|
||||
'allreduce_post_accumulation': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'deepspeed_zero_optimization': {
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'required': False,
|
||||
'schema': {
|
||||
'stage': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'max': 1,
|
||||
'default': 0
|
||||
},
|
||||
}
|
||||
},
|
||||
'enable_adasum': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
}
|
||||
}
|
||||
},
|
||||
'lr_scheduler' : {
|
||||
'type' : 'optim.lr_scheduler',
|
||||
'nullable' : True,
|
||||
'default' : None
|
||||
},
|
||||
'mixed_precision' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'enabled' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
},
|
||||
'loss_scaler' : {
|
||||
'type' : 'amp.loss_scaler',
|
||||
'nullable' : True,
|
||||
'default' : None
|
||||
}
|
||||
}
|
||||
},
|
||||
'graph_transformer': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'default': {},
|
||||
'schema': {
|
||||
'attn_dropout_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'gelu_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'transformer_layer_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'number_recompute_layers': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'default': 0
|
||||
},
|
||||
'propagate_cast_ops_config': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'default': {},
|
||||
'schema': {
|
||||
'propagate_cast_ops_strategy': {
|
||||
'type': 'onnxruntime.training.PropagateCastOpsStrategy',
|
||||
'default': PropagateCastOpsStrategy.FLOOD_FILL
|
||||
},
|
||||
'propagate_cast_ops_level': {
|
||||
'type': 'integer',
|
||||
'default': 1
|
||||
},
|
||||
'propagate_cast_ops_allow': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'string'},
|
||||
'default': []
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
'utils' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'frozen_weights' : {
|
||||
'type' : 'list',
|
||||
'default' : []
|
||||
},
|
||||
'grad_norm_clip' : {
|
||||
'type' : 'boolean',
|
||||
'default' : True
|
||||
},
|
||||
'memory_efficient_gradient' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
},
|
||||
'run_symbolic_shape_infer' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
}
|
||||
}
|
||||
},
|
||||
'debug' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'deterministic_compute' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
},
|
||||
'check_model_export' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
},
|
||||
'graph_save_paths' : {
|
||||
'type' : 'dict',
|
||||
'default': {},
|
||||
'required': False,
|
||||
'schema': {
|
||||
'model_after_graph_transforms_path': {
|
||||
'type': 'string',
|
||||
'default': ''
|
||||
},
|
||||
'model_with_gradient_graph_path':{
|
||||
'type': 'string',
|
||||
'default': ''
|
||||
},
|
||||
'model_with_training_graph_path': {
|
||||
'type': 'string',
|
||||
'default': ''
|
||||
},
|
||||
'model_with_training_graph_after_optimization_path': {
|
||||
'type': 'string',
|
||||
'default': ''
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
'_internal_use' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
'default' : {},
|
||||
'schema' : {
|
||||
'enable_internal_postprocess' : {
|
||||
'type' : 'boolean',
|
||||
'default' : True
|
||||
},
|
||||
'extra_postprocess' : {
|
||||
'type' : 'callable',
|
||||
'nullable' : True,
|
||||
'default' : None
|
||||
},
|
||||
'onnx_opset_version': {
|
||||
'type': 'integer',
|
||||
'min' : 12,
|
||||
'max' : 14,
|
||||
'default': 14
|
||||
},
|
||||
'enable_onnx_contrib_ops' : {
|
||||
'type' : 'boolean',
|
||||
'default' : True
|
||||
}
|
||||
}
|
||||
},
|
||||
'provider_options':{
|
||||
'type': 'dict',
|
||||
'default': {},
|
||||
'required': False,
|
||||
'schema': {}
|
||||
},
|
||||
'session_options': {
|
||||
'type': 'SessionOptions',
|
||||
'nullable': True,
|
||||
'default': None
|
||||
},
|
||||
}
|
||||
|
||||
Keyword arguments:
|
||||
batch (dict):
|
||||
batch related settings
|
||||
batch.gradient_accumulation_steps (int, default is 1):
|
||||
number of steps to accumulate gradients before performing the collective gradient reduction
|
||||
device (dict):
|
||||
compute device related settings
|
||||
device.id (string, default is 'cuda'):
|
||||
device to run training
|
||||
device.mem_limit (int):
|
||||
maximum memory size (in bytes) used by device.id
|
||||
distributed (dict):
|
||||
distributed training options.
|
||||
distributed.world_rank (int, default is 0):
|
||||
rank ID used for data/horizontal parallelism
|
||||
distributed.world_size (int, default is 1):
|
||||
number of ranks participating in parallelism
|
||||
distributed.data_parallel_size (int, default is 1):
|
||||
number of ranks participating in data parallelism
|
||||
distributed.horizontal_parallel_size (int, default is 1):
|
||||
number of ranks participating in horizontal parallelism
|
||||
distributed.pipeline_parallel (dict):
|
||||
Options which are only useful to pipeline parallel.
|
||||
distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1):
|
||||
number of ranks participating in pipeline parallelism
|
||||
distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1):
|
||||
number of micro-batches. The input batch is divided into micro-batches and the graph is run once per micro-batch.
|
||||
distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''):
|
||||
string of cutting ids for pipeline partition.
|
||||
distributed.allreduce_post_accumulation (bool, default is False):
|
||||
True enables overlap of AllReduce with computation, while False
|
||||
postpones AllReduce until all gradients are ready
|
||||
distributed.deepspeed_zero_optimization:
|
||||
DeepSpeed ZeRO options.
|
||||
distributed.deepspeed_zero_optimization.stage (int, default is 0):
|
||||
select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled.
|
||||
distributed.enable_adasum (bool, default is False):
|
||||
enable `Adasum <https://arxiv.org/abs/2006.02924>`_
|
||||
algorithm for AllReduce
|
||||
lr_scheduler (optim._LRScheduler, default is None):
|
||||
specifies learning rate scheduler
|
||||
mixed_precision (dict):
|
||||
mixed precision training options
|
||||
mixed_precision.enabled (bool, default is False):
|
||||
enable mixed precision (fp16)
|
||||
mixed_precision.loss_scaler (amp.LossScaler, default is None):
|
||||
specifies a loss scaler to be used for fp16. If not specified,
|
||||
:py:class:`.DynamicLossScaler` is used with default values.
|
||||
Users can also instantiate :py:class:`.DynamicLossScaler` and
|
||||
override its parameters. Lastly, a completely new implementation
|
||||
can be specified by extending :py:class:`.LossScaler` class from scratch
|
||||
graph_transformer (dict):
|
||||
graph transformer related configurations
|
||||
graph_transformer.attn_dropout_recompute(bool, default False)
|
||||
graph_transformer.gelu_recompute(bool, default False)
|
||||
graph_transformer.transformer_layer_recompute(bool, default False)
|
||||
graph_transformer.number_recompute_layers (int, default 0)
|
||||
graph_transformer.propagate_cast_ops_config (dict):
|
||||
graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL)
|
||||
Specifies the cast propagation optimization strategy: NONE, INSERT_AND_REDUCE or FLOOD_FILL.
|
||||
NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations
|
||||
locally change cast operations; for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFusion optimization could
|
||||
interchange Transpose and Cast if the Cast node exists between Transpose and MatMul.
|
||||
INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes.
|
||||
FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike
|
||||
INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region.
|
||||
graph_transformer.propagate_cast_ops_config.level(integer, default 1)
|
||||
Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
|
||||
Use predetermined list of opcodes considered safe to move before/after cast operation
|
||||
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
|
||||
graph_transformer.propagate_cast_ops_config.allow(list of str, [])
|
||||
List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
|
||||
attn_dropout_recompute (bool, default is False):
|
||||
enable recomputing attention dropout to save memory
|
||||
gelu_recompute (bool, default is False):
|
||||
enable recomputing Gelu activation output to save memory
|
||||
transformer_layer_recompute (bool, default is False):
|
||||
enable recomputing transformer layerwise to save memory
|
||||
number_recompute_layers (int, default is 0)
|
||||
number of layers to which transformer_layer_recompute is applied; by default the system will
|
||||
apply recompute to all the layers, except for the last one
|
||||
utils (dict):
|
||||
miscellaneous options
|
||||
utils.frozen_weights (list of str, []):
|
||||
list of model parameter names to skip training (weights don't change)
|
||||
utils.grad_norm_clip (bool, default is True):
|
||||
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
|
||||
utils.memory_efficient_gradient (bool, default is False):
|
||||
enables use of memory aware gradient builder.
|
||||
utils.run_symbolic_shape_infer (bool, default is False):
|
||||
runs symbolic shape inference on the model
|
||||
debug (dict):
|
||||
debug options
|
||||
debug.deterministic_compute (bool, default is False)
|
||||
forces compute to be deterministic across runs
|
||||
debug.check_model_export (bool, default is False)
|
||||
compares PyTorch model outputs with ONNX model outputs in inference before the first
|
||||
train step to ensure successful model export
|
||||
debug.graph_save_paths (dict):
|
||||
paths used for dumping ONNX graphs for debugging purposes
|
||||
debug.graph_save_paths.model_after_graph_transforms_path (str, default is "")
|
||||
path to export the ONNX graph after training-related graph transforms have been applied.
|
||||
No output when it is empty.
|
||||
debug.graph_save_paths.model_with_gradient_graph_path (str, default is "")
|
||||
path to export the ONNX graph with the gradient graph added. No output when it is empty.
|
||||
debug.graph_save_paths.model_with_training_graph_path (str, default is "")
|
||||
path to export the training ONNX graph with forward, gradient and optimizer nodes.
|
||||
No output when it is empty.
|
||||
debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "")
|
||||
outputs the optimized training graph to the path if nonempty.
|
||||
_internal_use (dict):
|
||||
internal options, possibly undocumented, that might be removed without notice
|
||||
_internal_use.enable_internal_postprocess (bool, default is True):
|
||||
enable internal post-processing of the ONNX model
|
||||
_internal_use.extra_postprocess (callable, default is None)
|
||||
a functor to postprocess the ONNX model and return a new ONNX model.
|
||||
It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complements it
|
||||
_internal_use.onnx_opset_version (int, default is 14):
|
||||
ONNX opset version used during model exporting.
|
||||
_internal_use.enable_onnx_contrib_ops (bool, default is True)
|
||||
enable PyTorch to export nodes as contrib ops in ONNX.
|
||||
This flag may be removed anytime in the future.
|
||||
session_options (onnxruntime.SessionOptions):
|
||||
The SessionOptions instance that TrainingSession will use.
|
||||
provider_options (dict):
|
||||
The provider_options for customized execution providers. It is a dict mapping each EP name to
|
||||
a dict of key-value pairs, e.g. {'EP1' : {'key1' : 'val1'}, ...}
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
opts = ORTTrainerOptions({
|
||||
'batch' : {
|
||||
'gradient_accumulation_steps' : 128
|
||||
},
|
||||
'device' : {
|
||||
'id' : 'cuda:0',
|
||||
'mem_limit' : 2*1024*1024*1024,
|
||||
},
|
||||
'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(),
|
||||
'mixed_precision' : {
|
||||
'enabled': True,
|
||||
'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16))
|
||||
}
|
||||
})
|
||||
fp16_enabled = opts.mixed_precision.enabled
|
||||
"""
|
||||
|
||||
def __init__(self, options={}): # noqa: B006
|
||||
# Keep a copy of original input for debug
|
||||
self._original_opts = dict(options)
|
||||
|
||||
# Used for logging purposes
|
||||
self._main_class_name = self.__class__.__name__
|
||||
|
||||
# Validates user input
|
||||
self._validated_opts = dict(self._original_opts)
|
||||
validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA)
|
||||
self._validated_opts = validator.validated(self._validated_opts)
|
||||
if self._validated_opts is None:
|
||||
raise ValueError(f"Invalid options: {validator.errors}")
|
||||
|
||||
# Convert dict entries into object attributes
|
||||
for k, v in self._validated_opts.items():
|
||||
setattr(self, k, self._wrap(v))
|
||||
|
||||
def __repr__(self):
|
||||
return "{%s}" % str(
|
||||
", ".join(
|
||||
f"'{k}': {v!r}"
|
||||
for (k, v) in self.__dict__.items()
|
||||
if k not in ["_original_opts", "_validated_opts", "_main_class_name"]
|
||||
)
|
||||
)
|
||||
|
||||
def _wrap(self, v):
|
||||
if isinstance(v, (tuple, list, set, frozenset)):
|
||||
return type(v)([self._wrap(i) for i in v])
|
||||
else:
|
||||
return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v
|
||||
|
||||
|
||||
class _ORTTrainerOptionsInternal(ORTTrainerOptions):
|
||||
r"""Internal class used by ONNX Runtime training backend for input validation
|
||||
|
||||
NOTE: Users MUST NOT use this class in any way!
|
||||
"""
|
||||
|
||||
def __init__(self, main_class_name, options):
|
||||
# Used for logging purposes
|
||||
self._main_class_name = main_class_name
|
||||
# We don't call super().__init__(options) here but still call it "_validated_opts"
|
||||
# instead of "_original_opts" because it has been validated in the top-level
|
||||
# ORTTrainerOptions's constructor.
|
||||
self._validated_opts = dict(options)
|
||||
# Convert dict entries into object attributes
|
||||
for k, v in dict(options).items():
|
||||
setattr(self, k, self._wrap(v))
|
||||
|
||||
|
||||
class ORTTrainerOptionsValidator(cerberus.Validator):
|
||||
_LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ())
|
||||
_LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ())
|
||||
|
||||
_SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ())
|
||||
|
||||
_PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition(
|
||||
"propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), ()
|
||||
)
|
||||
|
||||
types_mapping = cerberus.Validator.types_mapping.copy()
|
||||
types_mapping["lr_scheduler"] = _LR_SCHEDULER
|
||||
types_mapping["loss_scaler"] = _LOSS_SCALER
|
||||
types_mapping["session_options"] = _SESSION_OPTIONS
|
||||
types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY
|
||||
|
||||
|
||||
def _check_is_callable(field, value, error):
|
||||
result = False
|
||||
try:
|
||||
# Python 3
|
||||
result = value is None or callable(value)
|
||||
except Exception:
|
||||
# Python 3 but < 3.2
|
||||
if hasattr(value, "__call__"): # noqa: B004
|
||||
result = True
|
||||
if not result:
|
||||
error(field, "Must be callable or None")
|
||||
|
||||
|
||||
_ORTTRAINER_OPTIONS_SCHEMA = {
|
||||
"batch": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}},
|
||||
},
|
||||
"device": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"id": {"type": "string", "default": "cuda"},
|
||||
"mem_limit": {"type": "integer", "min": 0, "default": 0},
|
||||
},
|
||||
},
|
||||
"distributed": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"world_rank": {"type": "integer", "min": 0, "default": 0},
|
||||
"world_size": {"type": "integer", "min": 1, "default": 1},
|
||||
"local_rank": {"type": "integer", "min": 0, "default": 0},
|
||||
"data_parallel_size": {"type": "integer", "min": 1, "default": 1},
|
||||
"horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1},
|
||||
"pipeline_parallel": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1},
|
||||
"num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1},
|
||||
"pipeline_cut_info_string": {"type": "string", "default": ""},
|
||||
"sliced_schema": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"keysrules": {"type": "string"},
|
||||
"valuesrules": {"type": "list", "schema": {"type": "integer"}},
|
||||
},
|
||||
"sliced_axes": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"keysrules": {"type": "string"},
|
||||
"valuesrules": {"type": "integer"},
|
||||
},
|
||||
"sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []},
|
||||
},
|
||||
},
|
||||
"allreduce_post_accumulation": {"type": "boolean", "default": False},
|
||||
"deepspeed_zero_optimization": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"stage": {"type": "integer", "min": 0, "max": 1, "default": 0},
|
||||
},
|
||||
},
|
||||
"enable_adasum": {"type": "boolean", "default": False},
|
||||
},
|
||||
},
|
||||
"lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None},
|
||||
"mixed_precision": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"enabled": {"type": "boolean", "default": False},
|
||||
"loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None},
|
||||
},
|
||||
},
|
||||
"graph_transformer": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"attn_dropout_recompute": {"type": "boolean", "default": False},
|
||||
"gelu_recompute": {"type": "boolean", "default": False},
|
||||
"transformer_layer_recompute": {"type": "boolean", "default": False},
|
||||
"number_recompute_layers": {"type": "integer", "min": 0, "default": 0},
|
||||
"propagate_cast_ops_config": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"strategy": {
|
||||
"type": "propagate_cast_ops_strategy",
|
||||
"nullable": True,
|
||||
"default": PropagateCastOpsStrategy.FLOOD_FILL,
|
||||
},
|
||||
"level": {"type": "integer", "min": -1, "default": 1},
|
||||
"allow": {"type": "list", "schema": {"type": "string"}, "default": []},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"utils": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"frozen_weights": {"type": "list", "default": []},
|
||||
"grad_norm_clip": {"type": "boolean", "default": True},
|
||||
"memory_efficient_gradient": {"type": "boolean", "default": False},
|
||||
"run_symbolic_shape_infer": {"type": "boolean", "default": False},
|
||||
},
|
||||
},
|
||||
"debug": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"deterministic_compute": {"type": "boolean", "default": False},
|
||||
"check_model_export": {"type": "boolean", "default": False},
|
||||
"graph_save_paths": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"model_after_graph_transforms_path": {"type": "string", "default": ""},
|
||||
"model_with_gradient_graph_path": {"type": "string", "default": ""},
|
||||
"model_with_training_graph_path": {"type": "string", "default": ""},
|
||||
"model_with_training_graph_after_optimization_path": {"type": "string", "default": ""},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"_internal_use": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"schema": {
|
||||
"enable_internal_postprocess": {"type": "boolean", "default": True},
|
||||
"extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None},
|
||||
"onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14},
|
||||
"enable_onnx_contrib_ops": {"type": "boolean", "default": True},
|
||||
},
|
||||
},
|
||||
"provider_options": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
"required": False,
|
||||
"allow_unknown": True,
|
||||
"schema": {},
|
||||
},
|
||||
"session_options": {"type": "session_options", "nullable": True, "default": None},
|
||||
}
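For reference, a minimal standalone sketch (plain cerberus.Validator, no custom types) of how default_setter plus nested default values normalize an options dict like the schema above; the toy_schema name is illustrative only:

import cerberus

# A toy schema mirroring the pattern above: an optional nested dict whose presence
# is materialized by `default_setter` and whose fields carry their own defaults.
toy_schema = {
    "batch": {
        "type": "dict",
        "required": False,
        "default_setter": lambda _: {},
        "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}},
    },
}

v = cerberus.Validator(toy_schema)
print(v.validated({"batch": {}}))  # -> {'batch': {'gradient_accumulation_steps': 1}}
print(v.validated({"batch": {"gradient_accumulation_steps": 0}}), v.errors)  # None, with a 'min value is 1' error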
|
|
@ -1,478 +0,0 @@
|
|||
import os.path # noqa: F401
|
||||
import struct
|
||||
import sys # noqa: F401
|
||||
|
||||
import numpy as np # noqa: F401
|
||||
import onnx
|
||||
from onnx import * # noqa: F403
|
||||
from onnx import helper, numpy_helper # noqa: F401
|
||||
|
||||
|
||||
def run_postprocess(model):
|
||||
# this post pass is not required for pytorch >= 1.5
|
||||
# where add_node_name in torch.onnx.export defaults to True
|
||||
model = add_name(model)
|
||||
|
||||
# this post pass is not required for pytorch > 1.6
|
||||
model = fuse_softmaxNLL_to_softmaxCE(model)
|
||||
|
||||
model = fix_expand_shape(model)
|
||||
model = fix_expand_shape_pt_1_5(model)
|
||||
return model
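A hypothetical driver for these passes on an exported model; the file names below are illustrative only:

model = onnx.load("exported_model.onnx")        # load the exported training graph
model = run_postprocess(model)                  # apply the post-export fix-up passes below
onnx.save(model, "exported_model_post.onnx")    # persist the patched model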
|
||||
|
||||
|
||||
def find_input_node(model, arg):
|
||||
result = []
|
||||
for node in model.graph.node:
|
||||
for output in node.output:
|
||||
if output == arg:
|
||||
result.append(node)
|
||||
return result[0] if len(result) == 1 else None
|
||||
|
||||
|
||||
def find_output_node(model, arg):
|
||||
result = []
|
||||
for node in model.graph.node:
|
||||
for input in node.input:
|
||||
if input == arg:
|
||||
result.append(node)
|
||||
return result[0] if len(result) == 1 else result
|
||||
|
||||
|
||||
def add_name(model):
|
||||
i = 0
|
||||
for node in model.graph.node:
|
||||
node.name = "%s_%d" % (node.op_type, i)
|
||||
i += 1
|
||||
return model
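A small sketch of how these traversal helpers behave; the two-node toy graph and the tensor names X/Y/Z are made up for illustration:

from onnx import TensorProto, helper

x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])
z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1])
relu = helper.make_node("Relu", ["X"], ["Y"])
ident = helper.make_node("Identity", ["Y"], ["Z"])
toy = helper.make_model(helper.make_graph([relu, ident], "toy", [x], [z]))

toy = add_name(toy)                    # nodes are renamed "Relu_0" and "Identity_1"
producer = find_input_node(toy, "Y")   # the single node producing "Y" (the Relu)
consumer = find_output_node(toy, "Y")  # the single node consuming "Y" (the Identity)
print(producer.name, consumer.name)    # Relu_0 Identity_1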
|
||||
|
||||
|
||||
# Expand Shape PostProcess
|
||||
|
||||
|
||||
def fix_expand_shape(model):
|
||||
expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"]
|
||||
model_inputs_names = [i.name for i in model.graph.input]
|
||||
|
||||
for expand_node in expand_nodes:
|
||||
shape = find_input_node(model, expand_node.input[1])
|
||||
if shape.op_type == "Shape":
|
||||
# an expand subgraph
|
||||
# Input Input2
|
||||
# | |
|
||||
# | Shape
|
||||
# | |
|
||||
# |__ __|
|
||||
# | |
|
||||
# Expand
|
||||
# |
|
||||
# output
|
||||
#
|
||||
# Only if Input2 is one of the model inputs, assign Input2's shape to output of expand.
|
||||
shape_input_name = shape.input[0]
|
||||
if shape_input_name in model_inputs_names:
|
||||
index = model_inputs_names.index(shape_input_name)
|
||||
expand_out = model.graph.value_info.add()
|
||||
expand_out.name = expand_node.output[0]
|
||||
expand_out.type.CopyFrom(model.graph.input[index].type)
|
||||
return model
|
||||
|
||||
|
||||
def fix_expand_shape_pt_1_5(model):
|
||||
# expand subgraph
|
||||
# Constant
|
||||
# +
|
||||
# ConstantOfShape
|
||||
# | + |
|
||||
# | + |
|
||||
# (Reshape subgraph) Mul |
|
||||
# |___ _________| |
|
||||
# + | | |
|
||||
# + Equal |
|
||||
# +++++|++++++++++++++|++
|
||||
# |____________ | +
|
||||
# | | +
|
||||
# (subgraph) Where
|
||||
# | |
|
||||
# |_____ ___________|
|
||||
# | |
|
||||
# Expand
|
||||
# |
|
||||
# output
|
||||
#
|
||||
# where the Reshape subgraph is
|
||||
#
|
||||
# Input
|
||||
# | |
|
||||
# | |___________________
|
||||
# | |
|
||||
# Shape Constant Shape Constant
|
||||
# | ______| | ______|
|
||||
# | | | |
|
||||
# Gather Gather
|
||||
# | |
|
||||
# Unsqueeze Unsqueeze
|
||||
# | |
|
||||
# | ..Number of dims.. |
|
||||
# | _________________|
|
||||
# |...|
|
||||
# Concat Constant
|
||||
# | |
|
||||
# |______ __________________|
|
||||
# | |
|
||||
# Reshape
|
||||
# |
|
||||
# output
|
||||
#
|
||||
# This pass will copy Input's shape to the output of Expand.
|
||||
expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"]
|
||||
model_inputs_names = [i.name for i in model.graph.input]
|
||||
|
||||
for expand_node in expand_nodes:
|
||||
n_where = find_input_node(model, expand_node.input[1])
|
||||
if n_where.op_type != "Where":
|
||||
continue
|
||||
|
||||
n_equal = find_input_node(model, n_where.input[0])
|
||||
n_cos = find_input_node(model, n_where.input[1])
|
||||
n_reshape = find_input_node(model, n_where.input[2])
|
||||
|
||||
if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape":
|
||||
continue
|
||||
|
||||
n_reshape_e = find_input_node(model, n_equal.input[0])
|
||||
n_mul = find_input_node(model, n_equal.input[1])
|
||||
if n_reshape_e != n_reshape or n_mul.op_type != "Mul":
|
||||
continue
|
||||
|
||||
n_cos_m = find_input_node(model, n_mul.input[0])
|
||||
n_constant = find_input_node(model, n_mul.input[1])
|
||||
if n_cos_m != n_cos or n_constant.op_type != "Constant":
|
||||
continue
|
||||
|
||||
n_concat = find_input_node(model, n_reshape.input[0])
|
||||
n_constant_r = find_input_node(model, n_reshape.input[1])
|
||||
if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant":
|
||||
continue
|
||||
|
||||
n_input_candidates = []
|
||||
for concat_in in n_concat.input:
|
||||
n_unsqueeze = find_input_node(model, concat_in)
|
||||
if n_unsqueeze.op_type != "Unsqueeze":
|
||||
break
|
||||
n_gather = find_input_node(model, n_unsqueeze.input[0])
|
||||
if n_gather.op_type != "Gather":
|
||||
break
|
||||
n_shape = find_input_node(model, n_gather.input[0])
|
||||
n_constant_g = find_input_node(model, n_gather.input[1])
|
||||
if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant":
|
||||
break
|
||||
n_input = n_shape.input[0]
|
||||
if n_input not in model_inputs_names:
|
||||
break
|
||||
n_input_candidates.append(n_input)
|
||||
|
||||
if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates):
|
||||
continue
|
||||
|
||||
index = model_inputs_names.index(n_input_candidates[0])
|
||||
expand_out = model.graph.value_info.add()
|
||||
expand_out.name = expand_node.output[0]
|
||||
expand_out.type.CopyFrom(model.graph.input[index].type)
|
||||
return model
|
||||
|
||||
|
||||
# LayerNorm PostProcess
|
||||
|
||||
|
||||
def find_nodes(graph, op_type):
|
||||
nodes = []
|
||||
for node in graph.node:
|
||||
if node.op_type == op_type:
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
|
||||
def is_type(node, op_type):
|
||||
if node is None or isinstance(node, list):
|
||||
return False
|
||||
return node.op_type == op_type
|
||||
|
||||
|
||||
def add_const(model, name, output, t_value=None, f_value=None):
|
||||
const_node = model.graph.node.add()
|
||||
const_node.op_type = "Constant"
|
||||
const_node.name = name
|
||||
const_node.output.extend([output])
|
||||
attr = const_node.attribute.add()
|
||||
attr.name = "value"
|
||||
if t_value is not None:
|
||||
attr.type = 4
|
||||
attr.t.CopyFrom(t_value)
|
||||
else:
|
||||
attr.type = 1
|
||||
attr.f = f_value
|
||||
return const_node
|
||||
|
||||
|
||||
def layer_norm_transform(model):
|
||||
# DEPRECATED: This pass is no longer needed as the transform is handled at the backend.
|
||||
# Converting below subgraph
|
||||
#
|
||||
# input
|
||||
# |
|
||||
# ReduceMean
|
||||
# |
|
||||
# Sub Constant
|
||||
# _||_____ |
|
||||
# | | |
|
||||
# | | |
|
||||
# | (optional) Cast (optional) Cast
|
||||
# | | |
|
||||
# | | ____________________|
|
||||
# | | |
|
||||
# | Pow
|
||||
# | |
|
||||
# | ReduceMean
|
||||
# | |
|
||||
# | Add
|
||||
# | |
|
||||
# |__ __Sqrt
|
||||
# | |
|
||||
# Div (weight)
|
||||
# | |
|
||||
# | _____|
|
||||
# | |
|
||||
# Mul (bias)
|
||||
# | |
|
||||
# | _____|
|
||||
# | |
|
||||
# Add
|
||||
# |
|
||||
# output
|
||||
#
|
||||
# to the below subgraph
|
||||
#
|
||||
# input (weight) (bias)
|
||||
# | | |
|
||||
# | _______| |
|
||||
# | | ________________|
|
||||
# | | |
|
||||
# LayerNormalization
|
||||
# |
|
||||
# output
|
||||
graph = model.graph
|
||||
|
||||
nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806
|
||||
|
||||
id = 0
|
||||
layer_norm_nodes = []
|
||||
remove_nodes = []
|
||||
for reduce_mean in nodes_ReduceMean:
|
||||
# check that reduce_mean output is Sub
|
||||
sub = find_output_node(model, reduce_mean.output[0])
|
||||
if not is_type(sub, "Sub"):
|
||||
continue
|
||||
|
||||
# check that sub output[0] is Div and output[1] is Pow
|
||||
pow, div = find_output_node(model, sub.output[0])
|
||||
if is_type(pow, "Cast"):
|
||||
# During an update in PyTorch, Cast nodes are inserted between Sub and Pow.
|
||||
remove_nodes += [pow]
|
||||
pow = find_output_node(model, pow.output[0])
|
||||
if not is_type(pow, "Pow"):
|
||||
continue
|
||||
cast_pow = find_input_node(model, pow.input[1])
|
||||
if not is_type(cast_pow, "Cast"):
|
||||
continue
|
||||
remove_nodes += [cast_pow]
|
||||
if not is_type(div, "Div") or not is_type(pow, "Pow"):
|
||||
continue
|
||||
|
||||
# check that pow output is ReduceMean
|
||||
reduce_mean2 = find_output_node(model, pow.output[0])
|
||||
if not is_type(reduce_mean2, "ReduceMean"):
|
||||
continue
|
||||
|
||||
# check that reduce_mean2 output is Add
|
||||
add = find_output_node(model, reduce_mean2.output[0])
|
||||
if not is_type(add, "Add"):
|
||||
continue
|
||||
|
||||
# check that add output is Sqrt
|
||||
sqrt = find_output_node(model, add.output[0])
|
||||
if not is_type(sqrt, "Sqrt"):
|
||||
continue
|
||||
|
||||
# check that sqrt output is div
|
||||
if div != find_output_node(model, sqrt.output[0]):
|
||||
continue
|
||||
|
||||
# check if div output is Mul
|
||||
optional_mul = find_output_node(model, div.output[0])
|
||||
if not is_type(optional_mul, "Mul"):
|
||||
optional_mul = None
|
||||
continue # default bias and weight not supported
|
||||
|
||||
# check if mul output is Add
|
||||
if optional_mul is not None:
|
||||
optional_add = find_output_node(model, optional_mul.output[0])
|
||||
else:
|
||||
optional_add = find_output_node(model, div.output[0])
|
||||
if not is_type(optional_add, "Add"):
|
||||
optional_add = None
|
||||
continue # default bias and weight not supported
|
||||
|
||||
# add nodes to remove_nodes
|
||||
remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt])
|
||||
|
||||
# create LayerNorm node
|
||||
layer_norm_input = []
|
||||
layer_norm_output = []
|
||||
|
||||
layer_norm_input.append(reduce_mean.input[0])
|
||||
|
||||
if optional_mul is not None:
|
||||
remove_nodes.append(optional_mul)
|
||||
weight = optional_mul.input[1]
|
||||
layer_norm_input.append(weight)
|
||||
|
||||
if optional_add is not None:
|
||||
remove_nodes.append(optional_add)
|
||||
bias = optional_add.input[1]
|
||||
layer_norm_input.append(bias)
|
||||
|
||||
if optional_add is not None:
|
||||
layer_norm_output.append(optional_add.output[0])
|
||||
elif optional_mul is not None:
|
||||
layer_norm_output.append(optional_mul.output[0])
|
||||
else:
|
||||
layer_norm_output.append(div.output[0])
|
||||
|
||||
layer_norm_output.append("saved_mean_" + str(id))
|
||||
layer_norm_output.append("saved_inv_std_var_" + str(id))
|
||||
|
||||
epsilon_node = find_input_node(model, add.input[1])
|
||||
epsilon = epsilon_node.attribute[0].t.raw_data
|
||||
epsilon = struct.unpack("f", epsilon)[0]
|
||||
|
||||
layer_norm = helper.make_node(
|
||||
"LayerNormalization",
|
||||
layer_norm_input,
|
||||
layer_norm_output,
|
||||
"LayerNormalization_" + str(id),
|
||||
None,
|
||||
axis=reduce_mean.attribute[0].ints[0],
|
||||
epsilon=epsilon,
|
||||
)
|
||||
layer_norm_nodes.append(layer_norm)
|
||||
id += 1
|
||||
|
||||
# remove orphan constant nodes
|
||||
for constant in graph.node:
|
||||
if constant.op_type == "Constant" and constant not in remove_nodes:
|
||||
is_orphan = True
|
||||
for out_name in constant.output:
|
||||
out = find_output_node(model, out_name)
|
||||
if out not in remove_nodes:
|
||||
is_orphan = False
|
||||
if is_orphan:
|
||||
remove_nodes.append(constant)
|
||||
|
||||
all_nodes = []
|
||||
for node in graph.node:
|
||||
if node not in remove_nodes:
|
||||
all_nodes.append(node)
|
||||
|
||||
for node in layer_norm_nodes:
|
||||
all_nodes.append(node) # noqa: PERF402
|
||||
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(all_nodes)
|
||||
return model
|
||||
|
||||
|
||||
# Fuse SoftmaxCrossEntropy
|
||||
|
||||
|
||||
def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802
|
||||
# Converting below subgraph
|
||||
#
|
||||
# (subgraph)
|
||||
# |
|
||||
# LogSoftmax (target) (optional weight)
|
||||
# | | |
|
||||
# nll_loss/NegativeLogLikelihoodLoss
|
||||
# |
|
||||
# output
|
||||
#
|
||||
# to the following
|
||||
#
|
||||
# (subgraph) (target) (optional weight)
|
||||
# | | _____|
|
||||
# | | |
|
||||
# SparseSoftmaxCrossEntropy
|
||||
# |
|
||||
# output
|
||||
nll_count = 0
|
||||
while True:
|
||||
nll_count = nll_count + 1
|
||||
nll_loss_node = None
|
||||
nll_loss_node_index = 0
|
||||
for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
|
||||
if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss":
|
||||
nll_loss_node = node
|
||||
break
|
||||
|
||||
if nll_loss_node is None:
|
||||
break
|
||||
|
||||
softmax_node = None
|
||||
softmax_node_index = 0
|
||||
label_input_name = None
|
||||
weight_input_name = None
|
||||
for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007
|
||||
if node.op_type == "LogSoftmax":
|
||||
# has to be connected to nll_loss
|
||||
if len(nll_loss_node.input) > 2:
|
||||
weight_input_name = nll_loss_node.input[2]
|
||||
if node.output[0] == nll_loss_node.input[0]:
|
||||
softmax_node = node
|
||||
label_input_name = nll_loss_node.input[1]
|
||||
break
|
||||
elif node.output[0] == nll_loss_node.input[1]:
|
||||
softmax_node = node
|
||||
label_input_name = nll_loss_node.input[0]
|
||||
break
|
||||
else:
|
||||
if softmax_node is not None:
|
||||
break
|
||||
|
||||
if softmax_node is None:
|
||||
break
|
||||
|
||||
# delete nll_loss and LogSoftmax nodes in order
|
||||
if nll_loss_node_index < softmax_node_index:
|
||||
del onnx_model.graph.node[softmax_node_index]
|
||||
del onnx_model.graph.node[nll_loss_node_index]
|
||||
else:
|
||||
del onnx_model.graph.node[nll_loss_node_index]
|
||||
del onnx_model.graph.node[softmax_node_index]
|
||||
|
||||
probability_output_name = softmax_node.output[0]
|
||||
node = onnx_model.graph.node.add()
|
||||
inputs = (
|
||||
[softmax_node.input[0], label_input_name, weight_input_name]
|
||||
if weight_input_name
|
||||
else [softmax_node.input[0], label_input_name]
|
||||
)
|
||||
node.CopyFrom(
|
||||
onnx.helper.make_node(
|
||||
"SparseSoftmaxCrossEntropy",
|
||||
inputs,
|
||||
[nll_loss_node.output[0], probability_output_name],
|
||||
"nll_loss_node_" + str(nll_count),
|
||||
)
|
||||
)
|
||||
|
||||
return onnx_model
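A self-contained sketch of the rewrite on a toy graph (names and shapes are illustrative; the model is not checker-validated since only the node list matters here):

from onnx import TensorProto, helper

x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
t = helper.make_tensor_value_info("target", TensorProto.INT64, [2])
loss = helper.make_tensor_value_info("loss", TensorProto.FLOAT, [])
logsm = helper.make_node("LogSoftmax", ["X"], ["logp"], axis=1)
nll = helper.make_node("NegativeLogLikelihoodLoss", ["logp", "target"], ["loss"])
m = helper.make_model(helper.make_graph([logsm, nll], "toy_loss", [x, t], [loss]))

m = fuse_softmaxNLL_to_softmaxCE(m)
print([n.op_type for n in m.graph.node])  # ['SparseSoftmaxCrossEntropy']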
|
|
@ -1,144 +0,0 @@
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class OutputGrabber:
|
||||
"""
|
||||
Class used to grab standard output or another stream.
|
||||
"""
|
||||
|
||||
escape_char = "\b"
|
||||
|
||||
def __init__(self, stream=None, threaded=False):
|
||||
self.origstream = stream
|
||||
self.threaded = threaded
|
||||
if self.origstream is None:
|
||||
self.origstream = sys.stdout
|
||||
self.origstreamfd = self.origstream.fileno()
|
||||
self.capturedtext = ""
|
||||
# Create a pipe so the stream can be captured:
|
||||
self.pipe_out, self.pipe_in = os.pipe()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start capturing the stream data.
|
||||
"""
|
||||
self.capturedtext = ""
|
||||
# Save a copy of the stream:
|
||||
self.streamfd = os.dup(self.origstreamfd)
|
||||
# Replace the original stream with our write pipe:
|
||||
os.dup2(self.pipe_in, self.origstreamfd)
|
||||
if self.threaded:
|
||||
# Start thread that will read the stream:
|
||||
self.workerThread = threading.Thread(target=self.readOutput)
|
||||
self.workerThread.start()
|
||||
# Make sure that the thread is running and os.read() has executed:
|
||||
time.sleep(0.01)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop capturing the stream data and save the text in `capturedtext`.
|
||||
"""
|
||||
# Print the escape character to make the readOutput method stop:
|
||||
self.origstream.write(self.escape_char)
|
||||
# Flush the stream to make sure all our data goes in before
|
||||
# the escape character:
|
||||
self.origstream.flush()
|
||||
if self.threaded:
|
||||
# wait until the thread finishes so we are sure that
|
||||
# we have read everything up to the last character:
|
||||
self.workerThread.join()
|
||||
else:
|
||||
self.readOutput()
|
||||
# Close the pipe:
|
||||
os.close(self.pipe_in)
|
||||
os.close(self.pipe_out)
|
||||
# Restore the original stream:
|
||||
os.dup2(self.streamfd, self.origstreamfd)
|
||||
# Close the duplicate stream:
|
||||
os.close(self.streamfd)
|
||||
|
||||
def readOutput(self):
|
||||
"""
|
||||
Read the stream data (one byte at a time)
|
||||
and save the text in `capturedtext`.
|
||||
"""
|
||||
while True:
|
||||
char = os.read(self.pipe_out, 1).decode(self.origstream.encoding)
|
||||
if not char or self.escape_char in char:
|
||||
break
|
||||
self.capturedtext += char
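A minimal usage sketch, assuming the class above is importable; this is essentially what the test further below does:

with OutputGrabber() as out:               # captures sys.stdout by default
    print("hello from the captured region")
print("hello" in out.capturedtext)         # True; stdout is restored after the block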
|
||||
|
||||
|
||||
import os # noqa: E402
|
||||
import unittest # noqa: E402
|
||||
|
||||
import numpy as np # noqa: E402, F401
|
||||
import torch # noqa: E402
|
||||
import torch.nn as nn # noqa: E402
|
||||
import torch.nn.functional as F # noqa: E402
|
||||
|
||||
from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401
|
||||
from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401
|
||||
|
||||
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x, target):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return my_loss(out, target)
|
||||
|
||||
|
||||
class OrtEPTests(unittest.TestCase):
|
||||
def test_external_graph_transformer_triggering(self):
|
||||
input_size = 784
|
||||
hidden_size = 500
|
||||
num_classes = 10
|
||||
batch_size = 128
|
||||
model = NeuralNet(input_size, hidden_size, num_classes)
|
||||
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
("x", [batch_size, input_size]),
|
||||
(
|
||||
"target",
|
||||
[
|
||||
batch_size,
|
||||
],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True)],
|
||||
}
|
||||
optim_config = optim.SGDConfig()
|
||||
opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}})
|
||||
model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
|
||||
# because orttrainer is lazily initialized, feed in random data to trigger the graph transformer
|
||||
data = torch.rand(batch_size, input_size)
|
||||
target = torch.randint(0, 10, (batch_size,))
|
||||
|
||||
with OutputGrabber() as out:
|
||||
model.train_step(data, target)
|
||||
assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,35 +0,0 @@
|
|||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "orttraining/core/optimizer/graph_transformer_registry.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
class MyRewriteRule : public RewriteRule {
|
||||
public:
|
||||
MyRewriteRule() noexcept
|
||||
: RewriteRule("MyRewriteRule") {
|
||||
}
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override {
|
||||
std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl;
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
void RegisterTrainingExternalTransformers() {
|
||||
ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true);
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
|
@ -1,26 +1,7 @@
|
|||
import copy
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import _utils, optim
|
||||
|
||||
|
||||
def _single_run(execution_file, scenario, checkpoint_dir=None):
|
||||
cmd = [sys.executable, execution_file]
|
||||
if scenario:
|
||||
cmd += ["--scenario", scenario]
|
||||
if checkpoint_dir:
|
||||
cmd += ["--checkpoint_dir", checkopint_dir]
|
||||
assert subprocess.call(cmd) == 0
|
||||
|
||||
|
||||
def is_windows():
|
||||
return sys.platform.startswith("win")
|
||||
|
@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en
|
|||
if log:
|
||||
log.debug("Subprocess completed. Return code=" + str(completed_process.returncode))
|
||||
return completed_process
|
||||
|
||||
|
||||
def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup):
|
||||
num_warmup_steps = warmup * total_steps
|
||||
if global_step < num_warmup_steps:
|
||||
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
|
||||
else:
|
||||
new_lr = initial_lr
|
||||
return new_lr
|
||||
|
||||
|
||||
def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles):
|
||||
num_warmup_steps = warmup * total_steps
|
||||
if global_step < num_warmup_steps:
|
||||
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
|
||||
else:
|
||||
progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps))
|
||||
new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress)))
|
||||
return new_lr
|
||||
|
||||
|
||||
def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup):
|
||||
num_warmup_steps = warmup * total_steps
|
||||
if global_step < num_warmup_steps:
|
||||
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
|
||||
else:
|
||||
new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps)))
|
||||
return new_lr
|
||||
|
||||
|
||||
def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end):
|
||||
num_warmup_steps = warmup * total_steps
|
||||
if global_step < num_warmup_steps:
|
||||
new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps))
|
||||
elif global_step > total_steps:
|
||||
new_lr = lr_end
|
||||
else:
|
||||
lr_range = initial_lr - lr_end
|
||||
decay_steps = total_steps - num_warmup_steps
|
||||
pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps
|
||||
decay = lr_range * pct_remaining**power + lr_end
|
||||
new_lr = decay
|
||||
return new_lr
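A quick, self-contained check of the warmup/decay shape these helpers produce; total_steps=10, warmup=0.2 and initial_lr=0.1 are arbitrary illustration values:

for step in (0, 1, 2, 5, 9):
    print(step,
          round(legacy_constant_lr_scheduler(step, 0.1, 10, 0.2), 4),
          round(legacy_linear_lr_scheduler(step, 0.1, 10, 0.2), 4))
# lr ramps up linearly over the 2 warmup steps, then stays flat (constant scheduler)
# or decays linearly toward 0 (linear scheduler).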
|
||||
|
||||
|
||||
def generate_dummy_optim_state(model, optimizer):
|
||||
np.random.seed(0)
|
||||
if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))):
|
||||
return dict()
|
||||
|
||||
moment_keys = ["Moment_1", "Moment_2"]
|
||||
uc_key = "Update_Count"
|
||||
step_key = "Step"
|
||||
shared_state_key = "shared_optimizer_state"
|
||||
|
||||
optim_state = dict()
|
||||
weight_shape_map = dict()
|
||||
if isinstance(model, torch.nn.Module):
|
||||
weight_shape_map = {name: param.size() for name, param in model.named_parameters()}
|
||||
elif isinstance(model, onnx.ModelProto):
|
||||
weight_shape_map = {n.name: n.dims for n in model.graph.initializer}
|
||||
else:
|
||||
raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")
|
||||
|
||||
for weight_name, weight_shape in weight_shape_map.items():
|
||||
per_weight_state = dict()
|
||||
for moment in moment_keys:
|
||||
per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32)
|
||||
if isinstance(optimizer, optim.AdamConfig):
|
||||
per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64)
|
||||
optim_state[weight_name] = copy.deepcopy(per_weight_state)
|
||||
if isinstance(optimizer, optim.LambConfig):
|
||||
step_val = np.full([1], 5, dtype=np.int64)
|
||||
optim_state[shared_state_key] = {step_key: step_val}
|
||||
return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}}
|
||||
|
||||
|
||||
def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None):
|
||||
# Loads external Pytorch TransformerModel into utils
|
||||
root = "samples"
|
||||
if not os.path.exists(root):
|
||||
root = os.path.normpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples")
|
||||
)
|
||||
if not os.path.exists(root):
|
||||
raise FileNotFoundError("Unable to find folder 'samples', tried %r." % root)
|
||||
pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer")
|
||||
pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py")
|
||||
pt_model = _utils.import_module_from_file(pt_model_path)
|
||||
ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py")
|
||||
ort_utils = _utils.import_module_from_file(ort_utils_path)
|
||||
utils_path = os.path.join(pytorch_transformer_path, "utils.py")
|
||||
utils = _utils.import_module_from_file(utils_path)
|
||||
|
||||
# Modeling
|
||||
model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
|
||||
my_loss = ort_utils.my_loss
|
||||
if legacy_api:
|
||||
if dynamic_axes:
|
||||
model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes()
|
||||
else:
|
||||
model_desc = ort_utils.legacy_transformer_model_description()
|
||||
else:
|
||||
if dynamic_axes:
|
||||
model_desc = ort_utils.transformer_model_description_dynamic_axes()
|
||||
else:
|
||||
model_desc = ort_utils.transformer_model_description()
|
||||
|
||||
# Preparing data
|
||||
train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir)
|
||||
return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data
|
||||
|
||||
|
||||
def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"):
|
||||
"""Generates a sample input for the BART model using the model desc"""
|
||||
|
||||
torch.manual_seed(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
dtype = torch.int64
|
||||
vocab_size = 30528
|
||||
sample_input = []
|
||||
for _index, input in enumerate(desc["inputs"]):
|
||||
size = []
|
||||
for s in input[1]:
|
||||
if isinstance(s, (int)):
|
||||
size.append(s)
|
||||
else:
|
||||
size.append(1)
|
||||
sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device))
|
||||
return sample_input
|
||||
|
||||
|
||||
def _load_bart_model():
|
||||
bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx")
|
||||
model = onnx.load(bart_onnx_model_path)
|
||||
batch = 2
|
||||
seq_len = 1024
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"src_tokens",
|
||||
[batch, seq_len],
|
||||
),
|
||||
(
|
||||
"prev_output_tokens",
|
||||
[batch, seq_len],
|
||||
),
|
||||
(
|
||||
"target",
|
||||
[batch * seq_len],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True)],
|
||||
}
|
||||
|
||||
return model, model_desc
|
||||
|
||||
|
||||
def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False):
|
||||
"""Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states"""
|
||||
|
||||
assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint)
|
||||
assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint)
|
||||
|
||||
if "model" in state_dict_pre_checkpoint:
|
||||
for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]:
|
||||
if reshape_states:
|
||||
assert_allclose(
|
||||
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key],
|
||||
state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape(
|
||||
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert_allclose(
|
||||
state_dict_pre_checkpoint["model"]["full_precision"][model_state_key],
|
||||
state_dict_post_checkpoint["model"]["full_precision"][model_state_key],
|
||||
)
|
||||
|
||||
if "optimizer" in state_dict_pre_checkpoint:
|
||||
for model_state_key in state_dict_pre_checkpoint["optimizer"]:
|
||||
for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]:
|
||||
if reshape_states:
|
||||
assert_allclose(
|
||||
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key],
|
||||
state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape(
|
||||
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert_allclose(
|
||||
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key],
|
||||
state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key],
|
||||
)
|
||||
|
|
|
@ -1,30 +1,11 @@
|
|||
import copy
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
|
||||
from onnxruntime.training import orttrainer
|
||||
|
||||
try:
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
|
||||
from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401
|
||||
GraphExecutionManagerFactory,
|
||||
)
|
||||
except ImportError:
|
||||
# Some pipelines do not contain ORTModule
|
||||
pass
|
||||
except Exception as e:
|
||||
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
|
||||
|
||||
if isinstance(e, ORTModuleInitException):
|
||||
# ORTModule is present but not ready to run
|
||||
# That is OK because this file is also used by ORTTrainer tests
|
||||
pass
|
||||
raise
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401
|
||||
|
||||
|
||||
def is_all_or_nothing_fallback_enabled(model, policy=None):
|
||||
|
@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
|
|||
assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch")
|
||||
|
||||
|
||||
def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
|
||||
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
|
||||
|
||||
Compares the weights of two different ONNX models (model_a and model_b)
|
||||
and raises AssertError when they diverge by more than atol or rtol
|
||||
|
||||
Args:
|
||||
model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
|
||||
verbose (bool, default is False): if True, prints absolute difference for each weight
|
||||
rtol (float, default is 1e-7): Max relative difference
|
||||
atol (float, default is 0): Max absolute difference
|
||||
"""
|
||||
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
|
||||
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
|
||||
assert len(state_dict_a.items()) == len(state_dict_b.items())
|
||||
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
|
||||
|
||||
|
||||
def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
|
||||
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
|
||||
|
||||
Compares the weights of a legacy model model_a and experimental model_b model
|
||||
and raises AssertError when they diverge by more than atol or rtol.
|
||||
|
||||
Args:
|
||||
model_a (ORTTrainer): Instance of legacy ORTTrainer
|
||||
model_b (ORTTrainer): Instance of experimental ORTTrainer
|
||||
verbose (bool, default is False): if True, prints absolute difference for each weight.
|
||||
rtol (float, default is 1e-7): Max relative difference
|
||||
atol (float, default is 0): Max absolute difference
|
||||
"""
|
||||
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer)
|
||||
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state()
|
||||
assert len(state_dict_a.items()) == len(state_dict_b.items())
|
||||
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
|
||||
|
||||
|
||||
def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol):
|
||||
r"""Asserts whether dicts a and b value differences are within specified tolerance
|
||||
|
||||
Compares the weights of two model's state_dict dicts and raises AssertError
|
||||
when they diverge by more than atol or rtol
|
||||
|
||||
Args:
|
||||
state_dict_a (dict): state_dict of the legacy ORTTrainer
|
||||
state_dict_b (dict): state_dict of the experimental ORTTrainer
|
||||
verbose (bool, default is False): if True, prints absolute difference for each weight.
|
||||
rtol (float): Max relative difference
|
||||
atol (float): Max absolute difference
|
||||
"""
|
||||
|
||||
for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
|
||||
np_a_vals = np.array(a_val).flatten()
|
||||
np_b_vals = np.array(b_val).flatten()
|
||||
assert np_a_vals.shape == np_b_vals.shape
|
||||
if verbose:
|
||||
print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}")
|
||||
assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}")
|
||||
|
||||
|
||||
def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
|
||||
r"""Asserts whether optimizer state differences are within specified tolerance
|
||||
|
||||
Compares the expected and actual optimizer states of dicts and raises AssertError
|
||||
when they diverge by more than atol or rtol.
|
||||
The optimizer dict is of the form:
|
||||
model_weight_name:
|
||||
{
|
||||
"Moment_1": moment1_tensor,
|
||||
"Moment_2": moment2_tensor,
|
||||
"Update_Count": update_tensor # if optimizer is adam, absent otherwise
|
||||
},
|
||||
...
|
||||
"shared_optimizer_state": # if optimizer is shared, absent otherwise.
|
||||
So far, only lamb optimizer uses this.
|
||||
{
|
||||
"step": step_tensor # int array of size 1
|
||||
}
|
||||
|
||||
Args:
|
||||
expected_state (dict(dict())): Expected optimizer state
|
||||
actual_state (dict(dict())): Actual optimizer state
|
||||
rtol (float, default is 1e-7): Max relative difference
|
||||
atol (float, default is 0): Max absolute difference
|
||||
"""
|
||||
assert expected_state.keys() == actual_state.keys()
|
||||
for param_name, a_state in actual_state.items():
|
||||
for k, v in a_state.items():
|
||||
assert_allclose(
|
||||
v,
|
||||
expected_state[param_name][k],
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
err_msg=f"Optimizer state mismatch for param {param_name}, key {k}",
|
||||
)
|
||||
|
||||
|
||||
def is_dynamic_axes(model):
|
||||
# Check inputs
|
||||
for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input:
|
||||
|
|
|
@ -1,325 +0,0 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from orttraining_test_bert_postprocess import postprocess_model
|
||||
from orttraining_test_data_loader import create_ort_test_dataloader
|
||||
from orttraining_test_transformers import BertForPreTraining, BertModelTest
|
||||
from orttraining_test_utils import map_optimizer_attributes
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import ( # noqa: F401
|
||||
IODescription,
|
||||
LossScaler,
|
||||
ModelDescription,
|
||||
ORTTrainer,
|
||||
generate_sample,
|
||||
)
|
||||
|
||||
torch.manual_seed(1)
|
||||
onnxruntime.set_seed(1)
|
||||
|
||||
|
||||
class Test_PostPasses(unittest.TestCase): # noqa: N801
|
||||
def get_onnx_model(
|
||||
self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None
|
||||
):
|
||||
lr_desc = IODescription(
|
||||
"Learning_Rate",
|
||||
[
|
||||
1,
|
||||
],
|
||||
torch.float32,
|
||||
)
|
||||
model = ORTTrainer(
|
||||
model,
|
||||
None,
|
||||
model_desc,
|
||||
"LambOptimizer",
|
||||
map_optimizer_attributes,
|
||||
lr_desc,
|
||||
device,
|
||||
world_rank=0,
|
||||
world_size=1,
|
||||
_opset_version=14,
|
||||
_enable_internal_postprocess=_enable_internal_postprocess,
|
||||
_extra_postprocess=_extra_postprocess,
|
||||
)
|
||||
|
||||
model.train_step(*inputs)
|
||||
return model.onnx_model_
|
||||
|
||||
def count_all_nodes(self, model):
|
||||
return len(model.graph.node)
|
||||
|
||||
def count_nodes(self, model, node_type):
|
||||
count = 0
|
||||
for node in model.graph.node:
|
||||
if node.op_type == node_type:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def find_nodes(self, model, node_type):
|
||||
nodes = []
|
||||
for node in model.graph.node:
|
||||
if node.op_type == node_type:
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def get_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
rel = os.path.join("testdata", name)
|
||||
if os.path.exists(rel):
|
||||
return rel
|
||||
this = os.path.dirname(__file__)
|
||||
data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata")
|
||||
res = os.path.join(data, name)
|
||||
if os.path.exists(res):
|
||||
return res
|
||||
raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'")
|
||||
|
||||
def test_layer_norm(self):
|
||||
class LayerNormNet(nn.Module):
|
||||
def __init__(self, target):
|
||||
super().__init__()
|
||||
self.ln_1 = nn.LayerNorm(10)
|
||||
self.loss = nn.CrossEntropyLoss()
|
||||
self.target = target
|
||||
|
||||
def forward(self, x):
|
||||
output1 = self.ln_1(x)
|
||||
loss = self.loss(output1, self.target)
|
||||
return loss, output1
|
||||
|
||||
device = torch.device("cpu")
|
||||
target = torch.ones(20, 10, 10, dtype=torch.int64).to(device)
|
||||
model = LayerNormNet(target)
|
||||
input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device)
|
||||
|
||||
input_desc = IODescription("input", [], "float32")
|
||||
output0_desc = IODescription("output0", [], "float32")
|
||||
output1_desc = IODescription("output1", [20, 5, 10, 10], "float32")
|
||||
model_desc = ModelDescription([input_desc], [output0_desc, output1_desc])
|
||||
|
||||
learning_rate = torch.tensor([1.0000000e00]).to(device)
|
||||
input_args = [input, learning_rate]
|
||||
|
||||
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
|
||||
|
||||
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
|
||||
count_nodes = self.count_all_nodes(onnx_model)
|
||||
|
||||
assert count_layer_norm == 0
|
||||
assert count_nodes == 3
|
||||
|
||||
def test_expand(self):
|
||||
class ExpandNet(nn.Module):
|
||||
def __init__(self, target):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss()
|
||||
self.target = target
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x, x1):
|
||||
output = x.expand_as(x1)
|
||||
output = self.linear(output)
|
||||
output = output + output
|
||||
loss = self.loss(output, self.target)
|
||||
return loss, output
|
||||
|
||||
device = torch.device("cpu")
|
||||
target = torch.ones(5, 5, 2, dtype=torch.int64).to(device)
|
||||
model = ExpandNet(target).to(device)
|
||||
|
||||
x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device)
|
||||
x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device)
|
||||
|
||||
input0_desc = IODescription("x", [5, 3, 1, 2], "float32")
|
||||
input1_desc = IODescription("x1", [5, 3, 5, 2], "float32")
|
||||
output0_desc = IODescription("output0", [], "float32")
|
||||
output1_desc = IODescription("output1", [5, 3, 5, 2], "float32")
|
||||
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
|
||||
|
||||
learning_rate = torch.tensor([1.0000000e00]).to(device)
|
||||
input_args = [x, x1, learning_rate]
|
||||
|
||||
onnx_model = self.get_onnx_model(model, model_desc, input_args, device)
|
||||
|
||||
# check that the Expand output has its shape/type recorded in value_info
|
||||
expand_nodes = self.find_nodes(onnx_model, "Expand")
|
||||
assert len(expand_nodes) == 1
|
||||
|
||||
model_info = onnx_model.graph.value_info
|
||||
assert model_info[0].name == expand_nodes[0].output[0]
|
||||
assert model_info[0].type == onnx_model.graph.input[1].type
|
||||
|
||||
def test_bert(self):
|
||||
device = torch.device("cpu")
|
||||
|
||||
model_tester = BertModelTest.BertModelTester(self)
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = model_tester.prepare_config_and_inputs()
|
||||
|
||||
model = BertForPreTraining(config=config)
|
||||
model.eval()
|
||||
|
||||
loss, prediction_scores, seq_relationship_score = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
|
||||
model_desc = ModelDescription(
|
||||
[
|
||||
model_tester.input_ids_desc,
|
||||
model_tester.attention_mask_desc,
|
||||
model_tester.token_type_ids_desc,
|
||||
model_tester.masked_lm_labels_desc,
|
||||
model_tester.next_sentence_label_desc,
|
||||
],
|
||||
[model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc],
|
||||
)
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
MyArgs = namedtuple(
|
||||
"MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
|
||||
)
|
||||
args = MyArgs(
|
||||
local_rank=0,
|
||||
world_size=1,
|
||||
max_steps=100,
|
||||
learning_rate=0.00001,
|
||||
warmup_proportion=0.01,
|
||||
batch_size=13,
|
||||
seq_len=7,
|
||||
)
|
||||
|
||||
dataset_len = 100
|
||||
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)
|
||||
learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device)
|
||||
for b in dataloader:
|
||||
batch = b
|
||||
break
|
||||
learning_rate = torch.tensor([1.00e00]).to(device)
|
||||
inputs = [*batch, learning_rate]
|
||||
|
||||
onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model)
|
||||
|
||||
self._bert_helper(onnx_model)
|
||||
|
||||
def _bert_helper(self, onnx_model):
|
||||
# count layer_norm
|
||||
count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
|
||||
assert count_layer_norm == 0
|
||||
|
||||
# get expand node and check output shape
|
||||
expand_nodes = self.find_nodes(onnx_model, "Expand")
|
||||
assert len(expand_nodes) == 1
|
||||
|
||||
model_info = onnx_model.graph.value_info
|
||||
assert model_info[0].name == expand_nodes[0].output[0]
|
||||
assert model_info[0].type == onnx_model.graph.input[0].type
|
||||
|
||||
def test_extra_postpass(self):
|
||||
def postpass_replace_first_add_with_sub(model):
|
||||
# this post pass replaces the first Add node with Sub in the model.
|
||||
# Previous graph
|
||||
# (subgraph 1) (subgraph 2)
|
||||
# | |
|
||||
# | |
|
||||
# |________ ________|
|
||||
# | |
|
||||
# Add
|
||||
# |
|
||||
# (subgraph 3)
|
||||
#
|
||||
# Post graph
|
||||
# (subgraph 1) (subgraph 2)
|
||||
# | |
|
||||
# | |
|
||||
# |________ ________|
|
||||
# | |
|
||||
# Sub
|
||||
# |
|
||||
# (subgraph 3)
|
||||
add_nodes = [n for n in model.graph.node if n.op_type == "Add"]
|
||||
add_nodes[0].op_type = "Sub"
|
||||
|
||||
class MultiAdd(nn.Module):
|
||||
def __init__(self, target):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss()
|
||||
self.target = target
|
||||
self.linear = torch.nn.Linear(2, 2, bias=False)
|
||||
|
||||
def forward(self, x, x1):
|
||||
output = x + x1
|
||||
output = output + x
|
||||
output = output + x1
|
||||
output = self.linear(output)
|
||||
loss = self.loss(output, self.target)
|
||||
return loss, output
|
||||
|
||||
device = torch.device("cpu")
|
||||
target = torch.ones(5, 2, dtype=torch.int64).to(device)
|
||||
model = MultiAdd(target).to(device)
|
||||
|
||||
x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
|
||||
x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
|
||||
|
||||
input0_desc = IODescription("x", [5, 5, 2], "float32")
|
||||
input1_desc = IODescription("x1", [5, 5, 2], "float32")
|
||||
output0_desc = IODescription("output0", [], "float32")
|
||||
output1_desc = IODescription("output1", [5, 5, 2], "float32")
|
||||
model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])
|
||||
|
||||
learning_rate = torch.tensor([1.0000000e00]).to(device)
|
||||
input_args = [x, x1, learning_rate]
|
||||
|
||||
onnx_model = self.get_onnx_model(
|
||||
model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub
|
||||
)
|
||||
|
||||
# check that extra postpass is called, and called only once.
|
||||
add_nodes = self.find_nodes(onnx_model, "Add")
|
||||
sub_nodes = self.find_nodes(onnx_model, "Sub")
|
||||
assert len(add_nodes) == 2
|
||||
assert len(sub_nodes) == 1
|
||||
|
||||
unprocessed_onnx_model = self.get_onnx_model(
|
||||
model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False
|
||||
)
|
||||
# check that the model is unchanged.
|
||||
add_nodes = self.find_nodes(unprocessed_onnx_model, "Add")
|
||||
sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub")
|
||||
assert len(add_nodes) == 3
|
||||
assert len(sub_nodes) == 0
|
||||
|
||||
processed_onnx_model = self.get_onnx_model(
|
||||
unprocessed_onnx_model,
|
||||
model_desc,
|
||||
input_args,
|
||||
device,
|
||||
_extra_postprocess=postpass_replace_first_add_with_sub,
|
||||
)
|
||||
# check that extra postpass is called, and called only once.
|
||||
add_nodes = self.find_nodes(processed_onnx_model, "Add")
|
||||
sub_nodes = self.find_nodes(processed_onnx_model, "Sub")
|
||||
assert len(add_nodes) == 2
|
||||
assert len(sub_nodes) == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(module=__name__, buffer=True)
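For reference, the post-pass tests removed above were written against the deprecated onnxruntime.capi.ort_trainer.ORTTrainer API. A minimal sketch of one non-deprecated alternative (an assumption here, not taken from this change) wraps the same LayerNormNet module with ORTModule and drives the update with a plain PyTorch optimizer:

import torch
from onnxruntime.training.ortmodule import ORTModule

# assumption: LayerNormNet is the module defined in the removed test above
device = torch.device("cpu")
target = torch.ones(20, 10, 10, dtype=torch.int64).to(device)
model = ORTModule(LayerNormNet(target).to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

x = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device)
loss, _ = model(x)   # forward runs through the exported/optimized ONNX graph
loss.backward()      # gradients come from the ORT training graph
optimizer.step()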

@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache):
    env = get_env_with_transformers_cache(transformers_cache)

    command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"]
    command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"]

    run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode()

@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log):
def run_hooks_tests(cwd, log):
    log.debug("Running: Data hooks tests")

    command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"]
    command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"]

    run_subprocess(command, cwd=cwd, log=log).check_returncode()

@ -1,801 +0,0 @@
|
|||
# ==================
|
||||
import dataclasses
|
||||
import datetime
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import unittest
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import BertConfig, BertForPreTraining, HfArgumentParser
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases.
|
||||
# the fix for ignore_index == -100 cases is already in pytorch master.
|
||||
# however to use current torch master is causing computation changes in many tests.
|
||||
# eventually we will use pytorch with fixed nll_loss once computation
|
||||
# issues are understood and solved.
|
||||
import onnxruntime.capi.pt_patch
|
||||
from onnxruntime.training import amp, optim, orttrainer
|
||||
from onnxruntime.training.checkpoint import aggregate_checkpoints
|
||||
from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401
|
||||
|
||||
# we cannot make a full convergence run in the nightly pipeline because of its timeout limit,
# max_steps is still needed to calculate the learning rate. force_to_stop_max_steps is used to
# terminate the training before the pipeline run hits its timeout.
|
||||
force_to_stop_max_steps = 2500
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process(args):
|
||||
if hasattr(args, "world_rank"):
|
||||
return args.world_rank in [-1, 0]
|
||||
else:
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def bert_model_description(config):
|
||||
vocab_size = config.vocab_size
|
||||
new_model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"token_type_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"masked_lm_labels",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"next_sentence_label",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
),
|
||||
],
|
||||
"outputs": [
|
||||
("loss", [], True),
|
||||
(
|
||||
"prediction_scores",
|
||||
["batch", "max_seq_len_in_batch", vocab_size],
|
||||
),
|
||||
(
|
||||
"seq_relationship_scores",
|
||||
["batch", 2],
|
||||
),
|
||||
],
|
||||
}
|
||||
return new_model_desc
|
||||
|
||||
|
||||
def create_pretraining_dataset(input_file, max_pred_length, args):
|
||||
train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
|
||||
train_sampler = RandomSampler(train_data)
|
||||
train_dataloader = DataLoader(
|
||||
train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True
|
||||
)
|
||||
return train_dataloader, input_file
|
||||
|
||||
|
||||
class pretraining_dataset(Dataset): # noqa: N801
|
||||
def __init__(self, input_file, max_pred_length):
|
||||
logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length)
|
||||
self.input_file = input_file
|
||||
self.max_pred_length = max_pred_length
|
||||
f = h5py.File(input_file, "r")
|
||||
keys = [
|
||||
"input_ids",
|
||||
"input_mask",
|
||||
"segment_ids",
|
||||
"masked_lm_positions",
|
||||
"masked_lm_ids",
|
||||
"next_sentence_labels",
|
||||
]
|
||||
self.inputs = [np.asarray(f[key][:]) for key in keys]
|
||||
f.close()
|
||||
|
||||
def __len__(self):
|
||||
"Denotes the total number of samples"
|
||||
return len(self.inputs[0])
|
||||
|
||||
def __getitem__(self, index):
|
||||
[input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [
|
||||
torch.from_numpy(input[index].astype(np.int64))
|
||||
if indice < 5
|
||||
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
|
||||
for indice, input in enumerate(self.inputs)
|
||||
]
|
||||
|
||||
# HF model use default ignore_index value (-100) for CrossEntropyLoss
|
||||
masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100
|
||||
index = self.max_pred_length
|
||||
# store number of masked tokens in index
|
||||
padded_mask_indices = (masked_lm_positions == 0).nonzero()
|
||||
if len(padded_mask_indices) != 0:
|
||||
index = padded_mask_indices[0].item()
|
||||
masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]
|
||||
return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels]
|
||||
|
||||
|
||||
import argparse # noqa: E402
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# batch size test config parameters
|
||||
parser.add_argument(
|
||||
"--enable_mixed_precision",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit float precision instead of 32-bit",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sequence_length",
|
||||
default=512,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence"
|
||||
)
|
||||
parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.")
|
||||
|
||||
parser.add_argument("--gelu_recompute", default=False, action="store_true")
|
||||
|
||||
parser.add_argument("--attn_dropout_recompute", default=False, action="store_true")
|
||||
|
||||
parser.add_argument("--transformer_layer_recompute", default=False, action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
input_dir: str = field(
|
||||
default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"}
|
||||
)
|
||||
|
||||
bert_model: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Bert pre-trained model selected in the list: bert-base-uncased, \
|
||||
bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
|
||||
},
|
||||
)
|
||||
|
||||
output_dir: str = field(
|
||||
default=None, metadata={"help": "The output directory where the model checkpoints will be written."}
|
||||
)
|
||||
|
||||
cache_dir: str = field(
|
||||
default="/tmp/bert_pretrain/",
|
||||
metadata={"help": "The output directory where the model checkpoints will be written."},
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer \
|
||||
than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
|
||||
max_predictions_per_seq: Optional[int] = field(
|
||||
default=80, metadata={"help": "The maximum total of masked tokens in input sequence."}
|
||||
)
|
||||
|
||||
train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."})
|
||||
|
||||
learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."})
|
||||
|
||||
num_train_epochs: Optional[float] = field(
|
||||
default=3.0, metadata={"help": "Total number of training epochs to perform."}
|
||||
)
|
||||
|
||||
max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."})
|
||||
|
||||
warmup_proportion: Optional[float] = field(
|
||||
default=0.01,
|
||||
metadata={
|
||||
"help": "Proportion of training to perform linear learning rate warmup for. \
|
||||
E.g., 0.1 = 10%% of training."
|
||||
},
|
||||
)
|
||||
|
||||
local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."})
|
||||
|
||||
world_rank: Optional[int] = field(default=-1)
|
||||
|
||||
world_size: Optional[int] = field(default=1)
|
||||
|
||||
seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."})
|
||||
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "Number of update steps to accumulate before performing a backward/update pass."}
|
||||
)
|
||||
|
||||
fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."})
|
||||
|
||||
gelu_recompute: bool = field(
|
||||
default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."}
|
||||
)
|
||||
attn_dropout_recompute: bool = field(
|
||||
default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."}
|
||||
)
|
||||
transformer_layer_recompute: bool = field(
|
||||
default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."}
|
||||
)
|
||||
|
||||
loss_scale: Optional[float] = field(
|
||||
default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."}
|
||||
)
|
||||
|
||||
deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"})
|
||||
|
||||
log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."})
|
||||
|
||||
checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."})
|
||||
|
||||
resume_from_checkpoint: bool = field(
|
||||
default=False, metadata={"help": "Whether to resume training from checkpoint."}
|
||||
)
|
||||
|
||||
resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."})
|
||||
|
||||
num_steps_per_checkpoint: Optional[int] = field(
|
||||
default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."}
|
||||
)
|
||||
|
||||
save_checkpoint: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Enable for saving a model checkpoint to disk."}
|
||||
)
|
||||
|
||||
init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."})
|
||||
|
||||
phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."})
|
||||
|
||||
allreduce_post_accumulation: bool = field(
|
||||
default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."}
|
||||
)
|
||||
|
||||
allreduce_post_accumulation_fp16: bool = field(
|
||||
default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."}
|
||||
)
|
||||
|
||||
accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."})
|
||||
|
||||
phase1_end_step: Optional[int] = field(
|
||||
default=7038, metadata={"help": "Number of optimization steps in phase 1 (sequence length 128) pretraining."}
|
||||
)
|
||||
|
||||
tensorboard_dir: Optional[str] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
schedule: Optional[str] = field(
|
||||
default="warmup_poly",
|
||||
)
|
||||
|
||||
# this argument is test specific. running the full bert model would take too long. instead, we reduce
# the number of hidden layers so that it can show convergence to an extent, to help detect any regression.
|
||||
force_num_hidden_layers: Optional[int] = field(
|
||||
default=None, metadata={"help": "Whether to use fp16 gradient accumulators."}
|
||||
)
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Serializes this instance to a JSON string.
|
||||
"""
|
||||
return json.dumps(dataclasses.asdict(self), indent=2)
|
||||
|
||||
def to_sanitized_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitized serialization to use with TensorBoard`s hparams
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
valid_types = [bool, int, float, str, torch.Tensor]
|
||||
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
|
||||
|
||||
|
||||
def setup_training(args):
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
if args.local_rank == -1:
|
||||
args.local_rank = 0
|
||||
args.world_rank = 0
|
||||
|
||||
print("args.local_rank: ", args.local_rank)
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError(
|
||||
f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
|
||||
)
|
||||
if args.train_batch_size % args.gradient_accumulation_steps != 0:
|
||||
raise ValueError(
|
||||
"Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
|
||||
args.gradient_accumulation_steps, args.train_batch_size
|
||||
)
|
||||
)
|
||||
|
||||
# args.train_batch_size is per global step (optimization step) batch size
|
||||
# now make it a per gpu batch size
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
args.train_batch_size = args.train_batch_size // args.world_size
|
||||
|
||||
logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size)
|
||||
return device, args
|
||||
|
||||
|
||||
def setup_torch_distributed(world_rank, world_size):
|
||||
os.environ["RANK"] = str(world_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12345"
|
||||
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank)
|
||||
return
|
||||
|
||||
|
||||
def prepare_model(args, device):
|
||||
config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)
|
||||
|
||||
# config.num_hidden_layers = 12
|
||||
if args.force_num_hidden_layers:
|
||||
logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
|
||||
config.num_hidden_layers = args.force_num_hidden_layers
|
||||
|
||||
model = BertForPreTraining(config)
|
||||
if args.init_state_dict is not None:
|
||||
model.load_state_dict(args.init_state_dict)
|
||||
model_desc = bert_model_description(config)
|
||||
|
||||
lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)
|
||||
|
||||
loss_scaler = amp.DynamicLossScaler() if args.fp16 else None
|
||||
|
||||
options = orttrainer.ORTTrainerOptions(
|
||||
{
|
||||
"batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps},
|
||||
"device": {"id": str(device)},
|
||||
"mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler},
|
||||
"graph_transformer": {
|
||||
"attn_dropout_recompute": args.attn_dropout_recompute,
|
||||
"gelu_recompute": args.gelu_recompute,
|
||||
"transformer_layer_recompute": args.transformer_layer_recompute,
|
||||
},
|
||||
"debug": {
|
||||
"deterministic_compute": True,
|
||||
},
|
||||
"utils": {"grad_norm_clip": True},
|
||||
"distributed": {
|
||||
"world_rank": max(0, args.local_rank),
|
||||
"world_size": args.world_size,
|
||||
"local_rank": max(0, args.local_rank),
|
||||
"allreduce_post_accumulation": args.allreduce_post_accumulation,
|
||||
"deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage},
|
||||
"enable_adasum": False,
|
||||
},
|
||||
"lr_scheduler": lr_scheduler,
|
||||
}
|
||||
)
|
||||
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
|
||||
params = [
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
|
||||
"alpha": 0.9,
|
||||
"beta": 0.999,
|
||||
"lambda": 0.0,
|
||||
"epsilon": 1e-6,
|
||||
},
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
|
||||
"alpha": 0.9,
|
||||
"beta": 0.999,
|
||||
"lambda": 0.0,
|
||||
"epsilon": 1e-6,
|
||||
},
|
||||
]
|
||||
|
||||
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
|
||||
model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_data_file(f_id, world_rank, world_size, files):
|
||||
num_files = len(files)
|
||||
if world_size > num_files:
|
||||
remainder = world_size % num_files
|
||||
return files[(f_id * world_size + world_rank + remainder * f_id) % num_files]
|
||||
elif world_size > 1:
|
||||
return files[(f_id * world_size + world_rank) % num_files]
|
||||
else:
|
||||
return files[f_id % num_files]
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser(PretrainArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
do_pretrain(args)
|
||||
|
||||
|
||||
def do_pretrain(args):
|
||||
if is_main_process(args) and args.tensorboard_dir:
|
||||
tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
|
||||
tb_writer.add_text("args", args.to_json_string())
|
||||
tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
ort.set_seed(args.seed)
|
||||
|
||||
device, args = setup_training(args)
|
||||
|
||||
model = prepare_model(args, device)
|
||||
|
||||
logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate)
|
||||
|
||||
average_loss = 0.0
|
||||
epoch = 0
|
||||
training_steps = 0
|
||||
|
||||
pool = ProcessPoolExecutor(1)
|
||||
while True:
|
||||
files = [
|
||||
os.path.join(args.input_dir, f)
|
||||
for f in os.listdir(args.input_dir)
|
||||
if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f
|
||||
]
|
||||
files.sort()
|
||||
random.shuffle(files)
|
||||
|
||||
f_id = 0
|
||||
train_dataloader, data_file = create_pretraining_dataset(
|
||||
get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args
|
||||
)
|
||||
|
||||
for f_id in range(1, len(files)):
|
||||
logger.info("data file %s" % (data_file))
|
||||
|
||||
dataset_future = pool.submit(
|
||||
create_pretraining_dataset,
|
||||
get_data_file(f_id, args.world_rank, args.world_size, files),
|
||||
args.max_predictions_per_seq,
|
||||
args,
|
||||
)
|
||||
|
||||
train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader
|
||||
for _step, batch in enumerate(train_iter):
|
||||
training_steps += 1
|
||||
batch = [t.to(device) for t in batch] # noqa: PLW2901
|
||||
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
|
||||
|
||||
loss, _, _ = model.train_step(
|
||||
input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels
|
||||
)
|
||||
average_loss += loss.item()
|
||||
|
||||
global_step = model._train_step_info.optimization_step
|
||||
if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
|
||||
if is_main_process(args):
|
||||
divisor = args.log_freq * args.gradient_accumulation_steps
|
||||
if tb_writer:
|
||||
lr = model.options.lr_scheduler.get_last_lr()[0]
|
||||
tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step)
|
||||
if args.fp16:
|
||||
tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step)
|
||||
# TODO: ORTTrainer to expose all_finite
|
||||
# tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step)
|
||||
tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step)
|
||||
|
||||
print(f"Step:{global_step} Average Loss = {average_loss / divisor}")
|
||||
|
||||
if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
|
||||
if tb_writer:
|
||||
tb_writer.close()
|
||||
|
||||
if global_step >= args.max_steps:
|
||||
if args.save_checkpoint:
|
||||
model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp"))
|
||||
final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps)
|
||||
return final_loss
|
||||
|
||||
average_loss = 0
|
||||
|
||||
del train_dataloader
|
||||
|
||||
train_dataloader, data_file = dataset_future.result(timeout=None)
|
||||
|
||||
epoch += 1
|
||||
|
||||
|
||||
def generate_tensorboard_logdir(root_dir):
|
||||
current_date_time = datetime.datetime.today()
|
||||
|
||||
dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S")
|
||||
return os.path.join(root_dir, dt_string)
|
||||
|
||||
|
||||
class ORTBertPretrainTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results"
|
||||
self.bert_model = "bert-base-uncased"
|
||||
self.local_rank = -1
|
||||
self.world_rank = -1
|
||||
self.world_size = 1
|
||||
self.max_steps = 300000
|
||||
self.learning_rate = 5e-4
|
||||
self.max_seq_length = 512
|
||||
self.max_predictions_per_seq = 20
|
||||
self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
|
||||
self.train_batch_size = 4096
|
||||
self.gradient_accumulation_steps = 64
|
||||
self.fp16 = True
|
||||
self.allreduce_post_accumulation = True
|
||||
self.tensorboard_dir = "/bert_data/hf_data/test_out"
|
||||
|
||||
def test_pretrain_throughput(self, process_args=None):
|
||||
if process_args.sequence_length == 128:
|
||||
input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
|
||||
else:
|
||||
input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train"
|
||||
|
||||
print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision)
|
||||
print("process_args.sequence_length: ", process_args.sequence_length)
|
||||
print("process_args.max_batch_size: ", process_args.max_batch_size)
|
||||
print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq)
|
||||
print("process_args.gelu_recompute: ", process_args.gelu_recompute)
|
||||
print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute)
|
||||
print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute)
|
||||
|
||||
args = PretrainArguments(
|
||||
input_dir=input_dir,
|
||||
output_dir="/bert_data/hf_data/test_out/bert_pretrain_results",
|
||||
bert_model="bert-large-uncased",
|
||||
local_rank=self.local_rank,
|
||||
world_rank=self.world_rank,
|
||||
world_size=self.world_size,
|
||||
max_steps=10,
|
||||
learning_rate=5e-4,
|
||||
max_seq_length=process_args.sequence_length,
|
||||
max_predictions_per_seq=process_args.max_predictions_per_seq,
|
||||
train_batch_size=process_args.max_batch_size,
|
||||
gradient_accumulation_steps=1,
|
||||
fp16=process_args.enable_mixed_precision,
|
||||
gelu_recompute=process_args.gelu_recompute,
|
||||
attn_dropout_recompute=process_args.attn_dropout_recompute,
|
||||
transformer_layer_recompute=process_args.transformer_layer_recompute,
|
||||
allreduce_post_accumulation=True,
|
||||
# TODO: remove
|
||||
force_num_hidden_layers=2,
|
||||
)
|
||||
do_pretrain(args)
|
||||
|
||||
def test_pretrain_convergence(self):
|
||||
args = PretrainArguments(
|
||||
output_dir=self.output_dir,
|
||||
bert_model=self.bert_model,
|
||||
local_rank=self.local_rank,
|
||||
world_rank=self.world_rank,
|
||||
world_size=self.world_size,
|
||||
max_steps=self.max_steps,
|
||||
learning_rate=self.learning_rate,
|
||||
max_seq_length=self.max_seq_length,
|
||||
max_predictions_per_seq=self.max_predictions_per_seq,
|
||||
train_batch_size=self.train_batch_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
input_dir=self.input_dir,
|
||||
fp16=self.fp16,
|
||||
allreduce_post_accumulation=self.allreduce_post_accumulation,
|
||||
force_num_hidden_layers=self.force_num_hidden_layers,
|
||||
tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"),
|
||||
)
|
||||
final_loss = do_pretrain(args)
|
||||
return final_loss
|
||||
|
||||
def test_pretrain_zero(self):
|
||||
assert self.world_size > 0, "ZeRO test requires a distributed run."
|
||||
setup_torch_distributed(self.world_rank, self.world_size)
|
||||
per_gpu_batch_size = 32
|
||||
optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation
|
||||
|
||||
self.train_batch_size = optimization_batch_size
|
||||
self.gradient_accumulation_steps = 1
|
||||
self.deepspeed_zero_stage = 1
|
||||
self.force_num_hidden_layers = 2
|
||||
self.max_seq_length = 32
|
||||
self.output_dir = "./bert_pretrain_ckpt"
|
||||
if self.world_rank == 0:
|
||||
if os.path.isdir(self.output_dir):
|
||||
shutil.rmtree(self.output_dir)
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
assert os.path.exists(self.output_dir)
|
||||
|
||||
# run a few optimization steps
|
||||
self.max_steps = 200
|
||||
args = PretrainArguments(
|
||||
output_dir=self.output_dir,
|
||||
bert_model=self.bert_model,
|
||||
local_rank=self.local_rank,
|
||||
world_rank=self.world_rank,
|
||||
world_size=self.world_size,
|
||||
max_steps=self.max_steps,
|
||||
learning_rate=self.learning_rate,
|
||||
max_seq_length=self.max_seq_length,
|
||||
max_predictions_per_seq=self.max_predictions_per_seq,
|
||||
train_batch_size=self.train_batch_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
input_dir=self.input_dir,
|
||||
fp16=self.fp16,
|
||||
allreduce_post_accumulation=self.allreduce_post_accumulation,
|
||||
force_num_hidden_layers=self.force_num_hidden_layers,
|
||||
deepspeed_zero_stage=self.deepspeed_zero_stage,
|
||||
save_checkpoint=True,
|
||||
)
|
||||
do_pretrain(args)
|
||||
|
||||
# ensure all workers reach this point before loading the checkpointed state
|
||||
torch.distributed.barrier()
|
||||
|
||||
# on rank 0, load the trained state
|
||||
if args.world_rank == 0:
|
||||
checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp"))
|
||||
args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
# run a single step to get the loss; on rank 0 it should be less than the starting loss
|
||||
args.save_checkpoint = False
|
||||
args.max_steps = 1
|
||||
args.deepspeed_zero_stage = 0
|
||||
final_loss = do_pretrain(args)
|
||||
return final_loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
logger.warning("sys.argv: %s", sys.argv)
|
||||
# usage:
|
||||
# data parallel training
|
||||
# mpirun -n 4 python orttraining_run_bert_pretrain.py
|
||||
#
|
||||
# single gpu:
|
||||
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput
|
||||
# [batch size test arguments]
|
||||
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence
|
||||
#
|
||||
# torch.distributed.launch will not work because the ort backend requires MPI to broadcast ncclUniqueId;
# we call the unpublished get_mpi_context_xxx APIs to get rank/size numbers.
|
||||
try:
|
||||
# In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis.
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size
|
||||
|
||||
has_get_mpi_context_internal_api = True
|
||||
except ImportError:
|
||||
has_get_mpi_context_internal_api = False
|
||||
pass
|
||||
if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1:
|
||||
world_size = get_mpi_context_world_size()
|
||||
print("get_mpi_context_world_size(): ", world_size)
|
||||
local_rank = get_mpi_context_local_rank()
|
||||
|
||||
if local_rank == 0:
|
||||
print("================================================================> os.getpid() = ", os.getpid())
|
||||
|
||||
test = ORTBertPretrainTest()
|
||||
test.setUp()
|
||||
test.local_rank = local_rank
|
||||
test.world_rank = local_rank
|
||||
test.world_size = world_size
|
||||
|
||||
if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero":
|
||||
logger.info("running ORTBertPretrainTest.test_pretrain_zero()...")
|
||||
final_loss = test.test_pretrain_zero()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss)
|
||||
if local_rank == 0:
|
||||
test.assertLess(final_loss, 10.2)
|
||||
else:
|
||||
test.assertGreater(final_loss, 11.0)
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_zero() passed")
|
||||
elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence":
|
||||
logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...")
|
||||
test.max_steps = 200
|
||||
test.force_num_hidden_layers = 8
|
||||
final_loss = test.test_pretrain_convergence()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
|
||||
test.assertLess(final_loss, 8.5)
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
|
||||
else:
|
||||
# https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29
|
||||
# to make equivalent args for cpp convergence test
|
||||
test.max_seq_length = 128
|
||||
test.max_predictions_per_seq = 20
|
||||
test.gradient_accumulation_steps = 16
|
||||
|
||||
# cpp_batch_size (=64) * grad_acc * world_size
|
||||
test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size
|
||||
test.max_steps = 300000
|
||||
|
||||
test.force_num_hidden_layers = None
|
||||
|
||||
# already using Adam (e.g. AdamConfig)
|
||||
test.learning_rate = 5e-4
|
||||
test.warmup_proportion = 0.1
|
||||
|
||||
final_loss = test.test_pretrain_convergence()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
|
||||
else:
|
||||
# unittest does not accept user defined arguments
|
||||
# we need to run this script with user defined arguments
|
||||
if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput":
|
||||
run_test_pretrain_throughput, run_test_pretrain_convergence = True, False
|
||||
sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput")
|
||||
elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence":
|
||||
run_test_pretrain_throughput, run_test_pretrain_convergence = False, True
|
||||
sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence")
|
||||
else:
|
||||
run_test_pretrain_throughput, run_test_pretrain_convergence = True, True
|
||||
process_args = parse_arguments()
|
||||
test = ORTBertPretrainTest()
|
||||
test.setUp()
|
||||
|
||||
if run_test_pretrain_throughput:
|
||||
logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...")
|
||||
test.test_pretrain_throughput(process_args)
|
||||
logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed")
|
||||
|
||||
# unittest.main()
|
|
@ -1,67 +0,0 @@
import collections
import subprocess
import sys

Config = collections.namedtuple(
    "Config",
    [
        "enable_mixed_precision",
        "sequence_length",
        "max_batch_size",
        "max_predictions_per_seq",
        "gelu_recompute",
        "attn_dropout_recompute",
        "transformer_layer_recompute",
    ],
)

configs = [
    Config(True, 128, 46, 20, False, False, False),
    Config(True, 512, 8, 80, False, False, False),
    Config(False, 128, 26, 20, False, False, False),
    Config(False, 512, 4, 80, False, False, False),
    Config(True, 128, 50, 20, True, False, False),
    Config(True, 128, 50, 20, False, True, False),
    Config(True, 128, 76, 20, False, False, True),
    Config(True, 512, 8, 80, True, False, False),
    Config(True, 512, 9, 80, False, True, False),
    Config(True, 512, 15, 80, False, False, True),
]


def run_with_config(config):
    print(
        "##### testing name - {}-{} #####".format(
            "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length
        )
    )
    print("gelu_recompute: ", config.gelu_recompute)
    print("attn_dropout_recompute: ", config.attn_dropout_recompute)
    print("transformer_layer_recompute: ", config.transformer_layer_recompute)

    cmds = [
        sys.executable,
        "orttraining_run_bert_pretrain.py",
        "ORTBertPretrainTest.test_pretrain_throughput",
        "--sequence_length",
        str(config.sequence_length),
        "--max_batch_size",
        str(config.max_batch_size),
        "--max_predictions_per_seq",
        str(config.max_predictions_per_seq),
    ]
    if config.enable_mixed_precision:
        cmds.append("--enable_mixed_precision")
    if config.gelu_recompute:
        cmds.append("--gelu_recompute")
    if config.attn_dropout_recompute:
        cmds.append("--attn_dropout_recompute")
    if config.transformer_layer_recompute:
        cmds.append("--transformer_layer_recompute")

    # access to azure storage shared disk is much slower so we need a longer timeout.
    subprocess.run(cmds, timeout=1200).check_returncode()  # noqa: PLW1510


for config in configs:
    run_with_config(config)
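As an illustrative expansion derived from the removed runner above (not part of this diff), the first config, Config(True, 128, 46, 20, False, False, False), invokes roughly this command line:

python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput \
    --sequence_length 128 --max_batch_size 46 --max_predictions_per_seq 20 --enable_mixed_precision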
|
|
@ -1,323 +0,0 @@
|
|||
# adapted from run_glue.py of huggingface transformers
|
||||
|
||||
import dataclasses # noqa: F401
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
EvalPrediction,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
TrainingArguments,
|
||||
glue_compute_metrics,
|
||||
glue_output_modes,
|
||||
glue_tasks_num_labels,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
|
||||
|
||||
try:
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size
|
||||
|
||||
has_get_mpi_context_internal_api = True
|
||||
except ImportError:
|
||||
has_get_mpi_context_internal_api = False
|
||||
pass
|
||||
|
||||
|
||||
import torch # noqa: F401
|
||||
from orttraining_transformer_trainer import ORTTransformerTrainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def verify_old_and_new_api_are_equal(results_per_api):
|
||||
new_api_results = results_per_api[True]
|
||||
old_api_results = results_per_api[False]
|
||||
for key in new_api_results:
|
||||
assert_allclose(new_api_results[key], old_api_results[key])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"})
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
|
||||
|
||||
class ORTGlueTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# configurations not to be changed across tests
|
||||
self.max_seq_length = 128
|
||||
self.train_batch_size = 8
|
||||
self.learning_rate = 2e-5
|
||||
self.num_train_epochs = 3.0
|
||||
self.local_rank = -1
|
||||
self.world_size = 1
|
||||
self.overwrite_output_dir = True
|
||||
self.gradient_accumulation_steps = 1
|
||||
self.data_dir = "/bert_data/hf_data/glue_data/"
|
||||
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/")
|
||||
self.cache_dir = "/tmp/glue/"
|
||||
self.logging_steps = 10
|
||||
|
||||
def test_roberta_with_mrpc(self):
|
||||
expected_acc = 0.85
|
||||
expected_f1 = 0.88
|
||||
expected_loss = 0.35
|
||||
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False)
|
||||
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["f1"] >= expected_f1
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def test_roberta_fp16_with_mrpc(self):
|
||||
expected_acc = 0.87
|
||||
expected_f1 = 0.90
|
||||
expected_loss = 0.33
|
||||
|
||||
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True)
|
||||
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["f1"] >= expected_f1
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def test_bert_with_mrpc(self):
|
||||
if self.local_rank == -1:
|
||||
expected_acc = 0.83
|
||||
expected_f1 = 0.88
|
||||
expected_loss = 0.44
|
||||
elif self.local_rank == 0:
|
||||
expected_acc = 0.81
|
||||
expected_f1 = 0.86
|
||||
expected_loss = 0.44
|
||||
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False)
|
||||
|
||||
if self.local_rank in [-1, 0]:
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["f1"] >= expected_f1
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def test_bert_fp16_with_mrpc(self):
|
||||
expected_acc = 0.84
|
||||
expected_f1 = 0.88
|
||||
expected_loss = 0.46
|
||||
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True)
|
||||
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["f1"] >= expected_f1
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def model_to_desc(self, model_name, model):
|
||||
if model_name.startswith("bert") or model_name.startswith("xlnet"):
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"token_type_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"labels",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True), ("logits", ["batch", 2])],
|
||||
}
|
||||
elif model_name.startswith("roberta"):
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"labels",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True), ("logits", ["batch", 2])],
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"unsupported base model name {model_name}.")
|
||||
|
||||
return model_desc
|
||||
|
||||
def run_glue(self, model_name, task_name, fp16):
|
||||
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
|
||||
data_args = GlueDataTrainingArguments(
|
||||
task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=os.path.join(self.output_dir, task_name),
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
per_gpu_train_batch_size=self.train_batch_size,
|
||||
learning_rate=self.learning_rate,
|
||||
num_train_epochs=self.num_train_epochs,
|
||||
local_rank=self.local_rank,
|
||||
overwrite_output_dir=self.overwrite_output_dir,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
fp16=fp16,
|
||||
logging_steps=self.logging_steps,
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
set_seed(training_args.seed)
|
||||
onnxruntime.set_seed(training_args.seed)
|
||||
|
||||
try:
|
||||
num_labels = glue_tasks_num_labels[data_args.task_name]
|
||||
output_mode = glue_output_modes[data_args.task_name]
|
||||
except KeyError:
|
||||
raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||
|
||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
if output_mode == "classification":
|
||||
preds = np.argmax(p.predictions, axis=1)
|
||||
elif output_mode == "regression":
|
||||
preds = np.squeeze(p.predictions)
|
||||
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
|
||||
|
||||
model_desc = self.model_to_desc(model_name, model)
|
||||
# Initialize the ORTTrainer within ORTTransformerTrainer
|
||||
trainer = ORTTransformerTrainer(
|
||||
model=model,
|
||||
model_desc=model_desc,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval and training_args.local_rank in [-1, 0]:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
result = trainer.evaluate()
|
||||
|
||||
logger.info(f"***** Eval results {data_args.task_name} *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if has_get_mpi_context_internal_api:
|
||||
local_rank = get_mpi_context_local_rank()
|
||||
world_size = get_mpi_context_world_size()
|
||||
else:
|
||||
local_rank = -1
|
||||
world_size = 1
|
||||
|
||||
if world_size > 1:
|
||||
# mpi launch
|
||||
logger.warning("mpirun launch, local_rank / world_size: %s : % s", local_rank, world_size)
|
||||
|
||||
# TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl")
|
||||
# pytorch expects following environment settings (which would be set if launched with torch.distributed.launch).
|
||||
|
||||
os.environ["RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "29500"
|
||||
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id
|
||||
|
||||
set_cuda_device_id(local_rank)
|
||||
|
||||
test = ORTGlueTest()
|
||||
test.setUp()
|
||||
test.local_rank = local_rank
|
||||
test.world_size = world_size
|
||||
test.test_bert_with_mrpc()
|
||||
else:
|
||||
unittest.main()
|
|
@ -1,281 +0,0 @@
|
|||
# adapted from run_multiple_choice.py of huggingface transformers
|
||||
# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py
|
||||
|
||||
import dataclasses # noqa: F401
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch # noqa: F401
|
||||
from numpy.testing import assert_allclose # noqa: F401
|
||||
from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401
|
||||
from orttraining_transformer_trainer import ORTTransformerTrainer
|
||||
from transformers import HfArgumentParser # noqa: F401
|
||||
from transformers import Trainer # noqa: F401
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
EvalPrediction,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"})
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
task_name: str = field(metadata={"help": "The name of the task to train on."})
|
||||
data_dir: str = field(metadata={"help": "Should contain the data files for the task."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
|
||||
|
||||
|
||||
class ORTMultipleChoiceTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# configurations not to be changed across tests
|
||||
self.max_seq_length = 80
|
||||
self.train_batch_size = 16
|
||||
self.eval_batch_size = 2
|
||||
self.learning_rate = 2e-5
|
||||
self.num_train_epochs = 1.0
|
||||
self.local_rank = -1
|
||||
self.overwrite_output_dir = True
|
||||
self.gradient_accumulation_steps = 8
|
||||
self.data_dir = "/bert_data/hf_data/swag/swagaf/data"
|
||||
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/")
|
||||
self.cache_dir = "/tmp/multiple_choice/"
|
||||
self.logging_steps = 10
|
||||
self.rtol = 2e-01
|
||||
|
||||
def test_bert_with_swag(self):
|
||||
expected_acc = 0.75
|
||||
expected_loss = 0.64
|
||||
|
||||
results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False)
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def test_bert_fp16_with_swag(self):
|
||||
# larger batch can be handled with mixed precision
|
||||
self.train_batch_size = 32
|
||||
|
||||
expected_acc = 0.73
|
||||
expected_loss = 0.68
|
||||
|
||||
results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True)
|
||||
assert results["acc"] >= expected_acc
|
||||
assert results["loss"] <= expected_loss
|
||||
|
||||
def run_multiple_choice(self, model_name, task_name, fp16):
|
||||
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
|
||||
data_args = DataTrainingArguments(
|
||||
task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=os.path.join(self.output_dir, task_name),
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
per_gpu_train_batch_size=self.train_batch_size,
|
||||
per_gpu_eval_batch_size=self.eval_batch_size,
|
||||
learning_rate=self.learning_rate,
|
||||
num_train_epochs=self.num_train_epochs,
|
||||
local_rank=self.local_rank,
|
||||
overwrite_output_dir=self.overwrite_output_dir,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
fp16=fp16,
|
||||
logging_steps=self.logging_steps,
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
set_seed(training_args.seed)
|
||||
onnxruntime.set_seed(training_args.seed)
|
||||
|
||||
try:
|
||||
processor = SwagProcessor()
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
except KeyError:
|
||||
raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
model = AutoModelForMultipleChoice.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
MultipleChoiceDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
processor=processor,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.train,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
MultipleChoiceDataset(
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
task=data_args.task_name,
|
||||
processor=processor,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
overwrite_cache=data_args.overwrite_cache,
|
||||
mode=Split.dev,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
preds = np.argmax(p.predictions, axis=1)
|
||||
return {"acc": simple_accuracy(preds, p.label_ids)}
|
||||
|
||||
if model_name.startswith("bert"):
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", num_labels, "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", num_labels, "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"token_type_ids",
|
||||
["batch", num_labels, "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"labels",
|
||||
["batch", num_labels],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])],
|
||||
}
|
||||
else:
|
||||
model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", num_labels, "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", num_labels, "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"labels",
|
||||
["batch", num_labels],
|
||||
),
|
||||
],
|
||||
"outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])],
|
||||
}
|
||||
|
||||
# Initialize the ORTTrainer within ORTTransformerTrainer
|
||||
trainer = ORTTransformerTrainer(
|
||||
model=model,
|
||||
model_desc=model_desc,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval and training_args.local_rank in [-1, 0]:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
result = trainer.evaluate()
|
||||
|
||||
logger.info(f"***** Eval results {data_args.task_name} *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,6 +0,0 @@
|
|||
from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401
|
||||
from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401
|
||||
|
||||
|
||||
def postprocess_model(model):
|
||||
add_name(model)
|
|
@ -1,257 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# orttraining_test_checkpoint_storage.py
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from onnxruntime.training import _checkpoint_storage
|
||||
|
||||
# Helper functions
|
||||
|
||||
|
||||
def _equals(a, b):
|
||||
"""Checks recursively if two dictionaries are equal"""
|
||||
if isinstance(a, dict):
|
||||
return all(key in b and _equals(a[key], b[key]) for key in a)
|
||||
else:
|
||||
if isinstance(a, bytes):
|
||||
a = a.decode()
|
||||
if isinstance(b, bytes):
|
||||
b = b.decode()
|
||||
are_equal = a == b
|
||||
return are_equal if isinstance(are_equal, bool) else are_equal.all()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _numpy_types(obj_value):
|
||||
"""Return a bool indicating whether or not the input obj_value is a numpy type object
|
||||
|
||||
Recursively checks if the obj_value (could be a dictionary) is a numpy type object.
|
||||
Exceptions are str and bytes.
|
||||
|
||||
Returns true if object is numpy type, str, or bytes
|
||||
False if any other type
|
||||
"""
|
||||
if not isinstance(obj_value, dict):
|
||||
return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__
|
||||
|
||||
return all(_numpy_types(value) for _, value in obj_value.items())
|
||||
|
||||
|
||||
def _get_dict(separated_key):
|
||||
"""Create dummy dictionary with different datatypes
|
||||
|
||||
Returns the tuple of the entire dummy dictionary created, key argument as a dictionary for _checkpoint_storage.load
|
||||
function and the value for that key in the original dictionary
|
||||
|
||||
For example the complete dictionary is represented by:
|
||||
{
|
||||
'int1':1,
|
||||
'int2': 2,
|
||||
'int_list': [1,2,3,5,6],
|
||||
'dict1': {
|
||||
'np_array': np.arange(100),
|
||||
'dict2': {'int3': 3, 'int4': 4},
|
||||
'str1': "onnxruntime"
|
||||
},
|
||||
'bool1': bool(True),
|
||||
'int5': 5,
|
||||
'float1': 2.345,
|
||||
'np_array_float': np.array([1.234, 2.345, 3.456]),
|
||||
'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]])
|
||||
}
|
||||
|
||||
if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1'
|
||||
and the value corresponding to that is "onnxruntime"
|
||||
|
||||
so, for the above example, the returned tuple is:
|
||||
(original_dict, {'key': 'dict1/str1'}, "onnxruntime")
|
||||
"""
|
||||
test_dict = {
|
||||
"int1": 1,
|
||||
"int2": 2,
|
||||
"int_list": [1, 2, 3, 5, 6],
|
||||
"dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"},
|
||||
"bool1": True,
|
||||
"int5": 5,
|
||||
"float1": 2.345,
|
||||
"np_array_float": np.array([1.234, 2.345, 3.456]),
|
||||
"np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
|
||||
}
|
||||
key = ""
|
||||
expected_val = test_dict
|
||||
for single_key in separated_key:
|
||||
key += single_key + "/"
|
||||
expected_val = expected_val[single_key]
|
||||
return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val
|
||||
|
||||
|
||||
class _CustomClass:
|
||||
"""Custom object that encpsulates dummy values for loss, epoch and train_step"""
|
||||
|
||||
def __init__(self):
|
||||
self._loss = 1.23
|
||||
self._epoch = 12000
|
||||
self._train_step = 25
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, _CustomClass):
|
||||
return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step
|
||||
|
||||
|
||||
# Test fixtures
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope="function")
|
||||
def checkpoint_storage_test_setup():
|
||||
checkpoint_dir = os.path.abspath("checkpoint_dir/")
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp")
|
||||
yield "checkpoint_storage_test_setup"
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope="function")
|
||||
def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup):
|
||||
yield request.param
|
||||
|
||||
|
||||
# Tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpoint_storage_test_parameterized_setup",
|
||||
[
|
||||
_get_dict([]),
|
||||
_get_dict(["int1"]),
|
||||
_get_dict(["dict1"]),
|
||||
_get_dict(["dict1", "dict2"]),
|
||||
_get_dict(["dict1", "dict2", "int4"]),
|
||||
_get_dict(["dict1", "str1"]),
|
||||
_get_dict(["bool1"]),
|
||||
_get_dict(["float1"]),
|
||||
_get_dict(["np_array_float"]),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup):
|
||||
to_save = checkpoint_storage_test_parameterized_setup[0]
|
||||
key_arg = checkpoint_storage_test_parameterized_setup[1]
|
||||
expected = checkpoint_storage_test_parameterized_setup[2]
|
||||
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
|
||||
loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg)
|
||||
assert _equals(loaded, expected)
|
||||
assert _numpy_types(loaded)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpoint_storage_test_parameterized_setup",
|
||||
[{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352],
|
||||
indirect=True,
|
||||
)
|
||||
def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup):
|
||||
to_save = checkpoint_storage_test_parameterized_setup
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpoint_storage_test_parameterized_setup",
|
||||
[
|
||||
({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64),
|
||||
({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32),
|
||||
({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16),
|
||||
({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8),
|
||||
({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64),
|
||||
(
|
||||
{"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)},
|
||||
"float32_tensor",
|
||||
torch.float32,
|
||||
np.float32,
|
||||
),
|
||||
(
|
||||
{"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)},
|
||||
"float16_tensor",
|
||||
torch.float16,
|
||||
np.float16,
|
||||
),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup):
|
||||
tensor_dict = checkpoint_storage_test_parameterized_setup[0]
|
||||
tensor_name = checkpoint_storage_test_parameterized_setup[1]
|
||||
tensor_dtype = checkpoint_storage_test_parameterized_setup[2]
|
||||
np_dtype = checkpoint_storage_test_parameterized_setup[3]
|
||||
|
||||
_checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
|
||||
|
||||
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
|
||||
assert isinstance(loaded[tensor_name], np.ndarray)
|
||||
assert tensor_dict[tensor_name].dtype == tensor_dtype
|
||||
assert loaded[tensor_name].dtype == np_dtype
|
||||
assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpoint_storage_test_parameterized_setup",
|
||||
[
|
||||
({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"),
|
||||
({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"),
|
||||
({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup):
|
||||
tensor_dict = checkpoint_storage_test_parameterized_setup[0]
|
||||
tensor_name = checkpoint_storage_test_parameterized_setup[1]
|
||||
|
||||
_checkpoint_storage.save(tensor_dict, pytest.checkpoint_path)
|
||||
|
||||
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
|
||||
assert isinstance(loaded[tensor_name], np.ndarray)
|
||||
assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True
|
||||
)
|
||||
def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup):
|
||||
saved = checkpoint_storage_test_parameterized_setup
|
||||
_checkpoint_storage.save(saved, pytest.checkpoint_path)
|
||||
|
||||
loaded = _checkpoint_storage.load(pytest.checkpoint_path)
|
||||
assert _equals(saved, loaded)
|
||||
|
||||
|
||||
def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
_checkpoint_storage.load(pytest.checkpoint_path)
|
||||
|
||||
|
||||
def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup):
|
||||
custom_class = _CustomClass()
|
||||
user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class}
|
||||
|
||||
pickled_bytes = pickle.dumps(user_dict).hex()
|
||||
to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes}
|
||||
_checkpoint_storage.save(to_save, pytest.checkpoint_path)
|
||||
|
||||
loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path)
|
||||
assert (loaded_dict["a"] == to_save["a"].numpy()).all()
|
||||
try: # noqa: SIM105
|
||||
loaded_dict["user_dict"] = loaded_dict["user_dict"].decode()
|
||||
except AttributeError:
|
||||
pass
|
||||
loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"]))
|
||||
|
||||
assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"]))
|
||||
assert loaded_obj["custom_class"] == custom_class
|
|
@ -4,8 +4,6 @@ from enum import Enum
|
|||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from onnxruntime.capi.ort_trainer import generate_sample
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
|
@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
|||
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
||||
|
||||
|
||||
def generate_sample(desc, device=None):
|
||||
"""Generate a sample based on the description"""
|
||||
# symbolic dimensions are described with strings. set symbolic dimensions to be 1
|
||||
size = [s if isinstance(s, (int)) else 1 for s in desc.shape_]
|
||||
if desc.num_classes_:
|
||||
return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device)
|
||||
else:
|
||||
return torch.randn(size, dtype=desc.dtype_).to(device)
|
||||
|
||||
|
||||
class OrtTestDataset(Dataset):
|
||||
def __init__(self, input_desc, seq_len, dataset_len, device):
|
||||
import copy
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from _test_commons import _load_pytorch_transformer_model
|
||||
|
||||
from onnxruntime import set_seed
|
||||
from onnxruntime.training import optim, orttrainer
|
||||
|
||||
###############################################################################
|
||||
# Testing starts here #########################################################
|
||||
###############################################################################
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed, device",
|
||||
[
|
||||
(24, "cuda"),
|
||||
],
|
||||
)
|
||||
def testORTTransformerModelExport(seed, device):
|
||||
# Common setup
|
||||
optim_config = optim.LambConfig()
|
||||
opts = orttrainer.ORTTrainerOptions(
|
||||
{
|
||||
"debug": {
|
||||
"check_model_export": True,
|
||||
},
|
||||
"device": {
|
||||
"id": device,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Setup for the first ORTTrainer run
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
|
||||
data, targets = batcher_fn(train_data, 0)
|
||||
_ = first_trainer.train_step(data, targets)
|
||||
assert first_trainer._onnx_model is not None
|
|
@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log):
|
|||
|
||||
log.debug("Running: ort training api tests")
|
||||
|
||||
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"]
|
||||
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"]
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log):
|
|||
|
||||
log.debug("Running: onnxblock tests")
|
||||
|
||||
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"]
|
||||
command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"]
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
from orttraining_test_onnxblock import _get_models
|
||||
from orttraining_test_ort_apis_onnxblock import _get_models
|
||||
|
||||
import onnxruntime.training.onnxblock as onnxblock
|
||||
from onnxruntime import OrtValue, SessionOptions
|
File diff not shown because it is too large
Load Diff
|
@ -1,722 +0,0 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
from _test_commons import _load_pytorch_transformer_model
|
||||
|
||||
from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401
|
||||
|
||||
# Helper functions
|
||||
|
||||
|
||||
def _create_trainer(zero_enabled=False):
|
||||
"""Cerates a simple ORTTrainer for ORTTrainer functional tests"""
|
||||
|
||||
device = "cuda"
|
||||
optim_config = optim.LambConfig(lr=0.1)
|
||||
opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
|
||||
if zero_enabled:
|
||||
opts["distributed"] = {
|
||||
"world_rank": 0,
|
||||
"world_size": 1,
|
||||
"horizontal_parallel_size": 1,
|
||||
"data_parallel_size": 1,
|
||||
"allreduce_post_accumulation": True,
|
||||
"deepspeed_zero_optimization": {"stage": 1},
|
||||
}
|
||||
model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
trainer = orttrainer.ORTTrainer(
|
||||
model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts)
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class _training_session_mock: # noqa: N801
|
||||
"""Mock object for the ORTTrainer _training_session member"""
|
||||
|
||||
def __init__(self, model_states, optimizer_states, partition_info):
|
||||
self.model_states = model_states
|
||||
self.optimizer_states = optimizer_states
|
||||
self.partition_info = partition_info
|
||||
|
||||
def get_model_state(self, include_mixed_precision_weights=False):
|
||||
return self.model_states
|
||||
|
||||
def get_optimizer_state(self):
|
||||
return self.optimizer_states
|
||||
|
||||
def get_partition_info_map(self):
|
||||
return self.partition_info
|
||||
|
||||
|
||||
def _get_load_state_dict_strict_error_arguments():
|
||||
"""Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing
|
||||
|
||||
Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments)
|
||||
The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and
|
||||
throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys.
|
||||
"""
|
||||
|
||||
training_session_state_dict = {
|
||||
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
|
||||
"optimizer": {
|
||||
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(5)},
|
||||
},
|
||||
}
|
||||
|
||||
# input state dictionaries
|
||||
precision_key_missing = {"model": {}, "optimizer": {}}
|
||||
precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}}
|
||||
model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}}
|
||||
model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}}
|
||||
optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}}
|
||||
optimizer_model_state_key_unexpected = {
|
||||
"model": {"full_precision": {"a": 2, "b": 3}},
|
||||
"optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}},
|
||||
}
|
||||
optimizer_state_key_missing = {
|
||||
"model": {"full_precision": {"a": 2, "b": 3}},
|
||||
"optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}},
|
||||
}
|
||||
optimizer_state_key_unexpected = {
|
||||
"model": {"full_precision": {"a": 2, "b": 3}},
|
||||
"optimizer": {
|
||||
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)},
|
||||
},
|
||||
}
|
||||
|
||||
input_arguments = [
|
||||
(training_session_state_dict, precision_key_missing, ["full_precision"]),
|
||||
(training_session_state_dict, precision_key_unexpected, ["mixed_precision"]),
|
||||
(training_session_state_dict, model_state_key_missing, ["a", "b"]),
|
||||
(training_session_state_dict, model_state_key_unexpected, ["c"]),
|
||||
(training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]),
|
||||
(training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]),
|
||||
(training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]),
|
||||
(training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]),
|
||||
]
|
||||
|
||||
return input_arguments
|
||||
|
||||
|
||||
# Tests
|
||||
|
||||
|
||||
def test_empty_state_dict_when_training_session_uninitialized():
|
||||
trainer = _create_trainer()
|
||||
with pytest.warns(UserWarning) as user_warning:
|
||||
state_dict = trainer.state_dict()
|
||||
|
||||
assert len(state_dict.keys()) == 0
|
||||
assert (
|
||||
user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. "
|
||||
"Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()."
|
||||
)
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_empty_model_states(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
training_session_mock = _training_session_mock({}, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert len(state_dict["model"].keys()) == 0
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_model_states(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
|
||||
training_session_mock = _training_session_mock(model_states, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
|
||||
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_model_states_pytorch_format(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
|
||||
training_session_mock = _training_session_mock(model_states, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict(pytorch_format=True)
|
||||
assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5))))
|
||||
assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7))))
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_onnx_graph_provides_frozen_model_states(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
|
||||
training_session_mock = _training_session_mock(model_states, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"]
|
||||
trainer._onnx_model.graph.initializer = [
|
||||
onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"),
|
||||
onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"),
|
||||
onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"),
|
||||
]
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
|
||||
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
|
||||
assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all()
|
||||
assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"]
|
||||
assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all()
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_empty_optimizer_states(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
training_session_mock = _training_session_mock({}, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert len(state_dict["optimizer"].keys()) == 0
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_optimizer_states(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
optimizer_states = {
|
||||
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(1)},
|
||||
}
|
||||
training_session_mock = _training_session_mock({}, optimizer_states, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all()
|
||||
assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all()
|
||||
assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all()
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
|
||||
optimizer_states = {
|
||||
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(1)},
|
||||
}
|
||||
training_session_mock = _training_session_mock(model_states, optimizer_states, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict(pytorch_format=True)
|
||||
assert "optimizer" not in state_dict
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_empty_partition_info_map(onnx_model_mock):
|
||||
trainer = _create_trainer(zero_enabled=True)
|
||||
training_session_mock = _training_session_mock({}, {}, {})
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert len(state_dict["partition_info"].keys()) == 0
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_partition_info_map(onnx_model_mock):
|
||||
trainer = _create_trainer(zero_enabled=True)
|
||||
partition_info = {"a": {"original_dim": [1, 2, 3]}}
|
||||
training_session_mock = _training_session_mock({}, {}, partition_info)
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3]
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_training_session_provides_all_states(onnx_model_mock):
|
||||
trainer = _create_trainer(zero_enabled=True)
|
||||
model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}
|
||||
optimizer_states = {
|
||||
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(1)},
|
||||
}
|
||||
partition_info = {"a": {"original_dim": [1, 2, 3]}}
|
||||
training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info)
|
||||
trainer._training_session = training_session_mock
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all()
|
||||
assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all()
|
||||
assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all()
|
||||
assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all()
|
||||
assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all()
|
||||
assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_load_state_dict_holds_when_training_session_not_initialized():
|
||||
trainer = _create_trainer()
|
||||
state_dict = {
|
||||
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
|
||||
"optimizer": {
|
||||
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(5)},
|
||||
},
|
||||
}
|
||||
assert not trainer._load_state_dict
|
||||
state_dict = trainer.load_state_dict(state_dict)
|
||||
assert trainer._load_state_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_dict, input_state_dict, error_key",
|
||||
[
|
||||
(
|
||||
{"model": {}, "optimizer": {}},
|
||||
{"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}},
|
||||
"train_step_info",
|
||||
),
|
||||
(
|
||||
{"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}},
|
||||
{
|
||||
"optimizer": {},
|
||||
"trainer_options": {"optimizer_name": "LambOptimizer"},
|
||||
"train_step_info": {"optimization_step": 0, "step": 0},
|
||||
},
|
||||
"model",
|
||||
),
|
||||
(
|
||||
{"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}},
|
||||
{
|
||||
"model": {},
|
||||
"trainer_options": {"optimizer_name": "LambOptimizer"},
|
||||
"train_step_info": {"optimization_step": 0, "step": 0},
|
||||
},
|
||||
"optimizer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key):
|
||||
trainer = _create_trainer()
|
||||
trainer._training_session = _training_session_mock({}, {}, {})
|
||||
trainer.state_dict = Mock(return_value=state_dict)
|
||||
trainer._update_onnx_model_initializers = Mock()
|
||||
trainer._init_session = Mock()
|
||||
with patch("onnx.ModelProto") as onnx_model_mock:
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
trainer._onnx_model.graph.initializer = []
|
||||
with pytest.warns(UserWarning) as user_warning:
|
||||
trainer.load_state_dict(input_state_dict)
|
||||
|
||||
assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments())
|
||||
def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys):
|
||||
trainer = _create_trainer()
|
||||
trainer._training_session = _training_session_mock({}, {}, {})
|
||||
trainer.state_dict = Mock(return_value=state_dict)
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
trainer.load_state_dict(input_state_dict)
|
||||
|
||||
assert any(key in str(runtime_error.value) for key in error_keys)
|
||||
|
||||
|
||||
@patch("onnx.ModelProto")
|
||||
def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock):
|
||||
trainer = _create_trainer()
|
||||
training_session_state_dict = {
|
||||
"model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}},
|
||||
"optimizer": {
|
||||
"a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)},
|
||||
"shared_optimizer_state": {"step": np.arange(1)},
|
||||
},
|
||||
}
|
||||
|
||||
input_state_dict = {
|
||||
"model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}},
|
||||
"optimizer": {
|
||||
"a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])},
|
||||
"shared_optimizer_state": {"step": np.array([9])},
|
||||
},
|
||||
"trainer_options": {"optimizer_name": "LambOptimizer"},
|
||||
}
|
||||
trainer._training_session = _training_session_mock({}, {}, {})
|
||||
trainer.state_dict = Mock(return_value=training_session_state_dict)
|
||||
trainer._onnx_model = onnx_model_mock()
|
||||
trainer._onnx_model.graph.initializer = [
|
||||
onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"),
|
||||
onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"),
|
||||
]
|
||||
trainer._update_onnx_model_initializers = Mock()
|
||||
trainer._init_session = Mock()
|
||||
|
||||
trainer.load_state_dict(input_state_dict)
|
||||
|
||||
loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args
|
||||
state_dict_to_load, _ = trainer._init_session.call_args
|
||||
|
||||
assert "a" in loaded_initializers[0]
|
||||
assert (loaded_initializers[0]["a"] == np.array([1, 2])).all()
|
||||
assert "b" in loaded_initializers[0]
|
||||
assert (loaded_initializers[0]["b"] == np.array([3, 4])).all()
|
||||
|
||||
assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all()
|
||||
assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all()
|
||||
assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all()
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.save")
|
||||
def test_save_checkpoint_calls_checkpoint_storage_save(save_mock):
|
||||
trainer = _create_trainer()
|
||||
state_dict = {"model": {}, "optimizer": {}}
|
||||
trainer.state_dict = Mock(return_value=state_dict)
|
||||
|
||||
trainer.save_checkpoint("abc")
|
||||
|
||||
save_args, _ = save_mock.call_args
|
||||
assert "model" in save_args[0]
|
||||
assert not bool(save_args[0]["model"])
|
||||
assert "optimizer" in save_args[0]
|
||||
assert not bool(save_args[0]["optimizer"])
|
||||
assert save_args[1] == "abc"
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.save")
|
||||
def test_save_checkpoint_exclude_optimizer_states(save_mock):
|
||||
trainer = _create_trainer()
|
||||
state_dict = {"model": {}, "optimizer": {}}
|
||||
trainer.state_dict = Mock(return_value=state_dict)
|
||||
|
||||
trainer.save_checkpoint("abc", include_optimizer_states=False)
|
||||
|
||||
save_args, _ = save_mock.call_args
|
||||
assert "model" in save_args[0]
|
||||
assert not bool(save_args[0]["model"])
|
||||
assert "optimizer" not in save_args[0]
|
||||
assert save_args[1] == "abc"
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.save")
|
||||
def test_save_checkpoint_user_dict(save_mock):
|
||||
trainer = _create_trainer()
|
||||
state_dict = {"model": {}, "optimizer": {}}
|
||||
trainer.state_dict = Mock(return_value=state_dict)
|
||||
|
||||
trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)})
|
||||
|
||||
save_args, _ = save_mock.call_args
|
||||
assert "user_dict" in save_args[0]
|
||||
assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)})
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.load")
|
||||
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
|
||||
def test_load_checkpoint(aggregate_checkpoints_mock, load_mock):
|
||||
trainer = _create_trainer()
|
||||
trainer_options = {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
}
|
||||
state_dict = {
|
||||
"model": {},
|
||||
"optimizer": {},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
},
|
||||
}
|
||||
trainer.load_state_dict = Mock()
|
||||
|
||||
load_mock.side_effect = [trainer_options, state_dict]
|
||||
trainer.load_checkpoint("abc")
|
||||
|
||||
args_list = load_mock.call_args_list
|
||||
load_args, load_kwargs = args_list[0]
|
||||
assert load_args[0] == "abc"
|
||||
assert load_kwargs["key"] == "trainer_options"
|
||||
load_args, load_kwargs = args_list[1]
|
||||
assert load_args[0] == "abc"
|
||||
assert "key" not in load_kwargs
|
||||
assert not aggregate_checkpoints_mock.called
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.load")
|
||||
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
|
||||
@pytest.mark.parametrize(
|
||||
"trainer_options",
|
||||
[
|
||||
{
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(4),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(4),
|
||||
"zero_stage": np.int64(1),
|
||||
},
|
||||
{
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(1),
|
||||
},
|
||||
{
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(1),
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options):
|
||||
trainer = _create_trainer()
|
||||
trainer.load_state_dict = Mock()
|
||||
|
||||
load_mock.side_effect = [trainer_options]
|
||||
trainer.load_checkpoint("abc")
|
||||
|
||||
args_list = load_mock.call_args_list
|
||||
load_args, load_kwargs = args_list[0]
|
||||
assert load_args[0] == "abc"
|
||||
assert load_kwargs["key"] == "trainer_options"
|
||||
assert aggregate_checkpoints_mock.called
|
||||
call_args, _ = aggregate_checkpoints_mock.call_args
|
||||
assert call_args[0] == tuple(["abc"])
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.load")
|
||||
@patch("onnxruntime.training.checkpoint.aggregate_checkpoints")
|
||||
def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock):
|
||||
trainer = _create_trainer()
|
||||
trainer_options = {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
}
|
||||
state_dict = {
|
||||
"model": {},
|
||||
"optimizer": {},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
},
|
||||
"user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}),
|
||||
}
|
||||
trainer.load_state_dict = Mock()
|
||||
|
||||
load_mock.side_effect = [trainer_options, state_dict]
|
||||
user_dict = trainer.load_checkpoint("abc")
|
||||
|
||||
assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5))))
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.load")
|
||||
def test_checkpoint_aggregation(load_mock):
|
||||
trainer_options1 = {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(2),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(2),
|
||||
"zero_stage": np.int64(1),
|
||||
"optimizer_name": b"Adam",
|
||||
}
|
||||
trainer_options2 = {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(1),
|
||||
"world_size": np.int64(2),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(2),
|
||||
"zero_stage": np.int64(1),
|
||||
"optimizer_name": b"Adam",
|
||||
}
|
||||
|
||||
state_dict1 = {
|
||||
"model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
|
||||
"optimizer": {
|
||||
"optimizer_sharded": {
|
||||
"Moment_1": np.array([9, 8, 7]),
|
||||
"Moment_2": np.array([99, 88, 77]),
|
||||
"Step": np.array([5]),
|
||||
},
|
||||
"non_sharded": {
|
||||
"Moment_1": np.array([666, 555, 444]),
|
||||
"Moment_2": np.array([6666, 5555, 4444]),
|
||||
"Step": np.array([55]),
|
||||
},
|
||||
},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
"optimizer_name": b"Adam",
|
||||
},
|
||||
"partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}},
|
||||
}
|
||||
|
||||
state_dict2 = {
|
||||
"model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
|
||||
"optimizer": {
|
||||
"optimizer_sharded": {
|
||||
"Moment_1": np.array([6, 5, 4]),
|
||||
"Moment_2": np.array([66, 55, 44]),
|
||||
"Step": np.array([5]),
|
||||
},
|
||||
"non_sharded": {
|
||||
"Moment_1": np.array([666, 555, 444]),
|
||||
"Moment_2": np.array([6666, 5555, 4444]),
|
||||
"Step": np.array([55]),
|
||||
},
|
||||
},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(False),
|
||||
"world_rank": np.int64(1),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
"optimizer_name": b"Adam",
|
||||
},
|
||||
"partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}},
|
||||
}
|
||||
|
||||
load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]
|
||||
state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False)
|
||||
|
||||
assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all()
|
||||
assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all()
|
||||
assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all()
|
||||
assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all()
|
||||
assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all()
|
||||
|
||||
assert state_dict["trainer_options"]["mixed_precision"] is False
|
||||
assert state_dict["trainer_options"]["world_rank"] == 0
|
||||
assert state_dict["trainer_options"]["world_size"] == 1
|
||||
assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1
|
||||
assert state_dict["trainer_options"]["data_parallel_size"] == 1
|
||||
assert state_dict["trainer_options"]["zero_stage"] == 0
|
||||
assert state_dict["trainer_options"]["optimizer_name"] == b"Adam"
|
||||
|
||||
|
||||
@patch("onnxruntime.training._checkpoint_storage.load")
|
||||
def test_checkpoint_aggregation_mixed_precision(load_mock):
|
||||
trainer_options1 = {
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(2),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(2),
|
||||
"zero_stage": np.int64(1),
|
||||
"optimizer_name": b"Adam",
|
||||
}
|
||||
trainer_options2 = {
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(1),
|
||||
"world_size": np.int64(2),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(2),
|
||||
"zero_stage": np.int64(1),
|
||||
"optimizer_name": b"Adam",
|
||||
}
|
||||
|
||||
state_dict1 = {
|
||||
"model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}},
|
||||
"optimizer": {
|
||||
"sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])},
|
||||
"non_sharded": {
|
||||
"Moment_1": np.array([666, 555, 444]),
|
||||
"Moment_2": np.array([6666, 5555, 4444]),
|
||||
"Step": np.array([55]),
|
||||
},
|
||||
},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(0),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
"optimizer_name": b"Adam",
|
||||
},
|
||||
"partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
|
||||
}
|
||||
|
||||
state_dict2 = {
|
||||
"model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}},
|
||||
"optimizer": {
|
||||
"sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])},
|
||||
"non_sharded": {
|
||||
"Moment_1": np.array([666, 555, 444]),
|
||||
"Moment_2": np.array([6666, 5555, 4444]),
|
||||
"Step": np.array([55]),
|
||||
},
|
||||
},
|
||||
"trainer_options": {
|
||||
"mixed_precision": np.bool_(True),
|
||||
"world_rank": np.int64(1),
|
||||
"world_size": np.int64(1),
|
||||
"horizontal_parallel_size": np.int64(1),
|
||||
"data_parallel_size": np.int64(1),
|
||||
"zero_stage": np.int64(0),
|
||||
"optimizer_name": b"Adam",
|
||||
},
|
||||
"partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
|
||||
}
|
||||
|
||||
load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]
|
||||
state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False)
|
||||
|
||||
assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all()
|
||||
assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all()
|
||||
assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all()
|
||||
assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all()
|
||||
assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all()
|
||||
assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all()
|
||||
|
||||
assert state_dict["trainer_options"]["mixed_precision"] is True
|
||||
assert state_dict["trainer_options"]["world_rank"] == 0
|
||||
assert state_dict["trainer_options"]["world_size"] == 1
|
||||
assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1
|
||||
assert state_dict["trainer_options"]["data_parallel_size"] == 1
|
||||
assert state_dict["trainer_options"]["zero_stage"] == 0
|
||||
assert state_dict["trainer_options"]["optimizer_name"] == b"Adam"
|
File diff not shown because it is too large
Load Diff
|
@ -1,480 +0,0 @@
|
|||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.testing import assert_allclose
|
||||
from orttraining_test_data_loader import BatchArgsOption, ids_tensor
|
||||
from orttraining_test_utils import get_lr, run_test
|
||||
from transformers import BertConfig, BertForPreTraining
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401
|
||||
|
||||
|
||||
class BertModelTest(unittest.TestCase):
|
||||
class BertModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
device="cpu",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.device = device
|
||||
|
||||
# 1. superset of bert input/output descs
|
||||
# see BertPreTrainedModel doc
|
||||
self.input_ids_desc = IODescription(
|
||||
"input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size
|
||||
)
|
||||
self.attention_mask_desc = IODescription(
|
||||
"attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2
|
||||
)
|
||||
self.token_type_ids_desc = IODescription(
|
||||
"token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2
|
||||
)
|
||||
self.position_ids_desc = IODescription(
|
||||
"position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings
|
||||
)
|
||||
self.head_mask_desc = IODescription(
|
||||
"head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2
|
||||
)
|
||||
self.inputs_embeds_desc = IODescription(
|
||||
"inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
|
||||
)
|
||||
|
||||
self.encoder_hidden_states_desc = IODescription(
|
||||
"encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
|
||||
)
|
||||
self.encoder_attention_mask_desc = IODescription(
|
||||
"encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32
|
||||
)
|
||||
|
||||
# see BertForPreTraining doc
|
||||
self.masked_lm_labels_desc = IODescription(
|
||||
"masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size
|
||||
)
|
||||
self.next_sentence_label_desc = IODescription(
|
||||
"next_sentence_label",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
torch.int64,
|
||||
num_classes=2,
|
||||
)
|
||||
|
||||
# outputs
|
||||
self.loss_desc = IODescription(
|
||||
"loss",
|
||||
[
|
||||
1,
|
||||
],
|
||||
torch.float32,
|
||||
)
|
||||
self.prediction_scores_desc = IODescription(
|
||||
"prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32
|
||||
)
|
||||
|
||||
self.seq_relationship_scores_desc = IODescription(
|
||||
"seq_relationship_scores", ["batch", 2], torch.float32
|
||||
) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
|
||||
self.hidden_states_desc = IODescription(
|
||||
"hidden_states",
|
||||
[self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size],
|
||||
torch.float32,
|
||||
)
|
||||
self.attentions_desc = IODescription(
|
||||
"attentions",
|
||||
[
|
||||
self.num_hidden_layers,
|
||||
"batch",
|
||||
self.num_attention_heads,
|
||||
"max_seq_len_in_batch",
|
||||
"max_seq_len_in_batch",
|
||||
],
|
||||
torch.float32,
|
||||
)
|
||||
self.last_hidden_state_desc = IODescription(
|
||||
"last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32
|
||||
)
|
||||
self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32)
|
||||
|
||||
def BertForPreTraining_descs(self):
|
||||
return ModelDescription(
|
||||
[
|
||||
self.input_ids_desc,
|
||||
self.attention_mask_desc,
|
||||
self.token_type_ids_desc,
|
||||
self.masked_lm_labels_desc,
|
||||
self.next_sentence_label_desc,
|
||||
],
|
||||
# returns loss_desc if both masked_lm_labels_desc and next_sentence_label_desc are provided
# hidden_states_desc and attentions_desc should be included according to config.output_attentions and config.output_hidden_states
|
||||
[
|
||||
self.loss_desc,
|
||||
self.prediction_scores_desc,
|
||||
self.seq_relationship_scores_desc,
|
||||
# hidden_states_desc, attentions_desc
|
||||
],
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
|
||||
|
||||
config = BertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_bert_for_pretraining(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
option_use_internal_get_lr_this_step=[True], # noqa: B006
|
||||
option_use_internal_loss_scaler=[True], # noqa: B006
|
||||
):
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
|
||||
model = BertForPreTraining(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
model_desc = ModelDescription(
|
||||
[
|
||||
self.input_ids_desc,
|
||||
self.attention_mask_desc,
|
||||
self.token_type_ids_desc,
|
||||
self.masked_lm_labels_desc,
|
||||
self.next_sentence_label_desc,
|
||||
],
|
||||
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc],
|
||||
)
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
MyArgs = namedtuple(
|
||||
"MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
|
||||
)
|
||||
|
||||
dataset_len = 100
|
||||
epochs = 8
|
||||
max_steps = epochs * dataset_len
|
||||
args = MyArgs(
|
||||
local_rank=0,
|
||||
world_size=1,
|
||||
max_steps=max_steps,
|
||||
learning_rate=0.00001,
|
||||
warmup_proportion=0.01,
|
||||
batch_size=13,
|
||||
seq_len=7,
|
||||
)
|
||||
|
||||
def get_lr_this_step(global_step):
|
||||
return get_lr(args, global_step)
|
||||
|
||||
loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000)
|
||||
|
||||
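# Sweep every combination of the option lists below: each configuration trains the same BERT model with the
# deprecated ORTTrainer frontend and, when both internal LR scheduling and internal loss scaling are enabled,
# repeats the run with the newer orttrainer frontend so the two sets of outputs can be compared.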
for fp16 in option_fp16:
|
||||
for allreduce_post_accumulation in option_allreduce_post_accumulation:
|
||||
for gradient_accumulation_steps in option_gradient_accumulation_steps:
|
||||
for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
|
||||
for use_internal_loss_scaler in option_use_internal_loss_scaler:
|
||||
for split_batch in option_split_batch:
|
||||
print("gradient_accumulation_steps:", gradient_accumulation_steps)
|
||||
print("split_batch:", split_batch)
|
||||
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
|
||||
(
|
||||
old_api_loss_ort,
|
||||
old_api_prediction_scores_ort,
|
||||
old_api_seq_relationship_score_ort,
|
||||
) = run_test(
|
||||
model,
|
||||
model_desc,
|
||||
self.device,
|
||||
args,
|
||||
gradient_accumulation_steps,
|
||||
fp16,
|
||||
allreduce_post_accumulation,
|
||||
get_lr_this_step,
|
||||
use_internal_get_lr_this_step,
|
||||
loss_scaler,
|
||||
use_internal_loss_scaler,
|
||||
split_batch,
|
||||
dataset_len,
|
||||
epochs,
|
||||
use_new_api=False,
|
||||
)
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
if use_internal_get_lr_this_step and use_internal_loss_scaler:
|
||||
(
|
||||
new_api_loss_ort,
|
||||
new_api_prediction_scores_ort,
|
||||
new_api_seq_relationship_score_ort,
|
||||
) = run_test(
|
||||
model,
|
||||
model_desc,
|
||||
self.device,
|
||||
args,
|
||||
gradient_accumulation_steps,
|
||||
fp16,
|
||||
allreduce_post_accumulation,
|
||||
get_lr_this_step,
|
||||
use_internal_get_lr_this_step,
|
||||
loss_scaler,
|
||||
use_internal_loss_scaler,
|
||||
split_batch,
|
||||
dataset_len,
|
||||
epochs,
|
||||
use_new_api=True,
|
||||
)
|
||||
|
||||
assert_allclose(old_api_loss_ort, new_api_loss_ort)
|
||||
assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort)
|
||||
assert_allclose(
|
||||
old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertModelTest.BertModelTester(self)
|
||||
|
||||
def test_for_pretraining_mixed_precision(self):
|
||||
# It would be better to test both with and without mixed precision and allreduce_post_accumulation.
# However, a stress test of all 4 cases is not stable, at least on the test machine.
# Here we only test mixed precision with allreduce_post_accumulation because it is the most useful use case.
|
||||
option_fp16 = [True]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1]
|
||||
option_split_batch = [BatchArgsOption.ListAndDict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_mixed_precision_with_gradient_accumulation(self):
|
||||
# It would be better to test both with and without mixed precision and allreduce_post_accumulation.
# However, a stress test of all 4 cases is not stable, at least on the test machine.
# Here we only test mixed precision with allreduce_post_accumulation because it is the most useful use case.
|
||||
option_fp16 = [True]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [8]
|
||||
option_split_batch = [BatchArgsOption.ListAndDict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_all(self):
|
||||
# This test is not stable because it creates and runs an ORTSession multiple times.
# It occasionally gets a seg fault at ~MemoryPattern() when releasing patterns_.
# In order not to block the PR merging CI test, it is broken into the following individual tests.
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1, 8]
|
||||
option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_list_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1]
|
||||
option_split_batch = [BatchArgsOption.List]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_dict_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1]
|
||||
option_split_batch = [BatchArgsOption.Dict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_list_and_dict_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1]
|
||||
option_split_batch = [BatchArgsOption.ListAndDict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_grad_accumulation_list_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [8]
|
||||
option_split_batch = [BatchArgsOption.List]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_grad_accumulation_dict_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [8]
|
||||
option_split_batch = [BatchArgsOption.Dict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self):
|
||||
option_fp16 = [False]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [8]
|
||||
option_split_batch = [BatchArgsOption.ListAndDict]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_pretraining(
|
||||
*config_and_inputs,
|
||||
option_fp16,
|
||||
option_allreduce_post_accumulation,
|
||||
option_gradient_accumulation_steps,
|
||||
option_split_batch,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,246 +0,0 @@
|
|||
import math

import torch
from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch

from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer
from onnxruntime.training import amp, optim, orttrainer
from onnxruntime.training.optim import _LRScheduler


def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + torch.cos(math.pi * x))


def warmup_constant(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max((x - 1.0) / (warmup - 1.0), 0.0)


def warmup_poly(x, warmup=0.002, degree=0.5):
    if x < warmup:
        return x / warmup
    return (1.0 - x) ** degree


SCHEDULES = {
    "warmup_cosine": warmup_cosine,
    "warmup_constant": warmup_constant,
    "warmup_linear": warmup_linear,
    "warmup_poly": warmup_poly,
}


def get_lr(args, training_steps, schedule="warmup_poly"):
    if args.max_steps == -1:
        return args.learning_rate

    schedule_fct = SCHEDULES[schedule]
    return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion)


def map_optimizer_attributes(name):
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
    if no_decay:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
    else:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}


class WrapLRScheduler(_LRScheduler):
    def __init__(self, get_lr_this_step):
        super().__init__()
        self.get_lr_this_step = get_lr_this_step

    def get_lr(self, train_step_info):
        return [self.get_lr_this_step(train_step_info.optimization_step)]
||||
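# Build a synthetic dataloader, construct the trainer with either the deprecated ORTTrainer API or the newer
# orttrainer.ORTTrainer API (use_new_api), run the training loop, and return the eval-step outputs as numpy arrays.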
def run_test(
|
||||
model,
|
||||
model_desc,
|
||||
device,
|
||||
args,
|
||||
gradient_accumulation_steps,
|
||||
fp16,
|
||||
allreduce_post_accumulation,
|
||||
get_lr_this_step,
|
||||
use_internal_get_lr_this_step,
|
||||
loss_scaler,
|
||||
use_internal_loss_scaler,
|
||||
batch_args_option,
|
||||
dataset_len,
|
||||
epochs,
|
||||
use_new_api,
|
||||
):
|
||||
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)
|
||||
|
||||
if use_new_api:
|
||||
assert use_internal_loss_scaler, "new api should always use internal loss scaler"
|
||||
|
||||
new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)
|
||||
|
||||
new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
|
||||
options = orttrainer.ORTTrainerOptions(
|
||||
{
|
||||
"batch": {"gradient_accumulation_steps": gradient_accumulation_steps},
|
||||
"device": {"id": device},
|
||||
"mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler},
|
||||
"debug": {
|
||||
"deterministic_compute": True,
|
||||
},
|
||||
"utils": {"grad_norm_clip": True},
|
||||
"distributed": {"allreduce_post_accumulation": True},
|
||||
"lr_scheduler": new_api_lr_scheduler,
|
||||
}
|
||||
)
|
||||
|
||||
param_optimizer = list(model.named_parameters())
|
||||
params = [
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
|
||||
"alpha": 0.9,
|
||||
"beta": 0.999,
|
||||
"lambda": 0.0,
|
||||
"epsilon": 1e-6,
|
||||
},
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
|
||||
"alpha": 0.9,
|
||||
"beta": 0.999,
|
||||
"lambda": 0.0,
|
||||
"epsilon": 1e-6,
|
||||
},
|
||||
]
|
||||
|
||||
vocab_size = 99
|
||||
new_model_desc = {
|
||||
"inputs": [
|
||||
(
|
||||
"input_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"attention_mask",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"token_type_ids",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"masked_lm_labels",
|
||||
["batch", "max_seq_len_in_batch"],
|
||||
),
|
||||
(
|
||||
"next_sentence_label",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
),
|
||||
],
|
||||
"outputs": [
|
||||
(
|
||||
"loss",
|
||||
[
|
||||
1,
|
||||
],
|
||||
True,
|
||||
),
|
||||
("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]),
|
||||
("seq_relationship_scores", ["batch", 2]),
|
||||
],
|
||||
}
|
||||
|
||||
optim_config = optim.LambConfig(params=params, lr=2e-5)
|
||||
model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
|
||||
print("running with new frontend API")
|
||||
else:
|
||||
model = ORTTrainer(
|
||||
model,
|
||||
None,
|
||||
model_desc,
|
||||
"LambOptimizer",
|
||||
map_optimizer_attributes=map_optimizer_attributes,
|
||||
learning_rate_description=IODescription(
|
||||
"Learning_Rate",
|
||||
[
|
||||
1,
|
||||
],
|
||||
torch.float32,
|
||||
),
|
||||
device=device,
|
||||
_enable_internal_postprocess=True,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
# BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
|
||||
world_rank=args.local_rank,
|
||||
world_size=args.world_size,
|
||||
use_mixed_precision=fp16,
|
||||
allreduce_post_accumulation=allreduce_post_accumulation,
|
||||
get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
|
||||
loss_scaler=loss_scaler if use_internal_loss_scaler else None,
|
||||
_opset_version=14,
|
||||
_use_deterministic_compute=True,
|
||||
)
|
||||
print("running with old frontend API")
|
||||
|
||||
# training loop
|
||||
eval_batch = None
|
||||
if not use_new_api:
|
||||
model.train()
|
||||
for _epoch in range(epochs):
|
||||
for step, batch in enumerate(dataloader):
|
||||
if eval_batch is None:
|
||||
eval_batch = batch
|
||||
|
||||
if not use_internal_get_lr_this_step:
|
||||
lr = get_lr_this_step(step)
|
||||
learning_rate = torch.tensor([lr])
|
||||
|
||||
if not use_internal_loss_scaler and fp16:
|
||||
loss_scale = torch.tensor([loss_scaler.loss_scale_])
|
||||
|
||||
if batch_args_option == BatchArgsOption.List:
|
||||
if not use_internal_get_lr_this_step:
|
||||
batch = [*batch, learning_rate] # noqa: PLW2901
|
||||
if not use_internal_loss_scaler and fp16:
|
||||
batch = [*batch, loss_scale] # noqa: PLW2901
|
||||
outputs = model.train_step(*batch)
|
||||
elif batch_args_option == BatchArgsOption.Dict:
|
||||
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
|
||||
if not use_internal_get_lr_this_step:
|
||||
kwargs["Learning_Rate"] = learning_rate
|
||||
if not use_internal_loss_scaler and fp16:
|
||||
kwargs[model.loss_scale_input_name] = loss_scale
|
||||
outputs = model.train_step(*args, **kwargs)
|
||||
else:
|
||||
args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
|
||||
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
|
||||
if not use_internal_get_lr_this_step:
|
||||
kwargs["Learning_Rate"] = learning_rate
|
||||
if not use_internal_loss_scaler and fp16:
|
||||
kwargs[model.loss_scale_input_name] = loss_scale
|
||||
outputs = model.train_step(*args, **kwargs)
|
||||
|
||||
# eval
|
||||
if batch_args_option == BatchArgsOption.List:
|
||||
outputs = model.eval_step(*batch)
|
||||
elif batch_args_option == BatchArgsOption.Dict:
|
||||
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
|
||||
outputs = model.eval_step(*args, **kwargs)
|
||||
else:
|
||||
args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
|
||||
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
|
||||
outputs = model.eval_step(*args, **kwargs)
|
||||
|
||||
return (output.cpu().numpy() for output in outputs)
|
|
@ -1,357 +0,0 @@
|
|||
# adapted from Trainer.py of huggingface transformers
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import SequentialSampler
|
||||
from tqdm import tqdm, trange
|
||||
from transformers.data.data_collator import DefaultDataCollator
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import amp, optim, orttrainer
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter # noqa: F401
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
|
||||
def is_tensorboard_available():
|
||||
return _has_tensorboard
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
|
||||
|
||||
class EvalPrediction(NamedTuple):
|
||||
predictions: np.ndarray
|
||||
label_ids: np.ndarray
|
||||
|
||||
|
||||
class PredictionOutput(NamedTuple):
|
||||
predictions: np.ndarray
|
||||
label_ids: Optional[np.ndarray]
|
||||
metrics: Optional[Dict[str, float]]
|
||||
|
||||
|
||||
class TrainOutput(NamedTuple):
|
||||
global_step: int
|
||||
training_loss: float
|
||||
|
||||
|
||||
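# Returns a callable that maps the current global step to a learning rate which warms up linearly to base_lr
# over num_warmup_steps and then decays linearly to zero by num_training_steps.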
def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr):
|
||||
def lr_lambda_linear(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
|
||||
|
||||
def lambda_lr_get_lr(current_global_step):
|
||||
# LambdaLR increments self.last_epoch at every step()
|
||||
return base_lr * lr_lambda_linear(current_global_step)
|
||||
|
||||
return lambda_lr_get_lr
|
||||
|
||||
|
||||
class ORTTransformerTrainer:
|
||||
""" """
|
||||
|
||||
model: PreTrainedModel
|
||||
args: TrainingArguments
|
||||
train_dataset: Dataset
|
||||
eval_dataset: Dataset
|
||||
compute_metrics: Callable[[EvalPrediction], Dict]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
model_desc: dict,
|
||||
args: TrainingArguments,
|
||||
train_dataset: Dataset,
|
||||
eval_dataset: Dataset,
|
||||
compute_metrics: Callable[[EvalPrediction], Dict],
|
||||
world_size: Optional[int] = 1,
|
||||
):
|
||||
""" """
|
||||
|
||||
self.model = model
|
||||
self.model_desc = model_desc
|
||||
self.args = args
|
||||
self.world_size = world_size
|
||||
self.data_collator = DefaultDataCollator()
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
set_seed(self.args.seed)
|
||||
# Create output directory if needed
|
||||
if self.args.local_rank in [-1, 0]:
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
train_sampler = (
|
||||
SequentialSampler(self.train_dataset)
|
||||
if self.args.local_rank == -1
|
||||
else DistributedSampler(self.train_dataset)
|
||||
)
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
def get_eval_dataloader(self) -> DataLoader:
|
||||
return DataLoader(
|
||||
self.eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||
# We use the same batch_size as for eval.
|
||||
return DataLoader(
|
||||
test_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
Main training entry point.
|
||||
"""
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
|
||||
if self.args.max_steps > 0:
|
||||
t_total = self.args.max_steps
|
||||
num_train_epochs = (
|
||||
self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
|
||||
)
|
||||
else:
|
||||
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
|
||||
lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total))
|
||||
|
||||
loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
|
||||
device = self.args.device.type
|
||||
|
||||
device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0"
|
||||
options = orttrainer.ORTTrainerOptions(
|
||||
{
|
||||
"batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps},
|
||||
"device": {"id": device},
|
||||
"mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler},
|
||||
"debug": {
|
||||
"deterministic_compute": True,
|
||||
},
|
||||
"utils": {"grad_norm_clip": False},
|
||||
"distributed": {
|
||||
# We are running a single-node multi-GPU test, thus world_rank = local_rank
# and world_size = self.args.n_gpu
|
||||
"world_rank": max(0, self.args.local_rank),
|
||||
"world_size": int(self.world_size),
|
||||
"local_rank": max(0, self.args.local_rank),
|
||||
"allreduce_post_accumulation": True,
|
||||
},
|
||||
"lr_scheduler": lr_scheduler,
|
||||
}
|
||||
)
|
||||
|
||||
param_optimizer = list(self.model.named_parameters())
|
||||
params = [
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
|
||||
"weight_decay_mode": 1,
|
||||
},
|
||||
{
|
||||
"params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
|
||||
"weight_decay_mode": 1,
|
||||
},
|
||||
]
|
||||
|
||||
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
|
||||
self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataloader.dataset))
|
||||
logger.info(" Num Epochs = %d", num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
self.args.train_batch_size
|
||||
* self.args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
tr_loss = 0.0
|
||||
logging_loss = 0.0
|
||||
train_iterator = trange(
|
||||
epochs_trained,
|
||||
int(num_train_epochs),
|
||||
desc="Epoch",
|
||||
disable=self.args.local_rank not in [-1, 0],
|
||||
)
|
||||
|
||||
for _epoch in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
tr_loss += self._training_step(self.model, inputs)
|
||||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||
len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)
|
||||
):
|
||||
global_step += 1
|
||||
|
||||
if self.args.local_rank in [-1, 0]:
|
||||
if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
|
||||
global_step == 1 and self.args.logging_first_step
|
||||
):
|
||||
logs = {}
|
||||
if self.args.evaluate_during_training:
|
||||
results = self.evaluate()
|
||||
for key, value in results.items():
|
||||
eval_key = f"eval_{key}"
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
|
||||
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if self.args.max_steps > 0 and global_step > self.args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if self.args.max_steps > 0 and global_step > self.args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
logger.info("\n\nTraining completed. \n\n")
|
||||
return TrainOutput(global_step, tr_loss / global_step)
|
||||
|
||||
def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float:
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.to(self.args.device)
|
||||
|
||||
outputs = model.train_step(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
return loss.item()
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx"))
|
||||
|
||||
def evaluate(self) -> Dict[str, float]:
|
||||
"""
|
||||
Run evaluation and return metrics.
|
||||
|
||||
Returns:
|
||||
A dict containing:
|
||||
- the eval loss
|
||||
- the potential metrics computed from the predictions
|
||||
"""
|
||||
eval_dataloader = self.get_eval_dataloader()
|
||||
|
||||
output = self._prediction_loop(eval_dataloader, description="Evaluation")
|
||||
return output.metrics
|
||||
|
||||
def predict(self, test_dataset: Dataset) -> PredictionOutput:
|
||||
"""
|
||||
Run prediction and return predictions and potential metrics.
|
||||
|
||||
Depending on the dataset and your use case, your test dataset may contain labels.
|
||||
In that case, this method will also return metrics, like in evaluate().
|
||||
"""
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
return self._prediction_loop(test_dataloader, description="Prediction")
|
||||
|
||||
def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput:
|
||||
"""
|
||||
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
|
||||
|
||||
Works both with or without labels.
|
||||
"""
|
||||
|
||||
logger.info("***** Running %s *****", description)
|
||||
logger.info(" Num examples = %d", len(dataloader.dataset))
|
||||
logger.info(" Batch size = %d", dataloader.batch_size)
|
||||
eval_losses: List[float] = []
|
||||
preds: np.ndarray = None
|
||||
label_ids: np.ndarray = None
|
||||
|
||||
for inputs in tqdm(dataloader, desc=description):
|
||||
has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"])
|
||||
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.to(self.args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model.eval_step(**inputs)
|
||||
|
||||
if has_labels:
|
||||
step_eval_loss, logits = outputs[:2]
|
||||
eval_losses += [step_eval_loss.mean().item()]
|
||||
else:
|
||||
logits = outputs[0]
|
||||
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
if inputs.get("labels") is not None:
|
||||
if label_ids is None:
|
||||
label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||
else:
|
||||
metrics = {}
|
||||
if len(eval_losses) > 0:
|
||||
metrics["loss"] = np.mean(eval_losses)
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
|
@ -1,269 +0,0 @@
|
|||
# adapted from run_multiple_choice.py of huggingface transformers
|
||||
# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py
|
||||
|
||||
import csv
|
||||
import glob # noqa: F401
|
||||
import json # noqa: F401
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputExample:
|
||||
"""
|
||||
A single training/test example for multiple choice
|
||||
|
||||
Args:
|
||||
example_id: Unique id for the example.
|
||||
question: string. The untokenized text of the second sequence (question).
|
||||
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
|
||||
endings: list of str. The options for the multiple choice question. Its length must equal the length of contexts.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
question: str
|
||||
contexts: List[str]
|
||||
endings: List[str]
|
||||
label: Optional[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputFeatures:
|
||||
"""
|
||||
A single set of features of data.
|
||||
Property names are the same names as the corresponding inputs to a model.
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
input_ids: List[List[int]]
|
||||
attention_mask: Optional[List[List[int]]]
|
||||
token_type_ids: Optional[List[List[int]]]
|
||||
label: Optional[int]
|
||||
|
||||
|
||||
class Split(Enum):
|
||||
train = "train"
|
||||
dev = "dev"
|
||||
test = "test"
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""Base class for data converters for multiple choice data sets."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the test set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MultipleChoiceDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
processor: DataProcessor,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
):
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
mode.value,
|
||||
tokenizer.__class__.__name__,
|
||||
str(max_seq_length),
|
||||
task,
|
||||
),
|
||||
)
|
||||
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
if mode == Split.dev:
|
||||
examples = processor.get_dev_examples(data_dir)
|
||||
elif mode == Split.test:
|
||||
examples = processor.get_test_examples(data_dir)
|
||||
else:
|
||||
examples = processor.get_train_examples(data_dir)
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
class SwagProcessor(DataProcessor):
|
||||
"""Processor for the SWAG data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info(f"LOOKING AT {data_dir} train")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
raise ValueError(
|
||||
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
|
||||
"setting!"
|
||||
)
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if type == "train" and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
|
||||
examples = [
|
||||
InputExample(
|
||||
example_id=line[2],
|
||||
question=line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
contexts=[line[4], line[4], line[4], line[4]],
|
||||
endings=[line[7], line[8], line[9], line[10]],
|
||||
label=line[11],
|
||||
)
|
||||
for line in lines[1:] # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
pad_token_segment_id=0,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
mask_padding_with_zero=True,
|
||||
) -> List[InputFeatures]:
|
||||
"""
|
||||
Loads a data file into a list of `InputFeatures`
|
||||
"""
|
||||
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
choices_inputs = []
|
||||
for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
|
||||
text_a = context
|
||||
if example.question.find("_") != -1:
|
||||
# this is for cloze question
|
||||
text_b = example.question.replace("_", ending)
|
||||
else:
|
||||
text_b = example.question + " " + ending
|
||||
|
||||
inputs = tokenizer.encode_plus(
|
||||
text_a,
|
||||
text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||
logger.info(
|
||||
"Attention! you are cropping tokens (swag task is ok). "
|
||||
"If you are training ARC and RACE and you are poping question + options,"
|
||||
"you need to try to use a bigger max seq length!"
|
||||
)
|
||||
|
||||
choices_inputs.append(inputs)
|
||||
|
||||
label = label_map[example.label]
|
||||
|
||||
input_ids = [x["input_ids"] for x in choices_inputs]
|
||||
attention_mask = (
|
||||
[x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
|
||||
)
|
||||
token_type_ids = (
|
||||
[x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
|
||||
)
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id=example.example_id,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
label=label,
|
||||
)
|
||||
)
|
||||
|
||||
for f in features[:2]:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("feature: %s" % f)
|
||||
|
||||
return features
|
|
@ -1,200 +0,0 @@
|
|||
## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
## with modification to do training using onnxruntime as backend on cuda device.
|
||||
## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo.
|
||||
|
||||
## Model testing is not complete.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np # noqa: F401
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim # noqa: F401
|
||||
from mpi4py import MPI
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer
|
||||
|
||||
try: # noqa: SIM105
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
|
||||
|
||||
def train_with_trainer(args, trainer, device, train_loader, epoch):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device) # noqa: PLW2901
|
||||
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
|
||||
|
||||
learning_rate = torch.tensor([args.lr])
|
||||
loss = trainer.train_step(data, target, learning_rate)
|
||||
|
||||
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print(
|
||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch,
|
||||
batch_idx * len(data),
|
||||
len(train_loader.dataset),
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
loss[0],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# TODO: complete this once ORT training can do evaluation.
|
||||
def test_with_trainer(args, trainer, device, test_loader):
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device) # noqa: PLW2901
|
||||
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
|
||||
output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1)
|
||||
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print(
|
||||
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
|
||||
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
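# Describe the trainer inputs (flattened 28x28 image and integer label) and outputs (scalar loss and per-class
# probability) so ORTTrainer can bind them by name, shape, and dtype.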
def mnist_model_description():
|
||||
input_desc = IODescription("input1", ["batch", 784], torch.float32)
|
||||
label_desc = IODescription(
|
||||
"label",
|
||||
[
|
||||
"batch",
|
||||
],
|
||||
torch.int64,
|
||||
num_classes=10,
|
||||
)
|
||||
loss_desc = IODescription("loss", [], torch.float32)
|
||||
probability_desc = IODescription("probability", ["batch", 10], torch.float32)
|
||||
return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])
|
||||
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)")
|
||||
parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
|
||||
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=10,
|
||||
metavar="N",
|
||||
help="how many batches to wait before logging training status",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
kwargs = {"num_workers": 0, "pin_memory": True}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"../data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
|
||||
),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
**kwargs,
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"../data",
|
||||
train=False,
|
||||
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
|
||||
),
|
||||
batch_size=args.test_batch_size,
|
||||
shuffle=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
args.local_rank = (
|
||||
int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0
|
||||
)
|
||||
args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0
|
||||
args.world_size = comm.Get_size()
|
||||
if use_cuda:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
set_cuda_device_id(args.local_rank)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
input_size = 784
|
||||
hidden_size = 500
|
||||
num_classes = 10
|
||||
model = NeuralNet(input_size, hidden_size, num_classes)
|
||||
|
||||
model_desc = mnist_model_description()
|
||||
# use log_interval as gradient accumulation steps
|
||||
trainer = ORTTrainer(
|
||||
model,
|
||||
my_loss,
|
||||
model_desc,
|
||||
"SGDOptimizer",
|
||||
None,
|
||||
IODescription(
|
||||
"Learning_Rate",
|
||||
[
|
||||
1,
|
||||
],
|
||||
torch.float32,
|
||||
),
|
||||
device,
|
||||
1,
|
||||
args.world_rank,
|
||||
args.world_size,
|
||||
use_mixed_precision=False,
|
||||
allreduce_post_accumulation=True,
|
||||
)
|
||||
print("\nBuild ort model done.")
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train_with_trainer(args, trainer, device, train_loader, epoch)
|
||||
test_with_trainer(args, trainer, device, test_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary data: samples/python/training/orttrainer/mnist/mnist_original.onnx
Binary file not shown.
|
@ -1,174 +0,0 @@
|
|||
# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
# with modification to do training using onnxruntime as backend on cuda device.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim
|
||||
|
||||
|
||||
# Pytorch model
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1):
|
||||
out = self.fc1(input1)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
# ONNX Runtime training
|
||||
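# Dict-based model description used by the newer orttrainer API: each entry pairs a name with its symbolic
# shape, and the trailing True on the "loss" output marks it as the loss to optimize.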
def mnist_model_description():
|
||||
return {
|
||||
"inputs": [("input1", ["batch", 784]), ("label", ["batch"])],
|
||||
"outputs": [("loss", [], True), ("probability", ["batch", 10])],
|
||||
}
|
||||
|
||||
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
|
||||
|
||||
# Helpers
|
||||
def train(log_interval, trainer, device, train_loader, epoch, train_steps):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
if batch_idx == train_steps:
|
||||
break
|
||||
|
||||
# Fetch data
|
||||
data, target = data.to(device), target.to(device) # noqa: PLW2901
|
||||
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
|
||||
|
||||
# Train step
|
||||
loss, prob = trainer.train_step(data, target)
|
||||
|
||||
# Stats
|
||||
if batch_idx % log_interval == 0:
|
||||
print(
|
||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test(trainer, device, test_loader):
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device) # noqa: PLW2901
|
||||
data = data.reshape(data.shape[0], -1) # noqa: PLW2901
|
||||
|
||||
# Set fetches around eval_step so that 'target' does not have to be passed as an input
|
||||
trainer._train_step_info.fetches = ["probability"]
|
||||
output = F.log_softmax(trainer.eval_step(data), dim=1)
|
||||
trainer._train_step_info.fetches = []
|
||||
|
||||
# Stats
|
||||
test_loss += F.nll_loss(output, target, reduction="sum").item()
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print(
|
||||
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
|
||||
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example")
|
||||
parser.add_argument(
|
||||
"--train-steps",
|
||||
type=int,
|
||||
default=-1,
|
||||
metavar="N",
|
||||
help="number of steps to train. Set -1 to run through whole dataset (default: -1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)")
|
||||
parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
|
||||
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
|
||||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=10,
|
||||
metavar="N",
|
||||
help="how many batches to wait before logging training status",
|
||||
)
|
||||
parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state")
|
||||
|
||||
# Basic setup
|
||||
args = parser.parse_args()
|
||||
if not args.no_cuda and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
torch.manual_seed(args.seed)
|
||||
onnxruntime.set_seed(args.seed)
|
||||
|
||||
# Data loader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"./data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
|
||||
),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
if args.test_batch_size > 0:
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"./data",
|
||||
train=False,
|
||||
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
|
||||
),
|
||||
batch_size=args.test_batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# Modeling
|
||||
model = NeuralNet(784, 500, 10)
|
||||
model_desc = mnist_model_description()
|
||||
optim_config = optim.SGDConfig(lr=args.lr)
|
||||
opts = {"device": {"id": device}}
|
||||
opts = ORTTrainerOptions(opts)
|
||||
|
||||
trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
|
||||
|
||||
# Train loop
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps)
|
||||
if args.test_batch_size > 0:
|
||||
test(trainer, device, test_loader)
|
||||
|
||||
# Save model
|
||||
if args.save_path:
|
||||
torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,157 +0,0 @@
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


# Pytorch model
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out


def my_loss(x, target, is_train=True):
    if is_train:
        return F.nll_loss(F.log_softmax(x, dim=1), target)
    else:
        return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum")


# Helpers
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx == args.train_steps:
            break
        data, target = data.to(device), target.to(device)  # noqa: PLW2901
        data = data.reshape(data.shape[0], -1)  # noqa: PLW2901
        optimizer.zero_grad()
        output = model(data)
        loss = my_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # noqa: PLW2901
            data = data.reshape(data.shape[0], -1)  # noqa: PLW2901
            output = model(data)
            # Stats
            test_loss += my_loss(output, target, False).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--train-steps",
        type=int,
        default=-1,
        metavar="N",
        help="number of steps to train. Set -1 to run through whole dataset (default: -1)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
    )
    parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model")

    # Basic setup
    args = parser.parse_args()
    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    torch.manual_seed(args.seed)

    # Data loader
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
        ),
        batch_size=args.batch_size,
        shuffle=True,
    )

    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "./data",
                train=False,
                transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
            ),
            batch_size=args.test_batch_size,
            shuffle=True,
        )

    # Modeling
    model = NeuralNet(784, 500, 10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    # Train loop
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if args.test_batch_size > 0:
            test(model, device, test_loader)

    # Save model
    if args.save_path:
        torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt"))


if __name__ == "__main__":
    main()
@ -1,33 +0,0 @@
# TransformerModel example

This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial

## Requirements

* PyTorch 1.6+
* TorchText 0.6+
* ONNX Runtime 1.5+

## Running PyTorch version

```bash
python pt_train.py
```

## Running ONNX Runtime version

```bash
python ort_train.py
```

## Optional arguments

| Argument          | Description                                              |   Default |
| :---------------- | :------------------------------------------------------: | --------: |
| --batch-size      | input batch size for training                            |        20 |
| --test-batch-size | input batch size for testing                             |        20 |
| --epochs          | number of epochs to train                                |         2 |
| --lr              | learning rate                                            |     0.001 |
| --no-cuda         | disables CUDA training                                   |     False |
| --seed            | random seed                                              |         1 |
| --log-interval    | how many batches to wait before logging training status  |       200 |
@ -1,89 +0,0 @@
import argparse

import torch
from ort_utils import my_loss, transformer_model_description_dynamic_axes
from pt_model import TransformerModel
from utils import get_batch, prepare_data

import onnxruntime


def train(trainer, data_source, device, epoch, args, bptt=35):
    total_loss = 0.0
    for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
        data, targets = get_batch(data_source, i)

        loss, pred = trainer.train_step(data, targets)
        total_loss += loss.item()
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            print(
                "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format(
                    epoch, batch, len(data_source) // bptt, cur_loss
                )
            )
            total_loss = 0


def evaluate(trainer, data_source, bptt=35):
    total_loss = 0.0
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            loss, pred = trainer.eval_step(data, targets)
            total_loss += len(data) * loss.item()
    return total_loss / (len(data_source) - 1)


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch TransformerModel example")
    parser.add_argument(
        "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)"
    )
    parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)")
    parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)")
    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=200,
        metavar="N",
        help="how many batches to wait before logging training status (default: 200)",
    )

    # Basic setup
    args = parser.parse_args()
    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    # Model
    optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr)
    model_desc = transformer_model_description_dynamic_axes()
    model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)

    # Preparing data
    train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
    trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss)

    # Train
    for epoch in range(1, args.epochs + 1):
        train(trainer, train_data, device, epoch, args)
        val_loss = evaluate(trainer, val_data)
        print("-" * 89)
        print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ")
        print("-" * 89)

    # Evaluate
    test_loss = evaluate(trainer, test_data)
    print("=" * 89)
    print(f"| End of training | test loss {test_loss:5.2f}")
    print("=" * 89)
@ -1,47 +0,0 @@
import torch

from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription
from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription


def my_loss(x, target):
    x = x.view(-1, 28785)
    return torch.nn.CrossEntropyLoss()(x, target)


def transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
    model_desc = {
        "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])],
        "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])],
    }
    return model_desc


def transformer_model_description_dynamic_axes(ntokens=28785):
    model_desc = {
        "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])],
        "outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])],
    }
    return model_desc


def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
    input_desc = Legacy_IODescription("input1", [bptt, batch_size])
    label_desc = Legacy_IODescription("label", [bptt * batch_size])
    loss_desc = Legacy_IODescription("loss", [])
    predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens])
    return (
        Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),
        Legacy_IODescription("__learning_rate", [1]),
    )


def legacy_transformer_model_description_dynamic_axes(ntokens=28785):
    input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"])
    label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"])
    loss_desc = Legacy_IODescription("loss", [])
    predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens])
    return (
        Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),
        Legacy_IODescription("__learning_rate", [1]),
    )
@ -1,62 +0,0 @@
import math

import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer

        self.model_type = "Transformer"
        self.input1_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input1):
        if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0):
            device = input1.device
            mask = self._generate_square_subsequent_mask(input1.size(0)).to(device)
            self.input1_mask = mask

        input1 = self.encoder(input1) * math.sqrt(self.ninp)
        input1 = self.pos_encoder(input1)
        output = self.transformer_encoder(input1, self.input1_mask)
        output = self.decoder(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
@ -1,94 +0,0 @@
import argparse

import torch
import torch.nn as nn
from pt_model import TransformerModel
from utils import get_batch, prepare_data


def train(model, data_source, device, epoch, args, bptt=35):
    total_loss = 0.0
    model.train()
    for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
        data, targets = get_batch(data_source, i)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, 28785), targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            print(
                "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format(
                    epoch, batch, len(data_source) // bptt, cur_loss
                )
            )
            total_loss = 0


def evaluate(model, data_source, criterion, bptt=35):
    total_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output = model(data)
            output_flat = output.view(-1, 28785)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch TransformerModel example")
    parser.add_argument(
        "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)"
    )
    parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)")
    parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)")
    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=200,
        metavar="N",
        help="how many batches to wait before logging training status (default: 200)",
    )

    # Basic setup
    args = parser.parse_args()
    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    torch.manual_seed(args.seed)

    # Model
    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Preparing data
    train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)

    # Train
    for epoch in range(1, args.epochs + 1):
        train(model, train_data, device, epoch, args)
        val_loss = evaluate(model, val_data, criterion)
        print("-" * 89)
        print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ")
        print("-" * 89)

    # Evaluate
    test_loss = evaluate(model, test_data, criterion)
    print("=" * 89)
    print(f"| End of training | test loss {test_loss:5.2f}")
    print("=" * 89)
@ -1,59 +0,0 @@
import os

import torch
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
from torchtext.vocab import build_vocab_from_iterator


def batchify(data, bsz, device):
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


def get_batch(source, i, bptt=35):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len].view(-1)
    return data, target


def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None):
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"

    download_path = ".data_wikitext_2_v1"
    extract_path = None
    if data_dir:
        download_path = os.path.join(data_dir, "download")
        os.makedirs(download_path, exist_ok=True)
        download_path = os.path.join(download_path, "wikitext-2-v1.zip")

        extract_path = os.path.join(data_dir, "extracted")
        os.makedirs(extract_path, exist_ok=True)

    test_filepath, valid_filepath, train_filepath = extract_archive(
        download_from_url(url, root=download_path), to_path=extract_path
    )
    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8"))))  # noqa: SIM115

    def data_process(raw_text_iter):
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    train_data = data_process(iter(open(train_filepath, encoding="utf8")))  # noqa: SIM115
    val_data = data_process(iter(open(valid_filepath, encoding="utf8")))  # noqa: SIM115
    test_data = data_process(iter(open(test_filepath, encoding="utf8")))  # noqa: SIM115

    device = torch.device(device)

    train_data = batchify(train_data, train_batch_size, device)
    val_data = batchify(val_data, eval_batch_size, device)
    test_data = batchify(test_data, eval_batch_size, device)

    return train_data, val_data, test_data
setup.py
@ -398,7 +398,6 @@ packages = [
    "onnxruntime",
    "onnxruntime.backend",
    "onnxruntime.capi",
    "onnxruntime.capi.training",
    "onnxruntime.datasets",
    "onnxruntime.tools",
    "onnxruntime.tools.mobile_helpers",