Updated DL device name for MAIA
This commit is contained in:
Родитель
130fed8ae3
Коммит
58c51e40e9
|
@ -86,7 +86,7 @@ OrtDevice GetOrtDevice(const DLDevice& device) {
|
|||
case DLDeviceType::kDLCUDA:
|
||||
case DLDeviceType::kDLROCM:
|
||||
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
|
||||
case DLDeviceType::kDLExtDev:
|
||||
case DLDeviceType::kDLMAIA:
|
||||
return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
|
||||
default:
|
||||
ORT_THROW("Unsupported device type");
|
||||
|
@ -202,7 +202,7 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) {
|
|||
break;
|
||||
case OrtDevice::FPGA:
|
||||
case OrtDevice::NPU:
|
||||
device.device_type = DLDeviceType::kDLExtDev;
|
||||
device.device_type = DLDeviceType::kDLMAIA;
|
||||
default:
|
||||
ORT_THROW("Cannot pack tensors on this device.");
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче