[RUST] Rust DSO module (#2976)
This commit is contained in:
Родитель
05f7fa9b05
Коммит
a479432d90
|
@ -20,6 +20,7 @@ members = [
|
|||
"common",
|
||||
"runtime",
|
||||
"runtime/tests/test_tvm_basic",
|
||||
"runtime/tests/test_tvm_dso",
|
||||
"runtime/tests/test_nnvm",
|
||||
"frontend",
|
||||
"frontend/tests/basics",
|
||||
|
|
|
@ -22,23 +22,30 @@ extern crate bindgen;
|
|||
use std::path::PathBuf;
|
||||
|
||||
fn main() {
|
||||
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
|
||||
let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.canonicalize()
|
||||
.unwrap();
|
||||
tvm_home
|
||||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
});
|
||||
if cfg!(feature = "bindings") {
|
||||
println!("cargo:rerun-if-env-changed=TVM_HOME");
|
||||
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
|
||||
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
|
||||
println!("cargo:rustc-link-search={}/build", tvm_home);
|
||||
}
|
||||
|
||||
// @see rust-bindgen#550 for `blacklist_type`
|
||||
bindgen::Builder::default()
|
||||
.header(format!(
|
||||
"{}/include/tvm/runtime/c_runtime_api.h",
|
||||
env!("TVM_HOME")
|
||||
))
|
||||
.header(format!(
|
||||
"{}/include/tvm/runtime/c_backend_api.h",
|
||||
env!("TVM_HOME")
|
||||
))
|
||||
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
|
||||
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
|
||||
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
|
||||
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
|
||||
.blacklist_type("max_align_t")
|
||||
.layout_tests(false)
|
||||
.derive_partialeq(true)
|
||||
|
|
|
@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" }
|
|||
|
||||
[target.'cfg(not(target_env = "sgx"))'.dependencies]
|
||||
num_cpus = "1.8.0"
|
||||
|
||||
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
|
||||
libloading = "0.5"
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashMap,
|
||||
ffi::CStr,
|
||||
os::raw::{c_char, c_int, c_void},
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
|
||||
|
||||
use crate::{
|
||||
threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
|
||||
workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
|
||||
TVMAPISetLastError,
|
||||
};
|
||||
|
||||
use super::Module;
|
||||
|
||||
const TVM_MAIN: &'static [u8] = b"__tvm_main__";
|
||||
const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
|
||||
|
||||
/// A module backed by a Dynamic Shared Object (dylib).
|
||||
pub struct DsoModule<'a> {
|
||||
lib: libloading::Library,
|
||||
packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
|
||||
_pin: std::marker::PhantomPinned,
|
||||
}
|
||||
|
||||
macro_rules! init_context_func {
|
||||
($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
|
||||
unsafe {
|
||||
$(
|
||||
let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
|
||||
if let Ok(fn_ptr) = fn_ptr {
|
||||
**fn_ptr = $fn;
|
||||
}
|
||||
)+
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl<'a> DsoModule<'a> {
|
||||
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
|
||||
let lib = libloading::Library::new(filename)?;
|
||||
|
||||
init_context_func!(
|
||||
lib,
|
||||
(TVMAPISetLastError, extern "C" fn(*const i8)),
|
||||
(
|
||||
TVMBackendAllocWorkspace,
|
||||
extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
|
||||
),
|
||||
(
|
||||
TVMBackendFreeWorkspace,
|
||||
extern "C" fn(c_int, c_int, *mut c_void) -> c_int
|
||||
),
|
||||
(
|
||||
TVMBackendParallelLaunch,
|
||||
extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
|
||||
),
|
||||
(
|
||||
TVMBackendParallelBarrier,
|
||||
extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
|
||||
),
|
||||
);
|
||||
|
||||
// Pin the module in memory so that `ctx` pointer (below) is stable.
|
||||
let dso_mod = Box::pin(Self {
|
||||
lib,
|
||||
packed_funcs: RefCell::new(HashMap::new()),
|
||||
_pin: std::marker::PhantomPinned,
|
||||
});
|
||||
|
||||
unsafe {
|
||||
if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
|
||||
**ctx = &dso_mod as *const _ as *const c_void;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dso_mod)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Module for DsoModule<'a> {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
|
||||
let name = name.as_ref();
|
||||
let func = match unsafe {
|
||||
self.lib
|
||||
.get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
|
||||
// If __tvm_main__ is present, it contains the name of the
|
||||
// actual main function.
|
||||
match self
|
||||
.lib
|
||||
.get::<*const c_char>(TVM_MAIN)
|
||||
.map(|p| CStr::from_ptr(*p))
|
||||
{
|
||||
Ok(m) => m.to_bytes(),
|
||||
_ => return None,
|
||||
}
|
||||
} else {
|
||||
name.as_bytes()
|
||||
})
|
||||
} {
|
||||
Ok(func) => unsafe { func.into_raw() },
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
self.packed_funcs.borrow_mut().insert(
|
||||
name.to_string(),
|
||||
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
|
||||
);
|
||||
|
||||
self.packed_funcs.borrow().get(name).map(|f| *f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for DsoModule<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.packed_funcs
|
||||
.replace(HashMap::new())
|
||||
.into_iter()
|
||||
.map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
|
||||
.for_each(std::mem::drop);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
|
||||
mod dso;
|
||||
mod syslib;
|
||||
|
||||
use tvm_common::{
|
||||
ffi::BackendPackedCFunc,
|
||||
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
|
||||
};
|
||||
|
||||
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
|
||||
pub use dso::DsoModule;
|
||||
pub use syslib::SystemLibModule;
|
||||
|
||||
pub trait Module {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
|
||||
}
|
||||
|
||||
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
||||
fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
|
||||
box move |args: &[TVMArgValue]| {
|
||||
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
|
||||
.into_iter()
|
||||
.map(|arg| {
|
||||
let (val, code) = arg.to_tvm_value();
|
||||
(val, code as i32)
|
||||
})
|
||||
.unzip();
|
||||
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
|
||||
if exit_code == 0 {
|
||||
Ok(TVMRetValue::default())
|
||||
} else {
|
||||
Err(tvm_common::errors::FuncCallError::get_with_context(
|
||||
func_name.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -21,14 +21,9 @@ use std::{
|
|||
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
|
||||
};
|
||||
|
||||
use tvm_common::{
|
||||
ffi::BackendPackedCFunc,
|
||||
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
|
||||
};
|
||||
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
|
||||
|
||||
pub trait Module {
|
||||
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
|
||||
}
|
||||
use super::Module;
|
||||
|
||||
pub struct SystemLibModule;
|
||||
|
||||
|
@ -53,30 +48,6 @@ impl Default for SystemLibModule {
|
|||
}
|
||||
}
|
||||
|
||||
// @see `WrapPackedFunc` in `llvm_module.cc`.
|
||||
pub(super) fn wrap_backend_packed_func(
|
||||
func_name: String,
|
||||
func: BackendPackedCFunc,
|
||||
) -> Box<dyn PackedFunc> {
|
||||
box move |args: &[TVMArgValue]| {
|
||||
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
|
||||
.into_iter()
|
||||
.map(|arg| {
|
||||
let (val, code) = arg.to_tvm_value();
|
||||
(val, code as i32)
|
||||
})
|
||||
.unzip();
|
||||
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
|
||||
if exit_code == 0 {
|
||||
Ok(TVMRetValue::default())
|
||||
} else {
|
||||
Err(tvm_common::errors::FuncCallError::get_with_context(
|
||||
func_name.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
||||
cname: *const c_char,
|
||||
|
@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
|
|||
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
|
||||
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
|
||||
name.to_string(),
|
||||
&*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
|
||||
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
|
||||
);
|
||||
return 0;
|
||||
}
|
|
@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv;
|
|||
#[cfg(target_env = "sgx")]
|
||||
use super::{TVMArgValue, TVMRetValue};
|
||||
|
||||
type FTVMParallelLambda =
|
||||
pub(crate) type FTVMParallelLambda =
|
||||
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
|
||||
|
||||
/// Holds a parallel job request made by a TVM library function.
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
[package]
|
||||
name = "test-tvm-dso"
|
||||
version = "0.0.0"
|
||||
license = "Apache-2.0"
|
||||
authors = ["TVM Contributors"]
|
||||
|
||||
[dependencies]
|
||||
ndarray="0.12"
|
||||
tvm-runtime = { path = "../../" }
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
use std::{env, path::Path, process::Command};
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
|
||||
let output = Command::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/build_test_lib.py"
|
||||
))
|
||||
.arg(&out_dir)
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
assert!(
|
||||
Path::new(&format!("{}/test.so", out_dir)).exists(),
|
||||
"Could not build tvm lib: {}",
|
||||
String::from_utf8(output.stderr)
|
||||
.unwrap()
|
||||
.trim()
|
||||
.split("\n")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
);
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Prepares a simple TVM library for testing."""
|
||||
|
||||
from os import path as osp
|
||||
import sys
|
||||
|
||||
import tvm
|
||||
from tvm.contrib import cc
|
||||
|
||||
def main():
|
||||
n = tvm.var('n')
|
||||
A = tvm.placeholder((n,), name='A')
|
||||
B = tvm.placeholder((n,), name='B')
|
||||
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
|
||||
s = tvm.create_schedule(C.op)
|
||||
s[C].parallel(s[C].op.axis[0])
|
||||
print(tvm.lower(s, [A, B, C], simple_mode=True))
|
||||
obj_file = osp.join(sys.argv[1], 'test.o')
|
||||
tvm.build(s, [A, B, C], 'llvm').save(obj_file)
|
||||
cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
extern crate ndarray;
|
||||
#[macro_use]
|
||||
extern crate tvm_runtime;
|
||||
|
||||
use ndarray::Array;
|
||||
use tvm_runtime::{DLTensor, DsoModule, Module};
|
||||
|
||||
fn main() {
|
||||
tvm_runtime::TVMGetLastError();
|
||||
let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
|
||||
let add = module
|
||||
.get_function("__tvm_main__")
|
||||
.expect("main function not found");
|
||||
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
|
||||
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
|
||||
let mut c = Array::from_vec(vec![0f32; 4]);
|
||||
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
|
||||
let mut a_dl: DLTensor = (&mut a).into();
|
||||
let mut b_dl: DLTensor = (&mut b).into();
|
||||
let mut c_dl: DLTensor = (&mut c).into();
|
||||
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
|
||||
assert!(c.all_close(&e, 1e-8f32));
|
||||
}
|
|
@ -48,6 +48,10 @@ cd tests/test_tvm_basic
|
|||
cargo run
|
||||
cd -
|
||||
|
||||
cd tests/test_tvm_dso
|
||||
cargo run
|
||||
cd -
|
||||
|
||||
# run NNVM graph test
|
||||
cd tests/test_nnvm
|
||||
cargo run
|
||||
|
|
Загрузка…
Ссылка в новой задаче