[tvm4j] support kNDArrayContainer (#1510)

This commit is contained in:
Yizhi Liu 2018-07-30 12:58:30 -07:00 коммит произвёл Tianqi Chen
Родитель f4b2c29300
Коммит feabd40690
3 изменённых файлов: 12 добавлений и 7 удалений

Просмотреть файл

@ -187,7 +187,8 @@ public class Function extends TVMValue {
* @return this
*/
public Function pushArg(NDArrayBase arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
return this;
}
@ -247,7 +248,9 @@ public class Function extends TVMValue {
} else if (arg instanceof byte[]) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id);
NDArrayBase nd = (NDArrayBase) arg;
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {

Просмотреть файл

@ -21,7 +21,7 @@ package ml.dmlc.tvm;
public enum TypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12);
FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);
public final int id;

Просмотреть файл

@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) {
return object;
}
jobject newNDArray(JNIEnv *env, jlong value) {
jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
jobject object = env->NewObject(cls, constructor, handle, isview);
env->DeleteLocalRef(cls);
return object;
}
@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
case kFuncHandle:
return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
case kArrayHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle));
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
case kNDArrayContainer:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
case kStr:
return newTVMValueString(env, value.v_str);
case kBytes: