change predict to be the same in staging
This commit is contained in:
Родитель
caf347a32b
Коммит
6f3233d3d3
|
@ -261,16 +261,11 @@ class Transformer:
|
|||
def predict(self, eval_dataloader, get_inputs, n_gpu=1, verbose=True, move_batch_to_device=None):
|
||||
device, num_gpus = get_device(num_gpus=n_gpu, local_rank=-1)
|
||||
|
||||
if isinstance(self.model, torch.nn.DataParallel):
|
||||
self.model = self.model.module
|
||||
|
||||
if num_gpus > 1:
|
||||
if not isinstance(self.model, torch.nn.DataParallel):
|
||||
self.model = torch.nn.DataParallel(self.model, device_ids=range(0,num_gpus))
|
||||
else:
|
||||
# make sure the prediction can switch between different numbers of multiple gpus
|
||||
self.model = self.model.module
|
||||
self.model = torch.nn.DataParallel(self.model, device_ids=range(0,num_gpus))
|
||||
else:
|
||||
if isinstance(self.model, torch.nn.DataParallel):
|
||||
self.model = self.model.module
|
||||
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(num_gpus)))
|
||||
|
||||
self.model.to(device)
|
||||
self.model.eval()
|
||||
|
|
Загрузка…
Ссылка в новой задаче