update xpu fusedadam opbuilder for pytorch 2.3 (#5702)

update the way to get queue for FusedAdam OpBuilder.

---------

Signed-off-by: baodii <di.bao@intel.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
baodi 2024-07-02 03:34:11 +08:00 коммит произвёл GitHub
Родитель df58a784c8
Коммит e39229676c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 3 добавлений и 4 удалений

Просмотреть файл

@ -10,6 +10,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <c10/xpu/XPUStream.h>
#include <ipex.h>
#include <sycl/sycl.hpp>
#include "compat.h"
@ -22,10 +23,8 @@ namespace at {
namespace cuda {
sycl::queue* getCurrentCUDAStream()
{
auto device_type = c10::DeviceType::XPU;
c10::impl::VirtualGuardImpl impl(device_type);
c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
auto& queue = xpu::get_queue_from_stream(c10_stream);
c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream();
auto& queue = stream.queue();
return &queue;
}