[tvm4j] support kNDArrayContainer (#1510)
This commit is contained in:
Родитель
f4b2c29300
Коммит
feabd40690
|
@ -187,7 +187,8 @@ public class Function extends TVMValue {
|
|||
* @return this
|
||||
*/
|
||||
public Function pushArg(NDArrayBase arg) {
|
||||
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
|
||||
int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
|
||||
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -247,7 +248,9 @@ public class Function extends TVMValue {
|
|||
} else if (arg instanceof byte[]) {
|
||||
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
|
||||
} else if (arg instanceof NDArrayBase) {
|
||||
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id);
|
||||
NDArrayBase nd = (NDArrayBase) arg;
|
||||
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
|
||||
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
|
||||
} else if (arg instanceof Module) {
|
||||
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
|
||||
} else if (arg instanceof Function) {
|
||||
|
|
|
@ -21,7 +21,7 @@ package ml.dmlc.tvm;
|
|||
public enum TypeCode {
|
||||
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
|
||||
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
|
||||
FUNC_HANDLE(10), STR(11), BYTES(12);
|
||||
FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);
|
||||
|
||||
public final int id;
|
||||
|
||||
|
|
|
@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) {
|
|||
return object;
|
||||
}
|
||||
|
||||
jobject newNDArray(JNIEnv *env, jlong value) {
|
||||
jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
|
||||
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
|
||||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
|
||||
jobject object = env->NewObject(cls, constructor, value);
|
||||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
|
||||
jobject object = env->NewObject(cls, constructor, handle, isview);
|
||||
env->DeleteLocalRef(cls);
|
||||
return object;
|
||||
}
|
||||
|
@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
|
|||
case kFuncHandle:
|
||||
return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
|
||||
case kArrayHandle:
|
||||
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle));
|
||||
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
|
||||
case kNDArrayContainer:
|
||||
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
|
||||
case kStr:
|
||||
return newTVMValueString(env, value.v_str);
|
||||
case kBytes:
|
||||
|
|
Загрузка…
Ссылка в новой задаче