[Refactor] Introduce target generic dispatch system (#556)
* [TVM] Introduce target generic dispatch system
* Fix target warning

Parent: c3cac46465
Commit: eb761f3630
@@ -8,6 +8,7 @@ Python API

   intrin
   tensor
   schedule
   target
   build
   module
   ndarray

@@ -0,0 +1,13 @@
tvm.target
----------
.. automodule:: tvm.target

.. autofunction:: tvm.target.generic_func

.. autoclass:: tvm.target.Target
    :members:

.. autofunction:: tvm.target.cuda
.. autofunction:: tvm.target.rocm
.. autofunction:: tvm.target.rasp
.. autofunction:: tvm.target.create
@@ -37,13 +37,11 @@ Index

.. autosummary::

   topi.cuda.schedule_conv2d_nchw
   topi.cuda.schedule_conv2d_hwcn
   topi.cuda.schedule_depthwise_conv2d_nchw
   topi.cuda.schedule_depthwise_conv2d_nhwc
   topi.cuda.schedule_reduce
   topi.cuda.schedule_broadcast
   topi.cuda.schedule_injective
   topi.generic.schedule_conv2d_nchw
   topi.generic.schedule_depthwise_conv2d_nchw
   topi.generic.schedule_reduce
   topi.generic.schedule_broadcast
   topi.generic.schedule_injective

topi
~~~~

@@ -75,14 +73,12 @@ topi.nn
.. autofunction:: topi.nn.depthwise_conv2d_nhwc


topi.cuda
~~~~~~~~~
.. automodule:: topi.cuda
topi.generic
~~~~~~~~~~~~
.. automodule:: topi.generic

.. autofunction:: topi.cuda.schedule_conv2d_nchw
.. autofunction:: topi.cuda.schedule_conv2d_hwcn
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc
.. autofunction:: topi.cuda.schedule_reduce
.. autofunction:: topi.cuda.schedule_broadcast
.. autofunction:: topi.cuda.schedule_injective
.. autofunction:: topi.generic.schedule_conv2d_nchw
.. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.generic.schedule_reduce
.. autofunction:: topi.generic.schedule_broadcast
.. autofunction:: topi.generic.schedule_injective
@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0):
      assert tvm.context("cuda", 0) == tvm.gpu(0)
    """
    if isinstance(dev_type, string_types):
        if dev_type not in TVMContext.STR2MASK:
            if dev_type.find("nvptx") != -1:
                dev_type = "cuda"
            if dev_type.find("rocm") != -1:
                dev_type = "rocm"
        dev_type = dev_type.split()[0]
        if dev_type not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % dev_type)
        dev_type = TVMContext.STR2MASK[dev_type]

@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure):
        12: 'ext_dev',
    }
    STR2MASK = {
        'llvm': 1,
        'stackvm': 1,
        'cpu': 1,
        'gpu': 2,
        'cuda': 2,
        'nvptx': 2,
        'cl': 4,
        'opencl': 4,
        'metal': 8,
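The two hunks above work together: a device string is reduced to its first token before the ``STR2MASK`` lookup, and ``nvptx`` now maps to the same device mask as ``cuda``. A minimal sketch of the resulting behavior, assuming the API shown in this diff (the option string is illustrative only):

.. code-block:: python

    import tvm

    # "nvptx" is a direct STR2MASK entry (mask 2, same as "cuda"/"gpu"),
    # and anything after the first token is stripped before the lookup.
    assert tvm.context("nvptx", 0).device_type == tvm.gpu(0).device_type
    assert tvm.context("cuda -mcpu=sm_35", 0) == tvm.gpu(0)  # illustrative option string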
@@ -15,6 +15,7 @@ from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target

class BuildConfig(object):
    """Configuration scope to set a build config option.

@@ -238,7 +239,7 @@ def lower(sch,

def build(sch,
          args=None,
          target="llvm",
          target=None,
          target_host=None,
          name="default_function",
          binds=None):

@@ -252,36 +253,10 @@ def build(sch,
    args : list of Buffer or Tensor or Var, optional
        The argument list to the function.

    target : str, optional
    target : str or :any:`tvm.target.Target`, optional
        The target and options of the compilation.
        When the target is llvm, you can set options like:

          - **-mtriple=<target triple>** or **-target**

            Specify the target triple, which is useful for cross
            compilation.

          - **-mcpu=<cpuname>**

            Specify a specific chip in the current architecture to
            generate code for. By default this is inferred from the
            target triple and autodetected to the current architecture.

          - **-mattr=a1,+a2,-a3,...**

            Override or control specific attributes of the target,
            such as whether SIMD operations are enabled or not. The
            default set of attributes is set by the current CPU.

          - **-system-lib**

            Build TVM system library module. System lib is a global module that contains
            self-registered functions in program startup. Users can get the module using
            :any:`tvm.module.system_lib`.
            It is useful in environments where dynamic loading APIs like dlopen are banned.
            The system lib will be available as long as the result code is linked by the program.

    target_host : str, optional
    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host (CPU) side code to interact with the driver
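For reference, a target string that exercises the options documented above (the same set that ``tvm.target.rasp()`` assembles later in this diff); a sketch, assuming the ``tvm.target.create`` API introduced here:

.. code-block:: python

    import tvm

    # Cross-compilation target for an ARM board, built from the documented options.
    target = tvm.target.create(
        "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon")
    assert target.target_name == "llvm"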
@@ -301,6 +276,10 @@ def build(sch,
    -------
    f : Function, or pair of functions
        The result function.

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(sch, schedule.Schedule):
        if args is None:

@@ -325,6 +304,9 @@ def build(sch,
            if x.name in fname_set:
                raise ValueError("Duplicate function name %s" % x.name)

    target = _target.current_target() if target is None else target
    target = _target.create(target) if target else _target.create("llvm")

    fhost = []
    fdevice = []
    for func in flist:

@@ -332,7 +314,7 @@ def build(sch,
            if BuildConfig.current.detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            warp_size = 32 if target == "cuda" else 1
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
            fhost.append(fsplits[0])

@@ -345,29 +327,28 @@ def build(sch,
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target)

    device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
    device_type = ndarray.context(device, 0).device_type
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if not target_host:
        if device == "cpu":
        if device_type == ndarray.cpu(0).device_type:
            target_host = target
            assert not fdevice
        else:
            target_host = "llvm" if module.enabled("llvm") else "stackvm"

    target_host = _target.create(target_host)
    target_device = target
    fdevice = [ir_pass.LowerIntrin(x, target_device) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host) for x in fhost]
    fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mhost = codegen.build_module(fhost, target_host)
    mhost = codegen.build_module(fhost, str(target_host))

    if fdevice:
        mdev = codegen.build_module(fdevice, target_device)
        mdev = codegen.build_module(fdevice, str(target_device))
        mhost.import_module(mdev)
    return mhost
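With ``target`` now defaulting to ``None``, ``build()`` falls back to the enclosing target scope and finally to ``"llvm"``, so the two calls below are interchangeable. A minimal sketch, assuming the 2017-era ``tvm`` API used throughout this diff:

.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # Explicit target argument (a string or a tvm.target.Target both work):
    f1 = tvm.build(s, [A, B], target=tvm.target.create("llvm"))

    # Or let build() pick up the current target scope, as the updated tests do:
    with tvm.target.create("llvm"):
        f2 = tvm.build(s, [A, B])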
@@ -1,63 +1,311 @@
"""Target management API of tvm"""
"""Target management API of TVM.

TVM's target string is in format ``<target_name> [-option=value]...``.

Note
----
The list of options includes:

- **-device=<device name>**

   The device name.

- **-mtriple=<target triple>** or **-target**

   Specify the target triple, which is useful for cross
   compilation.

- **-mcpu=<cpuname>**

   Specify a specific chip in the current architecture to
   generate code for. By default this is inferred from the
   target triple and autodetected to the current architecture.

- **-mattr=a1,+a2,-a3,...**

   Override or control specific attributes of the target,
   such as whether SIMD operations are enabled or not. The
   default set of attributes is set by the current CPU.

- **-system-lib**

   Build TVM system library module. System lib is a global module that contains
   self-registered functions in program startup. Users can get the module using
   :any:`tvm.module.system_lib`.
   It is useful in environments where dynamic loading APIs like dlopen are banned.
   The system lib will be available as long as the result code is linked by the program.

We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific functions in this module to create specific targets.
"""
from __future__ import absolute_import

import warnings
from ._ffi.base import _LIB_NAME

try:
    from decorator import decorate
except ImportError as err_msg:
    # Allow decorator to be missing in runtime
    if _LIB_NAME != "libtvm_runtime.so":
        raise err_msg


def _merge_opts(opts, new_opts):
    """Helper function to merge options"""
    if isinstance(new_opts, str):
        new_opts = new_opts.split()
    if new_opts:
        return opts + new_opts
    return opts


class Target(object):
    """A Target describes the target type on which computation should be carried out"""
    default_target = None
    str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
    type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
    def __init__(self, target_type):
        """Constructs a context."""
        if isinstance(target_type, Target):
            self.target_typeid = target_type.target_typeid
    """Target device information, used through the TVM API.

    Parameters
    ----------
    target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "ext_dev"}
        The major target name.

    options : list of str, optional
        Additional arguments appended to the target.

    Note
    ----
    Do not use the class constructor; you can create targets using the following functions:

    - :any:`tvm.target.create` create target from string
    - :any:`tvm.target.rasp` create raspberry pi target
    - :any:`tvm.target.cuda` create CUDA target
    - :any:`tvm.target.rocm` create ROCM target
    """
    current = None

    def __init__(self,
                 target_name,
                 options=None):
        self.target_name = target_name
        self.options = _merge_opts([], options)
        self.device_name = ""
        # Parse device option
        for item in self.options:
            if item.startswith("-device="):
                self.device_name = item.split("=")[1]
        # Target query searches the device name first
        if self.device_name:
            self.keys = (self.device_name,)
        else:
            self.target_typeid = Target.str2type[target_type]

    @property
    def target_type(self):
        """Returns the target type of current target."""
        return Target.type2str[self.target_typeid]

    def __hash__(self):
        """Compute hash value of target for dictionary lookup"""
        return hash(self.target_typeid)

    def __eq__(self, other):
        """Compares two targets. Two targets are equal if they
        have the same target type.
        """
        return isinstance(other, Target) and \
            self.target_typeid == other.target_typeid
            self.keys = ()
        # Target configuration handling
        self.thread_warp_size = 1
        if target_name in ("llvm", ):
            self.keys += ("cpu",)
        elif target_name in ("cuda", "nvptx"):
            self.keys += ("cuda", "gpu")
            self.max_num_threads = 512
            self.thread_warp_size = 32
        elif target_name in ("rocm", "opencl"):
            # For now assume rocm schedule for opencl
            self.keys += ("rocm", "gpu")
            self.max_num_threads = 256
        elif target_name in ("metal",):
            self.keys += ("gpu",)
            self.max_num_threads = 256
        elif target_name in ("stackvm", "ext_dev"):
            # No known device class for stackvm or ext_dev
            pass
        else:
            raise ValueError("Unknown target name %s" % target_name)

    def __str__(self):
        return '%s' % (self.target_type)
        return " ".join([self.target_name] + self.options)

    def __repr__(self):
        return self.__str__()

    def __enter__(self):
        self._old_target = Target.default_target
        Target.default_target = self
        self._old_target = Target.current
        if self._old_target is not None and str(self) != str(self._old_target):
            warnings.warn(
                "Override target '%s' with new target scope '%s'" % (
                    self._old_target, self))
        Target.current = self
        return self

    def __exit__(self, ptype, value, trace):
        Target.default_target = self._old_target
        Target.current = self._old_target

Target.default_target = Target('x86')

def x86():
    """Returns a x86 target."""
    return Target('x86')
def generic_func(fdefault):
    """Wrap a target generic function.

def cuda():
    """Returns a cuda target."""
    return Target('cuda')
    Generic function allows registration of further functions
    that can be dispatched on the current target context.
    If no registered dispatch is matched, fdefault will be called.

def rasp():
    """Returns a rasp target."""
    return Target('rasp')
    Parameters
    ----------
    fdefault : function
        The default function.

def current_target():
    """Returns the current target."""
    return Target.default_target
    Returns
    -------
    fgeneric : function
        A wrapped generic function.

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.generic_func
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    dispatch_dict = {}
    func_name = fdefault.__name__

    def register(key, func=None, override=False):
        """Register function to be the dispatch function.

        Parameters
        ----------
        key : str or list of str
            The key to be registered.

        func : function
            The function to be registered.

        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The register function if ``func`` is not specified.
        """
        def _do_reg(myf):
            key_list = [key] if isinstance(key, str) else key
            for k in key_list:
                if k in dispatch_dict and not override:
                    raise ValueError(
                        "Key is already registered for %s" % func_name)
                dispatch_dict[k] = myf
            return myf
        if func:
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        target = current_target()
        if target is None:
            return func(*args, **kwargs)
        for k in target.keys:
            if k in dispatch_dict:
                return dispatch_dict[k](*args, **kwargs)
        return func(*args, **kwargs)
    fdecorate = decorate(fdefault, dispatch_func)
    fdecorate.register = register
    return fdecorate


def cuda(options=None):
    """Returns a cuda target.

    Parameters
    ----------
    options : list of str
        Additional options
    """
    return Target("cuda", options)


def rocm(options=None):
    """Returns a ROCM target.

    Parameters
    ----------
    options : list of str
        Additional options
    """
    return Target("rocm", options)


def rasp(options=None):
    """Returns a rasp target.

    Parameters
    ----------
    options : list of str
        Additional options
    """
    opts = ["-device=rasp",
            "-mtriple=armv7l-none-linux-gnueabihf",
            "-mcpu=cortex-a53",
            "-mattr=+neon"]
    opts = _merge_opts(opts, options)
    return Target("llvm", opts)


def create(target_str):
    """Get a target given target string.

    Parameters
    ----------
    target_str : str
        The target string.

    Returns
    -------
    target : Target
        The target object

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(target_str, Target):
        return target_str
    if not isinstance(target_str, str):
        raise ValueError("target_str has to be string type")
    arr = target_str.split()
    # Parse device option
    device_name = ""
    for item in arr[1:]:
        if item.startswith("-device="):
            device_name = item.split("=")[1]
    if device_name == "rasp":
        return rasp(arr[1:])
    return Target(arr[0], arr[1:])


def current_target(allow_none=True):
    """Returns the current target.

    Parameters
    ----------
    allow_none : bool
        Whether to allow the current target to be None

    Raises
    ------
    RuntimeError if the current target is not set and allow_none is False.
    """
    if Target.current:
        return Target.current
    if not allow_none:
        raise RuntimeError(
            "Requires a current target in generic function, but it is not set. "
            "Please set it using `with TargetObject:`")
    return Target.current
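The dispatch keys derived in ``Target.__init__`` are what ``generic_func`` searches, in order, when it resolves a call. A few illustrative expectations that follow from the constructor logic above (a sketch, not part of the commit):

.. code-block:: python

    import tvm

    cuda = tvm.target.cuda()
    assert cuda.keys == ("cuda", "gpu")   # generic_func tries "cuda", then "gpu"
    assert cuda.thread_warp_size == 32

    rasp = tvm.target.rasp()              # an llvm target carrying -device=rasp
    assert rasp.target_name == "llvm"
    assert rasp.keys == ("rasp", "cpu")   # the -device name is searched first

    assert str(tvm.target.create("cuda")) == "cuda"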
@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str,
    } else {
      LOG(FATAL) << "invalid -mfloat-abi option " << value;
    }
  } else if (key == "-device") {
    // pass
  } else {
    LOG(FATAL) << "unknown option " << key;
  }
@@ -68,7 +68,8 @@ def test_gemm():
            print("skip because %s is not enabled.." % device)
            return

        f = tvm.build(s, [A, B, C], device)
        with tvm.target.create(device):
            f = tvm.build(s, [A, B, C])
        ctx = tvm.context(device, 0)
        # launch the kernel.
        n = nn
@@ -0,0 +1,44 @@
import tvm

@tvm.target.generic_func
def mygeneric(data):
    # default generic function
    return data + 1

@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    return data + 2

@mygeneric.register("rocm")
def rocm_func(data):
    return data + 3

@mygeneric.register("cpu")
def rocm_func(data):
    return data + 10


def test_target_dispatch():
    with tvm.target.cuda():
        assert mygeneric(1) == 3

    with tvm.target.rocm():
        assert mygeneric(1) == 4

    with tvm.target.create("cuda"):
        assert mygeneric(1) == 3

    with tvm.target.rasp():
        assert mygeneric(1) == 11

    with tvm.target.create("metal"):
        assert mygeneric(1) == 3

    try:
        mygeneric(0)
        raise RuntimeError("not reached")
    except RuntimeError:
        pass


if __name__ == "__main__":
    test_target_dispatch()
@@ -3,6 +3,7 @@
import tvm
from .. import util
from .. import tag
from .. import generic

def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
    """Schedule conv2d for specific feature_in_out_filter pattern"""

@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs):
    traverse(outs[0].op)
    return s


@generic.schedule_conv2d_nchw.register(["cuda", "gpu"])
def schedule_conv2d_nchw(outs):
    """Schedule for conv2d_nchw.
@@ -3,7 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic

@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
    """Schedule for dense operator.
@@ -3,7 +3,9 @@
import tvm
from ..util import get_const_tuple
from .. import tag
from .. import generic

@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nchw(outs):
    """Schedule for depthwise_conv2d nchw forward.
@@ -1,17 +1,21 @@
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic

def _schedule_injective(op, sch):
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
    num_thread = 512
    target = tvm.target.current_target()
    target = target if target else tvm.target.cuda()
    num_thread = target.max_num_threads
    bx, tx = sch[x].split(fused, factor=num_thread)
    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch


@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
    """Schedule for injective op.
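The hard-coded thread count is gone: the schedule now asks the current target (falling back to a CUDA target when none is in scope) for ``max_num_threads``, which the ``Target`` constructor sets per backend. For instance, with the values defined in this diff:

.. code-block:: python

    import tvm

    with tvm.target.create("cuda"):
        assert tvm.target.current_target().max_num_threads == 512
    with tvm.target.create("metal"):
        assert tvm.target.current_target().max_num_threads == 256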
@@ -2,7 +2,9 @@
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic

@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
    """Schedule for global_pool.

@@ -63,6 +65,7 @@ def schedule_global_pool(outs):
    return s


@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs):
    """Schedule for pool.
@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic

def _schedule_reduce(op, sch, is_idx_reduce=False):
    if is_idx_reduce:

@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
    return sch


@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs):
    """Schedule for inject->reduce->bcast ops.
@@ -1,7 +1,9 @@
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
from .. import generic

@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
    """Schedule for softmax op.
@@ -0,0 +1,19 @@
# pylint: disable=wildcard-import
"""Generic declaration and schedules.

This is the recommended way of using the TOPI API.
To use the generic schedule functions, users must set
the current target scope using a with block. See also :any:`tvm.target`

Example
-------
.. code-block:: python

  # create schedule that dispatches to topi.cuda.schedule_injective
  with tvm.target.create("cuda"):
    s = topi.generic.schedule_injective(outs)
"""
from __future__ import absolute_import as _abs

from .nn import *
from .injective import *
@@ -0,0 +1,32 @@
# pylint: disable=invalid-name
"""generic declaration and schedules."""
from __future__ import absolute_import as _abs

import tvm

@tvm.target.generic_func
def schedule_injective(outs):
    """Schedule for injective op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of injective ops in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    target = tvm.target.current_target(allow_none=False)
    if target.target_name != "llvm":
        raise RuntimeError("schedule_injective not registered for '%s'" % target)
    x = outs[0]
    s = tvm.create_schedule([x.op for x in outs])
    tvm.schedule.AutoInlineInjective(s)
    s[x].fuse(s[x].op.axis)
    return s

schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
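Backends override this llvm fallback by registering on the generic function, as ``topi/cuda/injective.py`` does above with ``@generic.schedule_injective.register(["cuda", "gpu"])``. A short usage sketch (shapes and the relu stage are illustrative only):

.. code-block:: python

    import tvm
    import topi

    A = tvm.placeholder((1024,), name="A")
    B = topi.nn.relu(A)

    with tvm.target.create("cuda"):
        s = topi.generic.schedule_injective([B])  # resolves to topi.cuda.schedule_injective
    # without a registered key (e.g. under "llvm"), the default body above is used instead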
@@ -0,0 +1,142 @@
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm


def _default_schedule(outs, auto_inline):
    """Default schedule for llvm."""
    target = tvm.target.current_target(allow_none=False)
    if target.target_name != "llvm":
        raise RuntimeError("schedule_pool not registered for '%s'" % target)
    s = tvm.create_schedule([x.op for x in outs])
    if auto_inline:
        x = outs[0]
        tvm.schedule.AutoInlineInjective(s)
        s[x].fuse(s[x].op.axis)
    return s


@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
    """Schedule for conv2d nchw

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of conv2d nchw in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
    """Schedule for depthwise_conv2d nchw

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of depthwise_conv2d nchw in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_reduce(outs):
    """Schedule for reduction

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of reduce in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, True)


@tvm.target.generic_func
def schedule_softmax(outs):
    """Schedule for softmax

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of softmax in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_dense(outs):
    """Schedule for dense

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_pool(outs):
    """Schedule for pool

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_global_pool(outs):
    """Schedule for global pool

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of global pool in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
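All of these default to ``_default_schedule``, which only accepts an ``llvm`` target; GPU backends register their own implementations (see the ``topi/cuda`` changes earlier in this diff). A usage sketch mirroring the updated tests below (the placeholder shapes are illustrative):

.. code-block:: python

    import tvm
    import topi

    data = tvm.placeholder((1, 64), name="data")
    weight = tvm.placeholder((16, 64), name="weight")
    D = topi.nn.dense(data, weight)

    with tvm.target.create("cuda"):
        s = topi.generic.schedule_dense(D)  # dispatches to the cuda registration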
@@ -1,9 +1,8 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm import target as _target
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify

@@ -51,9 +50,7 @@ _WORKLOADS = [
# platform specific schedule
_CONV_SCHEDULE = {}

# platform specific declaration
_CONV_DECLARATION = {}

@tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW'):
    """Conv2D operator.

@@ -80,10 +77,6 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # search platform specific declaration first
    target = _target.current_target()
    if target in _CONV_DECLARATION:
        return _CONV_DECLARATION[target](data, kernel, stride, padding, layout)

    # default declaration
    if layout == 'NCHW':
        return conv2d_nchw(data, kernel, stride, padding)

@@ -105,15 +98,15 @@ def _get_workload(data, kernel, stride, padding):
    return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)


def _get_schedule(wkl, target=None):
@tvm.target.generic_func
def _get_schedule(wkl):
    # pylint: disable=unreachable
    """ Get the platform specific schedule. """
    if target is None:
        target = _target.current_target()
    else:
        target = _target.Target(target)
    assert target in _CONV_SCHEDULE, "no schedule for such target: {}".format(target)
    return _CONV_SCHEDULE[target](wkl)

    target = tvm.target.current_target()
    raise RuntimeError(
        "No schedule for current target:{}".format(target))
    # This return has no use, merely to suppress pylint warning
    return wkl

def _spatial_pack(data, kernel, stride, padding):
    """ Compute convolution with pack on spatial axes. """
@@ -4,11 +4,12 @@ from __future__ import absolute_import as _abs
import tvm
from tvm import target as _target
from .. import tag
from ..nn.conv2d import conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack
from ..nn.conv2d import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC
from ..nn.conv2d import _get_workload, _get_schedule
from ..nn.conv2d import _get_workload
from ..nn.util import infer_pad, infer_stride
from .. import generic

_SCHEDULES = [
    SpatialPack(1, 8, 4, 1, 4, True),

@@ -36,6 +37,7 @@ _SCHEDULES = [
    Im2ColPack(7, 4, 1, 4, True),
]

@_get_schedule.register("rasp")
def _schedule_conv2d(wkl):
    if wkl not in _WORKLOADS:
        raise ValueError("no schedule for such workload: {}".format(wkl))

@@ -43,8 +45,8 @@ def _schedule_conv2d(wkl):
    sch = _SCHEDULES[idx]
    return sch

_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d

@conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout):
    assert layout == 'NCHW', "only support NCHW convolution on rasp"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"

@@ -52,7 +54,6 @@ def _declaration_conv2d(data, kernel, stride, padding, layout):
    sch = _get_schedule(wkl)
    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)

_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d

def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
                             kernel, kernel_vec,

@@ -64,7 +65,9 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)
    sch = _get_schedule(wkl, 'rasp')

    with tvm.target.rasp():
        sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter

@@ -170,7 +173,9 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)
    sch = _get_schedule(wkl, 'rasp')

    with _target.rasp():
        sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI = wkl.in_filter

@@ -275,6 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,

    return s

@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
def schedule_conv2d(outs):
    """Create schedule for tensors"""
    s = tvm.create_schedule([x.op for x in outs])
@@ -5,7 +5,7 @@ from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple
from .. import generic

_Workload = namedtuple('Workload',
                       ['height', 'width', 'channel', 'multiplier',

@@ -145,7 +145,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
    return s


@generic.schedule_depthwise_conv2d_nchw.register(["cpu", "rasp"])
def schedule_depthwise_conv2d(outs):
    """Schedule for depthwise_conv2d nchw forward.
@@ -8,16 +8,16 @@ def verify_broadcast_to_ele(in_shape, out_shape):
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.broadcast_to(A, out_shape)
    s = topi.cuda.schedule_broadcast(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="broadcast_to")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.broadcast_to(data_npy, out_shape)

        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        for _ in range(1):

@@ -48,11 +48,12 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
        C = topi.broadcast_minimum(A, B)
    else:
        raise NotImplementedError
    s = topi.cuda.schedule_broadcast(C)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
@@ -14,8 +14,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, paddin
    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    B = topi.nn.conv2d(A, W, stride, padding)
    s = topi.generic.schedule_conv2d_nchw([B])

    s = topi.rasp.schedule_conv2d([B])
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    dtype = A.dtype
@@ -14,8 +14,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    B = topi.nn.conv2d_nchw(A, W, stride, padding)
    C = topi.nn.relu(B)
    s1 = topi.cuda.schedule_conv2d_nchw([B])
    s2 = topi.cuda.schedule_conv2d_nchw([C])

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)

@@ -35,6 +33,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s1 = topi.generic.schedule_conv2d_nchw([B])
            s2 = topi.generic.schedule_conv2d_nchw([C])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
@@ -12,7 +12,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
    C = tvm.placeholder((out_dim,), name='C')
    D = topi.nn.dense(A, B, C if use_bias else None)
    D = topi.nn.relu(D)
    s = topi.cuda.schedule_dense(D)
    dtype = A.dtype

    # use memoize to pickle the test data for next time use

@@ -33,6 +32,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_dense(D)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
@@ -4,7 +4,7 @@ import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc


def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):

@@ -21,15 +21,18 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
    ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
    Relu = topi.nn.relu(ScaleShift)
    # schedule
    s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
    s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
    s3 = schedule_depthwise_conv2d_nchw(Relu)


    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            # schedule
            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
            s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
            s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)

        ctx = tvm.context(device, 0)
        # build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)

@@ -88,7 +91,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
    check_device("cuda")
    check_device("metal")
    check_device("rocm")


def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
    in_width = in_height
    filter_channel = in_channel
@@ -12,7 +12,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
    B = topi.nn.relu(B)
    s = topi.cuda.schedule_pool(B)
    dtype = A.dtype

    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)

@@ -36,6 +35,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)

@@ -57,7 +58,6 @@ def verify_global_pool(n, c, h, w, pool_type):
    A = tvm.placeholder((n, c, h, w), name='A')
    B = topi.nn.global_pool(A, pool_type=pool_type)
    B = topi.nn.relu(B)
    s = topi.cuda.schedule_global_pool(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    if pool_type == 'avg':

@@ -70,6 +70,8 @@ def verify_global_pool(n, c, h, w, pool_type):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_global_pool(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -45,11 +45,13 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
        out_dtype = "int32"
    else:
        raise NotImplementedError
    s = topi.cuda.schedule_reduce(B)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_reduce(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="sum")
        # Test
@@ -12,8 +12,6 @@ def verify_softmax(m, n):
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)

    s = topi.cuda.schedule_softmax(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.softmax_python(a_np)

@@ -21,6 +19,8 @@ def verify_softmax(m, n):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_softmax(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)

@@ -43,7 +43,6 @@ def verify_log_softmax(m, n):
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)

    s = topi.cuda.schedule_softmax(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.log_softmax_python(a_np)

@@ -52,6 +51,8 @@ def verify_log_softmax(m, n):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_softmax(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -6,11 +6,12 @@ import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.expand_dims(A, axis, num_newaxis)
    s = topi.cuda.schedule_broadcast(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)

@@ -23,17 +24,18 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    check_device("opencl")
    check_device("cuda")
    check_device("metal")
    check_device("rocm")
    check_device("rocm")


def verify_tranpose(in_shape, axes):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)
    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="tranpose")
        data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)

@@ -46,16 +48,17 @@ def verify_tranpose(in_shape, axes):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")
    check_device("rocm")

def verify_reshape(src_shape, dst_shape):
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.reshape(A, dst_shape)
    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="reshape")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)

@@ -68,16 +71,17 @@ def verify_reshape(src_shape, dst_shape):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")
    check_device("rocm")

def verify_squeeze(src_shape, axis):
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)
    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="squeeze")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)

@@ -94,18 +98,19 @@ def verify_squeeze(src_shape, axis):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")
    check_device("rocm")

def verify_concatenate(shapes, axis):
    tensor_l = []
    for i, shape in enumerate(shapes):
        tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
    out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
    s = topi.cuda.schedule_injective(out_tensor)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
        data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]

@@ -118,16 +123,17 @@ def verify_concatenate(shapes, axis):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")
    check_device("rocm")

def verify_split(src_shape, indices_or_sections, axis):
    A = tvm.placeholder(shape=src_shape, name="A")
    tensor_l = topi.split(A, indices_or_sections, axis=axis)
    s = topi.cuda.schedule_injective(tensor_l)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(tensor_l)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A] + tensor_l, device, name="split")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)

@@ -142,7 +148,7 @@ def verify_split(src_shape, indices_or_sections, axis):
    check_device("opencl")
    check_device("metal")
    check_device("rocm")


def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)

@@ -190,4 +196,3 @@ if __name__ == "__main__":
    test_squeeze()
    test_concatenate()
    test_split()