[CODEGEN] Enable cross compile of AMDGPU without rocm, update rpc (#1154)
This commit is contained in:
Родитель
11c7b6cfaf
Коммит
51c40b4f8b
|
@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
|
|||
def _connect(key):
|
||||
conn = yield websocket.websocket_connect(url)
|
||||
on_message = create_on_message(conn)
|
||||
temp = _server_env()
|
||||
temp = _server_env(None)
|
||||
# Start connecton
|
||||
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
|
||||
key = "server:" + key
|
||||
|
|
|
@ -11,6 +11,7 @@ Server is TCP based with the following protocol:
|
|||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
import ctypes
|
||||
import socket
|
||||
import select
|
||||
import struct
|
||||
|
@ -21,12 +22,13 @@ import time
|
|||
|
||||
from ..._ffi.function import register_func
|
||||
from ..._ffi.base import py_str
|
||||
from ..._ffi.libinfo import find_lib_path
|
||||
from ...module import load as _load_module
|
||||
from .. import util
|
||||
from . import base
|
||||
from . base import TrackerCode
|
||||
|
||||
def _server_env():
|
||||
def _server_env(load_library):
|
||||
"""Server environment function return temp dir"""
|
||||
temp = util.tempdir()
|
||||
# pylint: disable=unused-variable
|
||||
|
@ -41,13 +43,21 @@ def _server_env():
|
|||
m = _load_module(path)
|
||||
logging.info("load_module %s", path)
|
||||
return m
|
||||
|
||||
libs = []
|
||||
load_library = load_library.split(":") if load_library else []
|
||||
for file_name in load_library:
|
||||
file_name = find_lib_path(file_name)[0]
|
||||
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
|
||||
logging.info("Load additional library %s", file_name)
|
||||
temp.libs = libs
|
||||
return temp
|
||||
|
||||
|
||||
def _serve_loop(sock, addr):
|
||||
def _serve_loop(sock, addr, load_library):
|
||||
"""Server loop"""
|
||||
sockfd = sock.fileno()
|
||||
temp = _server_env()
|
||||
temp = _server_env(load_library)
|
||||
base._ServerLoop(sockfd)
|
||||
temp.remove()
|
||||
logging.info("Finish serving %s", addr)
|
||||
|
@ -62,7 +72,7 @@ def _parse_server_opt(opts):
|
|||
return ret
|
||||
|
||||
|
||||
def _listen_loop(sock, port, rpc_key, tracker_addr):
|
||||
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
|
||||
"""Lisenting loop of the server master."""
|
||||
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
|
||||
"""Accept connection from the other places.
|
||||
|
@ -162,7 +172,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
|
|||
|
||||
# step 3: serving
|
||||
logging.info("RPCServer: connection from %s", addr)
|
||||
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
|
||||
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
|
||||
server_proc.deamon = True
|
||||
server_proc.start()
|
||||
# close from our side.
|
||||
|
@ -174,7 +184,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
|
|||
server_proc.terminate()
|
||||
|
||||
|
||||
def _connect_proxy_loop(addr, key):
|
||||
def _connect_proxy_loop(addr, key, load_library):
|
||||
key = "server:" + key
|
||||
retry_count = 0
|
||||
max_retry = 5
|
||||
|
@ -198,7 +208,7 @@ def _connect_proxy_loop(addr, key):
|
|||
opts = _parse_server_opt(remote_key.split()[1:])
|
||||
logging.info("RPCProxy connected to %s", str(addr))
|
||||
process = multiprocessing.Process(
|
||||
target=_serve_loop, args=(sock, addr))
|
||||
target=_serve_loop, args=(sock, addr, load_library))
|
||||
process.deamon = True
|
||||
process.start()
|
||||
sock.close()
|
||||
|
@ -256,6 +266,9 @@ class Server(object):
|
|||
|
||||
key : str, optional
|
||||
The key used to identify the server in Proxy connection.
|
||||
|
||||
load_library : str, optional
|
||||
List of additional libraries to be loaded during execution.
|
||||
"""
|
||||
def __init__(self,
|
||||
host,
|
||||
|
@ -264,7 +277,8 @@ class Server(object):
|
|||
is_proxy=False,
|
||||
use_popen=False,
|
||||
tracker_addr=None,
|
||||
key=""):
|
||||
key="",
|
||||
load_library=None):
|
||||
try:
|
||||
if base._ServerLoop is None:
|
||||
raise RuntimeError("Please compile with USE_RPC=1")
|
||||
|
@ -283,6 +297,8 @@ class Server(object):
|
|||
assert key
|
||||
cmd += ["--tracker=%s:%d" % tracker_addr,
|
||||
"--key=%s" % key]
|
||||
if load_library:
|
||||
cmd += ["--load-libary", load_library]
|
||||
self.proc = multiprocessing.Process(
|
||||
target=subprocess.check_call, args=(cmd,))
|
||||
self.proc.deamon = True
|
||||
|
@ -308,12 +324,12 @@ class Server(object):
|
|||
self.sock = sock
|
||||
self.proc = multiprocessing.Process(
|
||||
target=_listen_loop, args=(
|
||||
self.sock, self.port, key, tracker_addr))
|
||||
self.sock, self.port, key, tracker_addr, load_library))
|
||||
self.proc.deamon = True
|
||||
self.proc.start()
|
||||
else:
|
||||
self.proc = multiprocessing.Process(
|
||||
target=_connect_proxy_loop, args=((host, port), key))
|
||||
target=_connect_proxy_loop, args=((host, port), key, load_library))
|
||||
self.proc.deamon = True
|
||||
self.proc.start()
|
||||
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
"""Start an RPC server"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import ctypes
|
||||
from ..contrib import rpc
|
||||
from .._ffi.libinfo import find_lib_path
|
||||
|
||||
def main():
|
||||
"""Main funciton"""
|
||||
|
@ -19,26 +15,12 @@ def main():
|
|||
help='The end search port of the PRC')
|
||||
parser.add_argument('--key', type=str, default="",
|
||||
help="RPC key used to identify the connection type.")
|
||||
parser.add_argument('--with-executor', type=bool, default=False,
|
||||
help="Whether to load executor runtime")
|
||||
parser.add_argument('--load-library', type=str, default="",
|
||||
help="Additional library to load")
|
||||
parser.add_argument('--tracker', type=str, default="",
|
||||
help="Report to RPC tracker")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
load_library = [lib for lib in args.load_library.split(":") if len(lib) != 0]
|
||||
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
|
||||
apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
|
||||
libs = []
|
||||
if args.with_executor:
|
||||
load_library += ["libtvm_graph_exec.so"]
|
||||
for file_name in load_library:
|
||||
file_name = find_lib_path(file_name, apps_path)[0]
|
||||
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
|
||||
logging.info("Load additional library %s", file_name)
|
||||
|
||||
if args.tracker:
|
||||
url, port = args.tracker.split(":")
|
||||
port = int(port)
|
||||
|
@ -53,8 +35,8 @@ def main():
|
|||
args.port,
|
||||
args.port_end,
|
||||
key=args.key,
|
||||
tracker_addr=tracker_addr)
|
||||
server.libs += libs
|
||||
tracker_addr=tracker_addr,
|
||||
load_library=args.load_library)
|
||||
server.proc.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <tvm/codegen.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include "../runtime/meta_data.h"
|
||||
|
||||
|
@ -111,17 +112,19 @@ class CodeGenSourceBase {
|
|||
runtime::Module SourceModuleCreate(std::string code, std::string fmt);
|
||||
|
||||
/*!
|
||||
* \brief Create a source module for viewing and limited saving
|
||||
* \param code The code to be viewed.
|
||||
* \brief Create a source module for viewing and limited saving for device.
|
||||
* \param data The code data to be viewed.
|
||||
* \param fmt The code. format.
|
||||
* \param fmap The map function information map of each function.
|
||||
* \param type_key The type_key of the runtime module of this source code
|
||||
* \param fget_source a closure to replace default get source behavior.
|
||||
*/
|
||||
runtime::Module DeviceSourceModuleCreate(
|
||||
std::string code,
|
||||
std::string data,
|
||||
std::string fmt,
|
||||
std::unordered_map<std::string, runtime::FunctionInfo> fmap,
|
||||
std::string type_key);
|
||||
std::string type_key,
|
||||
std::function<std::string(const std::string&)> fget_source = nullptr);
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
|
||||
|
|
|
@ -4,15 +4,18 @@
|
|||
* \brief AMDGPU code generator.
|
||||
*/
|
||||
#ifdef TVM_LLVM_VERSION
|
||||
#if TVM_ROCM_RUNTIME
|
||||
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include "./codegen_llvm.h"
|
||||
#include "../build_common.h"
|
||||
#include "../codegen_source_base.h"
|
||||
#include "../../pass/ir_util.h"
|
||||
|
||||
#if TVM_ROCM_RUNTIME
|
||||
#include "../../runtime/rocm/rocm_module.h"
|
||||
#endif // TVM_ROCM_RUNTIME
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
@ -131,19 +134,27 @@ class CodeGenAMDGPU : public CodeGenLLVM {
|
|||
}
|
||||
};
|
||||
|
||||
inline int DetectROCMComputeVersion() {
|
||||
inline int DetectROCMComputeVersion(const std::string& target) {
|
||||
size_t pos = target.find("=gfx");
|
||||
if (pos != std::string::npos) {
|
||||
int value;
|
||||
std::stringstream is(target.substr(pos + 4));
|
||||
if (is >> value) return value;
|
||||
}
|
||||
TVMContext tvm_ctx;
|
||||
tvm_ctx.device_type = kDLROCM;
|
||||
tvm_ctx.device_id = 0;
|
||||
TVMRetValue val;
|
||||
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
|
||||
tvm_ctx, tvm::runtime::kExist, &val);
|
||||
if (val.operator int() == 1) {
|
||||
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
|
||||
return val.operator int();
|
||||
} else {
|
||||
return 803;
|
||||
tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
|
||||
if (api != nullptr) {
|
||||
TVMRetValue val;
|
||||
api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
|
||||
if (val.operator int() == 1) {
|
||||
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
|
||||
return val.operator int();
|
||||
}
|
||||
}
|
||||
LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
|
||||
return 803;
|
||||
}
|
||||
|
||||
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
|
||||
|
@ -151,7 +162,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
|
|||
target.substr(0, 4) == "rocm");
|
||||
std::ostringstream config;
|
||||
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
|
||||
<< DetectROCMComputeVersion()
|
||||
<< DetectROCMComputeVersion(target)
|
||||
<< target.substr(4, target.length() - 4);
|
||||
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
|
||||
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
|
||||
|
@ -216,7 +227,19 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
|
|||
std::string hsaco = (*f)(arr);
|
||||
std::string ll(data_ll.begin(), data_ll.end());
|
||||
|
||||
#if TVM_ROCM_RUNTIME
|
||||
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
|
||||
#else
|
||||
LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
|
||||
auto fget_source = [ll, assembly](const std::string& format) {
|
||||
if (format.length() == 0) return assembly;
|
||||
if (format == "ll" || format == "llvm") return format;
|
||||
if (format == "asm") return assembly;
|
||||
return std::string("");
|
||||
};
|
||||
return DeviceSourceModuleCreate(
|
||||
hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source);
|
||||
#endif // TVM_ROCM_RUNTIME
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("codegen.build_rocm")
|
||||
|
@ -226,5 +249,4 @@ TVM_REGISTER_API("codegen.build_rocm")
|
|||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
#endif // TVM_ROCM_RUNTIME
|
||||
#endif // TVM_LLVM_VERSION
|
||||
|
|
|
@ -54,46 +54,71 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
|
|||
}
|
||||
|
||||
// supports limited save without cross compile
|
||||
class DeviceSourceModuleNode final : public SourceModuleNode {
|
||||
class DeviceSourceModuleNode final : public runtime::ModuleNode {
|
||||
public:
|
||||
DeviceSourceModuleNode(std::string code,
|
||||
DeviceSourceModuleNode(std::string data,
|
||||
std::string fmt,
|
||||
std::unordered_map<std::string, FunctionInfo> fmap,
|
||||
std::string type_key)
|
||||
: SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
|
||||
std::string type_key,
|
||||
std::function<std::string(const std::string&)> fget_source)
|
||||
: data_(data),
|
||||
fmt_(fmt),
|
||||
fmap_(fmap),
|
||||
type_key_(type_key),
|
||||
fget_source_(fget_source) {}
|
||||
|
||||
PackedFunc GetFunction(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
|
||||
LOG(FATAL) << "Source module cannot execute, to get executable module"
|
||||
<< " build TVM with \'" << fmt_ << "\' runtime support";
|
||||
return PackedFunc();
|
||||
}
|
||||
|
||||
std::string GetSource(const std::string& format) final {
|
||||
if (fget_source_ != nullptr) {
|
||||
return fget_source_(format);
|
||||
} else {
|
||||
return data_;
|
||||
}
|
||||
}
|
||||
|
||||
const char* type_key() const {
|
||||
return type_key_.c_str();
|
||||
}
|
||||
|
||||
void SaveToFile(const std::string& file_name,
|
||||
const std::string& format) final {
|
||||
const std::string& format) final {
|
||||
std::string fmt = GetFileFormat(file_name, format);
|
||||
CHECK_EQ(fmt, fmt_)
|
||||
<< "Can only save to format=" << fmt_;
|
||||
std::string meta_file = GetMetaFilePath(file_name);
|
||||
SaveMetaDataToFile(meta_file, fmap_);
|
||||
SaveBinaryToFile(file_name, code_);
|
||||
SaveBinaryToFile(file_name, data_);
|
||||
}
|
||||
|
||||
void SaveToBinary(dmlc::Stream* stream) final {
|
||||
stream->Write(fmt_);
|
||||
stream->Write(fmap_);
|
||||
stream->Write(code_);
|
||||
stream->Write(data_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string data_;
|
||||
std::string fmt_;
|
||||
std::unordered_map<std::string, FunctionInfo> fmap_;
|
||||
std::string type_key_;
|
||||
std::function<std::string(const std::string&)> fget_source_;
|
||||
};
|
||||
|
||||
runtime::Module DeviceSourceModuleCreate(
|
||||
std::string code,
|
||||
std::string data,
|
||||
std::string fmt,
|
||||
std::unordered_map<std::string, FunctionInfo> fmap,
|
||||
std::string type_key) {
|
||||
std::string type_key,
|
||||
std::function<std::string(const std::string&)> fget_source) {
|
||||
std::shared_ptr<DeviceSourceModuleNode> n =
|
||||
std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
|
||||
std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
|
||||
return runtime::Module(n);
|
||||
}
|
||||
|
||||
|
|
|
@ -121,9 +121,9 @@ bool RuntimeEnabled(const std::string& target) {
|
|||
} else if (target == "vpi" || target == "verilog") {
|
||||
f_name = "device_api.vpi";
|
||||
} else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
|
||||
f_name = "codegen.build_nvptx";
|
||||
f_name = "device_api.gpu";
|
||||
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
|
||||
f_name = "codegen.build_rocm";
|
||||
f_name = "device_api.rocm";
|
||||
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
|
||||
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
|
||||
if (pf == nullptr) return false;
|
||||
|
|
|
@ -41,13 +41,13 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
|
|||
b_np = l2norm_instance_python(a_np, eps, axis)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_l2norm(B)
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
|
||||
f = tvm.build(s, [A, B], device)
|
||||
|
|
|
@ -70,13 +70,13 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
|
|||
b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_lrn(B)
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
|
||||
f = tvm.build(s, [A, B], device)
|
||||
|
|
|
@ -29,7 +29,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
|
|||
a_np, b_np, c_np, d_np = get_ref_data()
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -40,7 +41,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
|
|||
s = topi.cpp.rocm.schedule_dense(target, [D])
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_dense(target, [D])
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(b_np, ctx)
|
||||
c = tvm.nd.array(c_np, ctx)
|
||||
|
|
|
@ -48,7 +48,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
|
|||
b_np = np.maximum(b_np, 0.0)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -57,7 +58,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
|
|||
s = topi.cpp.generic.default_schedule(target, [B], False)
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_pool(target, [B])
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
|
||||
f = tvm.build(s, [A, B], device)
|
||||
|
|
|
@ -46,7 +46,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
|
|||
raise NotImplementedError
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -56,7 +57,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
|
|||
else:
|
||||
s = topi.cpp.cuda.schedule_reduce(target, [B])
|
||||
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B], device, name="sum")
|
||||
# Test
|
||||
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
|
||||
|
|
|
@ -7,7 +7,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
|
|||
A = tvm.placeholder(shape=in_shape, name="A")
|
||||
B = topi.cpp.expand_dims(A, axis, num_newaxis)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -16,7 +17,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
|
|||
s = topi.cpp.generic.schedule_injective(target, [B])
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_injective(target, [B])
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B], device, name="expand_dims")
|
||||
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
|
||||
out_npy = data_npy.reshape(out_shape)
|
||||
|
@ -33,7 +33,8 @@ def verify_tranpose(in_shape, axes):
|
|||
A = tvm.placeholder(shape=in_shape, name="A")
|
||||
B = topi.cpp.transpose(A, axes)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -59,7 +60,8 @@ def verify_reshape(src_shape, dst_shape):
|
|||
A = tvm.placeholder(shape=src_shape, name="A")
|
||||
B = topi.cpp.reshape(A, dst_shape)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -68,7 +70,6 @@ def verify_reshape(src_shape, dst_shape):
|
|||
s = topi.cpp.generic.schedule_injective(target, [B])
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_injective(target, [B])
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B], device, name="reshape")
|
||||
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
|
||||
out_npy = np.reshape(data_npy, newshape=dst_shape)
|
||||
|
@ -85,7 +86,8 @@ def verify_squeeze(src_shape, axis):
|
|||
A = tvm.placeholder(shape=src_shape, name="A")
|
||||
B = topi.cpp.squeeze(A, axis)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -94,7 +96,6 @@ def verify_squeeze(src_shape, axis):
|
|||
s = topi.cpp.generic.schedule_injective(target, [B])
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_injective(target, [B])
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B], device, name="squeeze")
|
||||
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
|
||||
out_npy = np.squeeze(data_npy, axis=axis)
|
||||
|
@ -116,7 +117,8 @@ def verify_concatenate(shapes, axis):
|
|||
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
|
||||
out_tensor = topi.cpp.concatenate(tensor_l, axis)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -125,7 +127,6 @@ def verify_concatenate(shapes, axis):
|
|||
s = topi.cpp.generic.schedule_injective(target, [out_tensor])
|
||||
else:
|
||||
s = topi.cpp.cuda.schedule_injective(target, [out_tensor])
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
|
||||
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
|
||||
out_npy = np.concatenate(data_npys, axis=axis)
|
||||
|
@ -143,7 +144,8 @@ def verify_split(src_shape, indices_or_sections, axis):
|
|||
tensor_l = topi.cpp.split(A, indices_or_sections, axis)
|
||||
tensor_l = list(tensor_l)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
|
Загрузка…
Ссылка в новой задаче