зеркало из https://github.com/microsoft/DeepSpeed.git
Update XPU FusedAdam OpBuilder for PyTorch 2.3 (#5702)
Update the way the SYCL queue is obtained in the FusedAdam OpBuilder. --------- Signed-off-by: baodii <di.bao@intel.com> Co-authored-by: Logan Adams <loadams@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Родитель
df58a784c8
Коммит
e39229676c
|
@ -10,6 +10,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <ipex.h>
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "compat.h"
|
||||
|
@ -22,10 +23,8 @@ namespace at {
|
|||
namespace cuda {
|
||||
// Returns a pointer to the SYCL queue that backs the current XPU stream.
//
// The function keeps its CUDA-style name; presumably this lets the fused Adam
// code adapted from NVIDIA/apex call it unchanged on XPU -- confirm against
// the callers.
//
// The queue is looked up through the c10::xpu stream API
// (<c10/xpu/XPUStream.h>) rather than via a c10::impl::VirtualGuardImpl and
// xpu::get_queue_from_stream.
//
// NOTE(review): the returned pointer aliases the reference handed out by
// XPUStream::queue(). This assumes the queue is owned by the runtime and not
// by the local `stream` handle, so it stays valid after return -- confirm.
sycl::queue* getCurrentCUDAStream()
{
    c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream();
    auto& queue = stream.queue();
    return &queue;
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче