Mirror of https://github.com/microsoft/DeepSpeed.git
pipe engine _aggregate_total_loss: more efficient loss concatenation (#4327)
* _aggregate_total_loss: more efficient loss concatenation. Optimize the _aggregate_total_loss function to remove the dependency on copying from device to host and back to device. This reduces runtime on the host. * Fix the if/else block in which the optimization should take place. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent
0f2338f7b8
Commit
a02de228d0
@@ -549,7 +549,7 @@ class PipelineEngine(DeepSpeedEngine):
                 agg_loss /= self.dp_world_size
 
             assert self.global_rank in self.grid.pp_group
-            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
+            losses = torch.stack([self.dp_group_loss, agg_loss])
             if self.is_pipe_parallel:
                 dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
             else:
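
For context, a minimal self-contained sketch (not part of the commit) of why the change helps: torch.Tensor([...]) built from GPU tensors has to read each scalar back to the host to construct a CPU tensor, and .to(self.device) then copies the result back to the accelerator; torch.stack concatenates the tensors directly on their current device with no host round-trip. The scalar values and device selection below are illustrative assumptions.

import torch

# Hypothetical stand-ins for self.dp_group_loss and agg_loss: scalar losses
# that already live on the accelerator (falls back to CPU if CUDA is absent).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dp_group_loss = torch.tensor(0.25, device=device)
agg_loss = torch.tensor(0.75, device=device)

# Old pattern: each element is read back to the host (a device-to-host sync
# per element) to build a CPU tensor, which .to(device) then copies back.
losses_old = torch.Tensor([dp_group_loss, agg_loss]).to(device)

# New pattern: the result is allocated and filled directly on the tensors'
# own device, so the host is never involved.
losses_new = torch.stack([dp_group_loss, agg_loss])

assert torch.equal(losses_old, losses_new)

Both forms produce the same two-element tensor on the device; the stacked version simply skips the two extra transfers, which is why the commit message reports reduced host-side runtime.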