зеркало из https://github.com/microsoft/DeepSpeed.git
Enable dynamic shapes for pipeline parallel engine inputs (#5481)
This PR enables dynamic shapes for inputs to pipeline parallel (PP) engine. Currently PP engine checks tensor shapes and allocates communication buffer at the first forward/backward passes. This causes a tensor shape mismatch error when input tensor shapes changed. This PR adds an option to check tensor shapes at every iteration and allocate buffer based on the shapes. As shown below, you can enable this feature by passing `dynamic_shape=True` to `PipelineModule`. Note that this might have a performance impact and the option is set to False as default. ```python model = PipelineModule( ... dynamic_shape=True ) ``` This will increase the overhead of buffer allocation and communication for tensor metadata. To mitigate the overhead, this PR also includes these improvements: - Consolidate multiple communication calls to send/recv tensor shapes9f96ad4049
- Reuse (extend) communication buffer instead of creating a new oneb3c07504be
--------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Родитель
4d4ff0eddd
Коммит
1ab1928d79
|
@ -5,6 +5,8 @@
|
|||
|
||||
from types import MethodType
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
import torch
|
||||
from deepspeed import comm as dist
|
||||
|
@ -40,6 +42,9 @@ PIPE_SEND_GRAD_TIMER = 'pipe_send_grad'
|
|||
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
|
||||
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'
|
||||
|
||||
# The buffer size to store the meta data for each tensor.
|
||||
TENSOR_META_SIZE = 256
|
||||
|
||||
|
||||
def is_even(number):
|
||||
return number % 2 == 0
|
||||
|
@ -179,6 +184,7 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
}
|
||||
self.pipe_recv_buf = None
|
||||
self.grad_layer = None
|
||||
self._grad_layer_buf = []
|
||||
|
||||
self.meta_buffer = None
|
||||
|
||||
|
@ -250,6 +256,8 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
self.timers(STEP_MICRO_TIMER).start()
|
||||
self.timers(STEP_MICRO_TIMER).stop()
|
||||
|
||||
self.dynamic_shape = self.module.dynamic_shape
|
||||
|
||||
def set_has_attention_mask(self, value):
|
||||
assert isinstance(value, bool)
|
||||
self.has_attention_mask = value
|
||||
|
@ -318,6 +326,7 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
self.first_output_send = True
|
||||
self.pipe_recv_buf = None
|
||||
self.grad_layer = None
|
||||
self._grad_layer_buf = []
|
||||
self.meta_buffer = None
|
||||
|
||||
self.pipe_partition_input_meta_cache = None
|
||||
|
@ -926,51 +935,38 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
* ndims
|
||||
* shape
|
||||
"""
|
||||
send_bytes = 0
|
||||
meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
|
||||
if isinstance(buffer, torch.Tensor):
|
||||
type_tensor = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.send(type_tensor, recv_stage)
|
||||
send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
|
||||
send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
|
||||
p2p.send(send_ndims, recv_stage)
|
||||
p2p.send(send_shape, recv_stage)
|
||||
send_bytes += _tensor_bytes(buffer)
|
||||
elif isinstance(buffer, list):
|
||||
assert (False)
|
||||
type_tensor = torch.LongTensor(data=[1]).to(self.device)
|
||||
p2p.send(type_tensor, recv_stage)
|
||||
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
|
||||
p2p.send(count_tensor, recv_stage)
|
||||
meta_buf_list = [
|
||||
0, # type of data (0: tensor, 1: list (unused), 2: tuple)
|
||||
self.DTYPE_TO_ID[buffer.dtype], # dtype
|
||||
len(buffer.size()) # ndims
|
||||
]
|
||||
meta_buf_list.extend(buffer.size())
|
||||
assert len(
|
||||
meta_buf_list
|
||||
) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
|
||||
meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
|
||||
p2p.send(meta_buffer, recv_stage)
|
||||
|
||||
elif isinstance(buffer, tuple):
|
||||
meta_buf_list = [
|
||||
2, # type of data (0: tensor, 1: list (unused), 2: tuple)
|
||||
len(buffer) # num_tensors
|
||||
]
|
||||
|
||||
for tensor in buffer:
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
|
||||
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
|
||||
p2p.send(send_ndims, recv_stage)
|
||||
p2p.send(send_shape, recv_stage)
|
||||
send_bytes += _tensor_bytes(tensor)
|
||||
elif isinstance(buffer, tuple):
|
||||
type_tensor = torch.LongTensor(data=[2]).to(self.device)
|
||||
p2p.send(type_tensor, recv_stage)
|
||||
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
|
||||
p2p.send(count_tensor, recv_stage)
|
||||
for idx, tensor in enumerate(buffer):
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
|
||||
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
|
||||
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
|
||||
p2p.send(send_dtype, recv_stage)
|
||||
p2p.send(send_ndims, recv_stage)
|
||||
p2p.send(send_shape, recv_stage)
|
||||
# Useful for performance debugging.
|
||||
'''
|
||||
new_bytes = _tensor_bytes(tensor)
|
||||
send_bytes += _tensor_bytes(tensor)
|
||||
# Useful for performance debugging.
|
||||
if self.grid.data_parallel_id == 0:
|
||||
print(
|
||||
f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
|
||||
)
|
||||
'''
|
||||
meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype])
|
||||
meta_buf_list.append(len(tensor.size()))
|
||||
meta_buf_list.extend(tensor.size())
|
||||
|
||||
assert len(
|
||||
meta_buf_list
|
||||
) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
|
||||
meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
|
||||
p2p.send(meta_buffer, recv_stage)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Could not send meta type {type(buffer)}')
|
||||
|
||||
|
@ -983,49 +979,35 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
def _recv_tensor_meta(self, send_stage):
|
||||
"""Receive metadata about upcoming p2p transfers and return allocated buffers.
|
||||
|
||||
Metadata is communicated in this order:
|
||||
* type (0: tensor, 1: list)
|
||||
* num_tensors if type=list
|
||||
foreach tensor in buffer:
|
||||
* ndims
|
||||
* shape
|
||||
|
||||
Returns:
|
||||
Allocated buffer for receiving from send_stage.
|
||||
"""
|
||||
buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
|
||||
p2p.recv(buffer, send_stage)
|
||||
|
||||
type_tensor = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.recv(type_tensor, send_stage)
|
||||
recv_type = type_tensor.item()
|
||||
recv_type = buffer[0].item()
|
||||
|
||||
# A single tensor will be sent.
|
||||
if recv_type == 0:
|
||||
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.recv(recv_ndims, send_stage)
|
||||
recv_ndims = recv_ndims.item()
|
||||
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
|
||||
p2p.recv(recv_shape, send_stage)
|
||||
recv_shape = recv_shape.tolist()
|
||||
return self._allocate_buffer(recv_shape, num_buffers=1)[0]
|
||||
recv_dtype = self.ID_TO_DTYPE[buffer[1].item()]
|
||||
recv_ndims = buffer[2].item()
|
||||
recv_shape = buffer[3:3 + recv_ndims].tolist()
|
||||
return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)
|
||||
|
||||
# List or tuple of tensors
|
||||
# List or tuple of tensors (recv_type == 1 (list) is currently unused)
|
||||
elif recv_type == 1 or recv_type == 2:
|
||||
count_tensor = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.recv(count_tensor, send_stage)
|
||||
num_tensors = count_tensor.item()
|
||||
recv_shapes_and_dtypes = []
|
||||
for idx in range(num_tensors):
|
||||
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.recv(recv_dtype, send_stage)
|
||||
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
|
||||
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
|
||||
p2p.recv(recv_ndims, send_stage)
|
||||
recv_ndims = recv_ndims.item()
|
||||
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
|
||||
p2p.recv(recv_shape, send_stage)
|
||||
recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))
|
||||
num_tensors = buffer[1].item()
|
||||
|
||||
buffers = []
|
||||
offset = 2
|
||||
for idx in range(num_tensors):
|
||||
recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()]
|
||||
recv_ndims = buffer[offset + 1].item()
|
||||
recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist()
|
||||
offset += 2 + recv_ndims
|
||||
|
||||
buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype))
|
||||
|
||||
buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
|
||||
# Convert to tuples if requested.
|
||||
if recv_type == 2:
|
||||
buffers = tuple(buffers)
|
||||
|
@ -1048,7 +1030,7 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
outputs[-1] = outputs[-1].half()
|
||||
outputs = tuple(outputs)
|
||||
|
||||
if self.first_output_send:
|
||||
if self.dynamic_shape or self.first_output_send:
|
||||
self.first_output_send = False
|
||||
self._send_tensor_meta(outputs, self.next_stage)
|
||||
|
||||
|
@ -1133,7 +1115,7 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
recvd = None
|
||||
|
||||
# Allocate the buffer if necessary
|
||||
if self.pipe_recv_buf is None:
|
||||
if self.dynamic_shape or self.pipe_recv_buf is None:
|
||||
self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
|
||||
|
||||
if isinstance(self.pipe_recv_buf, torch.Tensor):
|
||||
|
@ -1188,10 +1170,9 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
self.pipe_buffers['outputs'][buffer_id] = outputs
|
||||
|
||||
# Allocate gradient if necessary
|
||||
if self.grad_layer is None:
|
||||
if self.dynamic_shape or self.grad_layer is None:
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
s = list(outputs.size())
|
||||
self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
|
||||
self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype)
|
||||
else:
|
||||
# XXX This is a HACK
|
||||
# When we exchange activations/gradients, the two pipe stages
|
||||
|
@ -1213,7 +1194,11 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
for t in outputs[2:] if t.is_floating_point()]
|
||||
else:
|
||||
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
|
||||
self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]
|
||||
|
||||
self.grad_layer = [
|
||||
self._allocate_or_extend_buffers(i, size, dtype)
|
||||
for i, (size, dtype) in enumerate(sizes_and_dtypes)
|
||||
]
|
||||
|
||||
if isinstance(self.grad_layer, torch.Tensor):
|
||||
p2p.recv(self.grad_layer, self.next_stage)
|
||||
|
@ -1294,16 +1279,17 @@ class PipelineEngine(DeepSpeedEngine):
|
|||
buffers.append(self._allocate_zeros(shape, **kwargs))
|
||||
return buffers
|
||||
|
||||
def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
|
||||
buffers = []
|
||||
if num_buffers == -1:
|
||||
num_buffers = self.num_pipe_buffers
|
||||
for count in range(num_buffers):
|
||||
buffer = []
|
||||
for shape, dtype in shapes_and_dtypes:
|
||||
buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
|
||||
buffers.append(buffer)
|
||||
return buffers
|
||||
def _allocate_or_extend_buffers(self, idx, shape, dtype):
|
||||
numel = reduce(mul, shape) if len(shape) > 0 else 1
|
||||
if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel:
|
||||
new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0]
|
||||
if len(self._grad_layer_buf) <= idx:
|
||||
self._grad_layer_buf.append(new_buf)
|
||||
else:
|
||||
self._grad_layer_buf[idx] = new_buf
|
||||
return self._grad_layer_buf[idx]
|
||||
else:
|
||||
return self._grad_layer_buf[idx].flatten()[:numel].view(shape)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Disabled for pipeline parallel training. See ``train_batch()``. """
|
||||
|
|
|
@ -117,6 +117,7 @@ class PipelineModule(nn.Module):
|
|||
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
|
||||
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
|
||||
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
|
||||
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -130,7 +131,8 @@ class PipelineModule(nn.Module):
|
|||
partition_method='parameters',
|
||||
activation_checkpoint_interval=0,
|
||||
activation_checkpoint_func=checkpointing.checkpoint,
|
||||
checkpointable_layers=None):
|
||||
checkpointable_layers=None,
|
||||
dynamic_shape=False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
@ -213,6 +215,8 @@ class PipelineModule(nn.Module):
|
|||
self.tied_comms = self._index_tied_modules()
|
||||
self._synchronize_tied_weights()
|
||||
|
||||
self.dynamic_shape = dynamic_shape
|
||||
|
||||
def _precompute_checkpointable_values(self):
|
||||
if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval:
|
||||
num_layers = len(self.forward_funcs)
|
||||
|
|
|
@ -14,6 +14,7 @@ import deepspeed.runtime.utils as ds_utils
|
|||
from deepspeed.utils.torch import required_torch_version
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
|
||||
from .util import no_child_process_in_deepspeed_io
|
||||
|
||||
|
||||
class AlexNet(nn.Module):
|
||||
|
@ -125,22 +126,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True,
|
|||
trainset = cifar_trainset(fp16=fp16)
|
||||
config['local_rank'] = dist.get_rank()
|
||||
|
||||
# deepspeed_io defaults to creating a dataloader that uses a
|
||||
# multiprocessing pool. Our tests use pools and we cannot nest pools in
|
||||
# python. Therefore we're injecting this kwarg to ensure that no pools
|
||||
# are used in the dataloader.
|
||||
old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io
|
||||
|
||||
def new_method(*args, **kwargs):
|
||||
kwargs["num_local_io_workers"] = 0
|
||||
return old_method(*args, **kwargs)
|
||||
|
||||
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method
|
||||
|
||||
engine, _, _, _ = deepspeed.initialize(config=config,
|
||||
model=model,
|
||||
model_parameters=[p for p in model.parameters()],
|
||||
training_data=trainset)
|
||||
with no_child_process_in_deepspeed_io():
|
||||
engine, _, _, _ = deepspeed.initialize(config=config,
|
||||
model=model,
|
||||
model_parameters=[p for p in model.parameters()],
|
||||
training_data=trainset)
|
||||
|
||||
losses = []
|
||||
for step in range(num_steps):
|
||||
|
|
|
@ -7,12 +7,15 @@ import copy
|
|||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
import deepspeed
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology
|
||||
from deepspeed.runtime.pipe.module import PipelineModule
|
||||
from unit.alexnet_model import AlexNetPipe, train_cifar
|
||||
from unit.common import DistributedTest
|
||||
from unit.util import skip_on_arch
|
||||
from unit.util import skip_on_arch, no_child_process_in_deepspeed_io
|
||||
|
||||
PipeTopo = PipeDataParallelTopology
|
||||
|
||||
|
@ -155,3 +158,95 @@ class TestPipeCifar10(DistributedTest):
|
|||
# the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1)
|
||||
# Check if models have same weights after training
|
||||
# self._check_model_params_equal(base_model, test_model)
|
||||
|
||||
|
||||
class DynamicShapeTestLayer(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(hidden_size, hidden_size)
|
||||
self.shapes = set()
|
||||
|
||||
def forward(self, x):
|
||||
self.shapes.add(x.shape)
|
||||
y = self.fc(x)
|
||||
return y
|
||||
|
||||
|
||||
class DynamicShapeTestModel(nn.Module):
|
||||
|
||||
def __init__(self, n_layers, hidden_size):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([DynamicShapeTestLayer(hidden_size) for _ in range(n_layers)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('topo_config', [
|
||||
{
|
||||
"num_pp": 1,
|
||||
"num_dp": 4
|
||||
},
|
||||
{
|
||||
"num_pp": 2,
|
||||
"num_dp": 2
|
||||
},
|
||||
{
|
||||
"num_pp": 4,
|
||||
"num_dp": 1
|
||||
},
|
||||
])
|
||||
class TestPipeDynamicShape(DistributedTest):
|
||||
world_size = 4
|
||||
|
||||
def test_pipe_base(self, topo_config):
|
||||
"""This test checks if the pipeline engine can handle dynamic shapes correctly.
|
||||
We pass inputs of different shapes to the pipeline engine.
|
||||
"""
|
||||
|
||||
n_iter = 10
|
||||
n_layers = 4
|
||||
n_samples = 1024
|
||||
batch_size = 4
|
||||
channel_dims = [8, 16, 32, 64]
|
||||
hidden_size = 16
|
||||
|
||||
topo = PipeTopo(**topo_config)
|
||||
|
||||
model = DynamicShapeTestModel(n_layers, hidden_size)
|
||||
model = PipelineModule(layers=model.layers, topology=topo, loss_fn=nn.MSELoss(), dynamic_shape=True)
|
||||
|
||||
# Each batch has different channel dim but we use the same channel dim in the same batch
|
||||
xs = [
|
||||
torch.randn(channel_dims[(i // batch_size) % len(channel_dims)], hidden_size, dtype=torch.float32)
|
||||
for i in range(n_samples)
|
||||
]
|
||||
ys = [torch.randn_like(x) for x in xs]
|
||||
|
||||
class CustomDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, xs, ys):
|
||||
self.xs = xs
|
||||
self.ys = ys
|
||||
|
||||
def __len__(self):
|
||||
return len(self.xs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.xs[idx], self.ys[idx]
|
||||
|
||||
dataset = CustomDataset(xs, ys)
|
||||
|
||||
config_dict["train_batch_size"] = batch_size
|
||||
|
||||
with no_child_process_in_deepspeed_io():
|
||||
engine, _, _, _ = deepspeed.initialize(config=config_dict,
|
||||
model=model,
|
||||
model_parameters=[p for p in model.parameters()],
|
||||
training_data=dataset)
|
||||
|
||||
for _ in range(n_iter):
|
||||
_ = engine.train_batch()
|
||||
|
||||
# Check if all layers have seen different shapes
|
||||
for layer in model.modules():
|
||||
if isinstance(layer, DynamicShapeTestLayer):
|
||||
assert len(layer.shapes) > 1
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
|
||||
from deepspeed.git_version_info import torch_info
|
||||
|
||||
|
@ -67,3 +69,22 @@ def required_amp_check():
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
class no_child_process_in_deepspeed_io:
|
||||
|
||||
def __enter__(self):
|
||||
# deepspeed_io defaults to creating a dataloader that uses a
|
||||
# multiprocessing pool. Our tests use pools and we cannot nest pools in
|
||||
# python. Therefore we're injecting this kwarg to ensure that no pools
|
||||
# are used in the dataloader.
|
||||
self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io
|
||||
|
||||
def new_method(*args, **kwargs):
|
||||
kwargs["num_local_io_workers"] = 0
|
||||
return self.old_method(*args, **kwargs)
|
||||
|
||||
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method
|
||||
|
||||
def __exit__(self, *_):
|
||||
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
|
||||
|
|
Загрузка…
Ссылка в новой задаче