[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph (#5604)

This PR allows `deepspeed.comm.inference_all_reduce()` to enter the
torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel
as a pybind function so that it could be called from Python code. However,
pybind functions are not recognized by PyTorch, so the graph breaks when
`inference_all_reduce` is called.

We address this issue by registering `inference_all_reduce` as a PyTorch custom
op, `torch.ops.deepspeed.inference_all_reduce`, so it can be built into the
PyTorch graph.
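
For readers less familiar with the custom-op mechanism, the same idea can be sketched at the Python level with `torch.library` (this is only an analogous illustration with a hypothetical `demo::my_identity` op; the PR itself registers the op from C++ via `TORCH_LIBRARY`, as shown in the diff below):
```
import torch

# Hypothetical namespace/op, used only to illustrate dispatcher-based
# registration, which is what makes an op visible to torch.compile.
lib = torch.library.Library("demo", "DEF")
lib.define("my_identity(Tensor self) -> Tensor")

def my_identity_cpu(self):
    return self.clone()

def my_identity_meta(self):
    # Shape/dtype propagation only; used while tracing.
    return torch.empty_like(self)

lib.impl("my_identity", my_identity_cpu, "CPU")
lib.impl("my_identity", my_identity_meta, "Meta")

@torch.compile(fullgraph=True)
def fn(x):
    # Registered through the dispatcher, the op is captured into the graph
    # instead of causing a break on an opaque pybind call.
    return torch.ops.demo.my_identity(x) + 1

print(fn(torch.ones(3)))
```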

The output trace code from TorchInductor:
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute);  primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce);  primals_3 = None
        return [addmm, inference_all_reduce]
```
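
For reference, the trace above would come from a toy model along the following lines (a hedged reconstruction from the file names and shapes in the trace; the module name and exact call signature are illustrative):
```
import torch
import deepspeed
import deepspeed.comm as dist

class AllReduceLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # weight f32[5, 4] and bias f32[5] correspond to primals_1/primals_2
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, input):
        # Lowers to the in-place torch.ops.deepspeed.inference_all_reduce_,
        # hence the copy_ back into primals_3 in the trace.
        dist.inference_all_reduce(input)
        return self.linear(input)

deepspeed.init_distributed()           # set up the comm backend first
model = torch.compile(AllReduceLinear())
out = model(torch.ones(4, 4))          # no graph break at the all-reduce
```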

Note that in this PR the `inference_all_reduce` op for CPU does not handle
multi-node setups or the FP16 data type. For FP16 support, we will align with
the PyTorch CPU FP16 plan. For multi-node, we are still looking into the
possibility of upstreaming the oneCCL integration into PyTorch, so that we can
make use of oneCCL for multi-node tensor parallel inference with PyTorch.

This PR is independent of
https://github.com/microsoft/DeepSpeed/pull/5571. They can work
separately or together without issue.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Ma, Guokai, 2024-07-16 06:24:11 +08:00, committed by GitHub
Parent: a07a3c5d22
Commit: ec6cbb3c08
GPG key ID: B5690EEEBB952194
2 changed files: 85 additions and 19 deletions


@@ -46,15 +46,13 @@ void initialize(int size, int rank)
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}
int get_rank(int group = 0) { return world_rank; }
int get_world_size(int group = 0) { return world_size; }
void inference_all_reduce_(torch::Tensor& data, int op);
// Success - return 0
// Fail (cannot hornor the request and need to fall back) - return -1
int inference_all_reduce(torch::Tensor& data, py::object op)
void inference_all_reduce_(torch::Tensor& data, int op)
{
if (!all_ranks_local_p) return -1;
assert(op == 0);
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
auto start = std::chrono::system_clock::now();
#endif
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
assert(py::int_(op.attr("value")) == ReduceOpSum);
auto numel = data.numel();
int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
default: data_type_fallback = true;
}
if (data_type_fallback) return -1;
if (data_type_fallback) return;
all_reduce_outer_loop(data, numel, data_size);
@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
}
}
#endif
return 0;
return;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }
TORCH_LIBRARY(deepspeed, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.def("inference_all_reduce(Tensor self) -> Tensor");
m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}
torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
{
torch::Tensor result_ = torch::empty_like(self_);
return result_;
}
torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }
torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
{
TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
torch::Tensor self_tensor = self_.contiguous();
inference_all_reduce_(self_tensor, 0);
return self_;
}
torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
{
torch::Tensor result = self_.clone();
inference_all_reduce__cpu(result);
return result;
}
#include <ATen/FunctionalTensorWrapper.h>
// The boilerplate functionalization logic, that teaches functionalization
// how to map x_() calls into x() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// You will still need to write an out-of-place version of that op!
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
{
// We expect all tensor inputs to our op to be "functional tensors"
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
// First, sync and unwrap and functional tensors
at::functionalization::impl::sync(x);
auto x_ = at::functionalization::impl::from_functional_tensor(x);
// Grab the dispatcher entry corresponding to the out-of-place op, "x"
static auto op_handle = c10::Dispatcher::singleton()
// specify namespace::op_name, op_overload_name
.findSchemaOrThrow("deepspeed::inference_all_reduce", "")
// Specify the C++ schema of the out-of-place op.
.typed<at::Tensor(const at::Tensor&)>();
// Next, redispatch to the out-of-place op, x() (user called x_, we call x)
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = op_handle.call(x_);
}
// Finally, tell functionalization about this mutation.
at::functionalization::impl::replace_(x, tmp_output);
at::functionalization::impl::commit_update(x);
at::functionalization::impl::sync(x);
return x;
}
TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
{
m.impl("inference_all_reduce", inference_all_reduce_cpu);
m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}
TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
{
m.impl("inference_all_reduce", inference_all_reduce_meta);
m.impl("inference_all_reduce_", inference_all_reduce__meta);
}
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
{
m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
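
With the schema plus CPU, Meta, and Functionalize registrations above, the op dispatches like a regular ATen op. For example, a meta-device tensor exercises the shape-only kernel that torch.compile relies on during tracing (a minimal sketch, assuming the extension has been built and loaded so that the `deepspeed` namespace exists under `torch.ops`):
```
import torch

x = torch.empty(4, 4, device="meta")
# Dispatches to inference_all_reduce_meta: shape/dtype are propagated
# without performing any communication.
y = torch.ops.deepspeed.inference_all_reduce(x)
print(y.shape, y.device)  # torch.Size([4, 4]) meta
```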


@@ -151,11 +151,12 @@ class TorchBackend(Backend):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@compiler.disable
def inference_all_reduce(self, tensor, op, group=None):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)
@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):