Fix rust compile issues and add GH action to run build validations and tests (#18346)
### Description This PR gets the onnxruntime Rust bindings to a foundation where they can be extended and validated as the onnxruntime progresses. Specifically, the PR does the following. - fixes some of the existing compilation issues due to missing some enums output tensor data types. - introduces a `just vendor` task that will vendor the source code from the onnxruntime to enable a common base directory within the crate directory rather than using a relative parent path. This enables `crate package` to be able to archive the onnxruntime native code, which will enable consumers of the onnxruntime-sys crate to be able to compile on their target. - introduces a GH action to lint the Rust code (rustfmt, clippy), build the library, validate through tests, and validate crate can package correctly. TODOs: - [x] This PR is based on #18200 and will need to be rebased once that PR is merged. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This is the first step to getting new onnxruntime Rust crates published through this project, which will unblock community Rust projects which would like to take a dependency on onnxruntime Rust. Follow up work to enable publication of onnxruntime Rust crates: - change name of the crates to be published (onnxruntime-rs and onnxruntime-sys are already taken and we'll need new names) - update authors / license to reflect contributions from previous maintainer(s) and new maintainers - introduce a crate publish GH action or ADO pipeline --------- Signed-off-by: David Justice <david@devigned.com>
This commit is contained in:
Родитель
8d50313816
Коммит
2c22b49876
|
@ -0,0 +1,44 @@
|
||||||
|
# yaml-language-server: $schema=https://json.schemastore.org/github-action.json
|
||||||
|
|
||||||
|
name: 'Rust toolchain setup'
|
||||||
|
description: 'Common setup steps for GitHub workflows for Rust projects'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- uses: dtolnay/rust-toolchain@1.71.0
|
||||||
|
with:
|
||||||
|
components: clippy, rustfmt
|
||||||
|
- uses: extractions/setup-just@v1
|
||||||
|
with:
|
||||||
|
just-version: '1.15.0' # optional semver specification, otherwise latest
|
||||||
|
|
||||||
|
###
|
||||||
|
### Linux setup
|
||||||
|
###
|
||||||
|
- name: rustup
|
||||||
|
# We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds.
|
||||||
|
if: ${{ (runner.os == 'Linux') }}
|
||||||
|
run: |
|
||||||
|
rustup set profile minimal
|
||||||
|
rustup install
|
||||||
|
shell: bash
|
||||||
|
# - name: Cargo login
|
||||||
|
# if: ${{ (runner.os == 'Linux') }}
|
||||||
|
# run: just cargo-login-ci
|
||||||
|
# shell: bash
|
||||||
|
|
||||||
|
###
|
||||||
|
### Windows setup
|
||||||
|
###
|
||||||
|
- name: rustup
|
||||||
|
# We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds.
|
||||||
|
if: ${{ (runner.os == 'Windows') }}
|
||||||
|
run: |
|
||||||
|
rustup set profile minimal
|
||||||
|
rustup install
|
||||||
|
shell: pwsh
|
||||||
|
# - name: Cargo login
|
||||||
|
# if: ${{ (runner.os == 'Windows') }}
|
||||||
|
# run: just cargo-login-ci-windows
|
||||||
|
# shell: pwsh
|
|
@ -0,0 +1,132 @@
|
||||||
|
name: Rust
|
||||||
|
|
||||||
|
on: [pull_request]
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_TERM_COLOR: always
|
||||||
|
RUST_LOG: onnxruntime=debug,onnxruntime-sys=debug
|
||||||
|
RUST_BACKTRACE: 1
|
||||||
|
MANIFEST_PATH: ${{ github.workspace }}/rust/Cargo.toml
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
fmt:
|
||||||
|
name: Rustfmt
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ./.github/actions/rust-toolchain-setup
|
||||||
|
- name: vendor onnxruntime source
|
||||||
|
run: just vendor
|
||||||
|
- name: fmt
|
||||||
|
run: cargo fmt --all -- --check
|
||||||
|
|
||||||
|
download:
|
||||||
|
name: Download prebuilt ONNX Runtime archive from build.rs
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
ORT_RUST_STRATEGY=download
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ./.github/actions/rust-toolchain-setup
|
||||||
|
- run: rustup target install x86_64-unknown-linux-gnu
|
||||||
|
- run: rustup target install x86_64-apple-darwin
|
||||||
|
- run: rustup target install i686-pc-windows-msvc
|
||||||
|
- run: rustup target install x86_64-pc-windows-msvc
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (CPU, x86_64-unknown-linux-gnu)
|
||||||
|
run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (CPU, x86_64-unknown-linux-gnu)
|
||||||
|
run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-1.*.tgz
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (CPU, x86_64-apple-darwin)
|
||||||
|
run: cargo build --target x86_64-apple-darwin --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (CPU, x86_64-apple-darwin)
|
||||||
|
run: ls -lh target/x86_64-apple-darwin/debug/build/onnxruntime-sys-*/out/onnxruntime-osx-x64-1.*.tgz
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (CPU, i686-pc-windows-msvc)
|
||||||
|
run: cargo build --target i686-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (CPU, i686-pc-windows-msvc)
|
||||||
|
run: ls -lh target/i686-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x86-1.*.zip
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (CPU, x86_64-pc-windows-msvc)
|
||||||
|
run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (CPU, x86_64-pc-windows-msvc)
|
||||||
|
run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x64-1.*.zip
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (GPU, x86_64-unknown-linux-gnu)
|
||||||
|
env:
|
||||||
|
ORT_USE_CUDA: "yes"
|
||||||
|
run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (GPU, x86_64-unknown-linux-gnu)
|
||||||
|
run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-gpu-1.*.tgz
|
||||||
|
# ******************************************************************
|
||||||
|
- name: Download prebuilt archive (GPU, x86_64-pc-windows-msvc)
|
||||||
|
env:
|
||||||
|
ORT_USE_CUDA: "yes"
|
||||||
|
run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Verify prebuilt archive downloaded (GPU, x86_64-pc-windows-msvc)
|
||||||
|
run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-gpu-x64-1.*.zip
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test Suite
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
target:
|
||||||
|
[
|
||||||
|
x86_64-unknown-linux-gnu,
|
||||||
|
x86_64-apple-darwin,
|
||||||
|
x86_64-pc-windows-msvc,
|
||||||
|
i686-pc-windows-msvc,
|
||||||
|
]
|
||||||
|
include:
|
||||||
|
- target: x86_64-unknown-linux-gnu
|
||||||
|
os: ubuntu-latest
|
||||||
|
- target: x86_64-apple-darwin
|
||||||
|
os: macos-latest
|
||||||
|
- target: x86_64-pc-windows-msvc
|
||||||
|
os: windows-latest
|
||||||
|
- target: i686-pc-windows-msvc
|
||||||
|
os: windows-latest
|
||||||
|
env:
|
||||||
|
CARGO_BUILD_TARGET: ${{ matrix.target }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ./.github/actions/rust-toolchain-setup
|
||||||
|
- name: vendor onnxruntime source
|
||||||
|
run: just vendor
|
||||||
|
- run: rustup target install ${{ matrix.target }}
|
||||||
|
- name: Install additional packages (macOS)
|
||||||
|
if: contains(matrix.target, 'x86_64-apple-darwin')
|
||||||
|
run: brew install libomp
|
||||||
|
- name: Build (cargo build)
|
||||||
|
run: cargo build --all --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Build tests (cargo test)
|
||||||
|
run: cargo test --no-run --manifest-path ${{ env.MANIFEST_PATH }}
|
||||||
|
- name: Build onnxruntime with 'model-fetching' feature
|
||||||
|
run: cargo build --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching
|
||||||
|
- name: Test onnxruntime-sys
|
||||||
|
run: cargo build --package onnxruntime-sys -- --test-threads=1 --nocapture
|
||||||
|
- name: Test onnxruntime
|
||||||
|
run: cargo test --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching -- --test-threads=1 --nocapture
|
||||||
|
|
||||||
|
clippy:
|
||||||
|
name: Clippy
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ./.github/actions/rust-toolchain-setup
|
||||||
|
- name: vendor onnxruntime source
|
||||||
|
run: just vendor
|
||||||
|
- run: clippy --all-features --manifest-path ${{ env.MANIFEST_PATH }} -- -D warnings
|
||||||
|
|
||||||
|
package-sys:
|
||||||
|
name: Package onnxruntime-sys
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ./.github/actions/rust-toolchain-setup
|
||||||
|
- name: vendor onnxruntime source
|
||||||
|
run: just vendor
|
||||||
|
- run: cargo package --allow-dirty --package onnxruntime-sys
|
|
@ -0,0 +1,13 @@
|
||||||
|
|
||||||
|
|
||||||
|
vendor:
|
||||||
|
mkdir -p ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
cp -rf ../onnxruntime ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
cp -rf ../cmake ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
rm -rf ./onnxruntime-sys/vendor/onnxruntime-src/cmake/external/onnx
|
||||||
|
cp -rf ../include ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
mkdir -p ./onnxruntime-sys/vendor/onnxruntime-src/tools
|
||||||
|
cp -rf ../tools/ci_build ./onnxruntime-sys/vendor/onnxruntime-src/tools
|
||||||
|
cp -rf ../samples ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
cp -f ../requirements.txt.in ./onnxruntime-sys/vendor/onnxruntime-src
|
||||||
|
cp -f ../VERSION_NUMBER ./onnxruntime-sys/vendor/onnxruntime-src
|
|
@ -0,0 +1 @@
|
||||||
|
vendor
|
|
@ -3,18 +3,16 @@ authors = ["Nicolas Bigaouette <nbigaouette@elementai.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
name = "onnxruntime-sys"
|
name = "onnxruntime-sys"
|
||||||
version = "0.0.14"
|
version = "0.0.14"
|
||||||
|
|
||||||
links = "onnxruntime"
|
links = "onnxruntime"
|
||||||
|
|
||||||
description = "Unsafe wrapper around Microsoft's ONNX Runtime"
|
description = "Unsafe wrapper around Microsoft's ONNX Runtime"
|
||||||
documentation = "https://docs.rs/onnxruntime-sys"
|
documentation = "https://docs.rs/onnxruntime-sys"
|
||||||
homepage = "https://github.com/microsoft/onnxruntime"
|
homepage = "https://github.com/microsoft/onnxruntime"
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
readme = "../README.md"
|
readme = "../README.md"
|
||||||
repository = "https://github.com/microsoft/onnxruntime"
|
repository = "https://github.com/microsoft/onnxruntime"
|
||||||
|
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
keywords = ["neuralnetworks", "onnx", "bindings"]
|
keywords = ["neuralnetworks", "onnx", "bindings"]
|
||||||
|
include = ["src", "example", "vendor", "build.rs"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
libloading = "0.7"
|
libloading = "0.7"
|
||||||
|
@ -22,6 +20,7 @@ libloading = "0.7"
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
bindgen = "0.63"
|
bindgen = "0.63"
|
||||||
cmake = "0.1"
|
cmake = "0.1"
|
||||||
|
anyhow = "1.0"
|
||||||
|
|
||||||
# Used on unix
|
# Used on unix
|
||||||
flate2 = "1.0"
|
flate2 = "1.0"
|
||||||
|
|
|
@ -8,12 +8,16 @@ use std::{
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// use cmake::build;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
|
|
||||||
/// ONNX Runtime version
|
/// ONNX Runtime version
|
||||||
///
|
///
|
||||||
/// WARNING: If version is changed, bindings for all platforms will have to be re-generated.
|
/// WARNING: If version is changed, bindings for all platforms will have to be re-generated.
|
||||||
/// To do so, run this:
|
/// To do so, run this:
|
||||||
/// cargo build --package onnxruntime-sys --features generate-bindings
|
/// cargo build --package onnxruntime-sys --features generate-bindings
|
||||||
const ORT_VERSION: &str = include_str!("../../VERSION_NUMBER");
|
const ORT_VERSION: &str = include_str!("./vendor/onnxruntime-src/VERSION_NUMBER");
|
||||||
|
|
||||||
/// Base Url from which to download pre-built releases/
|
/// Base Url from which to download pre-built releases/
|
||||||
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
|
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
|
||||||
|
@ -34,8 +38,8 @@ const ORT_RUST_ENV_GPU: &str = "ORT_RUST_USE_CUDA";
|
||||||
/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
|
/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
|
||||||
const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
|
const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
|
||||||
|
|
||||||
fn main() {
|
fn main() -> Result<()> {
|
||||||
let libort_install_dir = prepare_libort_dir();
|
let libort_install_dir = prepare_libort_dir().context("preparing libort directory")?;
|
||||||
|
|
||||||
let include_dir = libort_install_dir.join("include");
|
let include_dir = libort_install_dir.join("include");
|
||||||
let lib_dir = libort_install_dir.join("lib");
|
let lib_dir = libort_install_dir.join("lib");
|
||||||
|
@ -55,6 +59,7 @@ fn main() {
|
||||||
);
|
);
|
||||||
|
|
||||||
generate_bindings(&include_dir);
|
generate_bindings(&include_dir);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_bindings(include_dir: &Path) {
|
fn generate_bindings(include_dir: &Path) {
|
||||||
|
@ -70,11 +75,7 @@ fn generate_bindings(include_dir: &Path) {
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let path = include_dir
|
let path = include_dir.join("onnxruntime").join("onnxruntime_c_api.h");
|
||||||
.join("onnxruntime")
|
|
||||||
.join("core")
|
|
||||||
.join("session")
|
|
||||||
.join("onnxruntime_c_api.h");
|
|
||||||
|
|
||||||
// The bindgen::Builder is the main entry point
|
// The bindgen::Builder is the main entry point
|
||||||
// to bindgen, and lets you build up options for
|
// to bindgen, and lets you build up options for
|
||||||
|
@ -106,7 +107,7 @@ fn generate_bindings(include_dir: &Path) {
|
||||||
|
|
||||||
let generated_file = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
|
let generated_file = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
|
||||||
bindings
|
bindings
|
||||||
.write_to_file(&generated_file)
|
.write_to_file(generated_file)
|
||||||
.expect("Couldn't write bindings!");
|
.expect("Couldn't write bindings!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ fn extract_archive(filename: &Path, output: &Path) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_tgz(filename: &Path, output: &Path) {
|
fn extract_tgz(filename: &Path, output: &Path) {
|
||||||
let file = fs::File::open(&filename).unwrap();
|
let file = fs::File::open(filename).unwrap();
|
||||||
let buf = io::BufReader::new(file);
|
let buf = io::BufReader::new(file);
|
||||||
let tar = flate2::read::GzDecoder::new(buf);
|
let tar = flate2::read::GzDecoder::new(buf);
|
||||||
let mut archive = tar::Archive::new(tar);
|
let mut archive = tar::Archive::new(tar);
|
||||||
|
@ -152,7 +153,7 @@ fn extract_tgz(filename: &Path, output: &Path) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_zip(filename: &Path, outpath: &Path) {
|
fn extract_zip(filename: &Path, outpath: &Path) {
|
||||||
let file = fs::File::open(&filename).unwrap();
|
let file = fs::File::open(filename).unwrap();
|
||||||
let buf = io::BufReader::new(file);
|
let buf = io::BufReader::new(file);
|
||||||
let mut archive = zip::ZipArchive::new(buf).unwrap();
|
let mut archive = zip::ZipArchive::new(buf).unwrap();
|
||||||
for i in 0..archive.len() {
|
for i in 0..archive.len() {
|
||||||
|
@ -168,7 +169,7 @@ fn extract_zip(filename: &Path, outpath: &Path) {
|
||||||
);
|
);
|
||||||
if let Some(p) = outpath.parent() {
|
if let Some(p) = outpath.parent() {
|
||||||
if !p.exists() {
|
if !p.exists() {
|
||||||
fs::create_dir_all(&p).unwrap();
|
fs::create_dir_all(p).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let mut outfile = fs::File::create(&outpath).unwrap();
|
let mut outfile = fs::File::create(&outpath).unwrap();
|
||||||
|
@ -190,15 +191,15 @@ enum Architecture {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Architecture {
|
impl FromStr for Architecture {
|
||||||
type Err = String;
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
match s.to_lowercase().as_str() {
|
match s.to_lowercase().as_str() {
|
||||||
"x86" => Ok(Architecture::X86),
|
"x86" => Ok(Architecture::X86),
|
||||||
"x86_64" => Ok(Architecture::X86_64),
|
"x86_64" => Ok(Architecture::X86_64),
|
||||||
"arm" => Ok(Architecture::Arm),
|
"arm" => Ok(Architecture::Arm),
|
||||||
"aarch64" => Ok(Architecture::Arm64),
|
"aarch64" => Ok(Architecture::Arm64),
|
||||||
_ => Err(format!("Unsupported architecture: {}", s)),
|
_ => Err(anyhow!("Unsupported architecture: {s}")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -233,14 +234,14 @@ impl Os {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Os {
|
impl FromStr for Os {
|
||||||
type Err = String;
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
match s.to_lowercase().as_str() {
|
match s.to_lowercase().as_str() {
|
||||||
"windows" => Ok(Os::Windows),
|
"windows" => Ok(Os::Windows),
|
||||||
"macos" => Ok(Os::MacOs),
|
"macos" => Ok(Os::MacOs),
|
||||||
"linux" => Ok(Os::Linux),
|
"linux" => Ok(Os::Linux),
|
||||||
_ => Err(format!("Unsupported os: {}", s)),
|
_ => Err(anyhow!("Unsupported os: {s}")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -262,9 +263,9 @@ enum Accelerator {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Accelerator {
|
impl FromStr for Accelerator {
|
||||||
type Err = String;
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
match s.to_lowercase().as_str() {
|
match s.to_lowercase().as_str() {
|
||||||
"1" | "yes" | "true" | "on" => Ok(Accelerator::Cuda),
|
"1" | "yes" | "true" | "on" => Ok(Accelerator::Cuda),
|
||||||
_ => Ok(Accelerator::Cpu),
|
_ => Ok(Accelerator::Cpu),
|
||||||
|
@ -393,36 +394,37 @@ fn prepare_libort_dir_prebuilt() -> PathBuf {
|
||||||
extract_dir.join(prebuilt_archive.file_stem().unwrap())
|
extract_dir.join(prebuilt_archive.file_stem().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_libort_dir() -> PathBuf {
|
fn prepare_libort_dir() -> Result<PathBuf> {
|
||||||
let strategy = env::var(ORT_RUST_ENV_STRATEGY);
|
let strategy = env::var(ORT_RUST_ENV_STRATEGY);
|
||||||
println!(
|
println!(
|
||||||
"strategy: {:?}",
|
"strategy: {:?}",
|
||||||
strategy.as_ref().map_or_else(|_| "unknown", String::as_str)
|
strategy.as_ref().map_or_else(|_| "unknown", String::as_str)
|
||||||
);
|
);
|
||||||
match strategy.as_ref().map(String::as_str) {
|
match strategy.as_ref().map(String::as_str) {
|
||||||
Ok("download") => prepare_libort_dir_prebuilt(),
|
Ok("download") => Ok(prepare_libort_dir_prebuilt()),
|
||||||
Ok("system") => PathBuf::from(match env::var(ORT_RUST_ENV_SYSTEM_LIB_LOCATION) {
|
Ok("system") => {
|
||||||
Ok(p) => p,
|
let location = env::var(ORT_RUST_ENV_SYSTEM_LIB_LOCATION).context(format!(
|
||||||
Err(e) => {
|
"Could not get value of environment variable {:?}",
|
||||||
panic!(
|
ORT_RUST_ENV_SYSTEM_LIB_LOCATION
|
||||||
"Could not get value of environment variable {:?}: {:?}",
|
))?;
|
||||||
ORT_RUST_ENV_SYSTEM_LIB_LOCATION, e
|
Ok(PathBuf::from(location))
|
||||||
);
|
}
|
||||||
}
|
|
||||||
}),
|
|
||||||
Ok("compile") | Err(_) => prepare_libort_dir_compiled(),
|
Ok("compile") | Err(_) => prepare_libort_dir_compiled(),
|
||||||
_ => panic!("Unknown value for {:?}", ORT_RUST_ENV_STRATEGY),
|
_ => Err(anyhow!("Unknown value for {:?}", ORT_RUST_ENV_STRATEGY)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_libort_dir_compiled() -> PathBuf {
|
fn prepare_libort_dir_compiled() -> Result<PathBuf> {
|
||||||
let mut config = cmake::Config::new("../../cmake");
|
let manifest_dir_string = env::var("CARGO_MANIFEST_DIR").unwrap();
|
||||||
|
let mut config = cmake::Config::new(format!(
|
||||||
|
"{manifest_dir_string}/vendor/onnxruntime-src/cmake"
|
||||||
|
));
|
||||||
|
|
||||||
config.define("onnxruntime_BUILD_SHARED_LIB", "ON");
|
config.define("onnxruntime_BUILD_SHARED_LIB", "ON");
|
||||||
|
|
||||||
if env::var(ORT_RUST_ENV_GPU).unwrap_or_default().parse() == Ok(Accelerator::Cuda) {
|
if let Ok(Accelerator::Cuda) = env::var(ORT_RUST_ENV_GPU).unwrap_or_default().parse() {
|
||||||
config.define("onnxruntime_USE_CUDA", "ON");
|
config.define("onnxruntime_USE_CUDA", "ON");
|
||||||
}
|
};
|
||||||
|
|
||||||
config.build()
|
Ok(config.build())
|
||||||
}
|
}
|
||||||
|
|
|
@ -307,7 +307,7 @@ fn main() {
|
||||||
|
|
||||||
let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
|
let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
|
||||||
.iter()
|
.iter()
|
||||||
.map(|n| std::ffi::CString::new(n.clone()).unwrap())
|
.map(|n| std::ffi::CString::new(*n).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
|
let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
@ -290,9 +290,6 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
|
||||||
.unwrap()(shape_info);
|
.unwrap()(shape_info);
|
||||||
|
|
||||||
match element_type {
|
match element_type {
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => {
|
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => {
|
||||||
WithOutputTensor::try_from(value).map(OrtOutput::Float)
|
WithOutputTensor::try_from(value).map(OrtOutput::Float)
|
||||||
}
|
}
|
||||||
|
@ -317,12 +314,6 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => {
|
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => {
|
||||||
WithOutputTensor::try_from(value).map(OrtOutput::String)
|
WithOutputTensor::try_from(value).map(OrtOutput::String)
|
||||||
}
|
}
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => {
|
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => {
|
||||||
WithOutputTensor::try_from(value).map(OrtOutput::Double)
|
WithOutputTensor::try_from(value).map(OrtOutput::Double)
|
||||||
}
|
}
|
||||||
|
@ -332,14 +323,18 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => {
|
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => {
|
||||||
WithOutputTensor::try_from(value).map(OrtOutput::UInt64)
|
WithOutputTensor::try_from(value).map(OrtOutput::UInt64)
|
||||||
}
|
}
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => {
|
// Unimplemented output tensor data types
|
||||||
unimplemented!()
|
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
|
||||||
}
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => {
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
|
||||||
unimplemented!()
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
|
||||||
}
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
|
||||||
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => {
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
|
||||||
unimplemented!()
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN
|
||||||
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ
|
||||||
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
|
||||||
|
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 => {
|
||||||
|
unimplemented!("{:?}", element_type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче