[RUST][FRONTEND] Add rust frontend v0.1 (#2292)
This commit is contained in:
Родитель
18b2ebac9b
Коммит
e2970b226e
|
@ -1,6 +1,6 @@
|
|||
max_width = 100
|
||||
hard_tabs = false
|
||||
tab_spaces = 2
|
||||
tab_spaces = 4
|
||||
newline_style = "Auto"
|
||||
use_small_heuristics = "Default"
|
||||
indent_style = "Block"
|
||||
|
@ -38,7 +38,7 @@ trailing_comma = "Vertical"
|
|||
match_block_trailing_comma = false
|
||||
blank_lines_upper_bound = 1
|
||||
blank_lines_lower_bound = 0
|
||||
edition = "2015"
|
||||
edition = "2018"
|
||||
merge_derives = true
|
||||
use_try_shorthand = true
|
||||
use_field_init_shorthand = false
|
||||
|
@ -50,8 +50,8 @@ unstable_features = false
|
|||
disable_all_formatting = false
|
||||
skip_children = false
|
||||
hide_parse_errors = false
|
||||
error_on_line_overflow = false
|
||||
error_on_unformatted = false
|
||||
error_on_line_overflow = true
|
||||
error_on_unformatted = true
|
||||
report_todo = "Never"
|
||||
report_fixme = "Never"
|
||||
ignore = []
|
||||
|
|
|
@ -1,28 +1,11 @@
|
|||
[package]
|
||||
name = "tvm"
|
||||
version = "0.1.0"
|
||||
license = "Apache-2.0"
|
||||
description = "TVM Rust runtime"
|
||||
repository = "https://github.com/dmlc/tvm"
|
||||
readme = "README.md"
|
||||
keywords = ["tvm", "nnvm"]
|
||||
categories = ["api-bindings", "science"]
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[features]
|
||||
default = ["nom/std"]
|
||||
sgx = ["nom/alloc"]
|
||||
|
||||
[dependencies]
|
||||
bounded-spsc-queue = "0.4.0"
|
||||
error-chain = { version = "0.12.0", default-features = false }
|
||||
itertools = "0.7.8"
|
||||
lazy_static = "1.1.0"
|
||||
ndarray = "0.11.2"
|
||||
nom = {version = "4.0.0", default-features = false }
|
||||
serde = "1.0.59"
|
||||
serde_derive = "1.0.79"
|
||||
serde_json = "1.0.17"
|
||||
|
||||
[target.'cfg(not(target_env = "sgx"))'.dependencies]
|
||||
num_cpus = "1.8.0"
|
||||
[workspace]
|
||||
members = [
|
||||
"common",
|
||||
"runtime",
|
||||
"runtime/tests/test_tvm_basic",
|
||||
"runtime/tests/test_nnvm",
|
||||
"frontend",
|
||||
"frontend/tests/basics",
|
||||
"frontend/tests/callback",
|
||||
"frontend/examples/resnet"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
target
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
/tvm-sys/src/bindgen.rs
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "tvm-common"
|
||||
version = "0.1.0"
|
||||
authors = ["TVM Contributors"]
|
||||
license = "Apache-2.0"
|
||||
|
||||
[features]
|
||||
runtime = []
|
||||
frontend = ["tvm-sys"]
|
||||
|
||||
[dependencies]
|
||||
error-chain = { version = "0.12.0", default-features = false }
|
||||
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
|
|
@ -0,0 +1,15 @@
|
|||
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
|
||||
|
||||
error_chain! {
    errors {
        // Reported when a `TVMArgValue` is converted to a Rust type that its
        // type code does not allow; carries the expected and the actual type names.
        TryFromTVMArgValueError(expected: String, actual: String) {
            description("mismatched types while converting from TVMArgValue")
            display("expected `{}` but given `{}`", expected, actual)
        }

        // Reported when a `TVMRetValue` is downcast to a Rust type that its
        // type code does not allow; carries the expected and the actual type names.
        TryFromTVMRetValueError(expected: String, actual: String) {
            description("mismatched types while downcasting TVMRetValue")
            display("invalid downcast: expected `{}` but given `{}`", expected, actual)
        }
    }
}
|
|
@ -0,0 +1,39 @@
|
|||
//! This crate contains the refactored basic components required
|
||||
//! for `runtime` and `frontend` TVM crates.
|
||||
|
||||
#![crate_name = "tvm_common"]
|
||||
#![recursion_limit = "1024"]
|
||||
#![allow(non_camel_case_types, unused_imports)]
|
||||
#![feature(box_syntax, try_from)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate error_chain;
|
||||
|
||||
/// Unified ffi module for both runtime and frontend crates.
///
/// The `frontend` feature re-exports the `tvm_sys` crate as `ts`; the
/// `runtime` feature exposes the declarations included from
/// `src/c_runtime_api.rs` under `runtime`.
pub mod ffi {
    #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]

    // Raw C API bindings, only pulled in for frontend builds.
    #[cfg(feature = "frontend")]
    pub extern crate tvm_sys as ts;

    // Runtime-side declarations, only compiled for runtime builds.
    #[cfg(feature = "runtime")]
    pub mod runtime {
        use std::os::raw::{c_char, c_int, c_void};

        include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));

        /// C-ABI function pointer receiving the argument values, their type
        /// codes and the number of arguments, and returning a C `int`.
        pub type BackendPackedCFunc =
            extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
    }
}
|
||||
|
||||
pub mod errors;
|
||||
pub mod ty;
|
||||
pub mod value;
|
||||
|
||||
pub use errors::*;
|
||||
pub use ty::TVMTypeCode;
|
||||
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
|
|
@ -0,0 +1,144 @@
|
|||
//! This module contains `TVMTypeCode` and `TVMType` with some conversion methods.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! let dtype = TVMType::from("float");
|
||||
//! println!("dtype is: {}", dtype);
|
||||
//! ```
|
||||
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
fmt::{self, Display, Formatter},
|
||||
};
|
||||
|
||||
/// TVM type codes.
///
/// Each variant has a fixed numeric value (`0..=13`) that is used when the
/// code is converted to or from an integer.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
    /// Integer value (also the `Default` code).
    kDLInt = 0,
    /// Unsigned integer value.
    kDLUInt = 1,
    /// Floating point value.
    kDLFloat = 2,
    /// Opaque handle.
    kHandle = 3,
    /// Null value.
    kNull = 4,
    /// TVM type descriptor.
    kTVMType = 5,
    /// TVM context descriptor.
    kTVMContext = 6,
    /// Array handle.
    kArrayHandle = 7,
    /// Node handle.
    kNodeHandle = 8,
    /// Module handle.
    kModuleHandle = 9,
    /// Function handle.
    kFuncHandle = 10,
    /// String value.
    kStr = 11,
    /// Byte buffer.
    kBytes = 12,
    /// NDArray container.
    kNDArrayContainer = 13,
}
|
||||
|
||||
impl Default for TVMTypeCode {
|
||||
fn default() -> Self {
|
||||
TVMTypeCode::kDLInt
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TVMTypeCode> for i64 {
    /// Returns the numeric value of the type code.
    fn from(arg: TVMTypeCode) -> i64 {
        // The explicit discriminants declared on `TVMTypeCode` (0 through 13)
        // are exactly the values this conversion reports.
        arg as i64
    }
}
|
||||
|
||||
impl Into<TVMTypeCode> for i64 {
|
||||
fn into(self) -> TVMTypeCode {
|
||||
match self {
|
||||
0 => TVMTypeCode::kDLInt,
|
||||
1 => TVMTypeCode::kDLUInt,
|
||||
2 => TVMTypeCode::kDLFloat,
|
||||
3 => TVMTypeCode::kHandle,
|
||||
4 => TVMTypeCode::kNull,
|
||||
5 => TVMTypeCode::kTVMType,
|
||||
6 => TVMTypeCode::kTVMContext,
|
||||
7 => TVMTypeCode::kArrayHandle,
|
||||
8 => TVMTypeCode::kNodeHandle,
|
||||
9 => TVMTypeCode::kModuleHandle,
|
||||
10 => TVMTypeCode::kFuncHandle,
|
||||
11 => TVMTypeCode::kStr,
|
||||
12 => TVMTypeCode::kBytes,
|
||||
13 => TVMTypeCode::kNDArrayContainer,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TVMTypeCode {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
TVMTypeCode::kDLInt => "int",
|
||||
TVMTypeCode::kDLUInt => "uint",
|
||||
TVMTypeCode::kDLFloat => "float",
|
||||
TVMTypeCode::kHandle => "handle",
|
||||
TVMTypeCode::kNull => "null",
|
||||
TVMTypeCode::kTVMType => "TVM type",
|
||||
TVMTypeCode::kTVMContext => "TVM context",
|
||||
TVMTypeCode::kArrayHandle => "Array handle",
|
||||
TVMTypeCode::kNodeHandle => "Node handle",
|
||||
TVMTypeCode::kModuleHandle => "Module handle",
|
||||
TVMTypeCode::kFuncHandle => "Function handle",
|
||||
TVMTypeCode::kStr => "string",
|
||||
TVMTypeCode::kBytes => "bytes",
|
||||
TVMTypeCode::kNDArrayContainer => "ndarray container",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_prim_type {
|
||||
($type:ty, $variant:ident) => {
|
||||
impl<'a> From<&'a $type> for TVMTypeCode {
|
||||
fn from(_arg: &$type) -> Self {
|
||||
TVMTypeCode::$variant
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a mut $type> for TVMTypeCode {
|
||||
fn from(_arg: &mut $type) -> Self {
|
||||
TVMTypeCode::$variant
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_prim_type!(usize, kDLInt);
|
||||
impl_prim_type!(i64, kDLInt);
|
||||
impl_prim_type!(i32, kDLInt);
|
||||
impl_prim_type!(i16, kDLInt);
|
||||
impl_prim_type!(i8, kDLInt);
|
||||
|
||||
impl_prim_type!(u64, kDLUInt);
|
||||
impl_prim_type!(u32, kDLUInt);
|
||||
impl_prim_type!(u16, kDLUInt);
|
||||
impl_prim_type!(u8, kDLUInt);
|
||||
|
||||
impl_prim_type!(f64, kDLFloat);
|
||||
impl_prim_type!(f32, kDLFloat);
|
||||
|
||||
impl_prim_type!(str, kStr);
|
||||
impl_prim_type!(CStr, kStr);
|
||||
impl_prim_type!(String, kStr);
|
||||
impl_prim_type!(CString, kStr);
|
||||
|
||||
impl_prim_type!([u8], kBytes);
|
|
@ -0,0 +1,559 @@
|
|||
//! This module provides the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
|
||||
//! required for using TVM functions.
|
||||
|
||||
use std::{
|
||||
any::Any,
|
||||
convert::TryFrom,
|
||||
ffi::{CStr, CString},
|
||||
fmt::{self, Debug, Formatter},
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
ops::Deref,
|
||||
os::raw::{c_char, c_void},
|
||||
};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
use ffi::runtime::TVMValue as _TVMValue;
|
||||
|
||||
#[cfg(feature = "frontend")]
|
||||
use ffi::ts::TVMValue as _TVMValue;
|
||||
|
||||
use errors::*;
|
||||
|
||||
use ty::TVMTypeCode;
|
||||
|
||||
/// Wrapped TVMValue type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct TVMValue {
|
||||
pub inner: _TVMValue,
|
||||
}
|
||||
|
||||
impl TVMValue {
|
||||
/// Creates TVMValue from the raw part.
|
||||
pub fn new(inner: _TVMValue) -> Self {
|
||||
TVMValue { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn into_raw(self) -> _TVMValue {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for TVMValue {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
unsafe {
|
||||
write!(
|
||||
f,
|
||||
"TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
|
||||
[v_str: {:?}]",
|
||||
self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for TVMValue {
|
||||
type Target = _TVMValue;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_prim_val {
|
||||
($type:ty, $field:ident, $cast:ty) => {
|
||||
impl From<$type> for TVMValue {
|
||||
fn from(arg: $type) -> Self {
|
||||
let inner = _TVMValue {
|
||||
$field: arg as $cast,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a $type> for TVMValue {
|
||||
fn from(arg: &$type) -> Self {
|
||||
let inner = _TVMValue {
|
||||
$field: *arg as $cast,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a mut $type> for TVMValue {
|
||||
fn from(arg: &mut $type) -> Self {
|
||||
let inner = _TVMValue {
|
||||
$field: *arg as $cast,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(val: TVMValue) -> Result<Self> {
|
||||
Ok(unsafe { val.inner.$field as $type })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a TVMValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(val: &TVMValue) -> Result<Self> {
|
||||
Ok(unsafe { val.into_raw().$field as $type })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a mut TVMValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(val: &mut TVMValue) -> Result<Self> {
|
||||
Ok(unsafe { val.into_raw().$field as $type })
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_prim_val!(isize, v_int64, i64);
|
||||
impl_prim_val!(i64, v_int64, i64);
|
||||
impl_prim_val!(i32, v_int64, i64);
|
||||
impl_prim_val!(i16, v_int64, i64);
|
||||
impl_prim_val!(i8, v_int64, i64);
|
||||
impl_prim_val!(usize, v_int64, i64);
|
||||
impl_prim_val!(u64, v_int64, i64);
|
||||
impl_prim_val!(u32, v_int64, i64);
|
||||
impl_prim_val!(u16, v_int64, i64);
|
||||
impl_prim_val!(u8, v_int64, i64);
|
||||
|
||||
impl_prim_val!(f64, v_float64, f64);
|
||||
impl_prim_val!(f32, v_float64, f64);
|
||||
|
||||
impl<'a> From<&'a str> for TVMValue {
|
||||
fn from(arg: &str) -> TVMValue {
|
||||
let arg = CString::new(arg).unwrap();
|
||||
let inner = _TVMValue {
|
||||
v_str: arg.as_ptr() as *const c_char,
|
||||
};
|
||||
mem::forget(arg);
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a String> for TVMValue {
|
||||
fn from(arg: &String) -> TVMValue {
|
||||
let arg = CString::new(arg.as_bytes()).unwrap();
|
||||
let inner = _TVMValue {
|
||||
v_str: arg.as_ptr() as *const c_char,
|
||||
};
|
||||
mem::forget(arg);
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a CString> for TVMValue {
|
||||
fn from(arg: &CString) -> TVMValue {
|
||||
let arg = arg.to_owned();
|
||||
let inner = _TVMValue {
|
||||
v_str: arg.as_ptr() as *const c_char,
|
||||
};
|
||||
mem::forget(arg);
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a [u8]> for TVMValue {
|
||||
fn from(arg: &[u8]) -> TVMValue {
|
||||
let arg = arg.to_owned();
|
||||
let inner = _TVMValue {
|
||||
v_handle: &arg as *const _ as *mut c_void,
|
||||
};
|
||||
mem::forget(arg);
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
|
||||
/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
|
||||
/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// let s = "hello".to_string();
|
||||
/// let arg = TVMArgValue::from(&s);
|
||||
/// let tvm: String = arg.try_into().unwrap();
|
||||
/// assert_eq!(arg, s);
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TVMArgValue<'a> {
|
||||
/// The wrapped TVMValue
|
||||
pub value: TVMValue,
|
||||
/// The matching type code.
|
||||
pub type_code: TVMTypeCode,
|
||||
/// This is only exposed to runtime and frontend crates and is not meant to be used directly.
|
||||
pub lifetime: PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
impl<'a> TVMArgValue<'a> {
|
||||
pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
|
||||
TVMArgValue {
|
||||
value: value,
|
||||
type_code: type_code,
|
||||
lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if (arg.type_code == TVMTypeCode::kDLInt)
|
||||
| (arg.type_code == TVMTypeCode::kDLUInt)
|
||||
| (arg.type_code == TVMTypeCode::kNull)
|
||||
{
|
||||
Ok(unsafe { arg.value.inner.v_int64 })
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(i64).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kDLFloat {
|
||||
Ok(unsafe { arg.value.inner.v_float64 })
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(f64).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kStr {
|
||||
let ret_str = unsafe {
|
||||
match CStr::from_ptr(arg.value.inner.v_str).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => "Invalid UTF-8 message",
|
||||
}
|
||||
};
|
||||
Ok(ret_str.to_string())
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(String).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Main way to create a TVMArgValue from suported Rust values.
|
||||
impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
|
||||
where
|
||||
TVMValue: From<&'b T>,
|
||||
TVMTypeCode: From<&'b T>,
|
||||
{
|
||||
fn from(arg: &'b T) -> Self {
|
||||
TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a conversion to a `TVMArgValue` for an object handle.
|
||||
impl<'a, T> From<*const T> for TVMArgValue<'a> {
|
||||
fn from(ptr: *const T) -> Self {
|
||||
let value = TVMValue::new(_TVMValue {
|
||||
v_handle: ptr as *mut T as *mut c_void,
|
||||
});
|
||||
|
||||
TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
|
||||
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
|
||||
fn from(ptr: *mut T) -> Self {
|
||||
let value = TVMValue::new(_TVMValue {
|
||||
v_handle: ptr as *mut c_void,
|
||||
});
|
||||
|
||||
TVMArgValue::new(value, TVMTypeCode::kHandle)
|
||||
}
|
||||
}
|
||||
|
||||
/// An owned version of TVMPODValue. It can be converted from varieties of
|
||||
/// primitive and object types.
|
||||
/// It can be downcasted using `try_from` if it contains the desired type.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// let a = 42u32;
|
||||
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
|
||||
///
|
||||
/// let s = "hello, world!";
|
||||
/// let t: TVMRetValue = s.into();
|
||||
/// assert_eq!(String::try_from(t).unwrap(), s);
|
||||
/// ```
|
||||
pub struct TVMRetValue {
|
||||
/// A primitive return value, if any.
|
||||
pub prim_value: usize,
|
||||
/// An object return value, if any.
|
||||
pub box_value: Box<Any>,
|
||||
pub type_code: TVMTypeCode,
|
||||
}
|
||||
|
||||
impl TVMRetValue {
|
||||
fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
|
||||
Self {
|
||||
prim_value,
|
||||
box_value,
|
||||
type_code,
|
||||
}
|
||||
}
|
||||
|
||||
/// unsafe function to create `TVMRetValue` from `TVMValue` and
|
||||
/// its matching `TVMTypeCode`.
|
||||
pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
|
||||
let value = value.into_raw();
|
||||
match type_code {
|
||||
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
|
||||
Self::new(value.v_int64 as usize, box (), type_code)
|
||||
}
|
||||
TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
|
||||
TVMTypeCode::kHandle
|
||||
| TVMTypeCode::kArrayHandle
|
||||
| TVMTypeCode::kNodeHandle
|
||||
| TVMTypeCode::kModuleHandle
|
||||
| TVMTypeCode::kFuncHandle => {
|
||||
Self::new(value.v_handle as usize, box value.v_handle, type_code)
|
||||
}
|
||||
TVMTypeCode::kStr | TVMTypeCode::kBytes => {
|
||||
Self::new(value.v_str as usize, box (value.v_str), type_code)
|
||||
}
|
||||
_ => Self::new(0usize, box (), type_code),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the underlying `TVMValue` and `TVMTypeCode`.
|
||||
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
|
||||
let val = match self.type_code {
|
||||
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
|
||||
v_int64: self.prim_value as i64,
|
||||
}),
|
||||
TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
|
||||
v_float64: self.prim_value as f64,
|
||||
}),
|
||||
TVMTypeCode::kHandle
|
||||
| TVMTypeCode::kArrayHandle
|
||||
| TVMTypeCode::kNodeHandle
|
||||
| TVMTypeCode::kModuleHandle
|
||||
| TVMTypeCode::kFuncHandle
|
||||
| TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
|
||||
v_handle: self.prim_value as *const c_void as *mut c_void,
|
||||
}),
|
||||
TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
|
||||
v_str: self.prim_value as *const c_char,
|
||||
}),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
(val, self.type_code)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TVMRetValue {
|
||||
fn default() -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: 0usize,
|
||||
box_value: box (),
|
||||
type_code: TVMTypeCode::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TVMRetValue {
|
||||
fn clone(&self) -> Self {
|
||||
match self.type_code {
|
||||
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
|
||||
Self::new(self.prim_value.clone(), box (), self.type_code.clone())
|
||||
}
|
||||
TVMTypeCode::kHandle
|
||||
| TVMTypeCode::kArrayHandle
|
||||
| TVMTypeCode::kNodeHandle
|
||||
| TVMTypeCode::kModuleHandle
|
||||
| TVMTypeCode::kFuncHandle
|
||||
| TVMTypeCode::kNDArrayContainer => Self::new(
|
||||
self.prim_value.clone(),
|
||||
box (self.prim_value.clone() as *const c_void as *mut c_void),
|
||||
self.type_code.clone(),
|
||||
),
|
||||
TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
|
||||
self.prim_value.clone(),
|
||||
box (self.prim_value.clone() as *const c_char),
|
||||
self.type_code.clone(),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for TVMRetValue {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"prim_value: {:?}, box_value: {:?}, type_code: {:?}",
|
||||
self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_prim_ret_value {
|
||||
($type:ty, $code:expr) => {
|
||||
impl From<$type> for TVMRetValue {
|
||||
fn from(val: $type) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: val as usize,
|
||||
box_value: box (),
|
||||
type_code: $code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a $type> for TVMRetValue {
|
||||
fn from(val: &$type) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: *val as usize,
|
||||
box_value: box (),
|
||||
type_code: $code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a mut $type> for TVMRetValue {
|
||||
fn from(val: &mut $type) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: *val as usize,
|
||||
box_value: box (),
|
||||
type_code: $code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMRetValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<$type> {
|
||||
if ret.type_code == $code {
|
||||
Ok(ret.prim_value as $type)
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!($type).to_string(),
|
||||
ret.type_code.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
|
||||
impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
|
||||
impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
|
||||
impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
|
||||
impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
|
||||
|
||||
impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
|
||||
impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
|
||||
impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
|
||||
impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
|
||||
impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
|
||||
|
||||
impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
|
||||
impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
|
||||
|
||||
macro_rules! impl_ptr_ret_value {
|
||||
($type:ty) => {
|
||||
impl From<$type> for TVMRetValue {
|
||||
fn from(ptr: $type) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: ptr as usize,
|
||||
box_value: box (),
|
||||
type_code: TVMTypeCode::kHandle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMRetValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<$type> {
|
||||
if ret.type_code == TVMTypeCode::kHandle {
|
||||
Ok(ret.prim_value as $type)
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!($type).to_string(),
|
||||
ret.type_code.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ptr_ret_value!(*const c_void);
|
||||
impl_ptr_ret_value!(*mut c_void);
|
||||
|
||||
impl From<String> for TVMRetValue {
|
||||
fn from(val: String) -> Self {
|
||||
let pval = val.as_ptr() as *const c_char as usize;
|
||||
let bval = box (val.as_ptr() as *const c_char);
|
||||
mem::forget(val);
|
||||
TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMRetValue> for String {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<String> {
|
||||
// Note: simple downcast doesn't work for function call return values
|
||||
let ret_str = unsafe {
|
||||
match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => "Invalid UTF-8 message",
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ret_str.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// Round-trips the value 42 through `TVMRetValue` for several numeric types.
    #[test]
    fn numeric() {
        macro_rules! arg_ret_tests {
            ($v:expr; $($ty:ty),+) => {{
                $(
                    let expected = $v as $ty;
                    let wrapped = TVMRetValue::from(&expected);
                    let actual: $ty = wrapped.try_into().unwrap();
                    assert_eq!(actual, expected);
                )+
            }};
        }

        arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
    }

    /// Round-trips an owned string through `TVMRetValue`.
    #[test]
    fn string() {
        let s = "hello".to_string();
        let wrapped = TVMRetValue::from(s.clone());
        let tvm_arg: String = wrapped.try_into().unwrap();
        assert_eq!(tvm_arg, s);
    }
}
|
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "tvm-sys"
|
||||
version = "0.1.0"
|
||||
authors = ["TVM Contributors"]
|
||||
license = "Apache-2.0"
|
||||
description = "Raw C API"
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.37.4"
|
|
@ -0,0 +1,25 @@
|
|||
extern crate bindgen;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-env-changed=TVM_HOME");
|
||||
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
|
||||
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
|
||||
let bindings = bindgen::Builder::default()
|
||||
.header(format!(
|
||||
"{}/include/tvm/runtime/c_runtime_api.h",
|
||||
env!("TVM_HOME")
|
||||
))
|
||||
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
|
||||
.blacklist_type("max_align_t") // @see rust-bindgen#550
|
||||
.layout_tests(false)
|
||||
.derive_partialeq(true)
|
||||
.derive_eq(true)
|
||||
.generate()
|
||||
.expect("unable to generate bindings");
|
||||
|
||||
bindings
|
||||
.write_to_file(PathBuf::from("src/bindgen.rs"))
|
||||
.expect("can not write the bindings!");
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
#![allow(
|
||||
non_camel_case_types,
|
||||
non_snake_case,
|
||||
non_upper_case_globals,
|
||||
dead_code,
|
||||
improper_ctypes
|
||||
)]
|
||||
|
||||
include!("bindgen.rs");
|
|
@ -0,0 +1,7 @@
|
|||
target
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
/tests/basics/add_*
|
||||
/examples/resnet/deploy_*
|
||||
/examples/resnet/*.png
|
||||
/examples/resnet/synset.*
|
|
@ -0,0 +1,25 @@
|
|||
[package]
|
||||
name = "tvm-frontend"
|
||||
version = "0.1.0"
|
||||
license = "Apache-2.0"
|
||||
description = "Rust frontend support for TVM"
|
||||
repository = "https://github.com/dmlc/tvm"
|
||||
homepage = "https://github.com/dmlc/tvm"
|
||||
readme = "README.md"
|
||||
keywords = ["rust", "tvm", "nnvm"]
|
||||
categories = ["api-bindings", "science"]
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[lib]
|
||||
name = "tvm_frontend"
|
||||
crate-type = ["dylib"]
|
||||
|
||||
[dependencies]
|
||||
error-chain = "0.12.0"
|
||||
lazy_static = "1.1.0"
|
||||
ndarray = "0.12.1"
|
||||
num-traits = "0.2"
|
||||
tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
|
||||
|
||||
[features]
|
||||
blas = ["ndarray/blas"]
|
|
@ -0,0 +1,219 @@
|
|||
# TVM Runtime Frontend Support
|
||||
|
||||
This crate provides an idiomatic Rust API for the [TVM](https://github.com/dmlc/tvm) runtime frontend. Currently this requires **Nightly Rust** and is tested on `rustc 1.32.0-nightly`.
|
||||
|
||||
## What Does This Crate Offer?
|
||||
|
||||
Here is a major workflow
|
||||
|
||||
1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
|
||||
2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators.
|
||||
3. Deploy your models using **Rust** :heart:
|
||||
|
||||
### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
|
||||
|
||||
Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example.
|
||||
|
||||
Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM
|
||||
|
||||
```python
|
||||
block = get_model('resnet18_v1', pretrained=True)
|
||||
|
||||
sym, params = nnvm.frontend.from_mxnet(block)
|
||||
# add the softmax layer for prediction
|
||||
net = nnvm.sym.softmax(sym)
|
||||
# compile the model
|
||||
with nnvm.compiler.build_config(opt_level=opt_level):
|
||||
graph, lib, params = nnvm.compiler.build(
|
||||
net, target, shape={"data": data_shape}, params=params)
|
||||
# save the model artifacts
|
||||
lib.save(os.path.join(target_dir, "deploy_lib.o"))
|
||||
cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
|
||||
[os.path.join(target_dir, "deploy_lib.o")])
|
||||
|
||||
with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
|
||||
fo.write(graph.json())
|
||||
with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
|
||||
fo.write(nnvm.compiler.save_param_dict(params))
|
||||
```
|
||||
|
||||
Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image
|
||||
|
||||
![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
|
||||
|
||||
as demonstrated in the following Rust snippet
|
||||
|
||||
```rust
|
||||
let graph = fs::read_to_string("deploy_graph.json")?;
|
||||
// load the built module
|
||||
let lib = Module::load(&Path::new("deploy_lib.so"))?;
|
||||
// get the global TVM graph runtime function
|
||||
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
|
||||
let runtime_create_fn_ret = call_packed!(
|
||||
runtime_create_fn,
|
||||
&graph,
|
||||
&lib,
|
||||
&ctx.device_type,
|
||||
&ctx.device_id
|
||||
)?;
|
||||
// get graph runtime module
|
||||
let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
|
||||
// get the registered `load_params` from runtime module
|
||||
let ref load_param_fn = graph_runtime_module
|
||||
.get_function("load_params", false)
|
||||
.unwrap();
|
||||
// parse parameters and convert to TVMByteArray
|
||||
let params: Vec<u8> = fs::read("deploy_param.params")?;
|
||||
let barr = TVMByteArray::from(&params);
|
||||
// load the parameters
|
||||
call_packed!(load_param_fn, &barr)?;
|
||||
// get the set_input function
|
||||
let ref set_input_fn = graph_runtime_module
|
||||
.get_function("set_input", false)
|
||||
.unwrap();
|
||||
|
||||
call_packed!(set_input_fn, "data", &input)?;
|
||||
// get `run` function from runtime module
|
||||
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
|
||||
// execute the run function. Note that it has no argument
|
||||
call_packed!(run_fn,)?;
|
||||
// prepare to get the output
|
||||
let output_shape = &mut [1, 1000];
|
||||
let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
// get the `get_output` function from runtime module
|
||||
let ref get_output_fn = graph_runtime_module
|
||||
.get_function("get_output", false)
|
||||
.unwrap();
|
||||
// execute the get output function
|
||||
call_packed!(get_output_fn, &0, &output)?;
|
||||
// flatten the output as Vec<f32>
|
||||
let output = output.to_vec::<f32>()?;
|
||||
```
|
||||
|
||||
and the model correctly predicts the input image as **tiger cat**.
|
||||
|
||||
## Installations
|
||||
|
||||
Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
|
||||
|
||||
*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH`; this happens automatically in an Anaconda environment when they are installed individually.
|
||||
|
||||
## Supported TVM Functionalities
|
||||
|
||||
### Use TVM to Generate Shared Library
|
||||
|
||||
One can use the following Python snippet to generate `add_gpu.so`, which adds two vectors on the GPU.
|
||||
|
||||
```python
|
||||
import os
|
||||
import tvm
|
||||
from tvm.contrib import cc
|
||||
|
||||
def test_add(target_dir):
|
||||
if not tvm.module.enabled("cuda"):
|
||||
print(f"skip {__file__} because cuda is not enabled...")
|
||||
return
|
||||
n = tvm.var("n")
|
||||
A = tvm.placeholder((n,), name='A')
|
||||
B = tvm.placeholder((n,), name='B')
|
||||
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
|
||||
s = tvm.create_schedule(C.op)
|
||||
bx, tx = s[C].split(C.op.axis[0], factor=64)
|
||||
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
|
||||
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
|
||||
fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
|
||||
|
||||
fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
|
||||
fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
|
||||
cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
|
||||
[os.path.join(target_dir, "add_gpu.o")])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
if len(sys.argv) != 2:
|
||||
sys.exit(-1)
|
||||
test_add(sys.argv[1])
|
||||
```
|
||||
|
||||
### Run the Generated Shared Library
|
||||
|
||||
The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
|
||||
|
||||
```rust
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
let shape = &mut [2];
|
||||
let mut data = vec![3f32, 4.0];
|
||||
let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
|
||||
arr.copy_from_buffer(data.as_mut_slice());
|
||||
let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
|
||||
let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
|
||||
let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
|
||||
assert!(fadd.enabled("gpu"));
|
||||
fadd.import_module(fadd_dep);
|
||||
fadd.entry();
|
||||
function::Builder::from(&mut fadd)
|
||||
.arg(&arr)
|
||||
.arg(&arr)
|
||||
.set_output(&mut ret).unwrap()
|
||||
.invoke()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** `rustc` must be instructed to link against the generated `add_gpu.so` at runtime, for example with
|
||||
`cargo:rustc-link-search=native=add_gpu`.
|
||||
|
||||
See the custom `build.rs` files of the tests and examples for more details.
|
||||
|
||||
### Convert and Register a Rust Function as a TVM Packed Function
|
||||
|
||||
One can use `register_global_func!` macro to convert and register a Rust
|
||||
function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows
|
||||
|
||||
```rust
|
||||
#[macro_use]
|
||||
extern crate tvm_frontend as tvm;
|
||||
use std::convert::TryInto;
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
register_global_func! {
|
||||
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret = 0f32;
|
||||
let shape = &mut [2];
|
||||
for arg in args.iter() {
|
||||
let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
let arg: NDArray = arg.try_into()?;
|
||||
let arr = arg.copy_to_ndarray(e).unwrap();
|
||||
let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
|
||||
ret += rnd.scalar_sum();
|
||||
}
|
||||
let ret_val = TVMRetValue::from(&ret);
|
||||
Ok(ret_val)
|
||||
}
|
||||
}
|
||||
|
||||
let shape = &mut [2];
|
||||
let mut data = vec![3f32, 4.0];
|
||||
let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
arr.copy_from_buffer(data.as_mut_slice());
|
||||
let mut registered = function::Builder::default();
|
||||
let ret: f64 = registered
|
||||
.get_function("sum", true)
|
||||
.arg(&arr)
|
||||
.arg(&arr)
|
||||
.invoke()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ret, 14f64);
|
||||
}
|
||||
```
|
|
@ -0,0 +1,12 @@
|
|||
[package]
|
||||
name = "resnet"
|
||||
version = "0.0.0"
|
||||
authors = ["TVM Contributors"]
|
||||
license = "Apache-2.0"
|
||||
build = "build.rs"
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.12.1"
|
||||
tvm-frontend = { path = "../../" }
|
||||
image = "0.20.1"
|
||||
csv = "1"
|
|
@ -0,0 +1,15 @@
|
|||
## Resnet example
|
||||
|
||||
This end-to-end example shows how to:
|
||||
* build `Resnet 18` with `tvm` and `nnvm` from Python
|
||||
* use the provided Rust frontend API to test for an input image
|
||||
|
||||
To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
|
||||
and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).
|
||||
|
||||
* **Build the example**: `cargo build`
|
||||
|
||||
For the build to succeed, the Rust compiler must be instructed to link to the compiled shared library, for example with
|
||||
`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.
|
||||
|
||||
* **Run the example**: `cargo run`
|
|
@ -0,0 +1,16 @@
|
|||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
|
||||
"Could not prepare demo: {}",
|
||||
String::from_utf8(output.stderr).unwrap().trim()
|
||||
);
|
||||
println!(
|
||||
"cargo:rustc-link-search=native={}",
|
||||
env!("CARGO_MANIFEST_DIR")
|
||||
);
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
from os import path as osp
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.model_zoo.vision import get_model
|
||||
from mxnet.gluon.utils import download
|
||||
|
||||
import tvm
|
||||
from tvm.contrib import graph_runtime, cc
|
||||
import nnvm
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Resnet build example')
|
||||
aa = parser.add_argument
|
||||
aa('--batch-size', type=int, default=1, help='input image batch size')
|
||||
aa('--opt-level', type=int, default=3,
|
||||
help='level of optimization. 0 is unoptimized and 3 is the highest level')
|
||||
aa('--target', type=str, default='llvm', help='target context for compilation')
|
||||
aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
|
||||
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
|
||||
args = parser.parse_args()
|
||||
|
||||
target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
|
||||
batch_size = args.batch_size
|
||||
opt_level = args.opt_level
|
||||
target = tvm.target.create(args.target)
|
||||
image_shape = tuple(map(int, args.image_shape.split(",")))
|
||||
data_shape = (batch_size,) + image_shape
|
||||
|
||||
def build(target_dir):
    """Compile resnet18 with TVM and store the artifacts in ``target_dir``.

    Nothing is rebuilt when ``deploy_lib.o`` is already present in ``target_dir``.
    """
    obj_path = osp.join(target_dir, 'deploy_lib.o')
    if not osp.exists(obj_path):
        # pretrained resnet18 (imagenet1k dataset) for the image classification task
        model = get_model('resnet18_v1', pretrained=True)

        sym, params = nnvm.frontend.from_mxnet(model)
        # add the softmax layer for prediction
        net = nnvm.sym.softmax(sym)
        # compile the model
        with nnvm.compiler.build_config(opt_level=opt_level):
            graph, lib, params = nnvm.compiler.build(
                net, target, shape={"data": data_shape}, params=params)

        # save the model artifacts
        lib.save(obj_path)
        so_path = osp.join(target_dir, "deploy_lib.so")
        cc.create_shared(so_path, [osp.join(target_dir, "deploy_lib.o")])

        graph_path = osp.join(target_dir, "deploy_graph.json")
        with open(graph_path, "w") as graph_file:
            graph_file.write(graph.json())

        params_path = osp.join(target_dir, "deploy_param.params")
        with open(params_path, "wb") as params_file:
            params_file.write(nnvm.compiler.save_param_dict(params))
|
||||
|
||||
def download_img_labels():
    """Download an image and the imagenet1k class labels for the test.

    Fetches ``cat.png`` and ``synset.txt`` into the current working directory, then
    writes the label mapping to ``synset.csv`` as one ``class id,class name`` row per
    class (no header row).
    """
    import ast  # local import: only needed to parse the downloaded label file

    img_name = 'cat.png'
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                          'imagenet1000_clsid_to_human.txt'])
    synset_name = 'synset.txt'
    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
    download(synset_url, synset_name)

    with open(synset_name) as fin:
        # SECURITY: this text comes from the network, so it must never be executed.
        # `literal_eval` only accepts Python literals, unlike the `eval` used before.
        # NOTE(review): assumes the file holds a dict literal, as `.items()` below
        # requires - confirm against the gist.
        synset = ast.literal_eval(fin.read())

    # newline="" lets the csv module control line endings itself, so no extra
    # carriage returns appear on platforms that translate "\n".
    with open("synset.csv", "w", newline="") as fout:
        w = csv.writer(fout)
        w.writerows(synset.items())
|
||||
|
||||
def test_build(target_dir):
    """Sanity check: load the artifacts from ``target_dir`` and run them on random input."""
    # Read through context managers so both file handles are closed right away.
    with open(osp.join(target_dir, "deploy_graph.json")) as graph_file:
        graph = graph_file.read()
    lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
    with open(osp.join(target_dir, "deploy_param.params"), "rb") as params_file:
        params = bytearray(params_file.read())
    input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
    ctx = tvm.cpu()
    module = graph_runtime.create(graph, lib, ctx)
    module.load_params(params)
    module.run(data=input_data)
    # The result itself is not inspected; fetching it only checks that the call works.
    module.get_output(0).asnumpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("building the model")
|
||||
build(target_dir)
|
||||
logger.info("build was successful")
|
||||
logger.info("test the build artifacts")
|
||||
test_build(target_dir)
|
||||
logger.info("test was successful")
|
||||
download_img_labels()
|
||||
logger.info("image and synset downloads are successful")
|
|
@ -0,0 +1,134 @@
|
|||
#![feature(try_from)]
|
||||
|
||||
extern crate csv;
|
||||
extern crate image;
|
||||
extern crate ndarray;
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::TryInto,
|
||||
fs::{self, File},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use image::{FilterType, GenericImageView};
|
||||
use ndarray::{Array, ArrayD, Axis};
|
||||
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
let ctx = TVMContext::cpu(0);
|
||||
let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap();
|
||||
println!("original image dimensions: {:?}", img.dimensions());
|
||||
// for bigger size images, one needs to first resize to 256x256
|
||||
// with `img.resize_exact` method and then `image.crop` to 224x224
|
||||
let img = img.resize(224, 224, FilterType::Nearest).to_rgb();
|
||||
println!("resized image dimensions: {:?}", img.dimensions());
|
||||
let mut pixels: Vec<f32> = vec![];
|
||||
for pixel in img.pixels() {
|
||||
let tmp = pixel.data;
|
||||
// normalize the RGB channels using mean, std of imagenet1k
|
||||
let tmp = [
|
||||
(tmp[0] as f32 - 123.0) / 58.395, // R
|
||||
(tmp[1] as f32 - 117.0) / 57.12, // G
|
||||
(tmp[2] as f32 - 104.0) / 57.375, // B
|
||||
];
|
||||
for e in &tmp {
|
||||
pixels.push(*e);
|
||||
}
|
||||
}
|
||||
|
||||
let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap();
|
||||
let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
|
||||
// make arr shape as [1, 3, 224, 224] acceptable to resnet
|
||||
let arr = arr.insert_axis(Axis(0));
|
||||
// create input tensor from rust's ndarray
|
||||
let input =
|
||||
NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
|
||||
println!(
|
||||
"input size is {:?}",
|
||||
input.shape().expect("cannot get the input shape")
|
||||
);
|
||||
let graph =
|
||||
fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
|
||||
// load the built module
|
||||
let lib = Module::load(&Path::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/deploy_lib.so"
|
||||
)))
|
||||
.unwrap();
|
||||
// get the global TVM graph runtime function
|
||||
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
|
||||
let runtime_create_fn_ret = call_packed!(
|
||||
runtime_create_fn,
|
||||
&graph,
|
||||
&lib,
|
||||
&ctx.device_type,
|
||||
&ctx.device_id
|
||||
)
|
||||
.unwrap();
|
||||
// get graph runtime module
|
||||
let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap();
|
||||
// get the registered `load_params` from runtime module
|
||||
let ref load_param_fn = graph_runtime_module
|
||||
.get_function("load_params", false)
|
||||
.unwrap();
|
||||
// parse parameters and convert to TVMByteArray
|
||||
let params: Vec<u8> =
|
||||
fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap();
|
||||
let barr = TVMByteArray::from(¶ms);
|
||||
// load the parameters
|
||||
call_packed!(load_param_fn, &barr).unwrap();
|
||||
// get the set_input function
|
||||
let ref set_input_fn = graph_runtime_module
|
||||
.get_function("set_input", false)
|
||||
.unwrap();
|
||||
|
||||
call_packed!(set_input_fn, "data", &input).unwrap();
|
||||
// get `run` function from runtime module
|
||||
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
|
||||
// execute the run function. Note that it has no argument
|
||||
call_packed!(run_fn,).unwrap();
|
||||
// prepare to get the output
|
||||
let output_shape = &mut [1, 1000];
|
||||
let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
// get the `get_output` function from runtime module
|
||||
let ref get_output_fn = graph_runtime_module
|
||||
.get_function("get_output", false)
|
||||
.unwrap();
|
||||
// execute the get output function
|
||||
call_packed!(get_output_fn, &0, &output).unwrap();
|
||||
// flatten the output as Vec<f32>
|
||||
let output = output.to_vec::<f32>().unwrap();
|
||||
// find the maximum entry in the output and its index
|
||||
let mut argmax = -1;
|
||||
let mut max_prob = 0.;
|
||||
for i in 0..output.len() {
|
||||
if output[i] > max_prob {
|
||||
max_prob = output[i];
|
||||
argmax = i as i32;
|
||||
}
|
||||
}
|
||||
// create a hash map of (class id, class name)
|
||||
let mut synset: HashMap<i32, String> = HashMap::new();
|
||||
let file = File::open("synset.csv").unwrap();
|
||||
let mut rdr = csv::ReaderBuilder::new()
|
||||
.has_headers(true)
|
||||
.from_reader(file);
|
||||
|
||||
for result in rdr.records() {
|
||||
let record = result.unwrap();
|
||||
let id: i32 = record[0].parse().unwrap();
|
||||
let cls = record[1].to_string();
|
||||
synset.insert(id, cls);
|
||||
}
|
||||
|
||||
println!(
|
||||
"input image belongs to the class `{}` with probability {}",
|
||||
synset
|
||||
.get(&argmax)
|
||||
.expect("cannot find the class id for argmax"),
|
||||
max_prob
|
||||
);
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
//! Provides [`TVMByteArray`] used for passing the model parameters
|
||||
//! (stored as byte-array) to a runtime module.
|
||||
//!
|
||||
//! For more detail, please see the example `resnet` in `examples` repository.
|
||||
|
||||
use std::os::raw::c_char;
|
||||
|
||||
use crate::ts;
|
||||
|
||||
/// A struct holding TVM byte-array.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// let v = b"hello".to_vec();
|
||||
/// let barr = TVMByteArray::from(&v);
|
||||
/// assert_eq!(barr.len(), v.len());
|
||||
/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TVMByteArray {
|
||||
pub(crate) inner: ts::TVMByteArray,
|
||||
}
|
||||
|
||||
impl TVMByteArray {
|
||||
pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
|
||||
TVMByteArray { inner: barr }
|
||||
}
|
||||
|
||||
/// Gets the length of the underlying byte-array
|
||||
pub fn len(&self) -> usize {
|
||||
self.inner.size
|
||||
}
|
||||
|
||||
/// Gets the underlying byte-array as `Vec<i8>`
|
||||
pub fn data(&self) -> Vec<i8> {
|
||||
unsafe {
|
||||
let sz = self.len();
|
||||
let mut ret_buf = Vec::with_capacity(sz);
|
||||
ret_buf.set_len(sz);
|
||||
self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz);
|
||||
ret_buf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
|
||||
fn from(arg: &Vec<u8>) -> Self {
|
||||
let barr = ts::TVMByteArray {
|
||||
data: arg.as_ptr() as *const c_char,
|
||||
size: arg.len(),
|
||||
};
|
||||
TVMByteArray::new(barr)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn convert() {
|
||||
let v = vec![1u8, 2, 3];
|
||||
let barr = TVMByteArray::from(&v);
|
||||
assert_eq!(barr.len(), v.len());
|
||||
assert_eq!(barr.data(), vec![1i8, 2, 3]);
|
||||
let v = b"hello".to_vec();
|
||||
let barr = TVMByteArray::from(&v);
|
||||
assert_eq!(barr.len(), v.len());
|
||||
assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,286 @@
|
|||
//! Provides [`TVMContext`] and related device specific queries.
|
||||
//!
|
||||
//! Create a new context by device type (cpu is 1) and device id.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! let ctx = TVMContext::new(TVMDeviceType(1), 0);
|
||||
//! let cpu0 = TVMContext::cpu(0);
|
||||
//! assert_eq!(ctx, cpu0);
|
||||
//! ```
|
||||
//!
|
||||
//! Or from a supported device name.
|
||||
//!
|
||||
//! ```
|
||||
//! let cpu0 = TVMContext::from("cpu");
|
||||
//! println!("{}", cpu0);
|
||||
//! ```
|
||||
|
||||
use std::{
|
||||
fmt::{self, Display, Formatter},
|
||||
os::raw::c_void,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use crate::{function, ts, Result};
|
||||
|
||||
/// Device type can be from a supported device name. See the supported devices
|
||||
/// in [TVM](https://github.com/dmlc/tvm).
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// let cpu = TVMDeviceType::from("cpu");
|
||||
/// println!("device is: {}", cpu);
|
||||
///```
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct TVMDeviceType(pub usize);
|
||||
|
||||
impl Default for TVMDeviceType {
|
||||
/// default device is cpu.
|
||||
fn default() -> Self {
|
||||
TVMDeviceType(1)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TVMDeviceType> for ts::DLDeviceType {
|
||||
fn from(device_type: TVMDeviceType) -> Self {
|
||||
match device_type.0 {
|
||||
1 => ts::DLDeviceType_kDLCPU,
|
||||
2 => ts::DLDeviceType_kDLGPU,
|
||||
3 => ts::DLDeviceType_kDLCPUPinned,
|
||||
4 => ts::DLDeviceType_kDLOpenCL,
|
||||
7 => ts::DLDeviceType_kDLVulkan,
|
||||
8 => ts::DLDeviceType_kDLMetal,
|
||||
9 => ts::DLDeviceType_kDLVPI,
|
||||
10 => ts::DLDeviceType_kDLROCM,
|
||||
12 => ts::DLDeviceType_kDLExtDev,
|
||||
_ => panic!("device type not found!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ts::DLDeviceType> for TVMDeviceType {
|
||||
fn from(device_type: ts::DLDeviceType) -> Self {
|
||||
match device_type {
|
||||
ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
|
||||
ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
|
||||
ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
|
||||
ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
|
||||
ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
|
||||
ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
|
||||
ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
|
||||
ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
|
||||
ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
|
||||
_ => panic!("device type not found!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TVMDeviceType {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
TVMDeviceType(1) => "cpu",
|
||||
TVMDeviceType(2) => "gpu",
|
||||
TVMDeviceType(3) => "cpu_pinned",
|
||||
TVMDeviceType(4) => "opencl",
|
||||
TVMDeviceType(8) => "meta",
|
||||
TVMDeviceType(9) => "vpi",
|
||||
TVMDeviceType(10) => "rocm",
|
||||
TVMDeviceType(_) => "rpc",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for TVMDeviceType {
|
||||
fn from(type_str: &'a str) -> Self {
|
||||
match type_str {
|
||||
"cpu" => TVMDeviceType(1),
|
||||
"llvm" => TVMDeviceType(1),
|
||||
"stackvm" => TVMDeviceType(1),
|
||||
"gpu" => TVMDeviceType(2),
|
||||
"cuda" => TVMDeviceType(2),
|
||||
"nvptx" => TVMDeviceType(2),
|
||||
"cl" => TVMDeviceType(4),
|
||||
"opencl" => TVMDeviceType(4),
|
||||
"metal" => TVMDeviceType(8),
|
||||
"vpi" => TVMDeviceType(9),
|
||||
"rocm" => TVMDeviceType(10),
|
||||
_ => panic!("{:?} not supported!", type_str),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the underlying device context. Default is cpu.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```
|
||||
/// let ctx = TVMContext::from("gpu");
|
||||
/// assert!(ctx.exist());
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// It is possible to query the underlying context as follows
|
||||
///
|
||||
/// ```
|
||||
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
|
||||
/// println!("compute version: {}", ctx.compute_version());
|
||||
/// ```
|
||||
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub struct TVMContext {
|
||||
/// Supported device types
|
||||
pub device_type: TVMDeviceType,
|
||||
/// Device id
|
||||
pub device_id: usize,
|
||||
}
|
||||
|
||||
impl TVMContext {
|
||||
/// Creates context from device type and id.
|
||||
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
|
||||
TVMContext {
|
||||
device_type: device_type,
|
||||
device_id: device_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_ctxs {
|
||||
($(($ctx:ident, $dldevt:expr));+) => {
|
||||
$(
|
||||
impl TVMContext {
|
||||
pub fn $ctx(device_id: usize) -> Self {
|
||||
Self::new(TVMDeviceType($dldevt), device_id)
|
||||
}
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
impl_ctxs!((cpu, 1);
|
||||
(gpu, 2);
|
||||
(nvptx, 2);
|
||||
(cuda, 2);
|
||||
(cpu_pinned, 3);
|
||||
(cl, 4);
|
||||
(opencl, 4);
|
||||
(metal, 8);
|
||||
(vpi, 9);
|
||||
(rocm, 10);
|
||||
(opengl, 11);
|
||||
(ext_dev, 12));
|
||||
|
||||
impl<'a> From<&'a str> for TVMContext {
|
||||
fn from(target: &str) -> Self {
|
||||
TVMContext::new(TVMDeviceType::from(target), 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl TVMContext {
|
||||
/// Checks whether the context exists or not.
|
||||
pub fn exist(&self) -> bool {
|
||||
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
|
||||
.expect("API function always exists");
|
||||
let dt = self.device_type.0 as usize;
|
||||
// `unwrap` is ok here because if there is any error,
|
||||
// if would occure inside `call_packed!`
|
||||
let ret = call_packed!(func, &dt, &self.device_id, &0)
|
||||
.unwrap()
|
||||
.prim_value;
|
||||
ret != 0
|
||||
}
|
||||
|
||||
/// Synchronize the context stream.
|
||||
pub fn sync(&self) -> Result<()> {
|
||||
check_call!(ts::TVMSynchronize(
|
||||
self.device_type.0 as i32,
|
||||
self.device_id as i32,
|
||||
ptr::null_mut() as *mut c_void
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_device_attrs {
|
||||
($(($attr_name:ident, $attr_kind:expr));+) => {
|
||||
$(
|
||||
impl TVMContext {
|
||||
pub fn $attr_name(&self) -> usize {
|
||||
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
|
||||
.expect("API function always exists");
|
||||
let dt = self.device_type.0 as usize;
|
||||
// `unwrap` is ok here because if there is any error,
|
||||
// if would occur in function call.
|
||||
let ret = function::Builder::from(func)
|
||||
.args(&[dt, self.device_id, $attr_kind])
|
||||
.invoke()
|
||||
.unwrap();
|
||||
ret.prim_value as usize
|
||||
}
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
impl_device_attrs!((max_threads_per_block, 1);
|
||||
(warp_size, 2);
|
||||
(max_shared_memory_per_block, 3);
|
||||
(compute_version, 4);
|
||||
(device_name, 5);
|
||||
(max_clock_rate, 6);
|
||||
(multi_processor_count, 7);
|
||||
(max_thread_dimensions, 8));
|
||||
|
||||
impl From<ts::DLContext> for TVMContext {
|
||||
fn from(ctx: ts::DLContext) -> Self {
|
||||
TVMContext {
|
||||
device_type: TVMDeviceType::from(ctx.device_type),
|
||||
device_id: ctx.device_id as usize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TVMContext> for ts::DLContext {
|
||||
fn from(ctx: TVMContext) -> Self {
|
||||
ts::DLContext {
|
||||
device_type: ctx.device_type.into(),
|
||||
device_id: ctx.device_id as i32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TVMContext {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(f, "{}({})", self.device_type, self.device_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn context() {
|
||||
let ctx = TVMContext::cpu(0);
|
||||
println!("ctx: {}", ctx);
|
||||
let default_ctx = TVMContext::new(TVMDeviceType(1), 0);
|
||||
assert_eq!(ctx.clone(), default_ctx);
|
||||
assert_ne!(ctx, TVMContext::gpu(0));
|
||||
|
||||
let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
|
||||
assert_eq!(str_ctx.clone(), str_ctx);
|
||||
assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync() {
|
||||
let ctx = TVMContext::cpu(0);
|
||||
assert!(ctx.sync().is_ok())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
|
||||
|
||||
use std::{ffi, option};
|
||||
|
||||
use crate::{common_errors, rust_ndarray};
|
||||
|
||||
error_chain! {
|
||||
errors {
|
||||
EmptyArray {
|
||||
description("cannot convert from an empty array")
|
||||
}
|
||||
|
||||
NullHandle(name: String) {
|
||||
description("null handle")
|
||||
display("requested `{}` handle is null", name)
|
||||
}
|
||||
|
||||
FunctionNotFound {
|
||||
description("function not found")
|
||||
display("function was not set in `function::Builder`")
|
||||
}
|
||||
|
||||
TypeMismatch(expected: String, found: String) {
|
||||
description("type mismatch!")
|
||||
display("expected type `{}`, but found `{}`", expected, found)
|
||||
}
|
||||
|
||||
MissingShapeError {
|
||||
description("ndarray `shape()` returns `None`")
|
||||
display("called `Option::unwrap()` on a `None` value")
|
||||
}
|
||||
|
||||
AtMostOneReturn {
|
||||
description("TVM functions accept at most one return value")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
foreign_links {
|
||||
ShapeError(rust_ndarray::ShapeError);
|
||||
NulError(ffi::NulError);
|
||||
IntoStringError(ffi::IntoStringError);
|
||||
CommonError(common_errors::Error);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<option::NoneError> for Error {
|
||||
fn from(_err: option::NoneError) -> Self {
|
||||
ErrorKind::MissingShapeError.into()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,512 @@
|
|||
//! This module provides an idiomatic Rust API for creating and working with TVM functions.
|
||||
//!
|
||||
//! For calling an already registered TVM function use [`function::Builder`]
|
||||
//! To register a TVM packed function from Rust side either
|
||||
//! use [`function::register`] or the macro [`register_global_func`].
|
||||
//!
|
||||
//! See the tests and examples repository for more examples.
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
ffi::{CStr, CString},
|
||||
mem,
|
||||
os::raw::{c_char, c_int, c_void},
|
||||
ptr, slice, str,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
|
||||
|
||||
lazy_static! {
    // Registry keyed by the names returned by `TVMFuncListGlobalNames`. Every entry
    // starts as `None`; `Function::get` fills an entry in on first lookup.
    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
        let mut out_size = 0 as c_int;
        let name = ptr::null_mut() as *mut c_char;
        let mut out_array = name as *mut _;
        check_call!(ts::TVMFuncListGlobalNames(
            &mut out_size as *mut _,
            &mut out_array
        ));
        let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) };
        let mut registry = BTreeMap::new();
        for &name_ptr in names_list.iter() {
            // NOTE(review): the name strings are borrowed as `'static`; this presumes
            // the runtime keeps them alive for the whole program - confirm.
            let fn_name = unsafe { CStr::from_ptr(name_ptr).to_str().unwrap() };
            registry.insert(fn_name, None);
        }
        Mutex::new(registry)
    };
}
|
||||
|
||||
/// Wrapper around TVM function handle which includes `is_global`
|
||||
/// indicating whether the function is global or not, `is_released`
|
||||
/// to hint dropping the function handle and `is_cloned` showing
|
||||
/// not to drop a cloned function from Rust side.
|
||||
/// The value of these fields can be accessed through their respective methods.
|
||||
#[derive(Debug, Hash)]
|
||||
pub struct Function {
|
||||
pub(crate) handle: ts::TVMFunctionHandle,
|
||||
// whether the registered function is global or not.
|
||||
is_global: bool,
|
||||
// whether the function has been dropped from frontend or not.
|
||||
is_released: bool,
|
||||
// whether the function has been cloned from frontend or not.
|
||||
is_cloned: bool,
|
||||
}
|
||||
|
||||
unsafe impl Send for Function {}
|
||||
unsafe impl Sync for Function {}
|
||||
|
||||
impl Function {
|
||||
pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
|
||||
Function {
|
||||
handle: handle,
|
||||
is_global: is_global,
|
||||
is_released: is_released,
|
||||
is_cloned: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// For a given function, it returns a function by name.
|
||||
pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
|
||||
let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
|
||||
globals.get_mut(name.as_ref()).and_then(|maybe_func| {
|
||||
if maybe_func.is_none() {
|
||||
let name = CString::new(name.as_ref()).unwrap();
|
||||
let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
|
||||
check_call!(ts::TVMFuncGetGlobal(
|
||||
name.as_ptr() as *const c_char,
|
||||
&mut handle as *mut _
|
||||
));
|
||||
maybe_func.replace(Function::new(
|
||||
handle, is_global, false, /* is_released */
|
||||
));
|
||||
}
|
||||
unsafe {
|
||||
std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
|
||||
maybe_func.as_ref(),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the underlying TVM function handle.
|
||||
pub fn handle(&self) -> ts::TVMFunctionHandle {
|
||||
self.handle
|
||||
}
|
||||
|
||||
/// Returns `true` if the underlying TVM function is global and `false` otherwise.
|
||||
pub fn is_global(&self) -> bool {
|
||||
self.is_global
|
||||
}
|
||||
|
||||
/// Returns `true` if the underlying TVM function has been released
|
||||
/// from the frontend and `false` otherwise.
|
||||
pub fn is_released(&self) -> bool {
|
||||
self.is_released
|
||||
}
|
||||
|
||||
/// Returns `true` if the underlying TVM function has been cloned
|
||||
/// from the frontend and `false` otherwise.
|
||||
pub fn is_cloned(&self) -> bool {
|
||||
self.is_cloned
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Function {
|
||||
fn clone(&self) -> Function {
|
||||
if !self.is_released && !self.is_cloned {
|
||||
Self {
|
||||
handle: self.handle,
|
||||
is_global: self.is_global,
|
||||
is_released: self.is_released,
|
||||
is_cloned: true,
|
||||
}
|
||||
} else {
|
||||
Function::new(self.handle, self.is_global, self.is_released)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Function {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_released && !self.is_global && !self.is_cloned {
|
||||
check_call!(ts::TVMFuncFree(self.handle));
|
||||
self.is_released = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder used to assemble the arguments of a TVM function and then call it.
|
||||
///
|
||||
/// *Note:* Currently TVM functions accept *at most* one return value.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Builder<'a, 'm> {
|
||||
/// Function to call; invoking a builder whose `func` is `None` returns an error.
pub func: Option<&'m Function>,
|
||||
/// Arguments pushed so far, in call order; `None` until the first argument is pushed.
pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
|
||||
/// Optional caller-provided output value; at most one can be set.
pub ret_buf: Option<TVMRetValue>,
|
||||
}
|
||||
|
||||
impl<'a, 'm> Builder<'a, 'm> {
|
||||
pub fn new(
|
||||
func: Option<&'m Function>,
|
||||
arg_buf: Option<Box<[TVMArgValue<'a>]>>,
|
||||
ret_buf: Option<TVMRetValue>,
|
||||
) -> Self {
|
||||
Self {
|
||||
func,
|
||||
arg_buf,
|
||||
ret_buf,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
|
||||
self.func = Function::get(name, is_global);
|
||||
self
|
||||
}
|
||||
|
||||
/// Pushes a [`TVMArgValue`] into the function argument buffer.
|
||||
pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
|
||||
where
|
||||
TVMValue: From<&'b T>,
|
||||
TVMTypeCode: From<&'b T>,
|
||||
{
|
||||
let tvm_arg = TVMArgValue::from(arg);
|
||||
if self.arg_buf.is_none() {
|
||||
self.arg_buf = Some(Box::new([tvm_arg]));
|
||||
} else {
|
||||
let new_arg_buf = self.arg_buf.take().map(|bbuf| {
|
||||
let mut new_arg_buf = Vec::from(bbuf);
|
||||
new_arg_buf.push(tvm_arg);
|
||||
let new_len = new_arg_buf.len();
|
||||
new_arg_buf.truncate(new_len);
|
||||
new_arg_buf.into_boxed_slice()
|
||||
});
|
||||
self.arg_buf = new_arg_buf;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
|
||||
pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
|
||||
where
|
||||
I: IntoIterator<Item = &'b T>,
|
||||
TVMValue: From<&'b T>,
|
||||
TVMTypeCode: From<&'b T>,
|
||||
{
|
||||
for arg in args {
|
||||
self.arg(&arg);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets an output for a function that requirs a mutable output to be provided.
|
||||
/// See the `basics` in tests for an example.
|
||||
pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
|
||||
where
|
||||
TVMValue: From<&'b T>,
|
||||
TVMTypeCode: From<&'b T>,
|
||||
{
|
||||
if self.ret_buf.is_none() {
|
||||
let tvm_ret =
|
||||
unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
|
||||
self.ret_buf = Some(tvm_ret);
|
||||
} else {
|
||||
bail!(ErrorKind::AtMostOneReturn)
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Calls the function that created from `Builder`.
|
||||
pub fn invoke(&mut self) -> Result<TVMRetValue> {
|
||||
self.clone()(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
|
||||
type Output = Result<TVMRetValue>;
|
||||
extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
|
||||
if self.func.is_none() {
|
||||
bail!("{}", ErrorKind::FunctionNotFound);
|
||||
}
|
||||
|
||||
let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
|
||||
let mut ret_type_code = 0 as c_int;
|
||||
if self.arg_buf.is_some() {
|
||||
let arg_buf = self.arg_buf?;
|
||||
let mut num_args = arg_buf.len();
|
||||
let mut values = arg_buf
|
||||
.iter()
|
||||
.map(|tav| tav.value.inner)
|
||||
.collect::<Vec<ts::TVMValue>>();
|
||||
let mut tcodes = arg_buf
|
||||
.iter()
|
||||
.map(|tav| tav.type_code as c_int)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if self.ret_buf.is_some() {
|
||||
num_args = num_args + 1;
|
||||
let ret_buf = self.ret_buf?;
|
||||
let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
|
||||
values.append(&mut vec![ret_val.inner]);
|
||||
tcodes.append(&mut vec![ret_type_code as c_int]);
|
||||
}
|
||||
|
||||
values.truncate(num_args);
|
||||
tcodes.truncate(num_args);
|
||||
check_call!(ts::TVMFuncCall(
|
||||
self.func?.handle,
|
||||
values.as_mut_ptr(),
|
||||
tcodes.as_mut_ptr(),
|
||||
num_args as c_int,
|
||||
&mut ret_val as *mut _,
|
||||
&mut ret_type_code as *mut _
|
||||
));
|
||||
} else {
|
||||
check_call!(ts::TVMFuncCall(
|
||||
self.func?.handle,
|
||||
ptr::null_mut(),
|
||||
ptr::null_mut(),
|
||||
0 as c_int,
|
||||
&mut ret_val as *mut _,
|
||||
&mut ret_type_code as *mut _
|
||||
));
|
||||
}
|
||||
|
||||
let ret = unsafe {
|
||||
TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a [`Function`] to builder. Currently, this is the best way to work with
|
||||
/// TVM functions.
|
||||
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
|
||||
fn from(func: &'m Function) -> Self {
|
||||
Builder::new(Some(func), None, None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a mutable reference of a [`Module`] to [`Builder`].
|
||||
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
|
||||
fn from(module: &'m mut Module) -> Self {
|
||||
Builder::new(module.entry(), None, None)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn tvm_callback(
|
||||
args: *mut ts::TVMValue,
|
||||
type_codes: *mut c_int,
|
||||
num_args: c_int,
|
||||
ret: ts::TVMRetValueHandle,
|
||||
fhandle: *mut c_void,
|
||||
) -> c_int {
|
||||
// turning off the incorrect linter complaints
|
||||
#![allow(unused_assignments)]
|
||||
let len = num_args as usize;
|
||||
let args_list = slice::from_raw_parts_mut(args, len);
|
||||
let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
|
||||
let mut local_args: Vec<TVMArgValue> = Vec::new();
|
||||
let mut value = mem::uninitialized::<ts::TVMValue>();
|
||||
let mut tcode = mem::uninitialized::<c_int>();
|
||||
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
|
||||
for i in 0..len {
|
||||
value = args_list[i];
|
||||
tcode = type_codes_list[i];
|
||||
if tcode == ts::TVMTypeCode_kNodeHandle as c_int
|
||||
|| tcode == ts::TVMTypeCode_kFuncHandle as c_int
|
||||
|| tcode == ts::TVMTypeCode_kModuleHandle as c_int
|
||||
{
|
||||
check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
|
||||
}
|
||||
local_args.push(TVMArgValue::new(
|
||||
TVMValue::new(value),
|
||||
(tcode as i64).into(),
|
||||
));
|
||||
}
|
||||
|
||||
let rv = match rust_fn(local_args.as_slice()) {
|
||||
Ok(v) => v,
|
||||
Err(msg) => {
|
||||
crate::set_last_error(&msg);
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
|
||||
let mut ret_val = ret_val.inner;
|
||||
let mut ret_type_code = ret_tcode as c_int;
|
||||
check_call!(ts::TVMCFuncSetReturn(
|
||||
ret,
|
||||
&mut ret_val as *mut _,
|
||||
&mut ret_type_code as *mut _,
|
||||
1 as c_int
|
||||
));
|
||||
0
|
||||
}
|
||||
|
||||
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
|
||||
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
|
||||
mem::drop(rust_fn);
|
||||
}
|
||||
|
||||
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
|
||||
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
|
||||
let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
|
||||
check_call!(ts::TVMFuncCreateFromCFunc(
|
||||
Some(tvm_callback),
|
||||
resource_handle as *mut c_void,
|
||||
Some(tvm_callback_finalizer),
|
||||
&mut fhandle as *mut _
|
||||
));
|
||||
Function::new(fhandle, false, false)
|
||||
}
|
||||
|
||||
/// Registers a Rust function with signature
|
||||
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
|
||||
/// as a **global TVM packed function** from frontend to TVM backend.
|
||||
///
|
||||
/// Use [`register_global_func`] if overriding an existing global TVM function
|
||||
/// is not required.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// use std::convert::TryInto;
|
||||
///
|
||||
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
/// let mut ret = 0i64;
|
||||
/// for arg in args.iter() {
|
||||
/// let arg: i64 = arg.try_into()?;
|
||||
/// ret += arg;
|
||||
/// }
|
||||
/// let ret_val = TVMRetValue::from(&ret);
|
||||
/// Ok(ret_val)
|
||||
/// }
|
||||
///
|
||||
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
|
||||
/// let mut registered = function::Builder::default();
|
||||
/// registered.get_function("mysum", true);
|
||||
/// assert!(registered.func.is_some());
|
||||
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
|
||||
/// assert_eq!(ret, 60);
|
||||
/// ```
|
||||
pub fn register<S: AsRef<str>>(
|
||||
f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
|
||||
name: S,
|
||||
override_: bool,
|
||||
) -> Result<()> {
|
||||
let func = convert_to_tvm_func(f);
|
||||
let name = CString::new(name.as_ref())?;
|
||||
check_call!(ts::TVMFuncRegisterGlobal(
|
||||
name.as_ref().as_ptr() as *const c_char,
|
||||
func.handle(),
|
||||
override_ as c_int
|
||||
));
|
||||
mem::forget(name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convenient macro for registering functions from frontend to backend as global
|
||||
/// TVM packed functions without overriding. If overriding an existing function is needed
|
||||
/// use the [`function::register`] function instead.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// use std::convert::TryInto;
|
||||
///
|
||||
/// register_global_func! {
|
||||
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
/// let mut ret = 0f64;
|
||||
/// for arg in args.iter() {
|
||||
/// let arg: f64 = arg.try_into()?;
|
||||
/// ret += arg;
|
||||
/// }
|
||||
/// let ret_val = TVMRetValue::from(&ret);
|
||||
/// Ok(ret_val)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let mut registered = function::Builder::default();
|
||||
/// registered.get_function("sum", true);
|
||||
/// assert!(registered.func.is_some());
|
||||
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
|
||||
/// assert_eq!(ret, 60f64);
|
||||
/// ```
|
||||
#[macro_export]
macro_rules! register_global_func {
    {
        $(#[$attr:meta])*
        fn $name:ident($params:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
            $($body:tt)*
        }
    } => {{
        // Re-emit the function exactly as written ...
        $(#[$attr])*
        fn $name($params: &[TVMArgValue]) -> Result<TVMRetValue> {
            $($body)*
        }

        // ... then register it under its own identifier, without overriding.
        $crate::function::register($name, stringify!($name).to_owned(), false).unwrap();
    }}
}
|
||||
|
||||
/// Convenient macro for calling TVM packed functions by providing a
|
||||
/// function identifier and some arguments. This macro outputs a `Result` type
|
||||
/// and let user to perform proper error handling.
|
||||
///
|
||||
/// **Note**: this macro does *not* expect an outside mutable output. To
|
||||
/// set mutable output use [`set_output`] directly in the builder pattern.
|
||||
///
|
||||
/// [`set_output`]:function/struct.Builder.html#method.set_output
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// Instead of
|
||||
///
|
||||
/// ```
|
||||
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
|
||||
/// ```
|
||||
///
|
||||
/// one can use
|
||||
///
|
||||
/// ```
|
||||
/// call_packed!(func, &a, &b);
|
||||
/// ```
|
||||
#[macro_export]
macro_rules! call_packed {
    ($func:expr, $($param:expr),*) => {{
        // Build from the function, push every argument in order, then invoke.
        let mut builder = $crate::function::Builder::from($func);
        $(
            builder.arg($param);
        )*
        builder.invoke()
    }}
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Name of a global function the tests expect to be registered.
    static CANARY: &str = "module._LoadFromFile";

    #[test]
    fn list_global_func() {
        let registry = GLOBAL_FUNCTIONS.lock().unwrap();
        assert!(registry.contains_key(CANARY));
    }

    #[test]
    fn get_fn() {
        let found = Function::get(CANARY, true);
        assert!(found.is_some());
        let missing = Function::get("does not exists!", false);
        assert!(missing.is_none());
    }

    #[test]
    fn provide_args() {
        let mut func = Builder::default();
        func.get_function("tvm.graph_runtime.remote_create", true)
            .args(&[10, 20])
            .arg(&"test".to_owned());
        assert!(func.arg_buf.is_some());
        // Two integers plus one string were pushed.
        let arg_count = func.arg_buf.take().map(|buf| buf.len());
        assert_eq!(arg_count, Some(3));
    }
}
|
|
@ -0,0 +1,115 @@
|
|||
//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
|
||||
//!
|
||||
//! This crate provides an idiomatic Rust API for TVM runtime frontend.
|
||||
//!
|
||||
//! One particular use case is that given optimized deep learning model artifacts,
|
||||
//! (compiled with TVM) which include a shared library
|
||||
//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
|
||||
//! in Rust idiomatically to create a TVM Graph Runtime and
|
||||
//! run the model for some inputs and get the
|
||||
//! desired predictions *all in Rust*.
|
||||
//!
|
||||
//! Check out the `examples` repository for more details.
|
||||
|
||||
#![crate_name = "tvm_frontend"]
|
||||
#![recursion_limit = "1024"]
|
||||
#![allow(non_camel_case_types, unused_unsafe)]
|
||||
#![feature(
|
||||
try_from,
|
||||
try_trait,
|
||||
fn_traits,
|
||||
unboxed_closures,
|
||||
box_syntax,
|
||||
option_replace
|
||||
)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate error_chain;
|
||||
extern crate tvm_common as common;
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
extern crate ndarray as rust_ndarray;
|
||||
extern crate num_traits;
|
||||
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
str,
|
||||
};
|
||||
|
||||
use crate::common::ffi::ts;
|
||||
|
||||
// Evaluates a call into the TVM runtime shared library and panics with the
// runtime's last error message when the returned status is non-zero.
macro_rules! check_call {
    ($call:expr) => {{
        let status = unsafe { $call };
        if status != 0 {
            panic!("{}", $crate::get_last_error());
        }
    }};
}
|
||||
|
||||
/// Gets the last error message.
|
||||
pub fn get_last_error() -> &'static str {
|
||||
unsafe {
|
||||
match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => "Invalid UTF-8 message",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_last_error(err: &Error) {
|
||||
let c_string = CString::new(err.to_string()).unwrap();
|
||||
unsafe {
|
||||
ts::TVMAPISetLastError(c_string.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_use]
|
||||
pub mod function;
|
||||
pub mod bytearray;
|
||||
pub mod context;
|
||||
pub mod errors;
|
||||
pub mod module;
|
||||
pub mod ndarray;
|
||||
pub mod ty;
|
||||
pub mod value;
|
||||
|
||||
pub use crate::{
|
||||
bytearray::TVMByteArray,
|
||||
common::{
|
||||
errors as common_errors,
|
||||
ty::TVMTypeCode,
|
||||
value::{TVMArgValue, TVMRetValue, TVMValue},
|
||||
},
|
||||
context::{TVMContext, TVMDeviceType},
|
||||
errors::*,
|
||||
function::Function,
|
||||
module::Module,
|
||||
ndarray::NDArray,
|
||||
ty::TVMType,
|
||||
};
|
||||
|
||||
/// Outputs the current TVM version.
|
||||
pub fn version() -> &'static str {
|
||||
match str::from_utf8(ts::TVM_VERSION) {
|
||||
Ok(s) => s,
|
||||
Err(_) => "Invalid UTF-8 string",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn print_version() {
        let tvm_version = version();
        println!("TVM version: {}", tvm_version);
    }

    #[test]
    fn set_error() {
        // Round-trip an error message through the TVM runtime.
        let err: Error = ErrorKind::EmptyArray.into();
        set_last_error(&err);
        let expected = ErrorKind::EmptyArray.to_string();
        assert_eq!(get_last_error().trim(), expected);
    }
}
|
|
@ -0,0 +1,105 @@
|
|||
//! Provides the [`Module`] type and methods for working with runtime TVM modules.
|
||||
|
||||
use std::{
|
||||
convert::TryInto,
|
||||
ffi::CString,
|
||||
os::raw::{c_char, c_int},
|
||||
path::Path,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use crate::ts;
|
||||
|
||||
use crate::{function::Function, ErrorKind, Result};
|
||||
|
||||
const ENTRY_FUNC: &'static str = "__tvm_main__";
|
||||
|
||||
/// Wrapper around a TVM module handle, with a lazily looked-up entry function.
|
||||
/// The entry function is looked up on first use and cached by [`entry`].
|
||||
/// Also [`is_released`] shows whether the module is dropped or not.
|
||||
/// NOTE(review): `Clone` is derived while `Drop` calls `TVMModFree` on `handle`, so a clone and its source each free the same handle — confirm this is intended.
|
||||
/// [`entry`]:struct.Module.html#method.entry
|
||||
/// [`is_released`]:struct.Module.html#method.is_released
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Module {
|
||||
pub(crate) handle: ts::TVMModuleHandle,
|
||||
is_released: bool,
|
||||
entry_func: Option<Function>,
|
||||
}
|
||||
|
||||
impl Module {
|
||||
pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
|
||||
Self {
|
||||
handle,
|
||||
is_released,
|
||||
entry_func: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn entry(&mut self) -> Option<&Function> {
|
||||
if self.entry_func.is_none() {
|
||||
self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
|
||||
}
|
||||
self.entry_func.as_ref()
|
||||
}
|
||||
|
||||
/// Gets a function by name from a registered module.
|
||||
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
|
||||
let name = CString::new(name)?;
|
||||
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
|
||||
check_call!(ts::TVMModGetFunction(
|
||||
self.handle,
|
||||
name.as_ptr() as *const c_char,
|
||||
query_import as c_int,
|
||||
&mut fhandle as *mut _
|
||||
));
|
||||
if fhandle.is_null() {
|
||||
bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
|
||||
} else {
|
||||
Ok(Function::new(fhandle, false, false))
|
||||
}
|
||||
}
|
||||
|
||||
/// Imports a dependent module such as `.ptx` for gpu.
|
||||
pub fn import_module(&self, dependent_module: Module) {
|
||||
check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
|
||||
}
|
||||
|
||||
/// Loads a module shared library from path.
|
||||
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
|
||||
let ext = path.as_ref().extension()?.to_str()?;
|
||||
let func = Function::get("module._LoadFromFile", true /* is_global */)
|
||||
.expect("API function always exists");
|
||||
let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Checks if a target device is enabled for a module.
|
||||
pub fn enabled(&self, target: &str) -> bool {
|
||||
let func = Function::get("module._Enabled", true /* is_global */)
|
||||
.expect("API function always exists");
|
||||
// `unwrap` is safe here because if there is any error during the
|
||||
// function call, it would occur in `call_packed!`.
|
||||
let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
|
||||
ret != 0
|
||||
}
|
||||
|
||||
/// Returns the underlying module handle.
|
||||
pub fn handle(&self) -> ts::TVMModuleHandle {
|
||||
self.handle
|
||||
}
|
||||
|
||||
/// Returns true if the underlying module has been dropped and false otherwise.
|
||||
pub fn is_released(&self) -> bool {
|
||||
self.is_released
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Module {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_released {
|
||||
check_call!(ts::TVMModFree(self.handle));
|
||||
self.is_released = true;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,363 @@
|
|||
//! This module implements the [`NDArray`] type for working with *TVM tensors* or
|
||||
//! converting from Rust's ndarray to TVM `NDArray`.
|
||||
//!
|
||||
//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
|
||||
//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`].
|
||||
//! To copy an NDArray to different context use [`copy_to_ctx`].
|
||||
//!
|
||||
//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows:
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
|
||||
//! .unwrap()
|
||||
//! .into_dyn(); // Rust's ndarray
|
||||
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
|
||||
//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
|
||||
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
|
||||
//! assert!(rnd.all_close(&a, 1e-8f32));
|
||||
//! ```
|
||||
//!
|
||||
//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
|
||||
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
|
||||
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
|
||||
|
||||
use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
|
||||
|
||||
use crate::rust_ndarray::{Array, ArrayD};
|
||||
use num_traits::Num;
|
||||
|
||||
use crate::ts;
|
||||
|
||||
use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
|
||||
|
||||
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
|
||||
///
|
||||
/// Wrapper around a TVM array handle. Dropping it calls `TVMArrayFree` unless it is a view.
|
||||
#[derive(Debug)]
|
||||
pub struct NDArray {
|
||||
/// Raw TVM array handle; dereferenced to read shape, dtype, context and data.
pub(crate) handle: ts::TVMArrayHandle,
|
||||
/// When `true`, `Drop` does not free `handle`.
is_view: bool,
|
||||
}
|
||||
|
||||
impl NDArray {
|
||||
pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
|
||||
NDArray {
|
||||
handle: handle,
|
||||
is_view: is_view,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the underlying array handle.
|
||||
pub fn handle(&self) -> ts::TVMArrayHandle {
|
||||
self.handle
|
||||
}
|
||||
|
||||
pub fn is_view(&self) -> bool {
|
||||
self.is_view
|
||||
}
|
||||
|
||||
/// Returns the shape of the NDArray.
|
||||
pub fn shape(&self) -> Option<&mut [usize]> {
|
||||
let arr = unsafe { *(self.handle) };
|
||||
if arr.shape.is_null() || arr.data.is_null() {
|
||||
return None;
|
||||
};
|
||||
let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) };
|
||||
Some(slc)
|
||||
}
|
||||
|
||||
/// Returns the total number of entries of the NDArray.
|
||||
pub fn size(&self) -> Option<usize> {
|
||||
self.shape()
|
||||
.map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e))
|
||||
}
|
||||
|
||||
/// Returns the context which the NDArray was defined.
|
||||
pub fn ctx(&self) -> TVMContext {
|
||||
unsafe { (*self.handle).ctx.into() }
|
||||
}
|
||||
|
||||
/// Returns the type of the entries of the NDArray.
|
||||
pub fn dtype(&self) -> TVMType {
|
||||
unsafe { (*self.handle).dtype.into() }
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions of the NDArray.
|
||||
pub fn ndim(&self) -> usize {
|
||||
unsafe { (*self.handle).ndim as usize }
|
||||
}
|
||||
|
||||
/// Returns the strides of the underlying NDArray.
|
||||
pub fn strides(&self) -> Option<&[usize]> {
|
||||
unsafe {
|
||||
let sz = self.ndim() * mem::size_of::<usize>();
|
||||
let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
|
||||
Some(slc)
|
||||
}
|
||||
}
|
||||
|
||||
/// Shows whether the underlying ndarray is contiguous in memory or not.
|
||||
pub fn is_contiguous(&self) -> Result<bool> {
|
||||
Ok(match self.strides() {
|
||||
None => true,
|
||||
Some(strides) => {
|
||||
// MissingShapeError in case shape is not determined
|
||||
self.shape()?
|
||||
.iter()
|
||||
.zip(strides)
|
||||
.rfold(
|
||||
(true, 1),
|
||||
|(is_contig, expected_stride), (shape, stride)| {
|
||||
(
|
||||
is_contig && *stride == expected_stride,
|
||||
expected_stride * (*shape as usize),
|
||||
)
|
||||
},
|
||||
)
|
||||
.0
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn byte_offset(&self) -> isize {
|
||||
unsafe { (*self.handle).byte_offset as isize }
|
||||
}
|
||||
|
||||
/// Flattens the NDArray to a `Vec` of the same type in cpu.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// let shape = &mut [4];
|
||||
/// let mut data = vec![1i32, 2, 3, 4];
|
||||
/// let ctx = TVMContext::cpu(0);
|
||||
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
|
||||
/// ndarray.copy_from_buffer(&mut data);
|
||||
/// assert_eq!(ndarray.shape(), Some(shape));
|
||||
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
|
||||
/// ```
|
||||
pub fn to_vec<T>(&self) -> Result<Vec<T>> {
|
||||
if self.shape().is_none() {
|
||||
bail!("{}", ErrorKind::EmptyArray);
|
||||
}
|
||||
let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
|
||||
let target = self.copy_to_ndarray(earr)?;
|
||||
let arr = unsafe { *(target.handle) };
|
||||
let sz = self.size()? as usize;
|
||||
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
|
||||
unsafe {
|
||||
v.as_mut_ptr()
|
||||
.copy_from_nonoverlapping(arr.data as *const T, sz);
|
||||
v.set_len(sz);
|
||||
}
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Converts the NDArray to [`TVMByteArray`].
|
||||
pub fn to_bytearray(&self) -> Result<TVMByteArray> {
|
||||
let v = self.to_vec::<u8>()?;
|
||||
Ok(TVMByteArray::from(&v))
|
||||
}
|
||||
|
||||
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```
|
||||
/// let shape = &mut [2];
|
||||
/// let mut data = vec![1f32, 2];
|
||||
/// let ctx = TVMContext::gpu(0);
|
||||
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
|
||||
/// ndarray.copy_from_buffer(&mut data);
|
||||
/// ```
|
||||
///
|
||||
/// *Note*: if something goes wrong during the copy, it will panic
|
||||
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
|
||||
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
|
||||
check_call!(ts::TVMArrayCopyFromBytes(
|
||||
self.handle,
|
||||
data.as_ptr() as *mut _,
|
||||
data.len() * mem::size_of::<T>()
|
||||
));
|
||||
}
|
||||
|
||||
/// Copies the NDArray to another target NDArray.
|
||||
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
|
||||
if self.dtype() != target.dtype() {
|
||||
bail!(
|
||||
"{}",
|
||||
ErrorKind::TypeMismatch(
|
||||
format!("{}", self.dtype().to_string()),
|
||||
format!("{}", target.dtype().to_string()),
|
||||
)
|
||||
);
|
||||
}
|
||||
check_call!(ts::TVMArrayCopyFromTo(
|
||||
self.handle,
|
||||
target.handle,
|
||||
ptr::null_mut() as ts::TVMStreamHandle
|
||||
));
|
||||
Ok(target)
|
||||
}
|
||||
|
||||
/// Copies the NDArray to a target context.
|
||||
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
|
||||
let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
|
||||
let copy = self.copy_to_ndarray(tmp)?;
|
||||
Ok(copy)
|
||||
}
|
||||
|
||||
/// Converts a Rust's ndarray to TVM NDArray.
|
||||
pub fn from_rust_ndarray<T: Num32 + Copy>(
|
||||
rnd: &ArrayD<T>,
|
||||
ctx: TVMContext,
|
||||
dtype: TVMType,
|
||||
) -> Result<Self> {
|
||||
let mut shape = rnd.shape().to_vec();
|
||||
let mut nd = NDArray::empty(&mut shape, ctx, dtype);
|
||||
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
|
||||
nd.copy_from_buffer(buf.as_slice_mut()?);
|
||||
Ok(nd)
|
||||
}
|
||||
|
||||
/// Allocates and creates an empty NDArray given the shape, context and dtype.
|
||||
pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
|
||||
let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
|
||||
check_call!(ts::TVMArrayAlloc(
|
||||
shape.as_ptr() as *const i64,
|
||||
shape.len() as c_int,
|
||||
dtype.inner.code as c_int,
|
||||
dtype.inner.bits as c_int,
|
||||
dtype.inner.lanes as c_int,
|
||||
ctx.device_type.0 as c_int,
|
||||
ctx.device_id as c_int,
|
||||
&mut handle as *mut _,
|
||||
));
|
||||
NDArray::new(handle, false)
|
||||
}
|
||||
}
|
||||
|
||||
// Implements `TryFrom<&NDArray>` and `TryFrom<&mut NDArray>` for `ArrayD<$elem>`.
// `$dtype_name` is the TVM type string the array's dtype must equal.
macro_rules! impl_from_ndarray_rustndarray {
    ($elem:ty, $dtype_name:tt) => {
        impl<'a> TryFrom<&'a NDArray> for ArrayD<$elem> {
            type Error = Error;
            fn try_from(nd: &NDArray) -> Result<ArrayD<$elem>> {
                if nd.shape().is_none() {
                    bail!("{}", ErrorKind::EmptyArray);
                }
                // NOTE(review): a dtype mismatch panics here instead of returning `Err`.
                assert_eq!(nd.dtype(), TVMType::from($dtype_name), "Type mismatch");
                let data = nd.to_vec::<$elem>()?;
                Ok(Array::from_shape_vec(&*nd.shape()?, data)?)
            }
        }

        impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$elem> {
            type Error = Error;
            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$elem>> {
                // Same conversion as for a shared reference.
                ArrayD::<$elem>::try_from(&*nd)
            }
        }
    };
}
|
||||
|
||||
impl_from_ndarray_rustndarray!(i32, "int");
|
||||
impl_from_ndarray_rustndarray!(u32, "uint");
|
||||
impl_from_ndarray_rustndarray!(f32, "float");
|
||||
|
||||
impl Drop for NDArray {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_view {
|
||||
check_call!(ts::TVMArrayFree(self.handle));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
/// Private supertrait of `Num32` that prevents it from being implemented in downstream crates.
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
/// A trait for the 32-bit numerical types supported by the frontend; sealed via the private `sealed::Sealed` supertrait.
|
||||
pub trait Num32: Num + sealed::Sealed {
|
||||
// Default bit width for implementors.
const BITS: u8 = 32;
|
||||
}
|
||||
|
||||
// Implements `sealed::Sealed` and `Num32` for every listed type.
macro_rules! impl_num32 {
    ($($num:ty),+) => {
        $(
            impl sealed::Sealed for $num {}
            impl Num32 for $num {}
        )+
    };
}
|
||||
|
||||
impl_num32!(i32, u32, f32);
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basics() {
        let dims = &mut [1, 2, 3];
        let cpu = TVMContext::cpu(0);
        let array = NDArray::empty(dims, cpu, TVMType::from("int32"));
        assert_eq!(array.shape().unwrap(), dims);
        let expected_size: usize = dims.iter().product();
        assert_eq!(array.size().unwrap(), expected_size);
        assert_eq!(array.ndim(), 3);
        assert!(array.strides().is_none());
        assert_eq!(array.byte_offset(), 0);
    }

    #[test]
    fn copy() {
        let dims = &mut [4];
        let mut values = vec![1i32, 2, 3, 4];
        let cpu = TVMContext::cpu(0);
        let mut array = NDArray::empty(dims, cpu, TVMType::from("int32"));
        assert!(array.to_vec::<i32>().is_ok());
        array.copy_from_buffer(&mut values);
        assert_eq!(array.shape().unwrap(), dims);
        assert_eq!(array.to_vec::<i32>().unwrap(), values);
        assert_eq!(array.ndim(), 1);
        assert!(array.is_contiguous().is_ok());
        assert_eq!(array.byte_offset(), 0);

        // Copy into a second, freshly allocated array of the same shape and dtype.
        let mut target_dims = vec![4];
        let target = NDArray::empty(&mut target_dims, TVMContext::cpu(0), TVMType::from("int32"));
        let copied = array.copy_to_ndarray(target);
        assert!(copied.is_ok());
        assert_eq!(copied.unwrap().to_vec::<i32>().unwrap(), values);
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
    fn copy_wrong_dtype() {
        let mut dims = vec![4];
        let mut values = vec![1f32, 2., 3., 4.];
        let cpu = TVMContext::cpu(0);
        let mut float_array = NDArray::empty(&mut dims, cpu.clone(), TVMType::from("float32"));
        float_array.copy_from_buffer(&mut values);
        let int_array = NDArray::empty(&mut dims, cpu, TVMType::from("int32"));
        // The dtypes differ, so the copy must fail and `unwrap` must panic.
        float_array.copy_to_ndarray(int_array).unwrap();
    }

    #[test]
    fn rust_ndarray() {
        let source = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
            .unwrap()
            .into_dyn();
        let nd = NDArray::from_rust_ndarray(&source, TVMContext::cpu(0), TVMType::from("float32"))
            .unwrap();
        assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
        let round_trip: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
        assert!(round_trip.all_close(&source, 1e-8f32));
    }
}
|
|
@ -0,0 +1,150 @@
|
|||
//! This module implements the required conversions from Rust types to TVM types.
|
||||
//!
|
||||
//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
|
||||
//! and 64-bits pointers are supported.
|
||||
|
||||
use std::{
|
||||
fmt::{self, Display, Formatter},
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
use crate::ts;
|
||||
|
||||
use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
|
||||
|
||||
// Maps a frontend type, taken by value, by shared reference or by mutable
// reference, to a fixed `TVMTypeCode` variant. The argument value is ignored.
macro_rules! impl_prim_type {
    ($ty:ty, $code:ident) => {
        impl From<$ty> for TVMTypeCode {
            fn from(_value: $ty) -> Self {
                TVMTypeCode::$code
            }
        }

        impl<'a> From<&'a $ty> for TVMTypeCode {
            fn from(_value: &$ty) -> Self {
                TVMTypeCode::$code
            }
        }

        impl<'a> From<&'a mut $ty> for TVMTypeCode {
            fn from(_value: &mut $ty) -> Self {
                TVMTypeCode::$code
            }
        }
    };
}
|
||||
|
||||
impl_prim_type!(TVMDeviceType, kDLInt);
|
||||
impl_prim_type!(TVMContext, kTVMContext);
|
||||
impl_prim_type!(TVMType, kTVMType);
|
||||
impl_prim_type!(Function, kFuncHandle);
|
||||
impl_prim_type!(Module, kModuleHandle);
|
||||
impl_prim_type!(NDArray, kArrayHandle);
|
||||
impl_prim_type!(TVMByteArray, kBytes);
|
||||
|
||||
/// See the [module-level documentation](../ty/index.html) for more details.
|
||||
///
|
||||
/// Wrapper around underlying TVMType
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub struct TVMType {
|
||||
// inner fields are (code: u8, bits: u8, lanes: u16)
|
||||
pub inner: ts::TVMType,
|
||||
}
|
||||
|
||||
impl TVMType {
|
||||
pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
|
||||
TVMType {
|
||||
inner: ts::TVMType {
|
||||
code: type_code,
|
||||
bits: bits,
|
||||
lanes: lanes,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
|
||||
/// such as "int32", "float32" or with lane "float32x1".
|
||||
impl<'a> From<&'a str> for TVMType {
|
||||
fn from(type_str: &'a str) -> Self {
|
||||
if type_str == "bool" {
|
||||
return TVMType::new(1, 1, 1);
|
||||
}
|
||||
|
||||
let mut type_lanes = type_str.split("x");
|
||||
let typ = type_lanes.next().expect("Missing dtype");
|
||||
let lanes = type_lanes
|
||||
.next()
|
||||
.map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
|
||||
.unwrap_or(1);
|
||||
let (type_name, bits) = match typ.find(char::is_numeric) {
|
||||
Some(idx) => {
|
||||
let (name, bits_str) = typ.split_at(idx);
|
||||
(
|
||||
name,
|
||||
u8::from_str_radix(bits_str, 10)
|
||||
.expect(&format!("Bad dtype bits: {}", bits_str)),
|
||||
)
|
||||
}
|
||||
None => (typ, 32),
|
||||
};
|
||||
|
||||
let type_code = match type_name {
|
||||
"int" => 0,
|
||||
"uint" => 1,
|
||||
"float" => 2,
|
||||
"handle" => 3,
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
TVMType::new(type_code, bits, lanes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TVMType {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
let ts::TVMType { code, bits, lanes } = self.inner;
|
||||
if bits == 1 && lanes == 1 {
|
||||
return write!(f, "bool");
|
||||
}
|
||||
let mut tcode_str = match code {
|
||||
0 => "int",
|
||||
1 => "uint",
|
||||
2 => "float",
|
||||
4 => "handle",
|
||||
_ => "Unknown",
|
||||
}
|
||||
.to_string();
|
||||
|
||||
tcode_str += &bits.to_string();
|
||||
if lanes > 1 {
|
||||
tcode_str += &format!("x{}", lanes.to_string());
|
||||
}
|
||||
f.write_str(&tcode_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TVMType> for ts::DLDataType {
|
||||
fn from(dtype: TVMType) -> Self {
|
||||
dtype.inner
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps a raw data type by forwarding its code, bits and lanes to `TVMType::new`.
impl From<ts::DLDataType> for TVMType {
    fn from(raw: ts::DLDataType) -> Self {
        let (code, bits, lanes) = (raw.code, raw.bits, raw.lanes);
        Self::new(code, bits, lanes)
    }
}
|
||||
|
||||
impl Deref for TVMType {
|
||||
type Target = ts::TVMType;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for TVMType {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
|
@ -0,0 +1,241 @@
|
|||
//! This module implements [`TVMArgValue`] and [`TVMRetValue`] types
|
||||
//! and their conversions needed for the types used in frontend crate.
|
||||
//! `TVMRetValue` is the owned version of `TVMPODValue`.
|
||||
|
||||
use std::{convert::TryFrom, mem, os::raw::c_void};
|
||||
|
||||
use crate::{
|
||||
common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
|
||||
TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
|
||||
};
|
||||
|
||||
macro_rules! impl_tvm_val_from_handle {
|
||||
($($ty:ty),+) => {
|
||||
$(
|
||||
impl<'a> From<&'a $ty> for TVMValue {
|
||||
fn from(arg: &$ty) -> Self {
|
||||
let inner = ts::TVMValue {
|
||||
v_handle: arg.handle as *mut _ as *mut c_void,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
)+
|
||||
}
|
||||
}
|
||||
|
||||
impl_tvm_val_from_handle!(Module, Function, NDArray);
|
||||
|
||||
impl<'a> From<&'a TVMType> for TVMValue {
|
||||
fn from(ty: &TVMType) -> Self {
|
||||
let inner = ts::TVMValue { v_type: ty.inner };
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a TVMContext> for TVMValue {
|
||||
fn from(ctx: &TVMContext) -> Self {
|
||||
let inner = ts::TVMValue {
|
||||
v_ctx: ctx.clone().into(),
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a TVMDeviceType> for TVMValue {
|
||||
fn from(dev: &TVMDeviceType) -> Self {
|
||||
let inner = ts::TVMValue {
|
||||
v_int64: dev.0 as i64,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a TVMByteArray> for TVMValue {
|
||||
fn from(barr: &TVMByteArray) -> Self {
|
||||
let inner = ts::TVMValue {
|
||||
v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
|
||||
};
|
||||
Self::new(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kArrayHandle {
|
||||
let handle = unsafe { arg.value.inner.v_handle };
|
||||
let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
|
||||
Ok(Self::new(arr_handle, true))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(NDArray).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kModuleHandle {
|
||||
let handle = unsafe { arg.value.inner.v_handle };
|
||||
Ok(Self::new(handle, false))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(Module).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kBytes {
|
||||
unsafe {
|
||||
let barr_ptr =
|
||||
mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
|
||||
Ok(Self::new(*barr_ptr))
|
||||
}
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(TVMByteArray).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kTVMType {
|
||||
let ty = unsafe { arg.value.inner.v_type };
|
||||
Ok(TVMType::from(ty))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(TVMType).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
|
||||
type Error = Error;
|
||||
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
|
||||
if arg.type_code == TVMTypeCode::kTVMContext {
|
||||
let ty = unsafe { arg.value.inner.v_ctx };
|
||||
Ok(TVMContext::from(ty))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMArgValueError(
|
||||
stringify!(TVMContext).to_string(),
|
||||
arg.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_boxed_ret_value {
|
||||
($type:ty, $code:expr) => {
|
||||
impl From<$type> for TVMRetValue {
|
||||
fn from(val: $type) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: 0,
|
||||
box_value: box val,
|
||||
type_code: $code,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl TryFrom<TVMRetValue> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<$type> {
|
||||
if let Ok(val) = ret.box_value.downcast::<$type>() {
|
||||
Ok(*val)
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!($type).to_string(),
|
||||
ret.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
|
||||
impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
|
||||
impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
|
||||
|
||||
impl TryFrom<TVMRetValue> for Module {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<Module> {
|
||||
if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
|
||||
Ok(Module::new(*handle, false))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!(TVMTypeCode::kModuleHandle).to_string(),
|
||||
ret.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMRetValue> for Function {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<Function> {
|
||||
if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
|
||||
Ok(Function::new(*handle, false, false))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!(TVMTypeCode::kFuncHandle).to_string(),
|
||||
ret.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TVMRetValue> for NDArray {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<NDArray> {
|
||||
if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
|
||||
Ok(NDArray::new(*handle, false))
|
||||
} else {
|
||||
bail!(ErrorKind::TryFromTVMRetValueError(
|
||||
stringify!(TVMTypeCode::kArrayHandle).to_string(),
|
||||
ret.type_code.to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    // A byte array survives the trip through `TVMRetValue`.
    #[test]
    fn bytearray() {
        let bytes = vec![1u8, 2, 3, 4, 5];
        let expected = bytes.iter().map(|b| *b as i8).collect::<Vec<i8>>();

        let wrapped = TVMRetValue::from(TVMByteArray::from(&bytes));
        let unwrapped: TVMByteArray = wrapped.try_into().unwrap();
        assert_eq!(unwrapped.data(), expected);
    }

    // A `TVMType` survives the trip through `TVMRetValue`.
    #[test]
    fn ty() {
        let dtype = TVMType::from("int32");
        let wrapped = TVMRetValue::from(dtype);
        let unwrapped: TVMType = wrapped.try_into().unwrap();
        assert_eq!(unwrapped, dtype);
    }

    // A `TVMContext` survives the trip through `TVMRetValue`.
    #[test]
    fn ctx() {
        let context = TVMContext::from("gpu");
        let wrapped = TVMRetValue::from(context);
        let unwrapped: TVMContext = wrapped.try_into().unwrap();
        assert_eq!(unwrapped, context);
    }
}
|
|
@ -0,0 +1,7 @@
|
|||
/target
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
*.o
|
||||
*.so
|
||||
*.ptx
|
||||
*.json
|
|
@ -0,0 +1,15 @@
|
|||
[package]
|
||||
name = "basics"
|
||||
version = "0.0.0"
|
||||
authors = ["TVM Contributors"]
|
||||
license = "Apache-2.0"
|
||||
build = "build.rs"
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.12.1"
|
||||
tvm-frontend = { path = "../../" }
|
||||
|
||||
[features]
|
||||
default = ["cpu"]
|
||||
cpu = []
|
||||
gpu = []
|
|
@ -0,0 +1,27 @@
|
|||
fn main() {
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py"))
|
||||
.args(&[
|
||||
if cfg!(feature = "cpu") {
|
||||
"llvm"
|
||||
} else {
|
||||
"cuda"
|
||||
},
|
||||
&std::env::var("OUT_DIR").unwrap(),
|
||||
])
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(),
|
||||
"Could not build tvm lib: {}",
|
||||
String::from_utf8(output.stderr)
|
||||
.unwrap()
|
||||
.trim()
|
||||
.split("\n")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
);
|
||||
|
||||
println!("cargo:rustc-link-search=native={}", out_dir);
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
extern crate ndarray as rust_ndarray;
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
let shape = &mut [2];
|
||||
let mut data = vec![3f32, 4.0];
|
||||
|
||||
let (ctx, ctx_name) = if cfg!(feature = "cpu") {
|
||||
(TVMContext::cpu(0), "cpu")
|
||||
} else {
|
||||
(TVMContext::gpu(0), "gpu")
|
||||
};
|
||||
let dtype = TVMType::from("float32");
|
||||
let mut arr = NDArray::empty(shape, ctx, dtype);
|
||||
arr.copy_from_buffer(data.as_mut_slice());
|
||||
let mut ret = NDArray::empty(shape, ctx, dtype);
|
||||
let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
|
||||
if !fadd.enabled(ctx_name) {
|
||||
return;
|
||||
}
|
||||
if cfg!(feature = "gpu") {
|
||||
fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap());
|
||||
}
|
||||
function::Builder::from(&mut fadd)
|
||||
.arg(&arr)
|
||||
.arg(&arr)
|
||||
.set_output(&mut ret)
|
||||
.unwrap()
|
||||
.invoke()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
import tvm
|
||||
from tvm.contrib import cc
|
||||
|
||||
|
||||
def main(target, out_dir):
    """Build an element-wise vector add and write its artifacts to `out_dir`.

    Args:
        target: TVM build target; `'cuda'` additionally binds the loop to GPU
            block/thread axes and saves the imported device module.
        out_dir: Directory that receives `test_add.o`, `test_add.so` and, for
            `'cuda'`, `test_add.ptx`.
    """
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
    s = tvm.create_schedule(C.op)

    if target == 'cuda':
        bx, tx = s[C].split(C.op.axis[0], factor=64)
        s[C].bind(bx, tvm.thread_axis('blockIdx.x'))
        s[C].bind(tx, tvm.thread_axis('threadIdx.x'))

    fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')

    fadd.save(osp.join(out_dir, 'test_add.o'))
    if target == 'cuda':
        # Use the `osp` alias: this module imports `os.path as osp` and never
        # binds the name `os`, so `os.path.join` raised NameError here.
        fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx'))
    cc.create_shared(
        osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')])


if __name__ == '__main__':
    main(sys.argv[1], sys.argv[2])
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
[package]
|
||||
name = "callback"
|
||||
version = "0.0.0"
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.12.1"
|
||||
tvm-frontend = { path = "../../" }
|
|
@ -0,0 +1,44 @@
|
|||
#![feature(extern_crate_item_prelude, try_from)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
extern crate ndarray as rust_ndarray;
|
||||
#[macro_use]
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use rust_ndarray::ArrayD;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
register_global_func! {
|
||||
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret = 0f32;
|
||||
let shape = &mut [2];
|
||||
for arg in args.iter() {
|
||||
let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
let arg: NDArray = arg.try_into()?;
|
||||
let arr = arg.copy_to_ndarray(e)?;
|
||||
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
|
||||
ret += rnd.scalar_sum();
|
||||
}
|
||||
Ok(TVMRetValue::from(ret))
|
||||
}
|
||||
}
|
||||
|
||||
let shape = &mut [2];
|
||||
let mut data = vec![3f32, 4.0];
|
||||
let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
|
||||
arr.copy_from_buffer(data.as_mut_slice());
|
||||
|
||||
let mut registered = function::Builder::default();
|
||||
let ret: f32 = registered
|
||||
.get_function("sum", true)
|
||||
.arg(&arr)
|
||||
.arg(&arr)
|
||||
.invoke()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
assert_eq!(ret, 14f32);
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
#![feature(extern_crate_item_prelude, panic_info_message)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
use std::panic;
|
||||
|
||||
#[macro_use]
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
register_global_func! {
|
||||
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
Err(ErrorKind::TypeMismatch(
|
||||
format!("{}", "i64".to_string()),
|
||||
format!("{}", "f64".to_string()),
|
||||
).into())
|
||||
}
|
||||
}
|
||||
|
||||
let mut registered = function::Builder::default();
|
||||
registered.get_function("error", true);
|
||||
assert!(registered.func.is_some());
|
||||
registered.args(&[10, 20]);
|
||||
|
||||
println!("expected error message is:");
|
||||
panic::set_hook(Box::new(|panic_info| {
|
||||
if let Some(msg) = panic_info.message() {
|
||||
println!("{:?}", msg);
|
||||
}
|
||||
if let Some(location) = panic_info.location() {
|
||||
println!(
|
||||
"panic occurred in file '{}' at line {}",
|
||||
location.file(),
|
||||
location.line()
|
||||
);
|
||||
} else {
|
||||
println!("panic occurred but can't get location information");
|
||||
}
|
||||
}));
|
||||
|
||||
let _result = registered.invoke();
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
#![feature(extern_crate_item_prelude, try_from)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use std::convert::TryInto;
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
register_global_func! {
|
||||
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret = 0.0;
|
||||
for arg in args.iter() {
|
||||
let val: f64 = arg.try_into()?;
|
||||
ret += val;
|
||||
}
|
||||
Ok(TVMRetValue::from(&ret))
|
||||
}
|
||||
}
|
||||
|
||||
let mut registered = function::Builder::default();
|
||||
registered.get_function("sum", true);
|
||||
assert!(registered.func.is_some());
|
||||
let ret: f64 = registered
|
||||
.args(&[10.0f64, 20.0, 30.0])
|
||||
.invoke()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
assert_eq!(ret, 60f64);
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
#![feature(extern_crate_item_prelude, try_from)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
extern crate tvm_frontend as tvm;
|
||||
|
||||
use std::convert::TryInto;
|
||||
use tvm::*;
|
||||
|
||||
fn main() {
|
||||
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret = 0i64;
|
||||
for arg in args.iter() {
|
||||
let val: i64 = arg.try_into()?;
|
||||
ret += val;
|
||||
}
|
||||
Ok(TVMRetValue::from(&ret))
|
||||
}
|
||||
|
||||
tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
|
||||
|
||||
let mut registered = function::Builder::default();
|
||||
registered.get_function("mysum", true);
|
||||
assert!(registered.func.is_some());
|
||||
let ret: i64 = registered
|
||||
.args(&[10, 20, 30])
|
||||
.invoke()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
assert_eq!(ret, 60);
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
#![feature(extern_crate_item_prelude, try_from)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate tvm_frontend as tvm;
|
||||
use std::convert::TryInto;
|
||||
use tvm::*;
|
||||
|
||||
// FIXME
|
||||
fn main() {
|
||||
register_global_func! {
|
||||
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret = "".to_string();
|
||||
for arg in args.iter() {
|
||||
let val: String = arg.try_into()?;
|
||||
ret += val.as_str();
|
||||
}
|
||||
Ok(TVMRetValue::from(ret))
|
||||
}
|
||||
}
|
||||
let mut registered = function::Builder::default();
|
||||
registered.get_function("concate_str", true);
|
||||
assert!(registered.func.is_some());
|
||||
let a = "a".to_string();
|
||||
let b = "b".to_string();
|
||||
let c = "c".to_string();
|
||||
let ret: String = registered
|
||||
.args(&[a, b, c])
|
||||
.invoke()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
assert_eq!(ret, "abc".to_owned());
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
language: rust
|
||||
rust:
|
||||
- nightly
|
||||
matrix:
|
||||
fast_finish: true
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "tvm-runtime"
|
||||
version = "0.1.0"
|
||||
license = "Apache-2.0"
|
||||
description = "A static TVM runtime"
|
||||
repository = "https://github.com/dmlc/tvm"
|
||||
readme = "README.md"
|
||||
keywords = ["tvm", "nnvm"]
|
||||
categories = ["api-bindings", "science"]
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[features]
|
||||
default = ["nom/std"]
|
||||
sgx = ["nom/alloc"]
|
||||
|
||||
[dependencies]
|
||||
bounded-spsc-queue = "0.4.0"
|
||||
error-chain = { version = "0.12.0", default-features = false }
|
||||
itertools = "0.7.8"
|
||||
lazy_static = "1.1.0"
|
||||
ndarray = "0.11.2"
|
||||
nom = {version = "4.0.0", default-features = false }
|
||||
serde = "1.0.59"
|
||||
serde_derive = "1.0.79"
|
||||
serde_json = "1.0.17"
|
||||
tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
|
||||
|
||||
[target.'cfg(not(target_env = "sgx"))'.dependencies]
|
||||
num_cpus = "1.8.0"
|
|
@ -0,0 +1,52 @@
|
|||
#[cfg(target_env = "sgx")]
|
||||
use alloc::alloc::{self, Layout};
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use std::alloc::{self, Layout};
|
||||
|
||||
use crate::errors::*;
|
||||
|
||||
const DEFAULT_ALIGN_BYTES: usize = 4;
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct Allocation {
|
||||
layout: Layout,
|
||||
ptr: *mut u8,
|
||||
}
|
||||
|
||||
impl Allocation {
|
||||
/// Allocates a chunk of memory of `size` bytes with optional alignment.
|
||||
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
|
||||
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
|
||||
let layout = Layout::from_size_align(size, alignment)?;
|
||||
let ptr = unsafe { alloc::alloc(layout.clone()) };
|
||||
if ptr.is_null() {
|
||||
alloc::handle_alloc_error(layout);
|
||||
}
|
||||
Ok(Self {
|
||||
ptr: ptr,
|
||||
layout: layout,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
/// Returns the size of the Allocation in bytes.
|
||||
pub fn size(&self) -> usize {
|
||||
self.layout.size()
|
||||
}
|
||||
|
||||
/// Returns the byte alignment of the Allocation.
|
||||
pub fn align(&self) -> usize {
|
||||
self.layout.align()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Allocation {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
alloc::dealloc(self.ptr, self.layout.clone());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,507 @@
|
|||
use std::{
|
||||
any::TypeId,
|
||||
convert::TryFrom,
|
||||
mem,
|
||||
ops::{Deref, DerefMut},
|
||||
os::raw::{c_int, c_void},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
use ndarray;
|
||||
|
||||
use crate::{
|
||||
allocator::Allocation,
|
||||
errors::*,
|
||||
ffi::runtime::{
|
||||
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
|
||||
DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
|
||||
},
|
||||
};
|
||||
|
||||
/// A `Storage` is a container which holds `Tensor` data.
///
/// The bytes are either owned through an `Allocation` or borrowed as a
/// mutable byte slice.
#[derive(PartialEq)]
pub enum Storage<'a> {
    /// A `Storage` which owns its contained bytes.
    Owned(Allocation),

    /// A view of an existing `Storage`.
    ///
    /// The fields are the viewed bytes and their alignment in bytes.
    View(&'a mut [u8], usize), // ptr, align
}
|
||||
|
||||
impl<'a> Storage<'a> {
|
||||
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
|
||||
Ok(Storage::Owned(Allocation::new(size, align)?))
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.as_mut_ptr(),
|
||||
Storage::View(slice, _) => slice.as_ptr() as *mut u8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.size(),
|
||||
Storage::View(slice, _) => slice.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn align(&self) -> usize {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.align(),
|
||||
Storage::View(_, align) => *align,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.as_mut_ptr() as *const _
|
||||
}
|
||||
|
||||
/// Returns a `Storage::View` which points to an owned `Storage::Owned`.
|
||||
pub fn view(&self) -> Storage<'a> {
|
||||
match self {
|
||||
Storage::Owned(alloc) => Storage::View(
|
||||
unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
|
||||
self.align(),
|
||||
),
|
||||
Storage::View(slice, _) => Storage::View(
|
||||
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
|
||||
self.align(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_owned(&self) -> bool {
|
||||
match self {
|
||||
Storage::Owned(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an owned version of this storage via cloning.
|
||||
pub fn to_owned(&self) -> Storage<'static> {
|
||||
let s = Storage::new(self.size(), Some(self.align())).unwrap();
|
||||
unsafe {
|
||||
s.as_mut_ptr()
|
||||
.copy_from_nonoverlapping(self.as_ptr(), self.size());
|
||||
}
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
|
||||
fn from(data: &'d [T]) -> Self {
|
||||
let data = unsafe {
|
||||
slice::from_raw_parts_mut(
|
||||
data.as_ptr() as *const u8 as *mut u8,
|
||||
data.len() * mem::size_of::<T>() as usize,
|
||||
)
|
||||
};
|
||||
Storage::View(data, mem::align_of::<T>())
|
||||
}
|
||||
}
|
||||
|
||||
/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and
/// `ndarray::Array`. `Tensor` is primarily a holder of data which can be operated on via
/// TVM (via `DLTensor`) or converted to `ndarray::Array` for non-TVM processing.
///
/// # Examples
///
/// ```
/// extern crate ndarray;
///
/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut a).into();
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Tensor -> Array is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
    /// The bytes which contain the data this `Tensor` represents.
    pub(crate) data: Storage<'a>,
    /// The context this `Tensor` is associated with.
    pub(crate) ctx: TVMContext,
    /// The element type of this `Tensor`.
    pub(crate) dtype: DataType,
    /// The extent of each dimension.
    pub(crate) shape: Vec<i64>,
    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
    pub(crate) strides: Option<Vec<usize>>,
    /// Offset in bytes from the start of `data` to the first element.
    pub(crate) byte_offset: isize,
    /// The number of elements in the `Tensor`.
    pub(crate) size: usize,
}

// NOTE(review): `Send` is asserted by hand here; presumably needed because
// `Storage` holds a raw pointer — confirm the aliasing assumptions with callers.
unsafe impl<'a> Send for Tensor<'a> {}
|
||||
|
||||
impl<'a> Tensor<'a> {
|
||||
pub fn shape(&self) -> Vec<i64> {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
/// Returns the data of this `Tensor` as a `Vec`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
|
||||
pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
|
||||
assert!(self.is_contiguous());
|
||||
assert!(self.dtype.is_type::<T>());
|
||||
unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
|
||||
}
|
||||
|
||||
/// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
match self.strides {
|
||||
None => true,
|
||||
Some(ref strides) => {
|
||||
// check that stride for each dimension is the
|
||||
// product of all trailing dimensons' shapes
|
||||
self.shape
|
||||
.iter()
|
||||
.zip(strides)
|
||||
.rfold(
|
||||
(true, 1),
|
||||
|(is_contig, expected_stride), (shape, stride)| {
|
||||
(
|
||||
is_contig && *stride == expected_stride,
|
||||
expected_stride * (*shape as usize),
|
||||
)
|
||||
},
|
||||
)
|
||||
.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a clone of this `Tensor`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
|
||||
pub fn copy(&mut self, other: &Tensor) {
|
||||
assert!(
|
||||
self.dtype == other.dtype && self.size == other.size,
|
||||
"Tensor shape/dtype mismatch."
|
||||
);
|
||||
assert!(
|
||||
self.is_contiguous() && other.is_contiguous(),
|
||||
"copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
|
||||
self.strides,
|
||||
other.strides
|
||||
);
|
||||
unsafe {
|
||||
self.data
|
||||
.as_mut_ptr()
|
||||
.offset(self.byte_offset as isize)
|
||||
.copy_from_nonoverlapping(
|
||||
other.data.as_mut_ptr().offset(other.byte_offset),
|
||||
other.size * other.dtype.itemsize(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an owned version of this `Tensor` via cloning.
|
||||
pub fn to_owned(&self) -> Tensor<'static> {
|
||||
let t = Tensor {
|
||||
data: self.data.to_owned(),
|
||||
ctx: self.ctx.clone(),
|
||||
dtype: self.dtype.clone(),
|
||||
size: self.size.clone(),
|
||||
shape: self.shape.clone(),
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
};
|
||||
unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
|
||||
}
|
||||
|
||||
fn from_array_storage<'s, T, D: ndarray::Dimension>(
|
||||
arr: &ndarray::Array<T, D>,
|
||||
storage: Storage<'s>,
|
||||
type_code: usize,
|
||||
) -> Tensor<'s> {
|
||||
let type_width = mem::size_of::<T>() as usize;
|
||||
Tensor {
|
||||
data: storage,
|
||||
ctx: TVMContext::default(),
|
||||
dtype: DataType {
|
||||
code: type_code,
|
||||
bits: 8 * type_width,
|
||||
lanes: 1,
|
||||
},
|
||||
size: arr.len(),
|
||||
shape: arr.shape().iter().map(|&v| v as i64).collect(),
|
||||
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
|
||||
byte_offset: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
|
||||
macro_rules! impl_ndarray_try_from_tensor {
|
||||
($type:ty, $dtype:expr) => {
|
||||
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
|
||||
type Error = Error;
|
||||
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
|
||||
ensure!(
|
||||
tensor.dtype == $dtype,
|
||||
"Cannot convert Tensor with dtype {:?} to ndarray",
|
||||
tensor.dtype
|
||||
);
|
||||
Ok(ndarray::Array::from_shape_vec(
|
||||
tensor
|
||||
.shape
|
||||
.iter()
|
||||
.map(|s| *s as usize)
|
||||
.collect::<Vec<usize>>(),
|
||||
tensor.to_vec::<$type>(),
|
||||
)?)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
|
||||
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
|
||||
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
|
||||
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
|
||||
|
||||
pub struct DLTensor {
|
||||
pub(crate) inner: _DLTensor,
|
||||
}
|
||||
|
||||
impl Deref for DLTensor {
|
||||
type Target = _DLTensor;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for DLTensor {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl DLTensor {
|
||||
pub(crate) fn new(raw: _DLTensor) -> Self {
|
||||
Self { inner: raw }
|
||||
}
|
||||
|
||||
pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
|
||||
assert!(!flatten || tensor.is_contiguous());
|
||||
Self {
|
||||
inner: _DLTensor {
|
||||
data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
|
||||
ctx: DLContext::from(&tensor.ctx),
|
||||
ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
|
||||
dtype: DLDataType::from(&tensor.dtype),
|
||||
shape: if flatten {
|
||||
&tensor.size as *const _ as *mut i64
|
||||
} else {
|
||||
tensor.shape.as_ptr()
|
||||
} as *mut i64,
|
||||
strides: if flatten || tensor.is_contiguous() {
|
||||
ptr::null_mut()
|
||||
} else {
|
||||
tensor.strides.as_ref().unwrap().as_ptr()
|
||||
} as *mut i64,
|
||||
byte_offset: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a borrowed `Tensor` into a `DLTensor` without flattening it.
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
    fn from(source: &'a Tensor<'t>) -> Self {
        let flatten = false;
        Self::from_tensor(source, flatten)
    }
}
|
||||
|
||||
/// Describes a mutably borrowed `Tensor` with its full (unflattened) shape.
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
    fn from(tensor: &'a mut Tensor<'t>) -> Self {
        let flatten = false;
        Self::from_tensor(tensor, flatten)
    }
}
|
||||
|
||||
/// Element type of a tensor: a type-class code, a bit width and a lane count.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
    /// Type-class code. `is_type` treats 0 as signed int, 1 as unsigned int, 2 as float.
    pub(crate) code: usize,
    /// Width of one lane in bits.
    pub(crate) bits: usize,
    /// Number of lanes per element (1 for scalars).
    pub(crate) lanes: usize,
}

impl DataType {
    /// Returns the number of bytes occupied by an element of this `DataType`.
    ///
    /// The bit count is rounded up to a whole number of bytes, so sub-byte types
    /// (for example a 1-bit element) occupy one byte instead of zero.
    pub fn itemsize(&self) -> usize {
        (self.bits * self.lanes + 7) >> 3
    }

    /// Returns whether this `DataType` represents primitive type `T`.
    ///
    /// Only single-lane `i32`, `i64`, `u32`, `u64`, `f32` and `f64` are recognised;
    /// everything else yields `false`.
    pub fn is_type<T: 'static>(&self) -> bool {
        if self.lanes != 1 {
            return false;
        }
        let expected = match (self.code, self.bits) {
            (0, 32) => TypeId::of::<i32>(),
            (0, 64) => TypeId::of::<i64>(),
            (1, 32) => TypeId::of::<u32>(),
            (1, 64) => TypeId::of::<u64>(),
            (2, 32) => TypeId::of::<f32>(),
            (2, 64) => TypeId::of::<f64>(),
            _ => return false,
        };
        TypeId::of::<T>() == expected
    }
}
|
||||
|
||||
impl<'a> From<&'a DataType> for DLDataType {
|
||||
fn from(dtype: &'a DataType) -> Self {
|
||||
Self {
|
||||
code: dtype.code as u8,
|
||||
bits: dtype.bits as u8,
|
||||
lanes: dtype.lanes as u16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DLDataType> for DataType {
|
||||
fn from(dtype: DLDataType) -> Self {
|
||||
Self {
|
||||
code: dtype.code as usize,
|
||||
bits: dtype.bits as usize,
|
||||
lanes: dtype.lanes as usize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Declares a private `DataType` constant named `$const_name`.
///
/// `$type_code` is cast to `usize` for the `code` field; `$bits` and `$lanes` are
/// used as given.
macro_rules! make_dtype_const {
    ($const_name:ident, $type_code:ident, $bits:expr, $lanes:expr) => {
        const $const_name: DataType = DataType {
            code: $type_code as usize,
            bits: $bits,
            lanes: $lanes,
        };
    };
}
|
||||
|
||||
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
|
||||
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
|
||||
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
|
||||
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
|
||||
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
|
||||
|
||||
/// Device on which a tensor lives: a device type plus a device id.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TVMContext {
    /// Device type code.
    pub(crate) device_type: usize,
    /// Identifier of the device within its type.
    pub(crate) device_id: usize,
}
|
||||
|
||||
impl<'a> From<&'a TVMContext> for DLContext {
|
||||
fn from(ctx: &'a TVMContext) -> Self {
|
||||
Self {
|
||||
device_type: ctx.device_type as u32,
|
||||
device_id: ctx.device_id as i32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TVMContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_type: DLDeviceType_kDLCPU as usize,
|
||||
device_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<DLTensor> for Tensor<'a> {
|
||||
fn from(dlt: DLTensor) -> Self {
|
||||
unsafe {
|
||||
let dtype = DataType::from(dlt.dtype);
|
||||
let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
|
||||
let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
|
||||
let storage = Storage::from(slice::from_raw_parts(
|
||||
dlt.data as *const u8,
|
||||
dtype.itemsize() * size,
|
||||
));
|
||||
Self {
|
||||
data: storage,
|
||||
ctx: TVMContext::default(),
|
||||
dtype: dtype,
|
||||
size: size,
|
||||
shape: shape,
|
||||
strides: if dlt.strides == ptr::null_mut() {
|
||||
None
|
||||
} else {
|
||||
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
|
||||
},
|
||||
byte_offset: dlt.byte_offset as isize,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates `From<ndarray::Array<$type, D>>` (owned) and
/// `From<&ndarray::Array<$type, D>>` (borrowed) for `Tensor`.
///
/// The owned conversion copies the array data into owned storage; the borrowed one
/// keeps referring to the array's buffer. `$typecode` is passed on as the type code.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                let borrowed = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
                let owned = borrowed.to_owned();
                Tensor::from_array_storage(&arr, owned, $typecode as usize)
            }
        }

        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                let borrowed = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
                Tensor::from_array_storage(arr, borrowed, $typecode as usize)
            }
        }
    };
}
|
||||
|
||||
/// Generates `From<&mut ndarray::Array<$type, D>>` for `DLTensor`.
///
/// The descriptor borrows the array's data, shape and strides pointers rather than
/// copying them, which is why the conversion takes a reference. The context is
/// always CPU device 0, and the bit width is derived from `size_of::<$type>()`.
macro_rules! impl_dltensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
                let cpu = DLContext {
                    device_type: DLDeviceType_kDLCPU,
                    device_id: 0,
                };
                let elem_type = DLDataType {
                    code: $typecode as u8,
                    bits: 8 * mem::size_of::<$type>() as u8,
                    lanes: 1,
                };
                DLTensor {
                    inner: _DLTensor {
                        data: arr.as_mut_ptr() as *mut c_void,
                        ctx: cpu,
                        ndim: arr.ndim() as c_int,
                        dtype: elem_type,
                        shape: arr.shape().as_ptr() as *const i64 as *mut i64,
                        strides: arr.strides().as_ptr() as *const isize as *mut i64,
                        byte_offset: 0,
                    },
                }
            }
        }
    };
}
|
||||
|
||||
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
|
||||
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
|
||||
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
|
||||
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
|
||||
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
|
||||
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
|
||||
|
||||
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
|
||||
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
|
||||
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
|
||||
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
|
||||
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
|
||||
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
|
|
@ -4,16 +4,12 @@ use alloc::alloc;
|
|||
use std::alloc;
|
||||
use std::num;
|
||||
|
||||
use crate::common::errors as common_errors;
|
||||
use ndarray;
|
||||
use serde_json;
|
||||
|
||||
error_chain! {
|
||||
errors {
|
||||
TryFromTVMRetValueError(expected: String, actual: i64) {
|
||||
description("mismatched types while downcasting TVMRetValue")
|
||||
display("invalid downcast: expected `{}` but was `{}`", expected, actual)
|
||||
}
|
||||
|
||||
GraphFormatError(msg: String) {
|
||||
description("unable to load graph")
|
||||
display("could not load graph json: {}", msg)
|
||||
|
@ -29,11 +25,12 @@ error_chain! {
|
|||
GraphDeserialize(serde_json::Error);
|
||||
ParseInt(num::ParseIntError);
|
||||
ShapeError(ndarray::ShapeError);
|
||||
CommonError(common_errors::Error);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<alloc::LayoutErr> for Error {
|
||||
fn from(_err: alloc::LayoutErr) -> Error {
|
||||
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
|
||||
}
|
||||
fn from(_err: alloc::LayoutErr) -> Error {
|
||||
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,473 @@
|
|||
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
|
||||
|
||||
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
|
||||
use serde;
|
||||
use serde_json;
|
||||
|
||||
use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
|
||||
use crate::{
|
||||
common::value::TVMArgValue,
|
||||
errors::{Error, ErrorKind, Result},
|
||||
ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
|
||||
};
|
||||
|
||||
/// Mirrors `kTVMNDArrayMagic`. @see `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
/// Mirrors `kTVMNDArrayListMagic`. @see `graph_runtime.h`
const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
|
||||
|
||||
/// A TVM computation graph.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// let graph_json = fs::read_to_string("graph.json")).unwrap();
|
||||
/// let graph = Graph::try_from(&graph_json).unwrap();
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Graph {
|
||||
pub nodes: Vec<Node>,
|
||||
pub arg_nodes: Vec<usize>,
|
||||
pub heads: Vec<Entry>,
|
||||
pub node_row_ptr: Option<Vec<usize>>,
|
||||
pub attrs: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Entry {
|
||||
pub id: usize,
|
||||
pub index: usize,
|
||||
pub version: usize,
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
fn entry_index(&self, entry: &Entry) -> Result<usize> {
|
||||
self.node_row_ptr
|
||||
.as_ref()
|
||||
.map(|nrp| nrp[entry.id] + entry.index)
|
||||
.ok_or("Missing node_row_ptr.".into())
|
||||
}
|
||||
|
||||
/// Attempt to deserialize a JSON attribute to a type `T`.
|
||||
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
|
||||
Ok(serde_json::from_value::<T>(
|
||||
self.attrs
|
||||
.as_ref()
|
||||
.ok_or(ErrorKind::GraphFormatError(
|
||||
"Missing graph attrs".to_string(),
|
||||
))?
|
||||
.get(attr)
|
||||
.ok_or(ErrorKind::GraphFormatError(format!(
|
||||
"Missing {} attr",
|
||||
attr
|
||||
)))?
|
||||
.to_owned(),
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node {
|
||||
pub op: String,
|
||||
pub name: String,
|
||||
pub inputs: Vec<Entry>,
|
||||
pub attrs: Option<HashMap<String, String>>,
|
||||
pub control_deps: Option<Vec<Entry>>,
|
||||
}
|
||||
|
||||
/// Typed view of the node attributes the executor needs.
struct NodeAttrs {
    /// Name of the function implementing the node.
    func_name: String,
    /// Number of outputs the node produces.
    num_outputs: usize,
    /// Whether the node's tensors are passed as flattened, one-dimensional views.
    flatten_data: bool,
}
|
||||
|
||||
impl Node {
|
||||
fn parse_attrs(&self) -> Result<NodeAttrs> {
|
||||
let attrs = self
|
||||
.attrs
|
||||
.as_ref()
|
||||
.ok_or(format!("Missing node.attrs for `{}`", self.name))?;
|
||||
let func_name = attrs
|
||||
.get("func_name")
|
||||
.ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
|
||||
.to_string();
|
||||
let num_outputs = attrs
|
||||
.get("num_outputs")
|
||||
.ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
|
||||
.parse::<usize>()?;
|
||||
let flatten_data = attrs
|
||||
.get("flatten_data")
|
||||
.ok_or(format!(
|
||||
"Node `{}` is missing attrs.flatten_data",
|
||||
self.name
|
||||
))?
|
||||
.parse::<u8>()?
|
||||
== 1;
|
||||
Ok(NodeAttrs {
|
||||
func_name,
|
||||
num_outputs,
|
||||
flatten_data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a String> for Graph {
|
||||
type Error = Error;
|
||||
fn try_from(graph_json: &String) -> Result<Self> {
|
||||
let graph = serde_json::from_str(graph_json)?;
|
||||
Ok(graph)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a str> for Graph {
|
||||
type Error = Error;
|
||||
fn try_from(graph_json: &'a str) -> Result<Self> {
|
||||
let graph = serde_json::from_str(graph_json)?;
|
||||
Ok(graph)
|
||||
}
|
||||
}
|
||||
|
||||
/// A executor for a TVM computation graph.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ndarray::Array;
|
||||
///
|
||||
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
|
||||
///
|
||||
/// let mut params_bytes = Vec::new();
|
||||
/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
|
||||
/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap();
|
||||
///
|
||||
/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
|
||||
///
|
||||
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
|
||||
/// exec.load_params(params);
|
||||
///
|
||||
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
/// exec.set_input("data", x.into());
|
||||
/// exec.run();
|
||||
/// let output = exec.get_output(0).unwrap();
|
||||
///
|
||||
/// println!("{:#?}", Array::try_from(output).unwrap());
|
||||
/// ```
|
||||
pub struct GraphExecutor<'m, 't> {
|
||||
graph: Graph,
|
||||
op_execs: Vec<Box<Fn() + 'm>>,
|
||||
tensors: Vec<Tensor<'t>>,
|
||||
}
|
||||
|
||||
// NOTE(review): `Send` is asserted by hand here, presumably because the boxed op
// closures hold raw tensor pointers and so are not `Send` automatically — confirm
// that moving an executor to another thread is actually sound.
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
|
||||
|
||||
impl<'m, 't> GraphExecutor<'m, 't> {
|
||||
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
|
||||
let tensors = Self::setup_storages(&graph)?;
|
||||
Ok(GraphExecutor {
|
||||
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
|
||||
tensors: tensors,
|
||||
graph: graph,
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs the computation graph.
|
||||
pub fn run(&self) {
|
||||
self.op_execs.iter().for_each(|op_exec| {
|
||||
op_exec();
|
||||
});
|
||||
}
|
||||
|
||||
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
|
||||
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
|
||||
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
|
||||
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
|
||||
let dtypes = graph
|
||||
.get_attr::<(String, Vec<String>)>("dltype")?
|
||||
.1
|
||||
.iter()
|
||||
.map(|dltype| {
|
||||
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
|
||||
Ok(dtype)
|
||||
} else {
|
||||
Err(ErrorKind::GraphFormatError(
|
||||
format!("Invalid dltype: {}", dltype).to_string(),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<DataType>>>()?;
|
||||
|
||||
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
|
||||
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
|
||||
for (i, &storage_id) in storage_ids.iter().enumerate() {
|
||||
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
|
||||
let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
|
||||
storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
|
||||
}
|
||||
|
||||
let mut storages: Vec<Storage> = storage_num_bytes
|
||||
.into_iter()
|
||||
.map(|nbytes| Storage::new(nbytes, align))
|
||||
.collect::<Result<Vec<Storage>>>()?;
|
||||
|
||||
let tensors = izip!(storage_ids, shapes, dtypes)
|
||||
.map(|(storage_id, shape, dtype)| {
|
||||
let storage = storages[storage_id].view();
|
||||
Tensor {
|
||||
data: mem::replace(&mut storages[storage_id], storage),
|
||||
ctx: TVMContext::default(),
|
||||
dtype: dtype,
|
||||
size: shape.iter().product::<i64>() as usize,
|
||||
shape: shape,
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
/// Creates closures which represent the computation performed by this graph.
|
||||
fn setup_op_execs<M: 'm + Module>(
|
||||
graph: &Graph,
|
||||
lib: &'m M,
|
||||
tensors: &Vec<Tensor<'t>>,
|
||||
) -> Result<Vec<Box<Fn() + 'm>>> {
|
||||
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
|
||||
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
|
||||
|
||||
let mut op_execs = Vec::new();
|
||||
for (i, node) in graph.nodes.iter().enumerate() {
|
||||
if node.op == "null" {
|
||||
continue;
|
||||
}
|
||||
ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
|
||||
ensure!(node.attrs.is_some(), "Missing node attrs.");
|
||||
|
||||
let attrs = node.parse_attrs()?;
|
||||
|
||||
if attrs.func_name == "__nop" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let func = lib
|
||||
.get_function(&attrs.func_name)
|
||||
.ok_or(format!("Missing function {}", attrs.func_name))?;
|
||||
let arg_indices = node
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|entry| graph.entry_index(entry))
|
||||
.chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
|
||||
|
||||
let dl_tensors = arg_indices
|
||||
.map(|idx| {
|
||||
let tensor = &tensors[idx?];
|
||||
Ok(if attrs.flatten_data {
|
||||
DLTensor::from_tensor(tensor, true /* flatten */)
|
||||
} else {
|
||||
DLTensor::from(tensor)
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<DLTensor>>>()
|
||||
.unwrap();
|
||||
let op: Box<Fn()> = box move || {
|
||||
let args = dl_tensors
|
||||
.iter()
|
||||
.map(|t| t.into())
|
||||
.collect::<Vec<TVMArgValue>>();
|
||||
func(args.as_slice());
|
||||
};
|
||||
op_execs.push(op);
|
||||
}
|
||||
Ok(op_execs)
|
||||
}
|
||||
|
||||
pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
|
||||
params.into_iter().for_each(|(name, param)| {
|
||||
self.set_input(name, param);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor) {
|
||||
if let Some(idx) = self.get_input_index(name.as_ref()) {
|
||||
// TODO: consider `new_with_params` to avoid ever allocating
|
||||
let ptr = self.tensors[idx].data.as_ptr();
|
||||
let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
|
||||
let mut owner = to_replace.nth(0).unwrap();
|
||||
if value.data.is_owned() {
|
||||
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
|
||||
// mem::replace(&mut (*owner), value);
|
||||
// to_replace.for_each(|t| {
|
||||
// panic!("replacing");
|
||||
// t.data = owner.data.view();
|
||||
// });
|
||||
owner.copy(&value);
|
||||
} else {
|
||||
owner.copy(&value);
|
||||
}
|
||||
} else {
|
||||
println!("Unexpected input `{}`", name.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the graph input with name `name`, if it exists.
|
||||
pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
|
||||
self.get_input_index(name.as_ref())
|
||||
.and_then(move |idx| Some(&self.tensors[idx]))
|
||||
}
|
||||
|
||||
/// Returns the graph output with index `index`, if it exists.
|
||||
pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
|
||||
let graph = &self.graph;
|
||||
graph.heads.get(idx).and_then(|entry| {
|
||||
graph
|
||||
.entry_index(entry)
|
||||
.map(|idx| self.tensors.get(idx))
|
||||
.unwrap_or(None)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the index for graph input with name `name`, if it exists.
|
||||
pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
|
||||
let graph = &self.graph;
|
||||
(0..graph.nodes.len())
|
||||
.skip_while(|&i| graph.nodes[i].name != name.as_ref())
|
||||
.nth(0)
|
||||
.and_then(|i| {
|
||||
if graph.arg_nodes.iter().any(|&id| id == i) {
|
||||
graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
|
||||
named!(
|
||||
tvm_str_to_type<CompleteStr, DataType>,
|
||||
do_parse!(
|
||||
type_name: alpha1 >>
|
||||
bits: digit1 >>
|
||||
lanes: opt!(tuple!(tag!("x"), digit1)) >>
|
||||
(DataType {
|
||||
code: match type_name {
|
||||
CompleteStr("int") => DLDataTypeCode_kDLInt,
|
||||
CompleteStr("uint") => DLDataTypeCode_kDLUInt,
|
||||
CompleteStr("float") => DLDataTypeCode_kDLFloat,
|
||||
_ => DLDataTypeCode_kDLFloat,
|
||||
} as usize,
|
||||
bits: bits.parse::<u8>().unwrap() as usize,
|
||||
lanes: match lanes {
|
||||
Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
|
||||
None => 1,
|
||||
},
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
/// Converts a bytes to String.
|
||||
named!(
|
||||
name<String>,
|
||||
map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
|
||||
b.to_vec()
|
||||
))
|
||||
);
|
||||
|
||||
/// Parses a TVMContext
|
||||
named!(
|
||||
tvm_ctx<&[u8], TVMContext>,
|
||||
do_parse!(
|
||||
device_type: le_u32 >>
|
||||
device_id: le_i32 >>
|
||||
(TVMContext { device_type: device_type as usize, device_id: device_id as usize })
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a DataType
|
||||
named!(
|
||||
data_type<&[u8], DataType>,
|
||||
do_parse!(
|
||||
code: le_u8 >>
|
||||
bits: le_u8 >>
|
||||
lanes: le_u16 >>
|
||||
(DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a Tensor from a TVM array file.
|
||||
named!(
|
||||
tensor<Tensor>,
|
||||
do_parse!(
|
||||
take!(8)
|
||||
>> bits!(tag_bits!(u64, 64, 0))
|
||||
>> ctx: tvm_ctx
|
||||
>> ndim: le_u32
|
||||
>> dtype: data_type
|
||||
>> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
|
||||
>> length: le_i64
|
||||
>> data: take!(length)
|
||||
>> (Tensor {
|
||||
data: Storage::from(data),
|
||||
ctx: ctx,
|
||||
dtype: dtype,
|
||||
size: shape.iter().product::<i64>() as usize,
|
||||
shape: shape,
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a graph params dict from a params binary file.
|
||||
named!(
|
||||
parse_param_dict<HashMap<String, Tensor>>,
|
||||
do_parse!(
|
||||
take!(8)
|
||||
>> bits!(tag_bits!(u64, 64, 0))
|
||||
>> names: length_count!(le_u64, name)
|
||||
>> tensors: length_count!(le_u64, tensor)
|
||||
>> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
|
||||
)
|
||||
);
|
||||
|
||||
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
|
||||
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
|
||||
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
|
||||
if remaining_bytes.len() > 0 {
|
||||
bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
|
||||
} else {
|
||||
Ok(param_dict)
|
||||
}
|
||||
} else {
|
||||
bail!(ErrorKind::LoadGraphParamsError(
|
||||
"invalid parameters file".to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s` and returns the resulting `DataType`, panicking on failure.
    fn parse(s: &str) -> DataType {
        tvm_str_to_type(CompleteStr(s)).unwrap().1
    }

    #[test]
    fn test_str_to_type() {
        let scalar = DataType {
            code: DLDataTypeCode_kDLFloat as usize,
            bits: 24,
            lanes: 1,
        };
        assert_eq!(parse("float24"), scalar);

        let vector = DataType {
            code: DLDataTypeCode_kDLUInt as usize,
            bits: 111,
            lanes: 44,
        };
        assert_eq!(parse("uint111x44"), vector);
    }
}
|
|
@ -10,13 +10,13 @@
|
|||
//! For examples of use, please refer to the multi-file tests in the `tests` directory.
|
||||
|
||||
#![feature(
|
||||
alloc,
|
||||
allocator_api,
|
||||
box_syntax,
|
||||
fn_traits,
|
||||
try_from,
|
||||
unboxed_closures,
|
||||
vec_remove_item
|
||||
alloc,
|
||||
allocator_api,
|
||||
box_syntax,
|
||||
fn_traits,
|
||||
try_from,
|
||||
unboxed_closures,
|
||||
vec_remove_item
|
||||
)]
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
|
@ -39,29 +39,36 @@ extern crate serde;
|
|||
#[macro_use]
|
||||
extern crate serde_derive;
|
||||
extern crate serde_json;
|
||||
extern crate tvm_common as common;
|
||||
|
||||
pub mod ffi {
|
||||
#![allow(
|
||||
non_camel_case_types,
|
||||
non_snake_case,
|
||||
non_upper_case_globals,
|
||||
unused
|
||||
)]
|
||||
|
||||
pub mod runtime {
|
||||
use std::os::raw::{c_char, c_int, c_void};
|
||||
|
||||
include!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/runtime/c_runtime_api.rs"
|
||||
));
|
||||
|
||||
pub type BackendPackedCFunc =
|
||||
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
|
||||
}
|
||||
}
|
||||
|
||||
mod allocator;
|
||||
mod array;
|
||||
pub mod errors;
|
||||
pub mod runtime;
|
||||
mod module;
|
||||
#[macro_use]
|
||||
mod packed_func;
|
||||
mod graph;
|
||||
#[cfg(target_env = "sgx")]
|
||||
#[macro_use]
|
||||
pub mod sgx;
|
||||
mod threading;
|
||||
mod workspace;
|
||||
|
||||
pub use errors::*;
|
||||
pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
|
||||
|
||||
pub use self::{
|
||||
array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
|
||||
};
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use self::sgx::ocall_packed_func;
|
||||
|
||||
/// Error hook called by TVM library code with a C string describing the failure.
///
/// Outside SGX this panics with the message; invalid UTF-8 is replaced lossily so
/// the original text still reaches the panic message. Under SGX the message is
/// forwarded through the `__sgx_set_last_error__` ocall instead.
///
/// NOTE(review): `cmsg` is assumed to be a valid, NUL-terminated C string.
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
    #[cfg(not(target_env = "sgx"))]
    unsafe {
        panic!("{}", std::ffi::CStr::from_ptr(cmsg).to_string_lossy());
    }
    #[cfg(target_env = "sgx")]
    ocall_packed!("__sgx_set_last_error__", cmsg);
}
|
|
@ -0,0 +1,48 @@
|
|||
use std::{
|
||||
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ffi::runtime::BackendPackedCFunc,
|
||||
packed_func::{wrap_backend_packed_func, PackedFunc},
|
||||
};
|
||||
|
||||
pub trait Module {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
|
||||
}
|
||||
|
||||
/// Module backed by the process-wide table of registered system-library functions.
pub struct SystemLibModule;
|
||||
|
||||
lazy_static! {
    /// Functions registered through `TVMBackendRegisterSystemLibSymbol`, keyed by name.
    static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
        Mutex::new(HashMap::new());
}
|
||||
|
||||
impl Module for SystemLibModule {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
|
||||
SYSTEM_LIB_FUNCTIONS
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(name.as_ref())
|
||||
.map(|func| wrap_backend_packed_func(func.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SystemLibModule {
|
||||
fn default() -> Self {
|
||||
SystemLibModule {}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
||||
cname: *const c_char,
|
||||
func: BackendPackedCFunc,
|
||||
) -> i32 {
|
||||
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
|
||||
SYSTEM_LIB_FUNCTIONS
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(name.to_string(), func);
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
|
||||
|
||||
use super::Tensor;
|
||||
use crate::ffi::runtime::{
|
||||
BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
|
||||
TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
|
||||
};
|
||||
|
||||
use super::DLTensor;
|
||||
use crate::{
|
||||
common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
|
||||
errors::*,
|
||||
};
|
||||
|
||||
/// A boxed, thread-safe function taking TVM argument values and returning a `TVMRetValue`.
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
|
||||
|
||||
/// Calls a packed function and returns a `TVMRetValue`.
///
/// Every argument is converted with `.into()` before the call. With no arguments
/// the function receives an empty argument list.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
    ($func:expr, $($arg:expr),+) => {
        $func(&[$($arg.into(),)+])
    };
    ($func:expr) => {
        $func(&Vec::new())
    };
}
|
||||
|
||||
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
|
||||
fn from(arr: &'a DLTensor) -> Self {
|
||||
let raw = _TVMValue {
|
||||
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
|
||||
};
|
||||
TVMArgValue {
|
||||
value: TVMValue::new(raw),
|
||||
type_code: TVMTypeCode::kArrayHandle,
|
||||
lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
|
||||
fn from(arr: &'a mut DLTensor) -> Self {
|
||||
let raw = _TVMValue {
|
||||
v_handle: arr as *mut _ as *mut c_void,
|
||||
};
|
||||
TVMArgValue {
|
||||
value: TVMValue::new(raw),
|
||||
type_code: TVMTypeCode::kArrayHandle,
|
||||
lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
|
||||
type Error = Error;
|
||||
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
|
||||
ensure!(
|
||||
val.type_code == TVMTypeCode::kArrayHandle
|
||||
|| val.type_code == TVMTypeCode::kNDArrayContainer,
|
||||
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
|
||||
TVMTypeCode::kArrayHandle,
|
||||
TVMTypeCode::kNDArrayContainer,
|
||||
val.type_code,
|
||||
);
|
||||
|
||||
let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
|
||||
Ok(DLTensor::new(dlt).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
|
||||
fn from(val: &'t Tensor<'a>) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: 0,
|
||||
box_value: box DLTensor::from(val),
|
||||
type_code: TVMTypeCode::kNDArrayContainer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<Self> {
|
||||
ensure!(
|
||||
ret.type_code == TVMTypeCode::kArrayHandle
|
||||
|| ret.type_code == TVMTypeCode::kNDArrayContainer,
|
||||
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
|
||||
TVMTypeCode_kArrayHandle,
|
||||
TVMTypeCode_kNDArrayContainer,
|
||||
ret.type_code,
|
||||
);
|
||||
|
||||
let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
|
||||
Ok(DLTensor::new(dlt).into())
|
||||
}
|
||||
}
|
||||
|
||||
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
||||
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
|
||||
box move |args: &[TVMArgValue]| {
|
||||
func(
|
||||
args.iter()
|
||||
.map(|ref arg| arg.value.inner)
|
||||
.collect::<Vec<_TVMValue>>()
|
||||
.as_ptr(),
|
||||
args.iter()
|
||||
.map(|ref arg| arg.type_code as i32)
|
||||
.collect::<Vec<i32>>()
|
||||
.as_ptr() as *const i32,
|
||||
args.len() as i32,
|
||||
);
|
||||
TVMRetValue::default()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
use std::{
|
||||
ffi::CString,
|
||||
os::raw::{c_char, c_int},
|
||||
};
|
||||
|
||||
use errors::Result;
|
||||
use ffi::runtime::TVMValue;
|
||||
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
|
||||
|
||||
pub use runtime::threading::tvm_run_worker as run_worker;
|
||||
|
||||
/// Turns an SGX status code into a `Result`: 0 is `Ok(())`, anything else becomes
/// an `Err` carrying a message with the code.
#[macro_export]
macro_rules! tvm_ocall {
    ($call:expr) => {
        match $call {
            0 => Ok(()),
            status => Err(format!("SGX error: {}", status)),
        }
    };
}
|
||||
|
||||
/// Status code returned by SGX ocalls; `tvm_ocall!` treats 0 as success.
pub type SgxStatus = u32;
|
||||
|
||||
#[cfg(target_env = "sgx")]
extern "C" {
    /// Externally provided ocall that invokes the packed function called `name`.
    ///
    /// Arguments are passed as parallel arrays of values and type codes of length
    /// `num_args`; the result is written to `ret_val` and `ret_type_code`.
    fn tvm_ocall_packed_func(
        name: *const c_char,
        arg_values: *const TVMValue,
        type_codes: *const c_int,
        num_args: c_int,
        ret_val: *mut TVMValue,
        ret_type_code: *mut c_int,
    ) -> SgxStatus;
}
|
||||
|
||||
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret_val = TVMValue { v_int64: 0 };
|
||||
let ret_type_code = 0i64;
|
||||
unsafe {
|
||||
tvm_ocall!(tvm_ocall_packed_func(
|
||||
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
|
||||
args.iter()
|
||||
.map(|ref arg| arg.value)
|
||||
.collect::<Vec<TVMValue>>()
|
||||
.as_ptr(),
|
||||
args.iter()
|
||||
.map(|ref arg| arg.type_code as i32)
|
||||
.collect::<Vec<i32>>()
|
||||
.as_ptr() as *const i32,
|
||||
args.len() as i32,
|
||||
&mut ret_val as *mut TVMValue,
|
||||
&mut (ret_type_code as i32) as *mut c_int,
|
||||
))?;
|
||||
}
|
||||
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
|
||||
}
|
||||
|
||||
/// Calls `ocall_packed_func` with each argument converted through `.into()`.
///
/// # Panics
///
/// Panics with "Error calling `<name>`" if the call returns an error.
#[macro_export]
macro_rules! ocall_packed {
    ($name:expr, $($arg:expr),+) => {
        ocall_packed_func($name, &[$($arg.into(),)+])
            .expect(concat!("Error calling `", $name, "`"))
    };
    ($name:expr) => {
        ocall_packed_func($name, &Vec::new())
            .expect(concat!("Error calling `", $name, "`"))
    }
}
|
||||
|
||||
pub fn shutdown() {
|
||||
if env!("TVM_NUM_THREADS") != "0" {
|
||||
sgx_join_threads()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SystemLibModule {
|
||||
fn drop(&mut self) {
|
||||
shutdown()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,336 @@
|
|||
use std::{
|
||||
os::raw::{c_int, c_void},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
|
||||
Arc, Barrier,
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use num_cpus;
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use std::{
|
||||
env,
|
||||
thread::{self, JoinHandle},
|
||||
};
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use std::{collections::VecDeque, ptr, sync::Mutex};
|
||||
|
||||
use bounded_spsc_queue::{self, Producer};
|
||||
|
||||
use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
|
||||
|
||||
type FTVMParallelLambda =
|
||||
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
|
||||
|
||||
/// Holds a parallel job request made by a TVM library function.
|
||||
struct Job {
|
||||
cb: FTVMParallelLambda,
|
||||
cdata: *const c_void,
|
||||
req_num_tasks: usize,
|
||||
pending: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl Job {
|
||||
/// Splits this job into a number of `Task`s which can be scheduled.
|
||||
fn tasks(&self, num_workers: usize) -> Vec<Task> {
|
||||
let num_tasks = if self.req_num_tasks == 0 {
|
||||
num_workers
|
||||
} else {
|
||||
self.req_num_tasks.min(num_workers)
|
||||
};
|
||||
self.pending.store(num_tasks, Ordering::SeqCst);
|
||||
|
||||
let barrier = Arc::new(Barrier::new(num_tasks));
|
||||
|
||||
(0..num_tasks)
|
||||
.map(move |i| Task {
|
||||
id: i,
|
||||
flambda: self.cb,
|
||||
penv: TVMParallelGroupEnv {
|
||||
sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
|
||||
num_task: num_tasks as i32,
|
||||
},
|
||||
cdata: self.cdata,
|
||||
pending: Arc::clone(&self.pending),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Waits for all tasks in this `Job` to be completed.
|
||||
fn wait(&self) -> Result<()> {
|
||||
while self.pending.load(Ordering::Acquire) > 0 {
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
thread::yield_now();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A chunk of work requested by a TVM function.
|
||||
struct Task {
|
||||
id: usize,
|
||||
flambda: FTVMParallelLambda,
|
||||
penv: TVMParallelGroupEnv,
|
||||
cdata: *const c_void,
|
||||
pending: Arc<AtomicUsize>,
|
||||
}
|
||||
unsafe impl Send for Task {}
|
||||
unsafe impl Sync for Task {}
|
||||
|
||||
impl FnOnce<()> for Task {
|
||||
type Output = i32;
|
||||
extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
|
||||
let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
|
||||
self.pending.fetch_sub(1, Ordering::AcqRel);
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Threads {
|
||||
#[allow(unused)]
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
handles: Vec<JoinHandle<()>>,
|
||||
queues: Vec<Producer<Task>>,
|
||||
}
|
||||
|
||||
impl<'a> Threads {
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
|
||||
num_threads: usize,
|
||||
cb: F,
|
||||
) -> Self {
|
||||
let (handles, queues) = (0..num_threads)
|
||||
.map(|_| {
|
||||
let (p, c) = bounded_spsc_queue::make(2);
|
||||
let handle = thread::spawn(move || cb(c.into()));
|
||||
(handle, p)
|
||||
})
|
||||
.unzip();
|
||||
Threads {
|
||||
handles: handles,
|
||||
queues: queues,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
|
||||
num_threads: usize,
|
||||
_cb: F,
|
||||
) -> Self {
|
||||
let mut consumer_queues = SGX_QUEUES.lock().unwrap();
|
||||
let queues = (0..num_threads)
|
||||
.map(|_| {
|
||||
let (p, c) = bounded_spsc_queue::make(2);
|
||||
consumer_queues.push_back(c.into());
|
||||
p
|
||||
})
|
||||
.collect();
|
||||
ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
|
||||
Threads { queues: queues }
|
||||
}
|
||||
}
|
||||
|
||||
struct ThreadPool {
|
||||
num_workers: usize,
|
||||
#[allow(unused)]
|
||||
threads: Threads,
|
||||
}
|
||||
|
||||
thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
|
||||
|
||||
impl ThreadPool {
|
||||
fn new() -> Self {
|
||||
let num_workers = max_concurrency();
|
||||
ThreadPool {
|
||||
num_workers: num_workers,
|
||||
threads: Threads::launch(num_workers, ThreadPool::run_worker),
|
||||
}
|
||||
}
|
||||
|
||||
fn launch(&self, job: Job) {
|
||||
let mut tasks = job.tasks(self.num_workers + 1);
|
||||
|
||||
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
|
||||
self.threads.queues[i].push(task);
|
||||
}
|
||||
|
||||
tasks.pop().unwrap()();
|
||||
job.wait().unwrap();
|
||||
}
|
||||
|
||||
fn run_worker(queue: Consumer<Task>) {
|
||||
loop {
|
||||
let task = queue.pop();
|
||||
let result = task();
|
||||
if result == <i32>::min_value() {
|
||||
break;
|
||||
} else if result != 0 {
|
||||
panic!("Error running task.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send + Sync wrapper for bounded_spsc_queue::Consumer
|
||||
struct Consumer<T> {
|
||||
consumer: bounded_spsc_queue::Consumer<T>,
|
||||
}
|
||||
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
|
||||
fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
|
||||
Consumer { consumer: c }
|
||||
}
|
||||
}
|
||||
impl<T> Consumer<T> {
|
||||
fn pop(&self) -> T {
|
||||
self.consumer.pop()
|
||||
}
|
||||
}
|
||||
unsafe impl<T> Send for Consumer<T> {}
|
||||
unsafe impl<T> Sync for Consumer<T> {}
|
||||
|
||||
#[cfg(target_env = "sgx")]
lazy_static! {
    /// Holds tasks for untrusted threads which re-enter the enclave to execute.
    ///
    /// `Threads::launch` pushes one consumer per worker; `tvm_run_worker` pops one.
    static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
|
||||
fn max_concurrency() -> usize {
|
||||
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
|
||||
if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
|
||||
return threads;
|
||||
}
|
||||
}
|
||||
num_cpus::get_physical()
|
||||
}
|
||||
|
||||
/// Returns the thread count fixed at compile time through `TVM_NUM_THREADS`,
/// or `1` when that value is not a base-10 `usize`.
#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
    let configured: &str = env!("TVM_NUM_THREADS");
    configured.parse::<usize>().unwrap_or(1)
}
|
||||
|
||||
/// Always returns `0`: wasm doesn't support threads yet.
#[cfg(target_arch = "wasm32")]
fn max_concurrency() -> usize {
    0
}
|
||||
|
||||
/// Takes one queue from `SGX_QUEUES` and runs the worker loop on it.
///
/// Returns a default `TVMRetValue`, immediately if no queue is left.
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
    // The mutex guard is a temporary of this statement, so the lock is released
    // before `run_worker` starts (which may not return for a long time).
    let next_queue = SGX_QUEUES.lock().unwrap().pop_front();
    if let Some(queue) = next_queue {
        ThreadPool::run_worker(queue);
    }
    TVMRetValue::default()
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendParallelLaunch(
|
||||
cb: FTVMParallelLambda,
|
||||
cdata: *const c_void,
|
||||
num_task: usize,
|
||||
) -> c_int {
|
||||
if max_concurrency() == 0 {
|
||||
let penv = TVMParallelGroupEnv {
|
||||
sync_handle: 0 as *mut c_void,
|
||||
num_task: 1,
|
||||
};
|
||||
cb(0, &penv as *const _, cdata);
|
||||
} else {
|
||||
THREAD_POOL.with(|pool| {
|
||||
pool.launch(Job {
|
||||
cb: cb,
|
||||
cdata: cdata,
|
||||
req_num_tasks: num_task,
|
||||
pending: Arc::new(ATOMIC_USIZE_INIT),
|
||||
});
|
||||
});
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Stops every worker and asks the untrusted side to join the thread group.
///
/// A job whose callback returns `i32::min_value()` is launched; `run_worker`
/// leaves its loop when a task yields that value.
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
    extern "C" fn poison_pill(
        _task_id: usize,
        _penv: *const TVMParallelGroupEnv,
        _cdata: *const c_void,
    ) -> i32 {
        i32::min_value()
    }

    let shutdown_job = Job {
        cb: poison_pill,
        cdata: ptr::null(),
        req_num_tasks: 0,
        pending: Arc::new(ATOMIC_USIZE_INIT),
    };
    THREAD_POOL.with(|pool| pool.launch(shutdown_job));
    ocall_packed!("__sgx_thread_group_join__", 0);
}
|
||||
|
||||
// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
|
||||
let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
|
||||
barrier.wait();
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::{ptr, thread, time::Duration};

    use super::*;

    /// `TVM_NUM_THREADS` wins over `OMP_NUM_THREADS`; the latter is the fallback.
    #[test]
    fn test_max_concurrency() {
        env::set_var("TVM_NUM_THREADS", "42");
        env::set_var("OMP_NUM_THREADS", "24");
        assert_eq!(max_concurrency(), 42);
        env::remove_var("TVM_NUM_THREADS");
        assert_eq!(max_concurrency(), 24);
    }

    /// Test callback. With null `cdata` it does nothing; otherwise `cdata` is a
    /// `(counter, task_ids_sum)` pair of atomics which it updates.
    extern "C" fn flambda(
        task_id: usize,
        penv: *const TVMParallelGroupEnv,
        cdata: *const c_void,
    ) -> i32 {
        if cdata.is_null() {
            return 0;
        }
        unsafe {
            let shared = &*(cdata as *const (AtomicUsize, AtomicUsize));
            let (counter, task_ids_sum) = (&shared.0, &shared.1);
            thread::sleep(Duration::from_millis(50 * task_id as u64));
            counter.fetch_add(1, Ordering::SeqCst);
            task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
            assert_eq!((*penv).num_task, 3);
        }
        0
    }

    /// Every requested task runs exactly once and sees its own id.
    #[test]
    fn test_parallel_launch() {
        TVMBackendParallelLaunch(flambda, ptr::null(), 6);
        let cdata = (ATOMIC_USIZE_INIT, ATOMIC_USIZE_INIT);
        let num_tasks = 3;
        TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
        let (ran, id_sum) = (&cdata.0, &cdata.1);
        assert_eq!(ran.load(Ordering::SeqCst), num_tasks);
        assert_eq!(
            id_sum.load(Ordering::SeqCst),
            (0..num_tasks).sum::<usize>()
        );
    }
}
|
|
@ -0,0 +1,117 @@
|
|||
use std::{
|
||||
cell::RefCell,
|
||||
os::raw::{c_int, c_void},
|
||||
ptr,
|
||||
};
|
||||
|
||||
use super::allocator::Allocation;
|
||||
use crate::errors::*;
|
||||
|
||||
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
|
||||
|
||||
struct WorkspacePool {
|
||||
workspaces: Vec<Allocation>,
|
||||
free: Vec<usize>,
|
||||
in_use: Vec<usize>,
|
||||
}
|
||||
|
||||
impl WorkspacePool {
|
||||
fn new() -> Self {
|
||||
WorkspacePool {
|
||||
workspaces: Vec::new(),
|
||||
free: Vec::new(),
|
||||
in_use: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
|
||||
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
|
||||
self.in_use.push(self.workspaces.len() - 1);
|
||||
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
|
||||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
|
||||
if self.free.len() == 0 {
|
||||
return self.alloc_new(size);
|
||||
}
|
||||
let idx = self
|
||||
.free
|
||||
.iter()
|
||||
.fold(None, |cur_ws_idx: Option<usize>, &idx| {
|
||||
let ws_size = self.workspaces[idx].size();
|
||||
if !ws_size >= size {
|
||||
return cur_ws_idx;
|
||||
}
|
||||
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
|
||||
let cur_size = self.workspaces[cur_idx].size();
|
||||
Some(match ws_size <= cur_size {
|
||||
true => idx,
|
||||
false => cur_idx,
|
||||
})
|
||||
})
|
||||
});
|
||||
match idx {
|
||||
Some(idx) => {
|
||||
self.free.remove_item(&idx).unwrap();
|
||||
self.in_use.push(idx);
|
||||
Ok(self.workspaces[idx].as_mut_ptr())
|
||||
}
|
||||
None => self.alloc_new(size),
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, ptr: *mut u8) -> Result<()> {
|
||||
let mut ws_idx = None;
|
||||
for i in 0..self.in_use.len() {
|
||||
let idx = self.in_use[i];
|
||||
if self.workspaces[idx].as_mut_ptr() == ptr {
|
||||
self.in_use.remove(i);
|
||||
ws_idx = Some(idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(self
|
||||
.free
|
||||
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
|
||||
}
|
||||
}
|
||||
|
||||
thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
|
||||
|
||||
const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendAllocWorkspace(
|
||||
_device_type: c_int,
|
||||
_device_id: c_int,
|
||||
size: u64,
|
||||
_dtype_code_hint: c_int,
|
||||
_dtype_bits_hint: c_int,
|
||||
) -> *mut c_void {
|
||||
let nbytes = if size == 0 {
|
||||
WORKSPACE_PAGE_SIZE
|
||||
} else {
|
||||
size as usize
|
||||
};
|
||||
WORKSPACE_POOL.with(|pool_cell| {
|
||||
pool_cell
|
||||
.borrow_mut()
|
||||
.alloc(nbytes as usize)
|
||||
.unwrap_or(ptr::null_mut()) as *mut c_void
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendFreeWorkspace(
|
||||
_device_type: c_int,
|
||||
_device_id: c_int,
|
||||
ptr: *mut c_void,
|
||||
) -> c_int {
|
||||
WORKSPACE_POOL.with(|pool_cell| {
|
||||
(match pool_cell.borrow_mut().free(ptr as *mut u8) {
|
||||
Ok(()) => 0,
|
||||
Err(_) => -1,
|
||||
}) as c_int
|
||||
});
|
||||
return 0;
|
||||
}
|
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""Builds a simple NNVM graph for testing."""
|
||||
|
||||
from os import path as osp
|
|
@ -0,0 +1,39 @@
|
|||
#![feature(try_from)]
|
||||
|
||||
extern crate serde;
|
||||
extern crate serde_json;
|
||||
|
||||
extern crate tvm_runtime;
|
||||
|
||||
use std::{convert::TryFrom, fs, io::Read};
|
||||
|
||||
use tvm_runtime::Graph;
|
||||
|
||||
/// Loads the serialized parameters and the graph JSON produced by
/// `tests/build_model.py` and checks a few properties of the parsed graph.
#[test]
fn test_load_graph() {
    let mut params_bytes = Vec::new();
    fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
        .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
        .read_to_end(&mut params_bytes)
        .unwrap();
    // `&params_bytes` had been garbled into `¶ms_bytes` (mis-decoded `&para`).
    let _params = tvm_runtime::load_param_dict(&params_bytes);

    let graph = Graph::try_from(
        &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
    )
    .unwrap();

    assert_eq!(graph.nodes[3].op, "tvm_op");
    assert_eq!(
        graph.nodes[3]
            .attrs
            .as_ref()
            .unwrap()
            .get("func_name")
            .unwrap(),
        "fuse_dense"
    );
    assert_eq!(graph.nodes[5].inputs[0].index, 0);
    assert_eq!(graph.nodes[6].inputs[0].index, 1);
    assert_eq!(graph.heads.len(), 2);
}
|
|
@ -2,13 +2,13 @@
|
|||
name = "test-nnvm"
|
||||
version = "0.0.0"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.11.2"
|
||||
tvm = { path = "../../" }
|
||||
serde = "1.0.59"
|
||||
serde_json = "1.0.17"
|
||||
tvm-runtime = { path = "../../" }
|
||||
|
||||
[build-dependencies]
|
||||
ar = "0.6.0"
|
|
@ -0,0 +1,33 @@
|
|||
extern crate ar;
|
||||
|
||||
use std::{env, fs::File, path::Path, process::Command};
|
||||
|
||||
use ar::Builder;
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = Command::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/build_test_graph.py"
|
||||
))
|
||||
.arg(&out_dir)
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
Path::new(&format!("{}/graph.o", out_dir)).exists(),
|
||||
"Could not build graph lib: {}",
|
||||
String::from_utf8(output.stderr)
|
||||
.unwrap()
|
||||
.trim()
|
||||
.split("\n")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
);
|
||||
|
||||
let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
|
||||
builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
|
||||
|
||||
println!("cargo:rustc-link-lib=static=graph");
|
||||
println!("cargo:rustc-link-search=native={}", out_dir);
|
||||
}
|
|
@ -23,6 +23,7 @@ def _get_model(dshape):
|
|||
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
|
||||
if isinstance(graph, sym.Symbol):
|
||||
graph = nnvm.graph.create(graph)
|
||||
|
||||
ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
|
||||
param_shapes = dict(zip(graph.index.input_names, ishapes))
|
||||
np.random.seed(seed)
|
||||
|
@ -40,6 +41,7 @@ def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
|
|||
initializer(param, init_value)
|
||||
# init_value /= init_value.sum() + 1e-10
|
||||
params[param] = tvm.nd.array(init_value)
|
||||
|
||||
return params
|
||||
|
||||
def main():
|
||||
|
@ -56,6 +58,7 @@ def main():
|
|||
lib.save(osp.join(sys.argv[1], 'graph.o'))
|
||||
with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
|
||||
f_resnet.write(graph.json())
|
||||
|
||||
with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
|
||||
f_params.write(nnvm.compiler.save_param_dict(params))
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
#![feature(try_from)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate ndarray;
|
||||
extern crate serde;
|
||||
extern crate serde_json;
|
||||
|
||||
extern crate tvm_runtime;
|
||||
use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
|
||||
|
||||
use ndarray::Array;
|
||||
use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
|
||||
|
||||
const BATCH_SIZE: usize = 4;
|
||||
const IN_DIM: usize = 8;
|
||||
|
||||
/// Asserts that the element sums of two arrays differ by less than `1e-2`.
///
/// * `check_sum!(exec, name, expected)` compares the executor input called `name`.
/// * `check_sum!(exec, index, expected)` compares the executor output at `index`.
/// * `check_sum!(actual, expected)` compares two arrays directly.
macro_rules! check_sum {
    ($e:expr, $a:ident, $b:ident) => {
        let actual = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
        check_sum!(actual, $b);
    };
    ($e:expr, $a:expr, $b:ident) => {
        let actual = Array::try_from($e.get_output($a).unwrap()).unwrap();
        check_sum!(actual, $b);
    };
    ($a:ident, $b:ident) => {
        let lhs_sum: f32 = $a.scalar_sum();
        let rhs_sum: f32 = $b.scalar_sum();
        assert!((lhs_sum - rhs_sum).abs() < 1e-2, "{} != {}", lhs_sum, rhs_sum);
    };
}
|
||||
|
||||
fn main() {
|
||||
let syslib = SystemLibModule::default();
|
||||
|
||||
let mut params_bytes = Vec::new();
|
||||
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
|
||||
.unwrap()
|
||||
.read_to_end(&mut params_bytes)
|
||||
.unwrap();
|
||||
let params = tvm_runtime::load_param_dict(¶ms_bytes)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.to_owned()))
|
||||
.collect::<HashMap<String, Tensor<'static>>>();
|
||||
|
||||
let graph =
|
||||
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
|
||||
.unwrap();
|
||||
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
|
||||
|
||||
let x = Array::from_shape_vec(
|
||||
(BATCH_SIZE, IN_DIM),
|
||||
(0..BATCH_SIZE * IN_DIM)
|
||||
.map(|x| x as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let w = Array::try_from(params.get("dense0_weight").unwrap())
|
||||
.unwrap()
|
||||
.into_shape((IN_DIM * 2, IN_DIM))
|
||||
.unwrap();
|
||||
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
|
||||
let dense = x.dot(&w.t()) + &b;
|
||||
let left = dense.slice(s![.., 0..IN_DIM]);
|
||||
let right = dense.slice(s![.., IN_DIM..]);
|
||||
let expected_o0 = &left + 1f32;
|
||||
let expected_o1 = &right - 1f32;
|
||||
|
||||
exec.load_params(params);
|
||||
exec.set_input("data", (&x).into());
|
||||
|
||||
check_sum!(exec, data, x);
|
||||
check_sum!(exec, dense0_weight, w);
|
||||
check_sum!(exec, dense0_bias, b);
|
||||
|
||||
exec.run();
|
||||
|
||||
check_sum!(exec, 0, expected_o0);
|
||||
check_sum!(exec, 1, expected_o1);
|
||||
check_sum!(exec, 2, dense);
|
||||
}
|
|
@ -2,11 +2,11 @@
|
|||
name = "test-tvm-basic"
|
||||
version = "0.0.0"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.11.2"
|
||||
tvm = { path = "../../" }
|
||||
tvm-runtime = { path = "../../" }
|
||||
|
||||
[build-dependencies]
|
||||
ar = "0.6.0"
|
|
@ -0,0 +1,34 @@
|
|||
extern crate ar;
|
||||
|
||||
use std::{env, path::Path, process::Command};
|
||||
|
||||
use ar::Builder;
|
||||
use std::fs::File;
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = Command::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/build_test_lib.py"
|
||||
))
|
||||
.arg(&out_dir)
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
Path::new(&format!("{}/test.o", out_dir)).exists(),
|
||||
"Could not build tvm lib: {}",
|
||||
String::from_utf8(output.stderr)
|
||||
.unwrap()
|
||||
.trim()
|
||||
.split("\n")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
);
|
||||
|
||||
let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
|
||||
builder.append_path(format!("{}/test.o", out_dir)).unwrap();
|
||||
|
||||
println!("cargo:rustc-link-lib=static=test");
|
||||
println!("cargo:rustc-link-search=native={}", out_dir);
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
extern crate ndarray;
|
||||
#[macro_use]
|
||||
extern crate tvm_runtime;
|
||||
|
||||
use ndarray::Array;
|
||||
use tvm_runtime::{DLTensor, Module, SystemLibModule};
|
||||
|
||||
fn main() {
|
||||
let syslib = SystemLibModule::default();
|
||||
let add = syslib
|
||||
.get_function("default_function")
|
||||
.expect("main function not found");
|
||||
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
|
||||
let mut c = Array::from_vec(vec![0f32; 4]);
|
||||
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
|
||||
let mut a_dl: DLTensor = (&mut a).into();
|
||||
let mut b_dl: DLTensor = (&mut b).into();
|
||||
let mut c_dl: DLTensor = (&mut c).into();
|
||||
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
|
||||
assert!(c.all_close(&e, 1e-8f32));
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
#[cfg(target_env = "sgx")]
|
||||
use alloc::alloc::{self, Layout};
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use std::alloc::{self, Layout};
|
||||
|
||||
use errors::*;
|
||||
|
||||
const DEFAULT_ALIGN_BYTES: usize = 4;
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct Allocation {
|
||||
layout: Layout,
|
||||
ptr: *mut u8,
|
||||
}
|
||||
|
||||
impl Allocation {
|
||||
/// Allocates a chunk of memory of `size` bytes with optional alignment.
|
||||
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
|
||||
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
|
||||
let layout = Layout::from_size_align(size, alignment)?;
|
||||
let ptr = unsafe { alloc::alloc(layout.clone()) };
|
||||
if ptr.is_null() {
|
||||
alloc::handle_alloc_error(layout);
|
||||
}
|
||||
Ok(Self {
|
||||
ptr: ptr,
|
||||
layout: layout,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
/// Returns the size of the Allocation in bytes.
|
||||
pub fn size(&self) -> usize {
|
||||
self.layout.size()
|
||||
}
|
||||
|
||||
/// Returns the byte alignment of the Allocation.
|
||||
pub fn align(&self) -> usize {
|
||||
self.layout.align()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Allocation {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
alloc::dealloc(self.ptr, self.layout.clone());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,500 +0,0 @@
|
|||
use std::{
|
||||
any::TypeId,
|
||||
convert::TryFrom,
|
||||
mem,
|
||||
os::raw::{c_int, c_void},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
use ndarray;
|
||||
|
||||
use super::allocator::Allocation;
|
||||
use errors::*;
|
||||
use ffi::runtime::{
|
||||
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
|
||||
DLDeviceType_kDLCPU, DLTensor,
|
||||
};
|
||||
|
||||
/// A `Storage` is a container which holds `Tensor` data.
|
||||
#[derive(PartialEq)]
|
||||
pub enum Storage<'a> {
|
||||
/// A `Storage` which owns its contained bytes.
|
||||
Owned(Allocation),
|
||||
|
||||
/// A view of an existing `Storage`.
|
||||
View(&'a mut [u8], usize), // ptr, align
|
||||
}
|
||||
|
||||
impl<'a> Storage<'a> {
|
||||
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
|
||||
Ok(Storage::Owned(Allocation::new(size, align)?))
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.as_mut_ptr(),
|
||||
Storage::View(slice, _) => slice.as_ptr() as *mut u8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.size(),
|
||||
Storage::View(slice, _) => slice.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn align(&self) -> usize {
|
||||
match self {
|
||||
Storage::Owned(alloc) => alloc.align(),
|
||||
Storage::View(_, align) => *align,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.as_mut_ptr() as *const _
|
||||
}
|
||||
|
||||
/// Returns a `Storage::View` which points to an owned `Storage::Owned`.
|
||||
pub fn view(&self) -> Storage<'a> {
|
||||
match self {
|
||||
Storage::Owned(alloc) => Storage::View(
|
||||
unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
|
||||
self.align(),
|
||||
),
|
||||
Storage::View(slice, _) => Storage::View(
|
||||
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
|
||||
self.align(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_owned(&self) -> bool {
|
||||
match self {
|
||||
Storage::Owned(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an owned version of this storage via cloning.
|
||||
pub fn to_owned(&self) -> Storage<'static> {
|
||||
let s = Storage::new(self.size(), Some(self.align())).unwrap();
|
||||
unsafe {
|
||||
s.as_mut_ptr()
|
||||
.copy_from_nonoverlapping(self.as_ptr(), self.size())
|
||||
}
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a [T]> for Storage<'a> {
|
||||
fn from(data: &'a [T]) -> Self {
|
||||
let data = unsafe {
|
||||
slice::from_raw_parts_mut(
|
||||
data.as_ptr() as *const u8 as *mut u8,
|
||||
data.len() * mem::size_of::<T>() as usize,
|
||||
)
|
||||
};
|
||||
Storage::View(data, mem::align_of::<T>())
|
||||
}
|
||||
}
|
||||
|
||||
/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
|
||||
/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
|
||||
/// converted to `ndarray::Array` for non-TVM processing.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// extern crate ndarray;
|
||||
///
|
||||
/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
/// let mut a: Tensor = a_nd.into();
|
||||
/// let mut a_dl: DLTensor = (&mut t).into();
|
||||
/// call_packed!(tvm_fn, &mut a_dl);
|
||||
///
|
||||
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
|
||||
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
|
||||
/// ```
|
||||
#[derive(PartialEq)]
|
||||
pub struct Tensor<'a> {
|
||||
/// The bytes which contain the data this `Tensor` represents.
|
||||
pub(super) data: Storage<'a>,
|
||||
pub(super) ctx: TVMContext,
|
||||
pub(super) dtype: DataType,
|
||||
pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
|
||||
/// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
|
||||
pub(super) strides: Option<Vec<usize>>,
|
||||
pub(super) byte_offset: isize,
|
||||
/// The number of elements in the `Tensor`.
|
||||
pub(super) size: usize,
|
||||
}
|
||||
|
||||
unsafe impl<'a> Send for Tensor<'a> {}
|
||||
|
||||
impl<'a> Tensor<'a> {
|
||||
pub fn shape(&self) -> Vec<i64> {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
/// Returns the data of this `Tensor` as a `Vec`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
|
||||
pub fn to_vec<T: 'static>(&self) -> Vec<T> {
|
||||
assert!(self.is_contiguous());
|
||||
assert!(self.dtype.is_type::<T>());
|
||||
let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize());
|
||||
unsafe {
|
||||
vec.as_mut_ptr().copy_from_nonoverlapping(
|
||||
self.data.as_ptr().offset(self.byte_offset) as *const T,
|
||||
self.size,
|
||||
);
|
||||
vec.set_len(self.size);
|
||||
}
|
||||
vec
|
||||
}
|
||||
|
||||
/// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
match self.strides {
|
||||
None => true,
|
||||
Some(ref strides) => {
|
||||
// check that stride for each dimension is the product of all trailing dimensons' shapes
|
||||
self
|
||||
.shape
|
||||
.iter()
|
||||
.zip(strides)
|
||||
.rfold(
|
||||
(true, 1),
|
||||
|(is_contig, expected_stride), (shape, stride)| {
|
||||
(
|
||||
is_contig && *stride == expected_stride,
|
||||
expected_stride * (*shape as usize),
|
||||
)
|
||||
},
|
||||
)
|
||||
.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a clone of this `Tensor`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
|
||||
pub fn copy(&mut self, other: &Tensor) {
|
||||
assert!(
|
||||
self.dtype == other.dtype && self.size == other.size,
|
||||
"Tensor shape/dtype mismatch."
|
||||
);
|
||||
assert!(
|
||||
self.is_contiguous() && other.is_contiguous(),
|
||||
"copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
|
||||
self.strides,
|
||||
other.strides
|
||||
);
|
||||
unsafe {
|
||||
self
|
||||
.data
|
||||
.as_mut_ptr()
|
||||
.offset(self.byte_offset as isize)
|
||||
.copy_from_nonoverlapping(
|
||||
other.data.as_mut_ptr().offset(other.byte_offset),
|
||||
other.size * other.dtype.itemsize(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an owned version of this `Tensor` via cloning.
|
||||
pub fn to_owned(&self) -> Tensor<'static> {
|
||||
let t = Tensor {
|
||||
data: self.data.to_owned(),
|
||||
ctx: self.ctx.clone(),
|
||||
dtype: self.dtype.clone(),
|
||||
size: self.size.clone(),
|
||||
shape: self.shape.clone(),
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
};
|
||||
unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
|
||||
}
|
||||
|
||||
fn from_array_storage<'s, T, D: ndarray::Dimension>(
|
||||
arr: &ndarray::Array<T, D>,
|
||||
storage: Storage<'s>,
|
||||
type_code: usize,
|
||||
) -> Tensor<'s> {
|
||||
let type_width = mem::size_of::<T>() as usize;
|
||||
Tensor {
|
||||
data: storage,
|
||||
ctx: TVMContext::default(),
|
||||
dtype: DataType {
|
||||
code: type_code,
|
||||
bits: 8 * type_width,
|
||||
lanes: 1,
|
||||
},
|
||||
size: arr.len(),
|
||||
shape: arr.shape().iter().map(|&v| v as i64).collect(),
|
||||
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
|
||||
byte_offset: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
|
||||
macro_rules! impl_ndarray_try_from_tensor {
|
||||
($type:ty, $dtype:expr) => {
|
||||
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
|
||||
type Error = Error;
|
||||
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
|
||||
ensure!(
|
||||
tensor.dtype == $dtype,
|
||||
"Cannot convert Tensor with dtype {:?} to ndarray",
|
||||
tensor.dtype
|
||||
);
|
||||
Ok(ndarray::Array::from_shape_vec(
|
||||
tensor
|
||||
.shape
|
||||
.iter()
|
||||
.map(|s| *s as usize)
|
||||
.collect::<Vec<usize>>(),
|
||||
tensor.to_vec::<$type>(),
|
||||
)?)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
|
||||
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
|
||||
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
|
||||
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
|
||||
|
||||
impl DLTensor {
|
||||
pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
|
||||
assert!(!flatten || tensor.is_contiguous());
|
||||
Self {
|
||||
data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
|
||||
ctx: DLContext::from(&tensor.ctx),
|
||||
ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
|
||||
dtype: DLDataType::from(&tensor.dtype),
|
||||
shape: if flatten {
|
||||
&tensor.size as *const _ as *mut i64
|
||||
} else {
|
||||
tensor.shape.as_ptr()
|
||||
} as *mut i64,
|
||||
strides: if flatten || tensor.is_contiguous() {
|
||||
ptr::null_mut()
|
||||
} else {
|
||||
tensor.strides.as_ref().unwrap().as_ptr()
|
||||
} as *mut i64,
|
||||
byte_offset: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a shared `Tensor` reference into a non-flattened `DLTensor`.
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
    fn from(tensor: &'a Tensor<'t>) -> Self {
        let flatten = false;
        Self::from_tensor(tensor, flatten)
    }
}
|
||||
|
||||
/// Converts a mutable `Tensor` reference into a non-flattened `DLTensor`.
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
    fn from(tensor: &'a mut Tensor<'t>) -> Self {
        let flatten = false;
        Self::from_tensor(tensor, flatten)
    }
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct DataType {
|
||||
pub(super) code: usize,
|
||||
pub(super) bits: usize,
|
||||
pub(super) lanes: usize,
|
||||
}
|
||||
|
||||
impl DataType {
|
||||
/// Returns the number of bytes occupied by an element of this `DataType`.
|
||||
pub fn itemsize(&self) -> usize {
|
||||
(self.bits * self.lanes) >> 3
|
||||
}
|
||||
|
||||
/// Returns whether this `DataType` represents primitive type `T`.
|
||||
pub fn is_type<T: 'static>(&self) -> bool {
|
||||
if self.lanes != 1 {
|
||||
return false;
|
||||
}
|
||||
let typ = TypeId::of::<T>();
|
||||
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|
||||
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|
||||
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|
||||
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|
||||
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|
||||
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a DataType> for DLDataType {
|
||||
fn from(dtype: &'a DataType) -> Self {
|
||||
Self {
|
||||
code: dtype.code as u8,
|
||||
bits: dtype.bits as u8,
|
||||
lanes: dtype.lanes as u16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Widens every field of a `DLDataType` to `usize`.
impl From<DLDataType> for DataType {
    fn from(dtype: DLDataType) -> Self {
        let code = dtype.code as usize;
        let bits = dtype.bits as usize;
        let lanes = dtype.lanes as usize;
        Self { code, bits, lanes }
    }
}
|
||||
|
||||
macro_rules! make_dtype_const {
|
||||
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
|
||||
const $name: DataType = DataType {
|
||||
code: $code as usize,
|
||||
bits: $bits,
|
||||
lanes: $lanes,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
|
||||
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
|
||||
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
|
||||
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
|
||||
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
|
||||
|
||||
impl Default for DLContext {
|
||||
fn default() -> Self {
|
||||
DLContext {
|
||||
device_type: DLDeviceType_kDLCPU,
|
||||
device_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct TVMContext {
|
||||
pub(super) device_type: usize,
|
||||
pub(super) device_id: usize,
|
||||
}
|
||||
|
||||
impl<'a> From<&'a TVMContext> for DLContext {
|
||||
fn from(ctx: &'a TVMContext) -> Self {
|
||||
Self {
|
||||
device_type: ctx.device_type as u32,
|
||||
device_id: ctx.device_id as i32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TVMContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_type: DLDeviceType_kDLCPU as usize,
|
||||
device_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<DLTensor> for Tensor<'a> {
|
||||
fn from(dlt: DLTensor) -> Self {
|
||||
unsafe {
|
||||
let dtype = DataType::from(dlt.dtype);
|
||||
let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
|
||||
let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
|
||||
let storage = Storage::from(slice::from_raw_parts(
|
||||
dlt.data as *const u8,
|
||||
dtype.itemsize() * size,
|
||||
));
|
||||
Self {
|
||||
data: storage,
|
||||
ctx: TVMContext::default(),
|
||||
dtype: dtype,
|
||||
size: size,
|
||||
shape: shape,
|
||||
strides: if dlt.strides == ptr::null_mut() {
|
||||
None
|
||||
} else {
|
||||
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
|
||||
},
|
||||
byte_offset: dlt.byte_offset as isize,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// * Owned array -> `Tensor<'static>`: the element data is copied into storage
///   owned by the tensor, because the array is dropped when the conversion returns.
/// * Borrowed array -> `Tensor<'a>`: the tensor's storage is built from the
///   array's slice for the lifetime `'a`.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                assert!(arr.is_standard_layout(), "Array must be contiguous.");
                let num_bytes = arr.len() * mem::size_of::<$type>();
                let view =
                    Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, num_bytes) });
                // `arr` is dropped at the end of this function, so storage that merely
                // points at its buffer would dangle; detach it with `to_owned()` first.
                // NOTE(review): assumes `Storage::to_owned` copies the bytes, as its use
                // in `Tensor::to_owned` suggests — confirm.
                Tensor::from_array_storage(&arr, view.to_owned(), $typecode as usize)
            }
        }

        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                assert!(arr.is_standard_layout(), "Array must be contiguous.");
                // `as_slice` only returns `None` for non-standard layouts, which the
                // assertion above has already excluded.
                let storage = Storage::from(arr.as_slice().unwrap());
                Tensor::from_array_storage(arr, storage, $typecode as usize)
            }
        }
    };
}
|
||||
|
||||
/// `From` conversions to `DLTensor` for `ndarray::Array`.
|
||||
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
|
||||
macro_rules! impl_dltensor_from_ndarray {
|
||||
($type:ty, $typecode:expr) => {
|
||||
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
|
||||
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
|
||||
DLTensor {
|
||||
data: arr.as_mut_ptr() as *mut c_void,
|
||||
ctx: DLContext::default(),
|
||||
ndim: arr.ndim() as c_int,
|
||||
dtype: DLDataType {
|
||||
code: $typecode as u8,
|
||||
bits: 8 * mem::size_of::<$type>() as u8,
|
||||
lanes: 1,
|
||||
},
|
||||
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
|
||||
strides: arr.strides().as_ptr() as *const isize as *mut i64,
|
||||
byte_offset: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
|
||||
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
|
||||
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
|
||||
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
|
||||
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
|
||||
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
|
||||
|
||||
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
|
||||
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
|
||||
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
|
||||
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
|
||||
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
|
||||
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
|
|
@ -1,472 +0,0 @@
|
|||
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
|
||||
|
||||
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
|
||||
use serde;
|
||||
use serde_json;
|
||||
|
||||
use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
|
||||
use errors::{Error, ErrorKind, Result};
|
||||
use ffi::runtime::{
|
||||
DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
|
||||
};
|
||||
|
||||
// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
|
||||
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
|
||||
// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
|
||||
const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
|
||||
|
||||
/// A TVM computation graph.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// let graph_json = fs::read_to_string("graph.json")).unwrap();
|
||||
/// let graph = Graph::try_from(&graph_json).unwrap();
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Graph {
|
||||
pub nodes: Vec<Node>,
|
||||
pub arg_nodes: Vec<usize>,
|
||||
pub heads: Vec<Entry>,
|
||||
pub node_row_ptr: Option<Vec<usize>>,
|
||||
pub attrs: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Entry {
|
||||
pub id: usize,
|
||||
pub index: usize,
|
||||
pub version: usize,
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
fn entry_index(&self, entry: &Entry) -> Result<usize> {
|
||||
self
|
||||
.node_row_ptr
|
||||
.as_ref()
|
||||
.map(|nrp| nrp[entry.id] + entry.index)
|
||||
.ok_or("Missing node_row_ptr.".into())
|
||||
}
|
||||
|
||||
/// Attempt to deserialize a JSON attribute to a type `T`.
|
||||
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
|
||||
Ok(serde_json::from_value::<T>(
|
||||
self
|
||||
.attrs
|
||||
.as_ref()
|
||||
.ok_or(ErrorKind::GraphFormatError(
|
||||
"Missing graph attrs".to_string(),
|
||||
))?
|
||||
.get(attr)
|
||||
.ok_or(ErrorKind::GraphFormatError(format!(
|
||||
"Missing {} attr",
|
||||
attr
|
||||
)))?
|
||||
.to_owned(),
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node {
|
||||
pub op: String,
|
||||
pub name: String,
|
||||
pub inputs: Vec<Entry>,
|
||||
pub attrs: Option<HashMap<String, String>>,
|
||||
pub control_deps: Option<Vec<Entry>>,
|
||||
}
|
||||
|
||||
/// Typed view of the node attributes the executor needs.
struct NodeAttrs {
    /// Name of the function to look up in the module.
    func_name: String,
    /// Number of outputs produced by the node.
    num_outputs: usize,
    /// Whether the node's arguments are passed as flattened (1-D) tensors.
    flatten_data: bool,
}
|
||||
|
||||
impl Node {
|
||||
fn parse_attrs(&self) -> Result<NodeAttrs> {
|
||||
let attrs = self
|
||||
.attrs
|
||||
.as_ref()
|
||||
.ok_or(format!("Missing node.attrs for `{}`", self.name))?;
|
||||
let func_name = attrs
|
||||
.get("func_name")
|
||||
.ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
|
||||
.to_string();
|
||||
let num_outputs = attrs
|
||||
.get("num_outputs")
|
||||
.ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
|
||||
.parse::<usize>()?;
|
||||
let flatten_data = attrs
|
||||
.get("flatten_data")
|
||||
.ok_or(format!(
|
||||
"Node `{}` is missing attrs.flatten_data",
|
||||
self.name
|
||||
))?
|
||||
.parse::<u8>()?
|
||||
== 1;
|
||||
Ok(NodeAttrs {
|
||||
func_name,
|
||||
num_outputs,
|
||||
flatten_data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a String> for Graph {
|
||||
type Error = Error;
|
||||
fn try_from(graph_json: &String) -> Result<Self> {
|
||||
let graph = serde_json::from_str(graph_json)?;
|
||||
Ok(graph)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a str> for Graph {
|
||||
type Error = Error;
|
||||
fn try_from(graph_json: &'a str) -> Result<Self> {
|
||||
let graph = serde_json::from_str(graph_json)?;
|
||||
Ok(graph)
|
||||
}
|
||||
}
|
||||
|
||||
/// A executor for a TVM computation graph.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ndarray::Array;
|
||||
///
|
||||
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
|
||||
///
|
||||
/// let mut params_bytes = Vec::new();
|
||||
/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
|
||||
/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap();
|
||||
///
|
||||
/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
|
||||
///
|
||||
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
|
||||
/// exec.load_params(params);
|
||||
///
|
||||
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
/// exec.set_input("data", x.into());
|
||||
/// exec.run();
|
||||
/// let output = exec.get_output(0).unwrap();
|
||||
///
|
||||
/// println!("{:#?}", Array::try_from(output).unwrap());
|
||||
/// ```
|
||||
pub struct GraphExecutor<'m, 't> {
|
||||
graph: Graph,
|
||||
op_execs: Vec<Box<Fn() + 'm>>,
|
||||
tensors: Vec<Tensor<'t>>,
|
||||
}
|
||||
|
||||
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
|
||||
|
||||
impl<'m, 't> GraphExecutor<'m, 't> {
|
||||
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
|
||||
let tensors = Self::setup_storages(&graph)?;
|
||||
Ok(GraphExecutor {
|
||||
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
|
||||
tensors: tensors,
|
||||
graph: graph,
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs the computation graph.
|
||||
pub fn run(&self) {
|
||||
self.op_execs.iter().for_each(|op_exec| {
|
||||
op_exec();
|
||||
});
|
||||
}
|
||||
|
||||
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
|
||||
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
|
||||
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
|
||||
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
|
||||
let dtypes = graph
|
||||
.get_attr::<(String, Vec<String>)>("dltype")?
|
||||
.1
|
||||
.iter()
|
||||
.map(|dltype| {
|
||||
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
|
||||
Ok(dtype)
|
||||
} else {
|
||||
Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into())
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<DataType>>>()?;
|
||||
|
||||
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
|
||||
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
|
||||
for (i, &storage_id) in storage_ids.iter().enumerate() {
|
||||
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
|
||||
let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
|
||||
storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
|
||||
}
|
||||
|
||||
let mut storages: Vec<Storage> = storage_num_bytes
|
||||
.into_iter()
|
||||
.map(|nbytes| Storage::new(nbytes, align))
|
||||
.collect::<Result<Vec<Storage>>>()?;
|
||||
|
||||
let tensors = izip!(storage_ids, shapes, dtypes)
|
||||
.map(|(storage_id, shape, dtype)| {
|
||||
let storage = storages[storage_id].view();
|
||||
Tensor {
|
||||
data: mem::replace(&mut storages[storage_id], storage),
|
||||
ctx: TVMContext::default(),
|
||||
dtype: dtype,
|
||||
size: shape.iter().product::<i64>() as usize,
|
||||
shape: shape,
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
/// Creates closures which represent the computation performed by this graph.
|
||||
fn setup_op_execs<M: 'm + Module>(
|
||||
graph: &Graph,
|
||||
lib: &'m M,
|
||||
tensors: &Vec<Tensor<'t>>,
|
||||
) -> Result<Vec<Box<Fn() + 'm>>> {
|
||||
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
|
||||
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
|
||||
|
||||
let mut op_execs = Vec::new();
|
||||
for (i, node) in graph.nodes.iter().enumerate() {
|
||||
if node.op == "null" {
|
||||
continue;
|
||||
}
|
||||
ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
|
||||
ensure!(node.attrs.is_some(), "Missing node attrs.");
|
||||
|
||||
let attrs = node.parse_attrs()?;
|
||||
|
||||
if attrs.func_name == "__nop" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let func = lib
|
||||
.get_function(&attrs.func_name)
|
||||
.ok_or(format!("Missing function {}", attrs.func_name))?;
|
||||
let arg_indices = node
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|entry| graph.entry_index(entry))
|
||||
.chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
|
||||
|
||||
let dl_tensors = arg_indices
|
||||
.map(|idx| {
|
||||
let tensor = &tensors[idx?];
|
||||
Ok(if attrs.flatten_data {
|
||||
DLTensor::from_tensor(tensor, true /* flatten */)
|
||||
} else {
|
||||
DLTensor::from(tensor)
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<DLTensor>>>()
|
||||
.unwrap();
|
||||
let op: Box<Fn()> = box move || {
|
||||
let args = dl_tensors
|
||||
.iter()
|
||||
.map(|t| t.into())
|
||||
.collect::<Vec<TVMArgValue>>();
|
||||
func(args.as_slice());
|
||||
};
|
||||
op_execs.push(op);
|
||||
}
|
||||
Ok(op_execs)
|
||||
}
|
||||
|
||||
pub fn load_params(&mut self, params: HashMap<String, Tensor<'t>>) {
|
||||
params.into_iter().for_each(|(name, param)| {
|
||||
self.set_input(name, param);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor<'t>) {
|
||||
if let Some(idx) = self.get_input_index(name.as_ref()) {
|
||||
// TODO: consider `new_with_params` to avoid ever allocating
|
||||
let ptr = self.tensors[idx].data.as_ptr();
|
||||
let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
|
||||
let mut owner = to_replace.nth(0).unwrap();
|
||||
if value.data.is_owned() {
|
||||
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
|
||||
// mem::replace(&mut (*owner), value);
|
||||
// to_replace.for_each(|t| {
|
||||
// panic!("replacing");
|
||||
// t.data = owner.data.view();
|
||||
// });
|
||||
owner.copy(&value);
|
||||
} else {
|
||||
owner.copy(&value);
|
||||
}
|
||||
} else {
|
||||
println!("Unexpected input `{}`", name.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the graph input with name `name`, if it exists.
|
||||
pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
|
||||
self
|
||||
.get_input_index(name.as_ref())
|
||||
.and_then(move |idx| Some(&self.tensors[idx]))
|
||||
}
|
||||
|
||||
/// Returns the graph output with index `index`, if it exists.
|
||||
pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
|
||||
let graph = &self.graph;
|
||||
graph.heads.get(idx).and_then(|entry| {
|
||||
graph
|
||||
.entry_index(entry)
|
||||
.map(|idx| self.tensors.get(idx))
|
||||
.unwrap_or(None)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the index for graph input with name `name`, if it exists.
|
||||
pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
|
||||
let graph = &self.graph;
|
||||
(0..graph.nodes.len())
|
||||
.skip_while(|&i| graph.nodes[i].name != name.as_ref())
|
||||
.nth(0)
|
||||
.and_then(|i| {
|
||||
if graph.arg_nodes.iter().any(|&id| id == i) {
|
||||
graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
|
||||
named!(
|
||||
tvm_str_to_type<CompleteStr, DataType>,
|
||||
do_parse!(
|
||||
type_name: alpha1 >>
|
||||
bits: digit1 >>
|
||||
lanes: opt!(tuple!(tag!("x"), digit1)) >>
|
||||
(DataType {
|
||||
code: match type_name {
|
||||
CompleteStr("int") => DLDataTypeCode_kDLInt,
|
||||
CompleteStr("uint") => DLDataTypeCode_kDLUInt,
|
||||
CompleteStr("float") => DLDataTypeCode_kDLFloat,
|
||||
_ => DLDataTypeCode_kDLFloat,
|
||||
} as usize,
|
||||
bits: bits.parse::<u8>().unwrap() as usize,
|
||||
lanes: match lanes {
|
||||
Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
|
||||
None => 1,
|
||||
},
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
/// Converts a bytes to String.
|
||||
named!(
|
||||
name<String>,
|
||||
map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
|
||||
b.to_vec()
|
||||
))
|
||||
);
|
||||
|
||||
/// Parses a TVMContext
|
||||
named!(
|
||||
tvm_ctx<&[u8], TVMContext>,
|
||||
do_parse!(
|
||||
device_type: le_u32 >>
|
||||
device_id: le_i32 >>
|
||||
(TVMContext { device_type: device_type as usize, device_id: device_id as usize })
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a DataType
|
||||
named!(
|
||||
data_type<&[u8], DataType>,
|
||||
do_parse!(
|
||||
code: le_u8 >>
|
||||
bits: le_u8 >>
|
||||
lanes: le_u16 >>
|
||||
(DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a Tensor from a TVM array file.
|
||||
named!(
|
||||
tensor<Tensor>,
|
||||
do_parse!(
|
||||
take!(8)
|
||||
>> bits!(tag_bits!(u64, 64, 0))
|
||||
>> ctx: tvm_ctx
|
||||
>> ndim: le_u32
|
||||
>> dtype: data_type
|
||||
>> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
|
||||
>> length: le_i64
|
||||
>> data: take!(length)
|
||||
>> (Tensor {
|
||||
data: Storage::from(data),
|
||||
ctx: ctx,
|
||||
dtype: dtype,
|
||||
size: shape.iter().product::<i64>() as usize,
|
||||
shape: shape,
|
||||
strides: None,
|
||||
byte_offset: 0,
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
/// Parses a graph params dict from a params binary file.
|
||||
named!(
|
||||
parse_param_dict<HashMap<String, Tensor>>,
|
||||
do_parse!(
|
||||
take!(8)
|
||||
>> bits!(tag_bits!(u64, 64, 0))
|
||||
>> names: length_count!(le_u64, name)
|
||||
>> tensors: length_count!(le_u64, tensor)
|
||||
>> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
|
||||
)
|
||||
);
|
||||
|
||||
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
|
||||
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
|
||||
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
|
||||
if remaining_bytes.len() > 0 {
|
||||
bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
|
||||
} else {
|
||||
Ok(param_dict)
|
||||
}
|
||||
} else {
|
||||
bail!(ErrorKind::LoadGraphParamsError(
|
||||
"invalid parameters file".to_string()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `dltype` and returns the resulting `DataType`, panicking on failure.
    fn parse(dltype: &str) -> DataType {
        tvm_str_to_type(CompleteStr(dltype)).unwrap().1
    }

    #[test]
    fn test_str_to_type() {
        // No lane suffix: lanes defaults to 1.
        let float24 = DataType {
            code: DLDataTypeCode_kDLFloat as usize,
            bits: 24,
            lanes: 1,
        };
        assert_eq!(parse("float24"), float24);

        // Explicit `x<lanes>` suffix.
        let uint111x44 = DataType {
            code: DLDataTypeCode_kDLUInt as usize,
            bits: 111,
            lanes: 44,
        };
        assert_eq!(parse("uint111x44"), uint111x44);
    }
}
|
|
@ -1,28 +0,0 @@
|
|||
mod allocator;
|
||||
mod array;
|
||||
mod module;
|
||||
#[macro_use]
|
||||
mod packed_func;
|
||||
mod graph;
|
||||
#[cfg(target_env = "sgx")]
|
||||
#[macro_use]
|
||||
pub mod sgx;
|
||||
mod threading;
|
||||
mod workspace;
|
||||
|
||||
use std::os::raw::c_char;
|
||||
|
||||
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use self::sgx::ocall_packed_func;
|
||||
|
||||
/// Error hook exported under its unmangled name for C callers.
///
/// Outside SGX this panics with the message in `cmsg`. A null pointer panics with
/// a fixed message, and bytes that are not valid UTF-8 are replaced lossily, so
/// the reported error is never swapped for an unrelated `unwrap` failure. Under
/// SGX the message is forwarded to `__sgx_set_last_error__`.
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
    #[cfg(not(target_env = "sgx"))]
    unsafe {
        if cmsg.is_null() {
            panic!("TVMAPISetLastError called with a null message");
        }
        // NOTE(review): assumes a non-null `cmsg` points to a NUL-terminated string.
        panic!("{}", std::ffi::CStr::from_ptr(cmsg).to_string_lossy());
    }
    #[cfg(target_env = "sgx")]
    ocall_packed!("__sgx_set_last_error__", cmsg);
}
|
|
@ -1,46 +0,0 @@
|
|||
use std::{
|
||||
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
|
||||
};
|
||||
|
||||
use ffi::runtime::BackendPackedCFunc;
|
||||
use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
|
||||
|
||||
pub trait Module {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
|
||||
}
|
||||
|
||||
pub struct SystemLibModule;
|
||||
|
||||
lazy_static! {
|
||||
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
|
||||
Mutex::new(HashMap::new());
|
||||
}
|
||||
|
||||
impl Module for SystemLibModule {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
|
||||
SYSTEM_LIB_FUNCTIONS
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(name.as_ref())
|
||||
.map(|func| wrap_backend_packed_func(func.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SystemLibModule {
|
||||
fn default() -> Self {
|
||||
SystemLibModule {}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
||||
cname: *const c_char,
|
||||
func: BackendPackedCFunc,
|
||||
) -> i32 {
|
||||
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
|
||||
SYSTEM_LIB_FUNCTIONS
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(name.to_string(), func);
|
||||
return 0;
|
||||
}
|
|
@ -1,342 +0,0 @@
|
|||
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
|
||||
|
||||
use super::Tensor;
|
||||
use ffi::runtime::{
|
||||
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
|
||||
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
|
||||
};
|
||||
|
||||
use errors::*;
|
||||
|
||||
/// A boxed, thread-safe function that takes a slice of `TVMArgValue`s and returns
/// a `TVMRetValue`.
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
|
||||
|
||||
/// Calls a packed function and returns a `TVMRetValue`.
///
/// Each argument is converted with `.into()` and the converted values are passed
/// to the function as one slice. With no arguments, an empty `Vec` is passed.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
    ($fn:expr) => {
        $fn(&Vec::new())
    };
    ($fn:expr, $($args:expr),+) => {
        $fn(&[$($args.into()),+])
    };
}
|
||||
|
||||
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
|
||||
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct TVMArgValue<'a> {
|
||||
_lifetime: PhantomData<&'a ()>,
|
||||
pub(crate) value: TVMValue,
|
||||
pub(crate) type_code: i64,
|
||||
}
|
||||
|
||||
impl<'a> TVMArgValue<'a> {
|
||||
pub fn new(value: TVMValue, type_code: i64) -> Self {
|
||||
TVMArgValue {
|
||||
_lifetime: PhantomData,
|
||||
value: value,
|
||||
type_code: type_code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
|
||||
macro_rules! impl_prim_tvm_arg {
|
||||
($type:ty, $field:ident, $code:expr, $as:ty) => {
|
||||
impl<'a> From<$type> for TVMArgValue<'a> {
|
||||
fn from(val: $type) -> Self {
|
||||
TVMArgValue {
|
||||
value: TVMValue { $field: val as $as },
|
||||
type_code: $code as i64,
|
||||
_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
|
||||
type Error = Error;
|
||||
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
|
||||
ensure!(
|
||||
val.type_code == $code as i64,
|
||||
"Could not downcast arg. Expected `{}`, got `{}`",
|
||||
$code,
|
||||
val.type_code
|
||||
);
|
||||
Ok(unsafe { val.value.$field as $type })
|
||||
}
|
||||
}
|
||||
};
|
||||
($type:ty, $field:ident, $code:expr) => {
|
||||
impl_prim_tvm_arg!($type, $field, $code, $type);
|
||||
};
|
||||
($type:ty,v_int64) => {
|
||||
impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
|
||||
};
|
||||
($type:ty,v_float64) => {
|
||||
impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
|
||||
};
|
||||
}
|
||||
|
||||
impl_prim_tvm_arg!(f32, v_float64);
|
||||
impl_prim_tvm_arg!(f64, v_float64);
|
||||
impl_prim_tvm_arg!(i8, v_int64);
|
||||
impl_prim_tvm_arg!(u8, v_int64);
|
||||
impl_prim_tvm_arg!(i32, v_int64);
|
||||
impl_prim_tvm_arg!(u32, v_int64);
|
||||
impl_prim_tvm_arg!(i64, v_int64);
|
||||
impl_prim_tvm_arg!(u64, v_int64);
|
||||
|
||||
/// Creates a conversion to a `TVMArgValue` for an object handle.
|
||||
impl<'a, T> From<*const T> for TVMArgValue<'a> {
|
||||
fn from(ptr: *const T) -> Self {
|
||||
TVMArgValue {
|
||||
value: TVMValue {
|
||||
v_handle: ptr as *mut T as *mut c_void,
|
||||
},
|
||||
type_code: TVMTypeCode_kArrayHandle as i64,
|
||||
_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
|
||||
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
|
||||
fn from(ptr: *mut T) -> Self {
|
||||
TVMArgValue {
|
||||
value: TVMValue {
|
||||
v_handle: ptr as *mut c_void,
|
||||
},
|
||||
type_code: TVMTypeCode_kHandle as i64,
|
||||
_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
|
||||
fn from(arr: &'a mut DLTensor) -> Self {
|
||||
TVMArgValue {
|
||||
value: TVMValue {
|
||||
v_handle: arr as *mut _ as *mut c_void,
|
||||
},
|
||||
type_code: TVMTypeCode_kArrayHandle as i64,
|
||||
_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
|
||||
fn from(arr: &'a DLTensor) -> Self {
|
||||
TVMArgValue {
|
||||
value: TVMValue {
|
||||
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
|
||||
},
|
||||
type_code: TVMTypeCode_kArrayHandle as i64,
|
||||
_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
|
||||
type Error = Error;
|
||||
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
|
||||
ensure!(
|
||||
val.type_code == TVMTypeCode_kArrayHandle as i64
|
||||
|| val.type_code == TVMTypeCode_kNDArrayContainer as i64,
|
||||
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
|
||||
TVMTypeCode_kArrayHandle,
|
||||
TVMTypeCode_kNDArrayContainer,
|
||||
val.type_code,
|
||||
);
|
||||
|
||||
let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
|
||||
Ok(dlt.into())
|
||||
}
|
||||
}
|
||||
|
||||
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
|
||||
/// Can be downcasted using `try_from` if it contains the desired type.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// let a = 42u32;
|
||||
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
|
||||
///
|
||||
/// let s = "hello, world!";
|
||||
/// let t: TVMRetValue = s.into();
|
||||
/// assert_eq!(String::try_from(t).unwrap(), s);
|
||||
/// ```
|
||||
pub struct TVMRetValue {
|
||||
/// A primitive return value, if any.
|
||||
prim_value: u64,
|
||||
/// An object return value, if any.
|
||||
box_value: Box<Any>,
|
||||
/// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
|
||||
type_code: i64,
|
||||
}
|
||||
|
||||
#[cfg(target_env = "sgx")]
impl TVMRetValue {
    /// Builds a return value from a raw `TVMValue` and its type code.
    ///
    /// The union field selected by `type_code` is cast to `u64` and stored as the
    /// primitive value (unknown codes store `0`); the boxed value is always `()`.
    pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
        let prim_value = unsafe {
            match type_code {
                0 | 1 => value.v_int64 as u64,
                2 => value.v_float64 as u64,
                3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
                11 | 12 => value.v_str as u64,
                _ => 0,
            }
        };
        Self {
            prim_value,
            box_value: Box::new(()),
            type_code,
        }
    }

    /// Splits this value into a raw `TVMValue` and its type code.
    ///
    /// Integer and float codes are produced from the primitive value; handle and
    /// string codes hand out the raw pointer of the boxed value.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) for any type code not listed in the match.
    pub fn into_tvm_value(self) -> (TVMValue, i64) {
        let type_code = self.type_code;
        let val = match type_code {
            0 | 1 => TVMValue {
                v_int64: self.prim_value as i64,
            },
            2 => TVMValue {
                v_float64: self.prim_value as f64,
            },
            3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
                v_handle: Box::into_raw(self.box_value) as *mut c_void,
            },
            11 | 12 => TVMValue {
                v_str: Box::into_raw(self.box_value) as *const _,
            },
            _ => unreachable!(),
        };
        (val, type_code)
    }
}
|
||||
|
||||
impl Default for TVMRetValue {
|
||||
fn default() -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: 0,
|
||||
box_value: box (),
|
||||
type_code: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates conversions between a primitive `$type` and `TVMRetValue` for the
/// type code `$code`.
///
/// The value is stored with an `as u64` cast and read back with an `as $type`
/// cast; `TryFrom` fails when the stored type code differs from `$code`.
///
/// NOTE(review): for `f32`/`f64` the `as u64` cast converts numerically, so
/// fractional parts are not preserved — confirm this is intended.
macro_rules! impl_prim_ret_value {
    ($type:ty, $code:expr) => {
        impl From<$type> for TVMRetValue {
            fn from(val: $type) -> Self {
                TVMRetValue {
                    prim_value: val as u64,
                    box_value: Box::new(()),
                    type_code: $code,
                }
            }
        }

        impl TryFrom<TVMRetValue> for $type {
            type Error = Error;

            fn try_from(ret: TVMRetValue) -> Result<$type> {
                if ret.type_code != $code {
                    bail!(ErrorKind::TryFromTVMRetValueError(
                        stringify!($type).to_string(),
                        ret.type_code
                    ))
                }
                Ok(ret.prim_value as $type)
            }
        }
    };
}
|
||||
|
||||
/// Implements `From<$type> for TVMRetValue` and `TryFrom<TVMRetValue> for $type` for a type
/// that is stored in `box_value` under type code `$code`.
macro_rules! impl_boxed_ret_value {
    ($type:ty, $code:expr) => {
        impl From<$type> for TVMRetValue {
            fn from(val: $type) -> Self {
                Self {
                    prim_value: 0,
                    box_value: Box::new(val),
                    type_code: $code,
                }
            }
        }

        impl TryFrom<TVMRetValue> for $type {
            type Error = Error;

            /// Succeeds when the boxed object downcasts to `$type`; the type code is only
            /// used for the error message.
            fn try_from(ret: TVMRetValue) -> Result<$type> {
                match ret.box_value.downcast::<$type>() {
                    Ok(val) => Ok(*val),
                    Err(_) => {
                        bail!(ErrorKind::TryFromTVMRetValueError(
                            stringify!($type).to_string(),
                            ret.type_code
                        ))
                    }
                }
            }
        }
    };
}
|
||||
|
||||
impl_prim_ret_value!(i8, 0);
|
||||
impl_prim_ret_value!(u8, 1);
|
||||
impl_prim_ret_value!(i16, 0);
|
||||
impl_prim_ret_value!(u16, 1);
|
||||
impl_prim_ret_value!(i32, 0);
|
||||
impl_prim_ret_value!(u32, 1);
|
||||
impl_prim_ret_value!(f32, 2);
|
||||
impl_prim_ret_value!(i64, 0);
|
||||
impl_prim_ret_value!(u64, 1);
|
||||
impl_prim_ret_value!(f64, 2);
|
||||
impl_prim_ret_value!(isize, 0);
|
||||
impl_prim_ret_value!(usize, 1);
|
||||
impl_boxed_ret_value!(String, 11);
|
||||
|
||||
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
|
||||
fn from(val: &'t Tensor<'a>) -> Self {
|
||||
TVMRetValue {
|
||||
prim_value: 0,
|
||||
box_value: box DLTensor::from(val),
|
||||
type_code: TVMTypeCode_kNDArrayContainer as i64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
|
||||
type Error = Error;
|
||||
fn try_from(ret: TVMRetValue) -> Result<Self> {
|
||||
ensure!(
|
||||
ret.type_code == TVMTypeCode_kArrayHandle as i64
|
||||
|| ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
|
||||
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
|
||||
TVMTypeCode_kArrayHandle,
|
||||
TVMTypeCode_kNDArrayContainer,
|
||||
ret.type_code,
|
||||
);
|
||||
|
||||
let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
|
||||
Ok(dlt.into())
|
||||
}
|
||||
}
|
||||
|
||||
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
||||
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
|
||||
box move |args: &[TVMArgValue]| {
|
||||
func(
|
||||
args
|
||||
.iter()
|
||||
.map(|ref arg| arg.value)
|
||||
.collect::<Vec<TVMValue>>()
|
||||
.as_ptr(),
|
||||
args
|
||||
.iter()
|
||||
.map(|ref arg| arg.type_code as i32)
|
||||
.collect::<Vec<i32>>()
|
||||
.as_ptr() as *const i32,
|
||||
args.len() as i32,
|
||||
);
|
||||
TVMRetValue::default()
|
||||
}
|
||||
}
|
|
@ -1,82 +0,0 @@
|
|||
use std::{
|
||||
ffi::CString,
|
||||
os::raw::{c_char, c_int},
|
||||
};
|
||||
|
||||
use errors::Result;
|
||||
use ffi::runtime::TVMValue;
|
||||
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
|
||||
|
||||
pub use runtime::threading::tvm_run_worker as run_worker;
|
||||
|
||||
/// Evaluates an integer status expression and converts it into a `Result<(), String>`:
/// `0` becomes `Ok(())`, any other value becomes `Err("SGX error: <status>")`.
#[macro_export]
macro_rules! tvm_ocall {
    ($func: expr) => {{
        let status = $func;
        if status == 0 {
            Ok(())
        } else {
            Err(format!("SGX error: {}", status))
        }
    }};
}
|
||||
|
||||
pub type SgxStatus = u32;
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
extern "C" {
|
||||
fn tvm_ocall_packed_func(
|
||||
name: *const c_char,
|
||||
arg_values: *const TVMValue,
|
||||
type_codes: *const c_int,
|
||||
num_args: c_int,
|
||||
ret_val: *mut TVMValue,
|
||||
ret_type_code: *mut c_int,
|
||||
) -> SgxStatus;
|
||||
}
|
||||
|
||||
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
|
||||
let mut ret_val = TVMValue { v_int64: 0 };
|
||||
let ret_type_code = 0i64;
|
||||
unsafe {
|
||||
tvm_ocall!(tvm_ocall_packed_func(
|
||||
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
|
||||
args
|
||||
.iter()
|
||||
.map(|ref arg| arg.value)
|
||||
.collect::<Vec<TVMValue>>()
|
||||
.as_ptr(),
|
||||
args
|
||||
.iter()
|
||||
.map(|ref arg| arg.type_code as i32)
|
||||
.collect::<Vec<i32>>()
|
||||
.as_ptr() as *const i32,
|
||||
args.len() as i32,
|
||||
&mut ret_val as *mut TVMValue,
|
||||
&mut (ret_type_code as i32) as *mut c_int,
|
||||
))?;
|
||||
}
|
||||
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
|
||||
}
|
||||
|
||||
/// Calls `ocall_packed_func` with the given function name and optional arguments, converting
/// each argument with `.into()`, and panics with `Error calling `<name>`` if the call fails.
///
/// `$fn_name` must be a literal because it is spliced into the message with `concat!`.
#[macro_export]
macro_rules! ocall_packed {
    ($fn_name:expr, $($args:expr),+) => {
        ocall_packed_func($fn_name, &[$($args.into(),)+])
            .expect(concat!("Error calling `", $fn_name, "`"))
    };
    ($fn_name:expr) => {
        ocall_packed_func($fn_name, &[])
            .expect(concat!("Error calling `", $fn_name, "`"))
    }
}
|
||||
|
||||
pub fn shutdown() {
|
||||
if env!("TVM_NUM_THREADS") != "0" {
|
||||
sgx_join_threads()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SystemLibModule {
|
||||
fn drop(&mut self) {
|
||||
shutdown()
|
||||
}
|
||||
}
|
|
@ -1,337 +0,0 @@
|
|||
use std::{
|
||||
os::raw::{c_int, c_void},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
|
||||
Arc, Barrier,
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use num_cpus;
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
use std::{
|
||||
env,
|
||||
thread::{self, JoinHandle},
|
||||
};
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use std::{collections::VecDeque, ptr, sync::Mutex};
|
||||
|
||||
use bounded_spsc_queue::{self, Producer};
|
||||
|
||||
use super::super::errors::*;
|
||||
use ffi::runtime::TVMParallelGroupEnv;
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
|
||||
|
||||
type FTVMParallelLambda =
|
||||
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
|
||||
|
||||
/// Holds a parallel job request made by a TVM library function.
|
||||
struct Job {
|
||||
cb: FTVMParallelLambda,
|
||||
cdata: *const c_void,
|
||||
req_num_tasks: usize,
|
||||
pending: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl Job {
|
||||
/// Splits this job into a number of `Task`s which can be scheduled.
|
||||
fn tasks(&self, num_workers: usize) -> Vec<Task> {
|
||||
let num_tasks = if self.req_num_tasks == 0 {
|
||||
num_workers
|
||||
} else {
|
||||
self.req_num_tasks.min(num_workers)
|
||||
};
|
||||
self.pending.store(num_tasks, Ordering::SeqCst);
|
||||
|
||||
let barrier = Arc::new(Barrier::new(num_tasks));
|
||||
|
||||
(0..num_tasks)
|
||||
.map(move |i| Task {
|
||||
id: i,
|
||||
flambda: self.cb,
|
||||
penv: TVMParallelGroupEnv {
|
||||
sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
|
||||
num_task: num_tasks as i32,
|
||||
},
|
||||
cdata: self.cdata,
|
||||
pending: Arc::clone(&self.pending),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Waits for all tasks in this `Job` to be completed.
|
||||
fn wait(&self) -> Result<()> {
|
||||
while self.pending.load(Ordering::Acquire) > 0 {
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
thread::yield_now();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A chunk of work requested by a TVM function.
|
||||
struct Task {
|
||||
id: usize,
|
||||
flambda: FTVMParallelLambda,
|
||||
penv: TVMParallelGroupEnv,
|
||||
cdata: *const c_void,
|
||||
pending: Arc<AtomicUsize>,
|
||||
}
|
||||
unsafe impl Send for Task {}
|
||||
unsafe impl Sync for Task {}
|
||||
|
||||
impl FnOnce<()> for Task {
|
||||
type Output = i32;
|
||||
extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
|
||||
let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
|
||||
self.pending.fetch_sub(1, Ordering::AcqRel);
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Threads {
|
||||
#[allow(unused)]
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
handles: Vec<JoinHandle<()>>,
|
||||
queues: Vec<Producer<Task>>,
|
||||
}
|
||||
|
||||
impl<'a> Threads {
|
||||
#[cfg(not(target_env = "sgx"))]
|
||||
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
|
||||
num_threads: usize,
|
||||
cb: F,
|
||||
) -> Self {
|
||||
let (handles, queues) = (0..num_threads)
|
||||
.map(|_| {
|
||||
let (p, c) = bounded_spsc_queue::make(2);
|
||||
let handle = thread::spawn(move || cb(c.into()));
|
||||
(handle, p)
|
||||
})
|
||||
.unzip();
|
||||
Threads {
|
||||
handles: handles,
|
||||
queues: queues,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
|
||||
num_threads: usize,
|
||||
_cb: F,
|
||||
) -> Self {
|
||||
let mut consumer_queues = SGX_QUEUES.lock().unwrap();
|
||||
let queues = (0..num_threads)
|
||||
.map(|_| {
|
||||
let (p, c) = bounded_spsc_queue::make(2);
|
||||
consumer_queues.push_back(c.into());
|
||||
p
|
||||
})
|
||||
.collect();
|
||||
ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
|
||||
Threads { queues: queues }
|
||||
}
|
||||
}
|
||||
|
||||
struct ThreadPool {
|
||||
num_workers: usize,
|
||||
#[allow(unused)]
|
||||
threads: Threads,
|
||||
}
|
||||
|
||||
thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
|
||||
|
||||
impl ThreadPool {
|
||||
fn new() -> Self {
|
||||
let num_workers = max_concurrency();
|
||||
ThreadPool {
|
||||
num_workers: num_workers,
|
||||
threads: Threads::launch(num_workers, ThreadPool::run_worker),
|
||||
}
|
||||
}
|
||||
|
||||
fn launch(&self, job: Job) {
|
||||
let mut tasks = job.tasks(self.num_workers + 1);
|
||||
|
||||
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
|
||||
self.threads.queues[i].push(task);
|
||||
}
|
||||
|
||||
tasks.pop().unwrap()();
|
||||
job.wait().unwrap();
|
||||
}
|
||||
|
||||
fn run_worker(queue: Consumer<Task>) {
|
||||
loop {
|
||||
let task = queue.pop();
|
||||
let result = task();
|
||||
if result == <i32>::min_value() {
|
||||
break;
|
||||
} else if result != 0 {
|
||||
panic!("Error running task.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send + Sync wrapper for bounded_spsc_queue::Consumer
|
||||
struct Consumer<T> {
|
||||
consumer: bounded_spsc_queue::Consumer<T>,
|
||||
}
|
||||
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
|
||||
fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
|
||||
Consumer { consumer: c }
|
||||
}
|
||||
}
|
||||
impl<T> Consumer<T> {
|
||||
fn pop(&self) -> T {
|
||||
self.consumer.pop()
|
||||
}
|
||||
}
|
||||
unsafe impl<T> Send for Consumer<T> {}
|
||||
unsafe impl<T> Sync for Consumer<T> {}
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
lazy_static! {
|
||||
/// Holds tasks for untrusted threads which re-enter the enclave to execute.
|
||||
static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
|
||||
}
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
|
||||
fn max_concurrency() -> usize {
|
||||
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
|
||||
if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
|
||||
return threads;
|
||||
}
|
||||
}
|
||||
num_cpus::get_physical()
|
||||
}
|
||||
|
||||
#[cfg(target_env = "sgx")]
|
||||
fn max_concurrency() -> usize {
|
||||
usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
fn max_concurrency() -> usize {
|
||||
0 // wasm doesn't support threads yet
|
||||
}
|
||||
|
||||
/// Takes the next pending consumer queue from `SGX_QUEUES`, if any, and runs the worker loop
/// on it. Returns a default `TVMRetValue` once the loop ends or when no queue was pending.
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
    // The temporary `MutexGuard` is released at the end of this statement, so the lock is
    // not held while the worker loop runs.
    let next_queue = SGX_QUEUES.lock().unwrap().pop_front();
    if let Some(queue) = next_queue {
        ThreadPool::run_worker(queue);
    }
    TVMRetValue::default()
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendParallelLaunch(
|
||||
cb: FTVMParallelLambda,
|
||||
cdata: *const c_void,
|
||||
num_task: usize,
|
||||
) -> c_int {
|
||||
if max_concurrency() == 0 {
|
||||
let penv = TVMParallelGroupEnv {
|
||||
sync_handle: 0 as *mut c_void,
|
||||
num_task: 1,
|
||||
};
|
||||
cb(0, &penv as *const _, cdata);
|
||||
} else {
|
||||
THREAD_POOL.with(|pool| {
|
||||
pool.launch(Job {
|
||||
cb: cb,
|
||||
cdata: cdata,
|
||||
req_num_tasks: num_task,
|
||||
pending: Arc::new(ATOMIC_USIZE_INIT),
|
||||
});
|
||||
});
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Stops the worker threads and asks the host to join them.
///
/// A job whose callback returns `i32::min_value()` is launched on the pool; `run_worker`
/// treats that status as the signal to leave its loop.
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
    /// Callback that does nothing except report the stop status.
    extern "C" fn stop_worker(
        _task_id: usize,
        _penv: *const TVMParallelGroupEnv,
        _cdata: *const c_void,
    ) -> i32 {
        i32::min_value()
    }

    THREAD_POOL.with(|pool| {
        let stop_job = Job {
            cb: stop_worker,
            cdata: ptr::null(),
            req_num_tasks: 0,
            pending: Arc::new(AtomicUsize::new(0)),
        };
        pool.launch(stop_job);
    });
    ocall_packed!("__sgx_thread_group_join__", 0);
}
|
||||
|
||||
// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
|
||||
let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
|
||||
barrier.wait();
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::{ptr, thread, time::Duration};

    use super::*;

    /// `TVM_NUM_THREADS` takes precedence; `OMP_NUM_THREADS` is the fallback.
    #[test]
    fn test_max_concurrency() {
        env::set_var("TVM_NUM_THREADS", "42");
        env::set_var("OMP_NUM_THREADS", "24");
        assert_eq!(max_concurrency(), 42);

        env::remove_var("TVM_NUM_THREADS");
        assert_eq!(max_concurrency(), 24);
    }

    /// Test callback: with a null `cdata` it does nothing; otherwise it bumps a call counter
    /// and adds its task id to a running sum, after sleeping proportionally to the task id.
    extern "C" fn flambda(
        task_id: usize,
        penv: *const TVMParallelGroupEnv,
        cdata: *const c_void,
    ) -> i32 {
        if cdata.is_null() {
            return 0;
        }
        unsafe {
            let counters = &*(cdata as *const (AtomicUsize, AtomicUsize));
            thread::sleep(Duration::from_millis(50 * task_id as u64));
            counters.0.fetch_add(1, Ordering::SeqCst);
            counters.1.fetch_add(task_id, Ordering::SeqCst);
            assert_eq!((*penv).num_task, 3);
        }
        0
    }

    /// Every requested task runs exactly once, and the task ids are `0..num_tasks`.
    #[test]
    fn test_parallel_launch() {
        TVMBackendParallelLaunch(flambda, ptr::null(), 6);

        let cdata = (AtomicUsize::new(0), AtomicUsize::new(0));
        let num_tasks = 3;
        TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);

        let calls = cdata.0.load(Ordering::SeqCst);
        let id_sum = cdata.1.load(Ordering::SeqCst);
        assert_eq!(calls, num_tasks);
        assert_eq!(id_sum, (0..num_tasks).sum::<usize>());
    }
}
|
|
@ -1,119 +0,0 @@
|
|||
use std::{
|
||||
cell::RefCell,
|
||||
os::raw::{c_int, c_void},
|
||||
ptr,
|
||||
};
|
||||
|
||||
use super::allocator::Allocation;
|
||||
use errors::*;
|
||||
|
||||
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
|
||||
|
||||
struct WorkspacePool {
|
||||
workspaces: Vec<Allocation>,
|
||||
free: Vec<usize>,
|
||||
in_use: Vec<usize>,
|
||||
}
|
||||
|
||||
impl WorkspacePool {
|
||||
fn new() -> Self {
|
||||
WorkspacePool {
|
||||
workspaces: Vec::new(),
|
||||
free: Vec::new(),
|
||||
in_use: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
|
||||
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
|
||||
self.in_use.push(self.workspaces.len() - 1);
|
||||
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
|
||||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
|
||||
if self.free.len() == 0 {
|
||||
return self.alloc_new(size);
|
||||
}
|
||||
let idx = self
|
||||
.free
|
||||
.iter()
|
||||
.fold(None, |cur_ws_idx: Option<usize>, &idx| {
|
||||
let ws_size = self.workspaces[idx].size();
|
||||
if !ws_size >= size {
|
||||
return cur_ws_idx;
|
||||
}
|
||||
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
|
||||
let cur_size = self.workspaces[cur_idx].size();
|
||||
Some(match ws_size <= cur_size {
|
||||
true => idx,
|
||||
false => cur_idx,
|
||||
})
|
||||
})
|
||||
});
|
||||
match idx {
|
||||
Some(idx) => {
|
||||
self.free.remove_item(&idx).unwrap();
|
||||
self.in_use.push(idx);
|
||||
Ok(self.workspaces[idx].as_mut_ptr())
|
||||
}
|
||||
None => self.alloc_new(size),
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, ptr: *mut u8) -> Result<()> {
|
||||
let mut ws_idx = None;
|
||||
for i in 0..self.in_use.len() {
|
||||
let idx = self.in_use[i];
|
||||
if self.workspaces[idx].as_mut_ptr() == ptr {
|
||||
self.in_use.remove(i);
|
||||
ws_idx = Some(idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(
|
||||
self
|
||||
.free
|
||||
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
|
||||
|
||||
const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendAllocWorkspace(
|
||||
_device_type: c_int,
|
||||
_device_id: c_int,
|
||||
size: u64,
|
||||
_dtype_code_hint: c_int,
|
||||
_dtype_bits_hint: c_int,
|
||||
) -> *mut c_void {
|
||||
let nbytes = if size == 0 {
|
||||
WORKSPACE_PAGE_SIZE
|
||||
} else {
|
||||
size as usize
|
||||
};
|
||||
WORKSPACE_POOL.with(|pool_cell| {
|
||||
pool_cell
|
||||
.borrow_mut()
|
||||
.alloc(nbytes as usize)
|
||||
.unwrap_or(ptr::null_mut()) as *mut c_void
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendFreeWorkspace(
|
||||
_device_type: c_int,
|
||||
_device_id: c_int,
|
||||
ptr: *mut c_void,
|
||||
) -> c_int {
|
||||
WORKSPACE_POOL.with(|pool_cell| {
|
||||
(match pool_cell.borrow_mut().free(ptr as *mut u8) {
|
||||
Ok(()) => 0,
|
||||
Err(_) => -1,
|
||||
}) as c_int
|
||||
});
|
||||
return 0;
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
#![feature(try_from)]
|
||||
|
||||
extern crate serde;
|
||||
extern crate serde_json;
|
||||
|
||||
extern crate tvm;
|
||||
|
||||
use std::{convert::TryFrom, fs, io::Read};
|
||||
|
||||
use tvm::runtime::Graph;
|
||||
|
||||
/// Loads the serialized parameters and graph JSON produced by `tests/build_model.py` and
/// checks a few structural properties of the parsed graph.
#[test]
fn test_load_graph() {
    let mut params_bytes = Vec::new();
    fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
        .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
        .read_to_end(&mut params_bytes)
        .unwrap();
    // Only checks that the call can be made; the result is not inspected.
    let _params = tvm::runtime::load_param_dict(&params_bytes);

    let graph = Graph::try_from(
        &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
    )
    .unwrap();

    assert_eq!(graph.nodes[3].op, "tvm_op");
    assert_eq!(
        graph.nodes[3]
            .attrs
            .as_ref()
            .unwrap()
            .get("func_name")
            .unwrap(),
        "fuse_dense"
    );
    assert_eq!(graph.nodes[5].inputs[0].index, 0);
    assert_eq!(graph.nodes[6].inputs[0].index, 1);
    assert_eq!(graph.heads.len(), 2);
}
|
|
@ -1,40 +0,0 @@
|
|||
extern crate ar;
|
||||
|
||||
use std::{
|
||||
env,
|
||||
fs::File,
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
};
|
||||
|
||||
use ar::Builder;
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = Command::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/build_test_graph.py"
|
||||
))
|
||||
.arg(&out_dir)
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
Path::new(&format!("{}/graph.o", out_dir)).exists(),
|
||||
"Could not build graph lib: {}",
|
||||
String::from_utf8(output.stderr)
|
||||
.unwrap()
|
||||
.trim()
|
||||
.split("\n")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
);
|
||||
|
||||
let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
|
||||
let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
|
||||
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
|
||||
builder.append_path(in_path.to_str().unwrap()).unwrap();
|
||||
|
||||
println!("cargo:rustc-link-lib=static=graph");
|
||||
println!("cargo:rustc-link-search=native={}", out_dir);
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
#![feature(try_from)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate ndarray;
|
||||
extern crate serde;
|
||||
extern crate serde_json;
|
||||
|
||||
extern crate tvm;
|
||||
use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
|
||||
|
||||
use ndarray::Array;
|
||||
use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
|
||||
|
||||
const BATCH_SIZE: usize = 4;
|
||||
const IN_DIM: usize = 8;
|
||||
|
||||
/// Asserts that two arrays have (approximately) the same element sum.
///
/// * `check_sum!(exec, name, expected)` compares the executor input called `name`.
/// * `check_sum!(exec, index, expected)` compares the executor output at `index`.
/// * `check_sum!(actual, expected)` compares two arrays directly; the sums may differ by
///   less than `1e-2`.
macro_rules! check_sum {
    ($exec:expr, $input:ident, $expected:ident) => {
        let actual = Array::try_from($exec.get_input(stringify!($input)).unwrap()).unwrap();
        check_sum!(actual, $expected);
    };
    ($exec:expr, $output:expr, $expected:ident) => {
        let actual = Array::try_from($exec.get_output($output).unwrap()).unwrap();
        check_sum!(actual, $expected);
    };
    ($actual:ident, $expected:ident) => {
        let actual_sum: f32 = $actual.scalar_sum();
        let expected_sum: f32 = $expected.scalar_sum();
        assert!(
            (actual_sum - expected_sum).abs() < 1e-2,
            "{} != {}",
            actual_sum,
            expected_sum
        );
    };
}
|
||||
|
||||
fn main() {
|
||||
let syslib = SystemLibModule::default();
|
||||
|
||||
let mut params_bytes = Vec::new();
|
||||
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
|
||||
.unwrap()
|
||||
.read_to_end(&mut params_bytes)
|
||||
.unwrap();
|
||||
let params = tvm::runtime::load_param_dict(¶ms_bytes)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.to_owned()))
|
||||
.collect::<HashMap<String, Tensor<'static>>>();
|
||||
|
||||
let graph =
|
||||
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
|
||||
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
|
||||
|
||||
let x = Array::from_shape_vec(
|
||||
(BATCH_SIZE, IN_DIM),
|
||||
(0..BATCH_SIZE * IN_DIM)
|
||||
.map(|x| x as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
).unwrap();
|
||||
let w = Array::try_from(params.get("dense0_weight").unwrap())
|
||||
.unwrap()
|
||||
.into_shape((IN_DIM * 2, IN_DIM))
|
||||
.unwrap();
|
||||
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
|
||||
let dense = x.dot(&w.t()) + &b;
|
||||
let left = dense.slice(s![.., 0..IN_DIM]);
|
||||
let right = dense.slice(s![.., IN_DIM..]);
|
||||
let expected_o0 = &left + 1f32;
|
||||
let expected_o1 = &right - 1f32;
|
||||
|
||||
exec.load_params(params);
|
||||
exec.set_input("data", x.clone().into());
|
||||
|
||||
check_sum!(exec, data, x);
|
||||
check_sum!(exec, dense0_weight, w);
|
||||
check_sum!(exec, dense0_bias, b);
|
||||
|
||||
exec.run();
|
||||
|
||||
check_sum!(exec, 0, expected_o0);
|
||||
check_sum!(exec, 1, expected_o1);
|
||||
check_sum!(exec, 2, dense);
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
extern crate ar;
|
||||
|
||||
use std::{env, path::PathBuf, process::Command};
|
||||
|
||||
use ar::Builder;
|
||||
use std::fs::File;
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = Command::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/build_test_lib.py"
|
||||
)).arg(&out_dir)
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
if output.stderr.len() > 0 {
|
||||
panic!(String::from_utf8(output.stderr).unwrap());
|
||||
}
|
||||
|
||||
let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
|
||||
let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
|
||||
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
|
||||
builder.append_path(in_path.to_str().unwrap()).unwrap();
|
||||
|
||||
println!("cargo:rustc-link-lib=static=test");
|
||||
println!("cargo:rustc-link-search=native={}", out_dir);
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
extern crate ndarray;
|
||||
#[macro_use]
|
||||
extern crate tvm;
|
||||
|
||||
use ndarray::Array;
|
||||
use tvm::{
|
||||
ffi::runtime::DLTensor,
|
||||
runtime::{Module, SystemLibModule},
|
||||
};
|
||||
|
||||
fn main() {
|
||||
let syslib = SystemLibModule::default();
|
||||
let add = syslib
|
||||
.get_function("default_function")
|
||||
.expect("main function not found");
|
||||
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
|
||||
let mut c = Array::from_vec(vec![0f32; 4]);
|
||||
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
|
||||
let mut a_dl: DLTensor = (&mut a).into();
|
||||
let mut b_dl: DLTensor = (&mut b).into();
|
||||
let mut c_dl: DLTensor = (&mut c).into();
|
||||
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
|
||||
assert!(c.all_close(&e, 1e-8f32));
|
||||
}
|
|
@ -2,24 +2,60 @@
|
|||
|
||||
set -e
|
||||
|
||||
export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH
|
||||
export TVM_HOME="$(git rev-parse --show-toplevel)"
|
||||
|
||||
tvm_root="$(git rev-parse --show-toplevel)"
|
||||
export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python"
|
||||
export LD_LIBRARY_PATH="$TVM_HOME/lib":"$TVM_HOME/build":"$TVM_HOME/nnvm":$LD_LIBRARY_PATH
|
||||
export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python"
|
||||
export RUST_DIR="$TVM_HOME/rust"
|
||||
|
||||
#cd rust
|
||||
#cargo fmt -- --check
|
||||
cd $RUST_DIR
|
||||
cargo fmt -- --check
|
||||
|
||||
# test common
|
||||
cd $RUST_DIR/common
|
||||
cargo build --features runtime
|
||||
cargo test --features runtime --tests
|
||||
|
||||
cargo build --features frontend
|
||||
cargo test --features frontend --tests
|
||||
|
||||
# test runtime
|
||||
cd $RUST_DIR/runtime
|
||||
|
||||
# run basic tests
|
||||
#python3 tests/build_model.py
|
||||
#cargo test --tests
|
||||
python3 tests/build_model.py
|
||||
cargo test --tests
|
||||
|
||||
# run TVM module test
|
||||
#cd tests/test_tvm_basic
|
||||
#cargo run
|
||||
#cd -
|
||||
cd tests/test_tvm_basic
|
||||
cargo run
|
||||
cd -
|
||||
|
||||
# run NNVM graph test
|
||||
#cd tests/test_nnvm
|
||||
#cargo run
|
||||
#cd -
|
||||
cd tests/test_nnvm
|
||||
cargo run
|
||||
cd -
|
||||
|
||||
# test frontend
|
||||
cd $RUST_DIR/frontend
|
||||
|
||||
cargo test --tests -- --test-threads=1
|
||||
|
||||
# run basic tests on cpu
|
||||
cd tests/basics
|
||||
cargo build --features cpu
|
||||
cargo run --features cpu
|
||||
# uncomment when have more CI resources
|
||||
# cargo build --features gpu
|
||||
# cargo run --features gpu
|
||||
# fi
|
||||
cd -
|
||||
|
||||
# run callback tests separately: https://discuss.tvm.ai/t/are-global-functions-need-to-be-accessed-in-separate-processes/1075
|
||||
cd tests/callback
|
||||
cargo build
|
||||
cargo run --bin int
|
||||
cargo run --bin float
|
||||
cargo run --bin array
|
||||
cargo run --bin string
|
||||
cd -
|
||||
|
|
Загрузка…
Ссылка в новой задаче