[engine] train should be able to get `mode` arg (#571)

This commit is contained in:
Stas Bekman 2020-12-02 16:54:00 -08:00 коммит произвёл GitHub
Родитель 845921b3b6
Коммит 2d1f7c0172
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 2 добавлений и 2 удалений

Просмотреть файл

@ -800,12 +800,12 @@ class DeepSpeedEngine(Module):
data_parallel_world_size=data_parallel_world_size, data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank) data_parallel_rank=data_parallel_rank)
def train(self): def train(self, mode=True):
r""" r"""
""" """
self.warn_unscaled_loss = True self.warn_unscaled_loss = True
self.module.train() self.module.train(mode)
def eval(self): def eval(self):
r""" r"""