зеркало из https://github.com/microsoft/DeepSpeed.git
[engine] train should be able to get `mode` arg (#571)
This commit is contained in:
Родитель
845921b3b6
Коммит
2d1f7c0172
|
@ -800,12 +800,12 @@ class DeepSpeedEngine(Module):
|
||||||
data_parallel_world_size=data_parallel_world_size,
|
data_parallel_world_size=data_parallel_world_size,
|
||||||
data_parallel_rank=data_parallel_rank)
|
data_parallel_rank=data_parallel_rank)
|
||||||
|
|
||||||
def train(self):
|
def train(self, mode=True):
|
||||||
r"""
|
r"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.warn_unscaled_loss = True
|
self.warn_unscaled_loss = True
|
||||||
self.module.train()
|
self.module.train(mode)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
r"""
|
r"""
|
||||||
|
|
Загрузка…
Ссылка в новой задаче