Updated DL device name for MAIA

This commit is contained in:
kyule7 2024-02-22 15:52:48 -08:00
Родитель 130fed8ae3
Коммит 58c51e40e9
1 изменённых файлов: 2 добавлений и 2 удалений

Просмотреть файл

@ -86,7 +86,7 @@ OrtDevice GetOrtDevice(const DLDevice& device) {
case DLDeviceType::kDLCUDA:
case DLDeviceType::kDLROCM:
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
case DLDeviceType::kDLExtDev:
case DLDeviceType::kDLMAIA:
return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
default:
ORT_THROW("Unsupported device type");
@ -202,7 +202,7 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) {
break;
case OrtDevice::FPGA:
case OrtDevice::NPU:
device.device_type = DLDeviceType::kDLExtDev;
device.device_type = DLDeviceType::kDLMAIA;
default:
ORT_THROW("Cannot pack tensors on this device.");
}