[RUST] Rust DSO module (#2976)
This commit is contained in:
Родитель
05f7fa9b05
Коммит
a479432d90
|
@ -20,6 +20,7 @@ members = [
|
||||||
"common",
|
"common",
|
||||||
"runtime",
|
"runtime",
|
||||||
"runtime/tests/test_tvm_basic",
|
"runtime/tests/test_tvm_basic",
|
||||||
|
"runtime/tests/test_tvm_dso",
|
||||||
"runtime/tests/test_nnvm",
|
"runtime/tests/test_nnvm",
|
||||||
"frontend",
|
"frontend",
|
||||||
"frontend/tests/basics",
|
"frontend/tests/basics",
|
||||||
|
|
|
@ -22,23 +22,30 @@ extern crate bindgen;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
|
||||||
|
let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||||
|
.canonicalize()
|
||||||
|
.unwrap();
|
||||||
|
tvm_home
|
||||||
|
.parent()
|
||||||
|
.unwrap()
|
||||||
|
.parent()
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.to_string()
|
||||||
|
});
|
||||||
if cfg!(feature = "bindings") {
|
if cfg!(feature = "bindings") {
|
||||||
println!("cargo:rerun-if-env-changed=TVM_HOME");
|
println!("cargo:rerun-if-env-changed=TVM_HOME");
|
||||||
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
|
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
|
||||||
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
|
println!("cargo:rustc-link-search={}/build", tvm_home);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @see rust-bindgen#550 for `blacklist_type`
|
// @see rust-bindgen#550 for `blacklist_type`
|
||||||
bindgen::Builder::default()
|
bindgen::Builder::default()
|
||||||
.header(format!(
|
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
|
||||||
"{}/include/tvm/runtime/c_runtime_api.h",
|
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
|
||||||
env!("TVM_HOME")
|
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
|
||||||
))
|
|
||||||
.header(format!(
|
|
||||||
"{}/include/tvm/runtime/c_backend_api.h",
|
|
||||||
env!("TVM_HOME")
|
|
||||||
))
|
|
||||||
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
|
|
||||||
.blacklist_type("max_align_t")
|
.blacklist_type("max_align_t")
|
||||||
.layout_tests(false)
|
.layout_tests(false)
|
||||||
.derive_partialeq(true)
|
.derive_partialeq(true)
|
||||||
|
|
|
@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" }
|
||||||
|
|
||||||
[target.'cfg(not(target_env = "sgx"))'.dependencies]
|
[target.'cfg(not(target_env = "sgx"))'.dependencies]
|
||||||
num_cpus = "1.8.0"
|
num_cpus = "1.8.0"
|
||||||
|
|
||||||
|
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
|
||||||
|
libloading = "0.5"
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
cell::RefCell,
|
||||||
|
collections::HashMap,
|
||||||
|
ffi::CStr,
|
||||||
|
os::raw::{c_char, c_int, c_void},
|
||||||
|
pin::Pin,
|
||||||
|
};
|
||||||
|
|
||||||
|
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
|
||||||
|
workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
|
||||||
|
TVMAPISetLastError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::Module;
|
||||||
|
|
||||||
|
const TVM_MAIN: &'static [u8] = b"__tvm_main__";
|
||||||
|
const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
|
||||||
|
|
||||||
|
/// A module backed by a Dynamic Shared Object (dylib).
|
||||||
|
pub struct DsoModule<'a> {
|
||||||
|
lib: libloading::Library,
|
||||||
|
packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
|
||||||
|
_pin: std::marker::PhantomPinned,
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! init_context_func {
|
||||||
|
($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
|
||||||
|
unsafe {
|
||||||
|
$(
|
||||||
|
let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
|
||||||
|
if let Ok(fn_ptr) = fn_ptr {
|
||||||
|
**fn_ptr = $fn;
|
||||||
|
}
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DsoModule<'a> {
|
||||||
|
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
|
||||||
|
let lib = libloading::Library::new(filename)?;
|
||||||
|
|
||||||
|
init_context_func!(
|
||||||
|
lib,
|
||||||
|
(TVMAPISetLastError, extern "C" fn(*const i8)),
|
||||||
|
(
|
||||||
|
TVMBackendAllocWorkspace,
|
||||||
|
extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
|
||||||
|
),
|
||||||
|
(
|
||||||
|
TVMBackendFreeWorkspace,
|
||||||
|
extern "C" fn(c_int, c_int, *mut c_void) -> c_int
|
||||||
|
),
|
||||||
|
(
|
||||||
|
TVMBackendParallelLaunch,
|
||||||
|
extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
|
||||||
|
),
|
||||||
|
(
|
||||||
|
TVMBackendParallelBarrier,
|
||||||
|
extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Pin the module in memory so that `ctx` pointer (below) is stable.
|
||||||
|
let dso_mod = Box::pin(Self {
|
||||||
|
lib,
|
||||||
|
packed_funcs: RefCell::new(HashMap::new()),
|
||||||
|
_pin: std::marker::PhantomPinned,
|
||||||
|
});
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
|
||||||
|
**ctx = &dso_mod as *const _ as *const c_void;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(dso_mod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Module for DsoModule<'a> {
|
||||||
|
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
|
||||||
|
let name = name.as_ref();
|
||||||
|
let func = match unsafe {
|
||||||
|
self.lib
|
||||||
|
.get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
|
||||||
|
// If __tvm_main__ is present, it contains the name of the
|
||||||
|
// actual main function.
|
||||||
|
match self
|
||||||
|
.lib
|
||||||
|
.get::<*const c_char>(TVM_MAIN)
|
||||||
|
.map(|p| CStr::from_ptr(*p))
|
||||||
|
{
|
||||||
|
Ok(m) => m.to_bytes(),
|
||||||
|
_ => return None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
name.as_bytes()
|
||||||
|
})
|
||||||
|
} {
|
||||||
|
Ok(func) => unsafe { func.into_raw() },
|
||||||
|
Err(_) => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.packed_funcs.borrow_mut().insert(
|
||||||
|
name.to_string(),
|
||||||
|
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
|
||||||
|
);
|
||||||
|
|
||||||
|
self.packed_funcs.borrow().get(name).map(|f| *f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for DsoModule<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.packed_funcs
|
||||||
|
.replace(HashMap::new())
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
|
||||||
|
.for_each(std::mem::drop);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
|
||||||
|
mod dso;
|
||||||
|
mod syslib;
|
||||||
|
|
||||||
|
use tvm_common::{
|
||||||
|
ffi::BackendPackedCFunc,
|
||||||
|
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
|
||||||
|
pub use dso::DsoModule;
|
||||||
|
pub use syslib::SystemLibModule;
|
||||||
|
|
||||||
|
pub trait Module {
|
||||||
|
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
||||||
|
fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
|
||||||
|
box move |args: &[TVMArgValue]| {
|
||||||
|
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
|
||||||
|
.into_iter()
|
||||||
|
.map(|arg| {
|
||||||
|
let (val, code) = arg.to_tvm_value();
|
||||||
|
(val, code as i32)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
|
||||||
|
if exit_code == 0 {
|
||||||
|
Ok(TVMRetValue::default())
|
||||||
|
} else {
|
||||||
|
Err(tvm_common::errors::FuncCallError::get_with_context(
|
||||||
|
func_name.clone(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,14 +21,9 @@ use std::{
|
||||||
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
|
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
|
||||||
};
|
};
|
||||||
|
|
||||||
use tvm_common::{
|
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
|
||||||
ffi::BackendPackedCFunc,
|
|
||||||
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub trait Module {
|
use super::Module;
|
||||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct SystemLibModule;
|
pub struct SystemLibModule;
|
||||||
|
|
||||||
|
@ -53,30 +48,6 @@ impl Default for SystemLibModule {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
|
||||||
pub(super) fn wrap_backend_packed_func(
|
|
||||||
func_name: String,
|
|
||||||
func: BackendPackedCFunc,
|
|
||||||
) -> Box<dyn PackedFunc> {
|
|
||||||
box move |args: &[TVMArgValue]| {
|
|
||||||
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
|
|
||||||
.into_iter()
|
|
||||||
.map(|arg| {
|
|
||||||
let (val, code) = arg.to_tvm_value();
|
|
||||||
(val, code as i32)
|
|
||||||
})
|
|
||||||
.unzip();
|
|
||||||
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
|
|
||||||
if exit_code == 0 {
|
|
||||||
Ok(TVMRetValue::default())
|
|
||||||
} else {
|
|
||||||
Err(tvm_common::errors::FuncCallError::get_with_context(
|
|
||||||
func_name.clone(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
||||||
cname: *const c_char,
|
cname: *const c_char,
|
||||||
|
@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
||||||
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
|
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
|
||||||
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
|
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
|
||||||
name.to_string(),
|
name.to_string(),
|
||||||
&*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
|
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
|
||||||
);
|
);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv;
|
||||||
#[cfg(target_env = "sgx")]
|
#[cfg(target_env = "sgx")]
|
||||||
use super::{TVMArgValue, TVMRetValue};
|
use super::{TVMArgValue, TVMRetValue};
|
||||||
|
|
||||||
type FTVMParallelLambda =
|
pub(crate) type FTVMParallelLambda =
|
||||||
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
|
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
|
||||||
|
|
||||||
/// Holds a parallel job request made by a TVM library function.
|
/// Holds a parallel job request made by a TVM library function.
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
[package]
|
||||||
|
name = "test-tvm-dso"
|
||||||
|
version = "0.0.0"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
authors = ["TVM Contributors"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
ndarray="0.12"
|
||||||
|
tvm-runtime = { path = "../../" }
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
use std::{env, path::Path, process::Command};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let out_dir = env::var("OUT_DIR").unwrap();
|
||||||
|
|
||||||
|
let output = Command::new(concat!(
|
||||||
|
env!("CARGO_MANIFEST_DIR"),
|
||||||
|
"/src/build_test_lib.py"
|
||||||
|
))
|
||||||
|
.arg(&out_dir)
|
||||||
|
.output()
|
||||||
|
.expect("Failed to execute command");
|
||||||
|
assert!(
|
||||||
|
Path::new(&format!("{}/test.so", out_dir)).exists(),
|
||||||
|
"Could not build tvm lib: {}",
|
||||||
|
String::from_utf8(output.stderr)
|
||||||
|
.unwrap()
|
||||||
|
.trim()
|
||||||
|
.split("\n")
|
||||||
|
.last()
|
||||||
|
.unwrap_or("")
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""Prepares a simple TVM library for testing."""
|
||||||
|
|
||||||
|
from os import path as osp
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import tvm
|
||||||
|
from tvm.contrib import cc
|
||||||
|
|
||||||
|
def main():
|
||||||
|
n = tvm.var('n')
|
||||||
|
A = tvm.placeholder((n,), name='A')
|
||||||
|
B = tvm.placeholder((n,), name='B')
|
||||||
|
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
|
||||||
|
s = tvm.create_schedule(C.op)
|
||||||
|
s[C].parallel(s[C].op.axis[0])
|
||||||
|
print(tvm.lower(s, [A, B, C], simple_mode=True))
|
||||||
|
obj_file = osp.join(sys.argv[1], 'test.o')
|
||||||
|
tvm.build(s, [A, B, C], 'llvm').save(obj_file)
|
||||||
|
cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
extern crate ndarray;
|
||||||
|
#[macro_use]
|
||||||
|
extern crate tvm_runtime;
|
||||||
|
|
||||||
|
use ndarray::Array;
|
||||||
|
use tvm_runtime::{DLTensor, DsoModule, Module};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
tvm_runtime::TVMGetLastError();
|
||||||
|
let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
|
||||||
|
let add = module
|
||||||
|
.get_function("__tvm_main__")
|
||||||
|
.expect("main function not found");
|
||||||
|
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||||
|
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
|
||||||
|
let mut c = Array::from_vec(vec![0f32; 4]);
|
||||||
|
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
|
||||||
|
let mut a_dl: DLTensor = (&mut a).into();
|
||||||
|
let mut b_dl: DLTensor = (&mut b).into();
|
||||||
|
let mut c_dl: DLTensor = (&mut c).into();
|
||||||
|
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
|
||||||
|
assert!(c.all_close(&e, 1e-8f32));
|
||||||
|
}
|
|
@ -48,6 +48,10 @@ cd tests/test_tvm_basic
|
||||||
cargo run
|
cargo run
|
||||||
cd -
|
cd -
|
||||||
|
|
||||||
|
cd tests/test_tvm_dso
|
||||||
|
cargo run
|
||||||
|
cd -
|
||||||
|
|
||||||
# run NNVM graph test
|
# run NNVM graph test
|
||||||
cd tests/test_nnvm
|
cd tests/test_nnvm
|
||||||
cargo run
|
cargo run
|
||||||
|
|
Загрузка…
Ссылка в новой задаче