
TVM4J - Java Frontend for TVM Runtime

This folder contains the Java interface for the TVM runtime. It brings the TVM runtime to the Java virtual machine.

  • It enables you to construct an NDArray from a Java native array and vice versa.
  • You can register and convert Java native functions to TVM functions.
  • It enables you to load shared libraries created by Python and C++.
  • It provides a simple interface for an RPC server and client.

Installation

Requirements

  • JDK 1.6+. Oracle JDK and OpenJDK are well tested.
  • Maven 3 for build.

Modules

TVM4J contains three modules:

  • core
    • It contains all the Java interfaces.
  • native
    • The JNI native library is compiled in this module. It does not link the TVM runtime library (libtvm_runtime.so for Linux and libtvm_runtime.dylib for OSX). Instead, you have to set the Java system property libtvm.so.path to the location of the TVM runtime library.
  • assembly
    • It assembles the Java interfaces (core), the JNI library (native) and the TVM runtime library together. The simplest way to integrate tvm4j in your project is to rely on this module. It automatically extracts the native library to a temporary file and loads it.
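
When depending on core and native directly instead of assembly, the libtvm.so.path property has to be supplied by the application. The following is a minimal sketch of doing that from code, under two assumptions: that the property should hold the path of the runtime library file itself, and that it must be set before TVM4J initializes its native layer. The class TvmRuntimePath and its configure method are hypothetical helpers for illustration, not part of TVM4J.

```java
import java.io.File;

// Hypothetical helper, not part of TVM4J.
public class TvmRuntimePath {
  /**
   * Checks that the given runtime library file exists and publishes its
   * absolute path as the libtvm.so.path system property.
   * Returns the absolute path that was set.
   */
  public static String configure(String libraryPath) {
    File lib = new File(libraryPath);
    if (!lib.isFile()) {
      throw new IllegalArgumentException(
          "TVM runtime library not found: " + lib.getAbsolutePath());
    }
    String absolute = lib.getAbsolutePath();
    // Assumption: TVM4J reads this property when it first loads its JNI layer.
    System.setProperty("libtvm.so.path", absolute);
    return absolute;
  }

  public static void main(String[] args) {
    if (args.length != 1) {
      System.err.println("usage: TvmRuntimePath <path to libtvm_runtime.so or .dylib>");
      System.exit(1);
    }
    System.out.println("libtvm.so.path = " + configure(args[0]));
  }
}
```

The same property can be passed on the command line with the standard JVM flag -Dlibtvm.so.path=<path>, which avoids any ordering concerns inside the application.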

Build

First, refer to the Installation Guide and build the runtime shared library from the C++ code (libtvm_runtime.so for Linux and libtvm_runtime.dylib for OSX).

Then compile tvm4j with:

make jvmpkg

(Optional) Run the unit tests with:

make jvmpkg JVM_TEST_ARGS="-DskipTests=false"

After it is compiled and packaged, you can install tvm4j in your local Maven repository:

make jvminstall

Convert and Register Java Function as TVM Function

It is easy to define a Java function and call it from TVM. The following snippet demonstrates how to concatenate Java strings.

Function func = Function.convertFunc(new Function.Callback() {
      @Override public Object invoke(TVMValue... args) {
        StringBuilder res = new StringBuilder();
        for (TVMValue arg : args) {
          res.append(arg.asString());
        }
        return res.toString();
      }
    });
TVMValue res = func.pushArg("Hello").pushArg(" ").pushArg("World!").invoke();
assertEquals("Hello World!", res.asString());
res.release();
func.release();

It is your job to verify the types of callback arguments, as well as the type of returned result.

You can register the Java function with Function.register and later retrieve it with Function.getFunction.

Use TVM to Generate Shared Library

There's nothing special about this part. The following Python snippet generates add_cpu.so, which adds two vectors on the CPU.

import os
import tvm
from tvm.contrib import cc, util

def test_add(target_dir):
    n = tvm.var("n")
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
    s = tvm.create_schedule(C.op)
    fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")

    fadd.save(os.path.join(target_dir, "add_cpu.o"))
    cc.create_shared(os.path.join(target_dir, "add_cpu.so"),
            [os.path.join(target_dir, "add_cpu.o")])

if __name__ == "__main__":
    import sys
    if len(sys.argv) != 2:
        sys.exit(-1)
    test_add(sys.argv[1])

Run the Generated Shared Library

The following code snippet demonstrates how to load the generated shared library (add_cpu.so) and run it.

import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;

import java.io.File;
import java.util.Arrays;

public class LoadAddFunc {
  public static void main(String[] args) {
    String loadingDir = args[0];
    Module fadd = Module.load(loadingDir + File.separator + "add_cpu.so");

    TVMContext ctx = TVMContext.cpu();

    long[] shape = new long[]{2};
    NDArray arr = NDArray.empty(shape, ctx);
    arr.copyFrom(new float[]{3f, 4f});
    NDArray res = NDArray.empty(shape, ctx);

    fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
    System.out.println(Arrays.toString(res.asFloatArray()));

    arr.release();
    res.release();
    fadd.release();
  }
}

RPC Server

There are two ways to start an RPC server on JVM. A standalone server can be started by

Server server = new Server(port);
server.start();

This will open a socket and wait for remote requests. You can use Java, Python, or any other frontend to make an RPC call. Here's an example of calling the remote function test.rpc.strcat in Java.

RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.strcat");
String result = func.call("abc", 11L).asString();
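
Both snippets above leave port undefined. For local experiments, one option is to let the operating system pick an unused TCP port. The following is a minimal sketch of that idea; the class FreePort is a hypothetical helper, not part of TVM4J, and it uses only java.net.ServerSocket from the standard library.

```java
import java.io.IOException;
import java.net.ServerSocket;

// Hypothetical helper, not part of TVM4J.
public class FreePort {
  /** Asks the OS for a currently unused TCP port by binding to port 0. */
  public static int find() throws IOException {
    ServerSocket socket = new ServerSocket(0);
    try {
      return socket.getLocalPort();
    } finally {
      // Release the port so that the RPC server can bind to it afterwards.
      socket.close();
    }
  }

  public static void main(String[] args) throws IOException {
    int port = find();
    // The value could then be used wherever the snippets above refer to port.
    System.out.println("free port: " + port);
  }
}
```

Another process may claim the port between the call to find() and the server start, so this approach suits tests and demos better than production deployments, where a fixed, configured port is usually preferable.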

The other way is to start a proxy and have the server and client communicate with each other through it. The following snippet shows how to start a server that connects to a proxy.

Server server = new Server(proxyHost, proxyPort, "key");
server.start();

You can also use StandaloneServerProcessor and ConnectProxyServerProcessor to build your own RPC server. Refer to Android RPC Server for more details.