[NNPACK] Add argument nthreads (#631)
This commit is contained in:
Родитель
3548530772
Коммит
182a7852de
|
@ -16,7 +16,7 @@ def config(nthreads):
|
|||
"""
|
||||
_Config(nthreads)
|
||||
|
||||
def fully_connected_inference(lhs, rhs):
|
||||
def fully_connected_inference(lhs, rhs, nthreads=1):
|
||||
"""Create an extern op that compute fully connected of 1D tensor lhs and
|
||||
2D tensor rhs with nnpack.
|
||||
|
||||
|
@ -37,9 +37,9 @@ def fully_connected_inference(lhs, rhs):
|
|||
(m, ), [lhs, rhs],
|
||||
lambda ins, outs: _intrin.call_packed(
|
||||
"tvm.contrib.nnpack.fully_connected_inference",
|
||||
ins[0], ins[1], outs[0]), name="C")
|
||||
ins[0], ins[1], outs[0], nthreads), name="C")
|
||||
|
||||
def fully_connected_output(lhs, rhs):
|
||||
def fully_connected_output(lhs, rhs, nthreads=1):
|
||||
"""Create an extern op that compute fully connected of 2D tensor lhs and
|
||||
2D tensor rhs with nnpack.
|
||||
|
||||
|
@ -61,9 +61,9 @@ def fully_connected_output(lhs, rhs):
|
|||
(n, m), [lhs, rhs],
|
||||
lambda ins, outs: _intrin.call_packed(
|
||||
"tvm.contrib.nnpack.fully_connected_output",
|
||||
ins[0], ins[1], outs[0]), name="C")
|
||||
ins[0], ins[1], outs[0], nthreads), name="C")
|
||||
|
||||
def convolution_inference(data, kernel, bias, padding, stride):
|
||||
def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
|
||||
"""Create an extern op to do inference convolution of 3D tensor data and
|
||||
4D tensor kernel and 1D tensor bias with nnpack.
|
||||
|
||||
|
@ -104,9 +104,9 @@ def convolution_inference(data, kernel, bias, padding, stride):
|
|||
lambda ins, outs: _intrin.call_packed(
|
||||
"tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
|
||||
outs[0], padding[0], padding[1], padding[2], padding[3],
|
||||
stride[0], stride[1]), name="C")
|
||||
stride[0], stride[1], nthreads), name="C")
|
||||
|
||||
def convolution_output(data, kernel, bias, padding):
|
||||
def convolution_output(data, kernel, bias, padding, nthreads=1):
|
||||
"""Create an extern op to compute convolution of 4D tensor data and
|
||||
4D tensor kernel and 1D tensor bias with nnpack.
|
||||
|
||||
|
@ -142,6 +142,6 @@ def convolution_output(data, kernel, bias, padding):
|
|||
(batch, output_channels, output_height, output_width), [data, kernel, bias],
|
||||
lambda ins, outs: _intrin.call_packed(
|
||||
"tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
|
||||
outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
|
||||
outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
|
||||
|
||||
_init_api("tvm.contrib.nnpack")
|
||||
|
|
|
@ -24,6 +24,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
|
|||
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
|
||||
uint64_t stride_width = args[8], stride_height = args[9];
|
||||
nnp_size stride_size{stride_width, stride_height};
|
||||
NNPackConfig(args[10]);
|
||||
|
||||
CHECK_EQ(input->ndim, 3);
|
||||
CHECK_EQ(kernel->ndim, 4);
|
||||
|
@ -80,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
|
|||
DLTensor* output = args[3];
|
||||
uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
|
||||
nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
|
||||
NNPackConfig(args[8]);
|
||||
|
||||
CHECK_EQ(input->ndim, 4);
|
||||
CHECK_EQ(kernel->ndim, 4);
|
||||
|
|
|
@ -21,6 +21,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
|
|||
DLTensor* A = args[0];
|
||||
DLTensor* B = args[1];
|
||||
DLTensor* C = args[2];
|
||||
NNPackConfig(args[3]);
|
||||
|
||||
CHECK_EQ(A->ndim, 1);
|
||||
CHECK_EQ(B->ndim, 2);
|
||||
CHECK_EQ(C->ndim, 1);
|
||||
|
@ -49,6 +51,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
|
|||
DLTensor* A = args[0];
|
||||
DLTensor* B = args[1];
|
||||
DLTensor* C = args[2];
|
||||
NNPackConfig(args[3]);
|
||||
|
||||
CHECK_EQ(A->ndim, 2);
|
||||
CHECK_EQ(B->ndim, 2);
|
||||
CHECK_EQ(C->ndim, 2);
|
||||
|
|
|
@ -14,18 +14,23 @@ NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
|
|||
return NNPackThreadLocalStore::Get();
|
||||
}
|
||||
|
||||
bool NNPackConfig(uint64_t nthreads) {
|
||||
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
|
||||
if (entry->threadpool != NULL &&
|
||||
pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
|
||||
pthreadpool_destroy(entry->threadpool);
|
||||
entry->threadpool = NULL;
|
||||
}
|
||||
if (entry->threadpool == NULL) {
|
||||
entry->threadpool = pthreadpool_create(nthreads);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
|
||||
size_t nthreads = args[0].operator uint64_t();
|
||||
if (entry->threadpool != NULL &&
|
||||
pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
|
||||
pthreadpool_destroy(entry->threadpool);
|
||||
entry->threadpool = NULL;
|
||||
}
|
||||
if (entry->threadpool == NULL) {
|
||||
entry->threadpool = pthreadpool_create(nthreads);
|
||||
}
|
||||
CHECK(NNPackConfig(args[0]));
|
||||
});
|
||||
} // namespace contrib
|
||||
} // namespace tvm
|
||||
|
|
|
@ -18,6 +18,8 @@ struct NNPackThreadLocalEntry {
|
|||
pthreadpool_t threadpool{NULL};
|
||||
static NNPackThreadLocalEntry* ThreadLocal();
|
||||
};
|
||||
|
||||
bool NNPackConfig(uint64_t nthreads);
|
||||
} // namespace contrib
|
||||
} // namespace tvm
|
||||
#endif // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
|
||||
|
|
Загрузка…
Ссылка в новой задаче