зеркало из https://github.com/microsoft/DeepSpeed.git
[engine] train should be able to get `mode` arg (#571)
This commit is contained in:
Родитель
845921b3b6
Коммит
2d1f7c0172
|
@ -800,12 +800,12 @@ class DeepSpeedEngine(Module):
|
|||
data_parallel_world_size=data_parallel_world_size,
|
||||
data_parallel_rank=data_parallel_rank)
|
||||
|
||||
def train(self):
|
||||
def train(self, mode=True):
|
||||
r"""
|
||||
"""
|
||||
|
||||
self.warn_unscaled_loss = True
|
||||
self.module.train()
|
||||
self.module.train(mode)
|
||||
|
||||
def eval(self):
|
||||
r"""
|
||||
|
|
Загрузка…
Ссылка в новой задаче