mirror of https://github.com/microsoft/DeepSpeed.git
ec6cbb3c08
This PR allows `deepspeed.comm.inference_all_reduce()` to enter the torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel as a pybind function so that it could be called from Python code. However, PyTorch cannot recognize a pybind function, so the graph breaks whenever `inference_all_reduce` is called. We address this issue by registering `inference_all_reduce` as a PyTorch custom op, `torch.ops.deepspeed.inference_all_reduce`, so that it can be built into the PyTorch graph.

The output trace code from TorchInductor:

```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute);  primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce);  primals_3 = None
        return [addmm, inference_all_reduce]
```

Note that in this PR, the `inference_all_reduce` op for CPU does not handle multinode or the FP16 data type. For FP16 support, we will align with the PyTorch CPU FP16 plan. For multinode, we are still looking into the possibility of upstreaming oneCCL integration into PyTorch, so that we can use oneCCL for multinode tensor-parallel inference with PyTorch.

This PR is independent of https://github.com/microsoft/DeepSpeed/pull/5571. The two can work separately or together without issue.
---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
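To illustrate why registering the kernel as a custom op avoids the graph break, here is a minimal Python sketch, assuming PyTorch 2.x and its `torch.library.Library` API. It registers a hypothetical single-process stand-in op, `torch.ops.demo.fake_all_reduce` (the `demo` namespace and op name are illustrative only; DeepSpeed registers its real op from C++ under the `deepspeed` namespace, and the real op mutates its input in place, whereas this sketch returns a new tensor for simplicity). The op gets a CPU kernel plus a Meta kernel for shape inference, and the module is compiled with `fullgraph=True`, which makes `torch.compile` raise an error instead of silently splitting the graph.

```python
import torch

# Hypothetical stand-in op: the "demo" namespace and op name are illustrative,
# not DeepSpeed's real registration (which is done in C++).
_lib = torch.library.Library("demo", "DEF")
_lib.define("fake_all_reduce(Tensor x) -> Tensor")


def _fake_all_reduce_cpu(x: torch.Tensor) -> torch.Tensor:
    # With a world size of 1, a sum all-reduce leaves the values unchanged.
    return x.clone()


def _fake_all_reduce_meta(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only kernel used when torch.compile traces with fake tensors.
    return torch.empty_like(x)


_lib.impl("fake_all_reduce", _fake_all_reduce_cpu, "CPU")
_lib.impl("fake_all_reduce", _fake_all_reduce_meta, "Meta")


class TPLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # A registered custom op becomes a graph node, so Dynamo does not break here.
        x = torch.ops.demo.fake_all_reduce(x)
        return self.linear(x)


model = TPLinear()
# fullgraph=True raises if any graph break occurs; the "eager" backend
# avoids needing a C++ compiler for this illustration.
compiled = torch.compile(model, backend="eager", fullgraph=True)

with torch.no_grad():
    inp = torch.randn(4, 4)
    out = compiled(inp)
    ref = model(inp)
```

Calling an opaque pybind function inside `forward` instead would be expected to make this `fullgraph=True` compile fail, because Dynamo cannot trace into it; that is the graph break this PR removes.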