[ROCM] MIOpen contrib for convolution kernels (#722)
* first working miopen support * do FindFwdAlgo during build time * fix lint * update doc string * import topi after checking if rocm is enabled * add miopen namespace * fixed descriptor overwrite bug * add use_miopen option * fix lint * better miopen option handling * fix typo * fix options handling
This commit is contained in:
Родитель
5d37be6259
Коммит
3b9f16523b
1
Makefile
1
Makefile
|
@ -134,6 +134,7 @@ include make/contrib/cblas.mk
|
|||
include make/contrib/random.mk
|
||||
include make/contrib/nnpack.mk
|
||||
include make/contrib/cudnn.mk
|
||||
include make/contrib/miopen.mk
|
||||
include make/contrib/mps.mk
|
||||
|
||||
ifdef ADD_CFLAGS
|
||||
|
|
|
@ -72,5 +72,8 @@ USE_NNPACK = 0
|
|||
# Whether use CuDNN
|
||||
USE_CUDNN = 0
|
||||
|
||||
# Whether use MIOpen
|
||||
USE_MIOPEN = 0
|
||||
|
||||
# Whether use MPS
|
||||
USE_MPS = 0
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Source files of the MIOpen contrib module and the objects built from them.
MIOPEN_CONTRIB_SRC = $(wildcard src/contrib/miopen/*.cc)
MIOPEN_CONTRIB_OBJ = $(MIOPEN_CONTRIB_SRC:src/%.cc=build/%.o)

# MIOpen support is opt-in: nothing below applies unless USE_MIOPEN is 1.
ifeq ($(USE_MIOPEN), 1)
RUNTIME_DEP += $(MIOPEN_CONTRIB_OBJ)
ADD_LDFLAGS += -lMIOpen
CFLAGS += -DTVM_USE_MIOPEN=1
endif
|
|
@ -0,0 +1,102 @@
|
|||
"""External function interface to MIOpen library."""
|
||||
# pylint: disable-msg=C0103
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from .. import api as _api
|
||||
from .. import intrin as _intrin
|
||||
from .. import get_global_func as _get_global_func
|
||||
|
||||
|
||||
def _get_np_int32_array_handle(arr):
    """Return a void_p handle for a numpy int32 array.

    The handle is a raw pointer to the array's buffer, so the array must
    stay alive for as long as the handle is in use.

    Parameters
    ----------
    arr: numpy.NDArray
        source numpy array; must have dtype int32 and be C-contiguous

    Returns
    -------
    ptr: ctypes.c_void_p
        pointer to the data

    Raises
    ------
    TypeError
        If ``arr`` does not have dtype int32.
    ValueError
        If ``arr`` is not C-contiguous.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O`` and the pointer is handed to native code, where a wrong
    # element type or layout would silently corrupt memory.
    if arr.dtype != np.int32:
        raise TypeError("expected an int32 array, got dtype %s" % arr.dtype)
    if not arr.flags["C_CONTIGUOUS"]:
        raise ValueError("expected a C-contiguous array")
    ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
    return ctypes.cast(ptr, ctypes.c_void_p)
|
||||
|
||||
|
||||
def conv2d_forward(x,
                   w,
                   stride_h=1,
                   stride_w=1,
                   pad_h=0,
                   pad_w=0,
                   dilation_h=1,
                   dilation_w=1,
                   conv_mode=0):
    """Build an extern op that runs a 2D convolution through MIOpen.

    The "tvm.contrib.miopen.conv2d.setup" global function is invoked
    immediately (while the op is being constructed) to obtain the output
    shape and the forward algorithm; the returned op then calls
    "tvm.contrib.miopen.conv2d.forward" with that algorithm.

    Parameters
    ----------
    x: Tensor
        input feature map; its first four dimensions are passed to setup
    w: Tensor
        convolution weight; its first four dimensions are passed to setup
    stride_h: int
        height stride
    stride_w: int
        width stride
    pad_h: int
        height padding
    pad_w: int
        width padding
    dilation_h: int
        height dilation
    dilation_w: int
        width dilation
    conv_mode: int
        0: miopenConvolution
        1: miopenTranspose (rejected by the assertion below)

    Returns
    -------
    y: Tensor
        The result tensor
    """
    assert conv_mode == 0, "Transpose convolutions not supported yet."
    # Scratch array the setup function fills in through a raw pointer.
    out_dims = np.zeros((len(x.shape)), dtype=np.int32)
    in_shape = x.shape
    filt_shape = w.shape
    find_algo = _get_global_func("tvm.contrib.miopen.conv2d.setup")

    # Convolution parameters shared by the setup call and the forward call.
    conv_args = (conv_mode,
                 pad_h, pad_w,
                 stride_h, stride_w,
                 dilation_h, dilation_w)
    in_dims = tuple(in_shape[i].value for i in range(4))
    filt_dims = tuple(filt_shape[i].value for i in range(4))
    setup_args = conv_args + in_dims + filt_dims
    setup_args += (_get_np_int32_array_handle(out_dims),)
    algo = find_algo(*setup_args)

    def _emit_forward(ins, outs):
        # Same parameters as setup, followed by the algorithm and tensors.
        forward_args = conv_args + (algo, ins[0], ins[1], outs[0])
        return _intrin.call_packed("tvm.contrib.miopen.conv2d.forward",
                                   *forward_args)

    return _api.extern(list(out_dims), [x, w], _emit_forward, name="y")
|
|
@ -88,12 +88,17 @@ class Target(object):
|
|||
target_name,
|
||||
options=None):
|
||||
self.target_name = target_name
|
||||
self.options = _merge_opts([], options)
|
||||
self.options = []
|
||||
self.device_name = ""
|
||||
self.libs = []
|
||||
# Parse device option
|
||||
for item in self.options:
|
||||
if item.startswith("-device="):
|
||||
for item in _merge_opts([], options):
|
||||
if item.startswith("-libs="):
|
||||
self.libs.append(item.split("=")[1])
|
||||
continue
|
||||
elif item.startswith("-device="):
|
||||
self.device_name = item.split("=")[1]
|
||||
self.options.append(item)
|
||||
# Target query searchs device name first
|
||||
if self.device_name:
|
||||
self.keys = (self.device_name,)
|
||||
|
|
|
@ -0,0 +1,221 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file Use external miopen utils function
|
||||
*/
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include <tvm/runtime/util.h>
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include "miopen_utils.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace contrib {
|
||||
namespace miopen {
|
||||
|
||||
using namespace runtime;
|
||||
|
||||
// Packed function "tvm.contrib.miopen.conv2d.setup".
//
// Arguments:
//   [0]      convolution mode (cast to miopenConvolutionMode_t)
//   [1..2]   padding (h, w)
//   [3..4]   stride (h, w)
//   [5..6]   dilation (h, w)
//   [7..10]  the four input tensor dimensions
//   [11..14] the four filter tensor dimensions
//   [15]     pointer to four ints that receive the output dimensions
//
// Configures the thread-local descriptors, sizes the workspace, runs
// MIOpen's forward algorithm search on scratch buffers and returns the
// first-ranked algorithm as an int.
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  const int mode = args[0];
  const int pad_h = args[1];
  const int pad_w = args[2];
  const int stride_h = args[3];
  const int stride_w = args[4];
  const int dilation_h = args[5];
  const int dilation_w = args[6];
  const int x_dim0 = args[7];
  const int x_dim1 = args[8];
  const int x_dim2 = args[9];
  const int x_dim3 = args[10];
  const int w_dim0 = args[11];
  const int w_dim1 = args[12];
  const int w_dim2 = args[13];
  const int w_dim3 = args[14];
  void *out_shape = args[15];
  // The output dimensions are written through this pointer below.
  CHECK(out_shape != nullptr) << "miopen conv2d setup: null output shape buffer";

  MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
  // Set Mode
  entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
  // Set Ctx
  entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
  // Set Data Type
  entry_ptr->conv_entry.data_type = miopenFloat;  // MIOpen only supports fp32
  // Set Desc
  MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
                                              entry_ptr->conv_entry.mode,
                                              pad_h,
                                              pad_w,
                                              stride_h,
                                              stride_w,
                                              dilation_h,
                                              dilation_w));
  // Set Filter
  MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
                                          entry_ptr->conv_entry.data_type,
                                          w_dim0,
                                          w_dim1,
                                          w_dim2,
                                          w_dim3));
  // Set Input
  MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
                                          entry_ptr->conv_entry.data_type,
                                          x_dim0,
                                          x_dim1,
                                          x_dim2,
                                          x_dim3));

  // Set Output shape
  MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc,
                                                   entry_ptr->conv_entry.input_desc,
                                                   entry_ptr->conv_entry.filter_desc,
                                                   static_cast<int*>(out_shape),
                                                   static_cast<int*>(out_shape) + 1,
                                                   static_cast<int*>(out_shape) + 2,
                                                   static_cast<int*>(out_shape) + 3));

  const int *oshape = static_cast<int*>(out_shape);
  // Set Output
  MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
                                          entry_ptr->conv_entry.data_type,
                                          oshape[0],
                                          oshape[1],
                                          oshape[2],
                                          oshape[3]));

  // Set workspace
  size_t workspace_size = 0;
  MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle,
                                                       entry_ptr->conv_entry.filter_desc,
                                                       entry_ptr->conv_entry.input_desc,
                                                       entry_ptr->conv_entry.conv_desc,
                                                       entry_ptr->conv_entry.output_desc,
                                                       &workspace_size));
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);

  // Widen to size_t before multiplying so that large tensors do not
  // overflow int arithmetic.
  const size_t input_size = static_cast<size_t>(x_dim0) * x_dim1 * x_dim2 * x_dim3;
  const size_t filter_size = static_cast<size_t>(w_dim0) * w_dim1 * w_dim2 * w_dim3;
  const size_t output_size = static_cast<size_t>(oshape[0]) * oshape[1] * oshape[2] * oshape[3];

  // Scratch device buffers used only for the algorithm search.
  runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api;
  float* input_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
                                                                  input_size * sizeof(float)));
  float* filter_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
                                                                   filter_size * sizeof(float)));
  float* output_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
                                                                   output_size * sizeof(float)));

  const int request_algo_count = 4;
  const bool exhaustive_search = false;
  int returned_algo_count = 0;
  miopenConvAlgoPerf_t perfs[4];

  // Keep the status and check it only after the scratch buffers have been
  // released, so a failed search does not leak them.
  const miopenStatus_t find_status =
      miopenFindConvolutionForwardAlgorithm(entry_ptr->handle,
                                            entry_ptr->conv_entry.input_desc,
                                            input_buf,
                                            entry_ptr->conv_entry.filter_desc,
                                            filter_buf,
                                            entry_ptr->conv_entry.conv_desc,
                                            entry_ptr->conv_entry.output_desc,
                                            output_buf,
                                            request_algo_count,
                                            &returned_algo_count,
                                            perfs,
                                            entry_ptr->conv_entry.workspace,
                                            workspace_size,
                                            exhaustive_search);

  rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf);
  rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf);
  rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, output_buf);

  MIOPEN_CALL(find_status);
  // perfs[0] is only meaningful when at least one result was written.
  CHECK_GT(returned_algo_count, 0)
      << "miopen conv2d setup: no forward algorithm found";

  const std::vector<std::string> fwd_algo_names{
      "miopenConvolutionFwdAlgoGEMM",
      "miopenConvolutionFwdAlgoDirect",
      "miopenConvolutionFwdAlgoFFT",
      "miopenConvolutionFwdAlgoWinograd",
  };
  // Bounds-checked lookup: an algorithm id outside the table must not
  // index past the end of the vector.
  const auto algo_name = [&fwd_algo_names](int algo) -> std::string {
    if (algo < 0 || static_cast<size_t>(algo) >= fwd_algo_names.size()) {
      return "unknown(" + std::to_string(algo) + ")";
    }
    return fwd_algo_names[algo];
  };
  const auto best_algo = perfs[0].fwd_algo;
  LOG(INFO) << "\tMIOpen Found " << returned_algo_count
            << " fwd algorithms, choosing " << algo_name(static_cast<int>(best_algo));
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << algo_name(static_cast<int>(perfs[i].fwd_algo))
              << " - time: " << perfs[i].time << " ms"
              << ", Memory: " << perfs[i].memory;
  }
  // Set Algo
  ret[0] = static_cast<int>(best_algo);
});
|
||||
|
||||
|
||||
// Packed function "tvm.contrib.miopen.conv2d.forward".
//
// Arguments: [0] mode, [1..2] padding (h, w), [3..4] stride (h, w),
// [5..6] dilation (h, w), [7] forward algorithm id, [8] input tensor,
// [9] filter tensor, [10] output tensor.
//
// Re-initialises the thread-local descriptors from the arguments and the
// tensor shapes, then runs miopenConvolutionForward using whatever
// workspace the thread-local entry currently holds.
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  const int conv_mode = args[0];
  const int padding_h = args[1];
  const int padding_w = args[2];
  const int step_h = args[3];
  const int step_w = args[4];
  const int dil_h = args[5];
  const int dil_w = args[6];
  const int algo_id = args[7];
  const DLTensor *input = args[8];
  const DLTensor *filter = args[9];
  const DLTensor *output = args[10];

  MIOpenThreadEntry *tls = MIOpenThreadEntry::ThreadLocal();
  ConvEntry &conv = tls->conv_entry;

  // Per-call state: algorithm, mode, device context and element type.
  conv.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo_id);
  conv.mode = static_cast<miopenConvolutionMode_t>(conv_mode);
  conv.ctx = input->ctx;
  conv.data_type = miopenFloat;  // fp32 is the only type used here

  // Convolution descriptor.
  MIOPEN_CALL(miopenInitConvolutionDescriptor(
      conv.conv_desc, conv.mode,
      padding_h, padding_w,
      step_h, step_w,
      dil_h, dil_w));

  // Tensor descriptors: filter, input, output (four dimensions each).
  MIOPEN_CALL(miopenSet4dTensorDescriptor(
      conv.filter_desc, conv.data_type,
      filter->shape[0], filter->shape[1], filter->shape[2], filter->shape[3]));
  MIOPEN_CALL(miopenSet4dTensorDescriptor(
      conv.input_desc, conv.data_type,
      input->shape[0], input->shape[1], input->shape[2], input->shape[3]));
  MIOPEN_CALL(miopenSet4dTensorDescriptor(
      conv.output_desc, conv.data_type,
      output->shape[0], output->shape[1], output->shape[2], output->shape[3]));

  // Scaling factors handed to MIOpen as alpha and beta.
  const float alpha = 1.f;
  const float beta = 0.f;
  MIOPEN_CALL(miopenConvolutionForward(
      tls->handle,
      &alpha,
      conv.input_desc, input->data,
      conv.filter_desc, filter->data,
      conv.conv_desc,
      conv.fwd_algo,
      &beta,
      conv.output_desc, output->data,
      conv.workspace,
      conv.workspace_size));
});
|
||||
|
||||
} // namespace miopen
|
||||
} // namespace contrib
|
||||
} // namespace tvm
|
|
@ -0,0 +1,78 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file Use external miopen utils function
|
||||
*/
|
||||
#include "miopen_utils.h"
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace contrib {
|
||||
namespace miopen {
|
||||
|
||||
/*!
 * \brief Translate a MIOpen status code into a readable name.
 * \param error_code Status code, used as an index into the name table.
 * \return The status name for codes 0..7; a descriptive placeholder for
 *         any other value (previously an out-of-range read).
 */
std::string miopenGetErrorString(int error_code) {
  static const char* const kStatusNames[] = {
      "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ",
      "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ",
      "StatusNotImplemented ", "StatusUnknownError "};
  const int num_names =
      static_cast<int>(sizeof(kStatusNames) / sizeof(kStatusNames[0]));
  // Guard the lookup: this runs on error paths, where an unexpected code
  // must produce a message rather than undefined behaviour.
  if (error_code < 0 || error_code >= num_names) {
    return "UnrecognizedStatus(" + std::to_string(error_code) + ") ";
  }
  return kStatusNames[error_code];
}
|
||||
|
||||
// MiopenThreadEntry
|
||||
MIOpenThreadEntry::MIOpenThreadEntry() {
|
||||
auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
|
||||
auto func = runtime::Registry::Get("device_api.rocm");
|
||||
void *ret = (*func)();
|
||||
rocm_api = static_cast<runtime::DeviceAPI*>(ret);
|
||||
MIOPEN_CALL(miopenCreate(&handle));
|
||||
MIOPEN_CALL(miopenSetStream(handle, stream));
|
||||
conv_entry.rocm_api = rocm_api;
|
||||
}
|
||||
|
||||
MIOpenThreadEntry::~MIOpenThreadEntry() {
|
||||
MIOPEN_CALL(miopenDestroy(handle));
|
||||
}
|
||||
|
||||
typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;
|
||||
|
||||
MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
|
||||
return MIOpenThreadStore::Get();
|
||||
}
|
||||
|
||||
// ConvEntry
|
||||
|
||||
ConvEntry::ConvEntry() {
|
||||
MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
|
||||
MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
|
||||
MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
|
||||
MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
|
||||
}
|
||||
|
||||
ConvEntry::~ConvEntry() {
|
||||
MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
|
||||
MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
|
||||
MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
|
||||
MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
|
||||
CleanWorkspace();
|
||||
}
|
||||
|
||||
void ConvEntry::UpdateWorkspace(const size_t wsize) {
|
||||
if (workspace_size < wsize) {
|
||||
if (workspace != nullptr) {
|
||||
CleanWorkspace();
|
||||
}
|
||||
workspace_size = wsize;
|
||||
workspace = rocm_api->AllocWorkspace(ctx, workspace_size);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvEntry::CleanWorkspace() {
|
||||
if (workspace) rocm_api->FreeWorkspace(ctx, workspace);
|
||||
workspace_size = 0;
|
||||
}
|
||||
|
||||
} // namespace miopen
|
||||
} // namespace contrib
|
||||
} // namespace tvm
|
|
@ -0,0 +1,59 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file Use external miopen utils function
|
||||
*/
|
||||
|
||||
#ifndef TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_
|
||||
#define TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_
|
||||
|
||||
#include <dmlc/logging.h>
|
||||
#include <miopen/miopen.h>
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include <string>
|
||||
#include "../../runtime/rocm/rocm_common.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace contrib {
|
||||
namespace miopen {
|
||||
|
||||
std::string miopenGetErrorString(int error_code);
|
||||
|
||||
// Evaluates a MIOpen call once and CHECK-fails with the status name when
// it does not return miopenStatusSuccess. Wrapped in do { } while (0) so
// the macro behaves as a single statement (safe inside unbraced if/else);
// call sites must end with a semicolon.
#define MIOPEN_CALL(func)                                              \
  do {                                                                 \
    const miopenStatus_t e = (func);                                   \
    CHECK_EQ(e, miopenStatusSuccess)                                   \
        << "miopen error: " << miopenGetErrorString(e);                \
  } while (0)
|
||||
|
||||
struct ConvEntry {
|
||||
miopenConvolutionDescriptor_t conv_desc;
|
||||
miopenConvolutionMode_t mode{miopenConvolution};
|
||||
miopenTensorDescriptor_t filter_desc;
|
||||
miopenDataType_t data_type{miopenFloat};
|
||||
miopenTensorDescriptor_t input_desc;
|
||||
miopenTensorDescriptor_t output_desc;
|
||||
miopenConvFwdAlgorithm_t fwd_algo;
|
||||
TVMContext ctx;
|
||||
runtime::DeviceAPI *rocm_api;
|
||||
void *workspace{nullptr};
|
||||
size_t workspace_size{0};
|
||||
ConvEntry();
|
||||
~ConvEntry();
|
||||
void UpdateWorkspace(const size_t wsize);
|
||||
void CleanWorkspace();
|
||||
}; // ConvThreadEntry
|
||||
|
||||
struct MIOpenThreadEntry {
|
||||
MIOpenThreadEntry();
|
||||
~MIOpenThreadEntry();
|
||||
miopenHandle_t handle{nullptr};
|
||||
ConvEntry conv_entry;
|
||||
runtime::DeviceAPI *rocm_api{nullptr};
|
||||
static MIOpenThreadEntry *ThreadLocal();
|
||||
}; // MIOpenThreadEntry
|
||||
|
||||
} // namespace miopen
|
||||
} // namespace contrib
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_
|
|
@ -0,0 +1,64 @@
|
|||
import tvm
|
||||
from tvm.contrib import miopen
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_conv2d():
    """Compare the MIOpen conv2d extern op against topi's conv2d_nchw.

    Returns early (after printing a notice) when ROCm or the MIOpen
    contrib functions are not available.
    """
    in_channel, out_channel = 64, 128
    filter_h = filter_w = 3
    pad_h = pad_w = 1
    stride_h = stride_w = 1
    dilation_h = dilation_w = 1

    xshape = [1, in_channel, 64, 64]
    if not tvm.module.enabled("rocm"):
        print("skip because rocm is not enabled...")
        return
    if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
        print("skip because miopen is not enabled...")
        return
    wshape = (out_channel, in_channel, filter_h, filter_w)

    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = miopen.conv2d_forward(X, W,
                              stride_h, stride_w,
                              pad_h, pad_w,
                              dilation_h, dilation_w,
                              conv_mode=0)

    yshape = [dim.value for dim in Y.shape]
    s = tvm.create_schedule(Y.op)

    ctx = tvm.rocm(0)

    def _random_nd(shape):
        # float32 device array with values drawn uniformly from [-1, 1)
        data = np.random.uniform(-1, 1, shape).astype(np.float32)
        return tvm.nd.array(data, ctx)

    # Result of the MIOpen-backed op.
    f = tvm.build(s, [X, W, Y], "rocm", target_host="llvm", name="conv2d")
    x = _random_nd(xshape)
    w = _random_nd(wshape)
    y = _random_nd(yshape)
    f(x, w, y)

    # Reference result; topi is imported only once the skip checks passed.
    import topi
    Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
    with tvm.target.rocm():
        s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
    f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
    y_ref = _random_nd(yshape)
    f_ref(x, w, y_ref)

    print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
    np.testing.assert_allclose(y.asnumpy(), y_ref.asnumpy(), atol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_conv2d()
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
export PYTHONPATH=python
|
||||
export PYTHONPATH=python:topi/python
|
||||
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче