From 2d1f7c01721a25a2a98152092fedc1bf10f91f85 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 2 Dec 2020 16:54:00 -0800 Subject: [PATCH] [engine] train should be able to get `mode` arg (#571) --- deepspeed/runtime/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 42e356b36..ee515a072 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -800,12 +800,12 @@ class DeepSpeedEngine(Module): data_parallel_world_size=data_parallel_world_size, data_parallel_rank=data_parallel_rank) - def train(self): + def train(self, mode=True): r""" """ self.warn_unscaled_loss = True - self.module.train() + self.module.train(mode) def eval(self): r"""