[CONTRIB/BLAS] Add CBLAS Example to contrib (#120)
* [CONTRIB/BLAS] Add CBLAS Example to contrib
* Update makefile

Parent: 8a5b6c21ad
Commit: f364d563c2

@@ -7,7 +7,7 @@ endif()

include(cmake/Util.cmake)
tvm_option(USE_CUDA "Build with CUDA" ON)
tvm_option(USE_OPENCL "Build with OpenCL" ON)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" OFF)
tvm_option(USE_MSVC_MT "Build with MT" OFF)

Makefile (56 lines changed)

@@ -10,47 +10,52 @@ endif

include $(config)

# specify tensor path
.PHONY: clean all test doc pylint cpplint lint verilog cython cython2 cython3

all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a

LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a
# The source code dependencies
LIB_HALIDEIR = HalideIR/lib/libHalideIR.a

SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
CC_SRC = $(filter-out src/contrib/%.cc src/runtime/%.cc,\
         $(wildcard src/*/*.cc src/*/*/*.cc))
METAL_SRC = $(wildcard src/runtime/metal/*.mm)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
RUNTIME_DEP = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))

ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
# Objectives
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC))
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
CONTRIB_OBJ =

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
 -Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
export OBJCFLAGS= -fno-objc-arc
UNAME_S := $(shell uname -s)

# Deps
ALL_DEP = $(CC_OBJ) $(CONTRIB_OBJ) $(LIB_HALIDEIR)
RUNTIME_DEP = $(RUNTIME_OBJ)

# The flags
LDFLAGS = -pthread -lm
CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
 -Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
FRAMEWORKS =
OBJCFLAGS = -fno-objc-arc

# Dependency specific rules
ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc
CFLAGS += -I$(CUDA_PATH)/include
LDFLAGS += -L$(CUDA_PATH)/lib64
endif

ifeq ($(ENABLE_CUDA), 1)
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif

FRAMEWORKS=

UNAME_S := $(shell uname -s)

ifeq ($(ENABLE_OPENCL), 1)
ifeq ($(USE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
ifeq ($(UNAME_S), Darwin)
FRAMEWORKS += -framework OpenCL

@@ -61,10 +66,9 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif

ifeq ($(ENABLE_METAL), 1)
ifeq ($(USE_METAL), 1)
CFLAGS += -DTVM_METAL_RUNTIME=1
LDFLAGS += -lObjc
ALL_DEP += $(METAL_OBJ)
RUNTIME_DEP += $(METAL_OBJ)
FRAMEWORKS += -framework Metal -framework Foundation
else

@@ -74,13 +78,15 @@ endif
# llvm configuration
LLVM_CONFIG=llvm-config

ifeq ($(ENABLE_LLVM), 1)
ifeq ($(USE_LLVM), 1)
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
endif

include make/contrib/cblas.mk

ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
endif

@@ -106,7 +112,7 @@ build/%.o: src/%.mm
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@

lib/libtvm.so: $(ALL_DEP)
lib/libtvm.so: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

@@ -114,11 +120,11 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

lib/libtvm.a: $(ALL_DEP)
lib/libtvm.a: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)

$(LIB_HALIDE_IR): LIBHALIDEIR
$(LIB_HALIDEIR): LIBHALIDEIR

LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)

NEWS.md (13 lines changed)

@@ -4,8 +4,11 @@ TVM Change Log
This file records the changes in TVM library in reverse chronological order.

## Initial version (0.1rc)
- CUDA/OpenCL codegen
- LLVM codegen
- AOT and module system
- External function call
- Beta verilog codegen
- External function and contrib libraries
- Metal backend
- OpenCL backend
- CUDA backend
- LLVM backend
- DLPack integration support
- AOT and module system
- Basic code structure ready.

@@ -346,7 +346,7 @@ TVM_DLL int TVMFuncRegisterGlobal(
 * \brief Get a global function.
 *
 * \param name The name of the function.
 * \param out the result function pointer.
 * \param out the result function pointer, NULL if it does not exist.
 *
 * \note The function handle of a global function is managed by the TVM runtime,
 *  so TVMFuncFree should not be called when it gets deleted.

@@ -0,0 +1,27 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file util.h
 * \brief Useful runtime util.
 */
#ifndef TVM_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_

#include "./c_runtime_api.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Check whether type matches the given spec.
 * \param t The type
 * \param code The type code.
 * \param bits The number of bits to be matched.
 * \param lanes The number of lanes in the type.
 */
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
  return t.code == code && t.bits == bits && t.lanes == lanes;
}
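// Usage sketch (editorial comment, not part of the original commit): the cblas
// contrib op later in this commit uses this helper to enforce float32 inputs:
//   CHECK(TypeMatch(A->dtype, kFloat, 32));  // true only for a float32 tensor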

}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_UTIL_H_

@@ -16,11 +16,6 @@
# $ make -j8
#-------------------------------------------------------------------------------

#---------------------
# choice of compiler
#--------------------
export NVCC = nvcc

# whether compile with debug
DEBUG = 0

@@ -31,22 +26,27 @@ ADD_LDFLAGS =
ADD_CFLAGS =

#---------------------------------------------
# matrix computation libraries for CPU/GPU
# Backend runtimes.
#---------------------------------------------

# whether enable CUDA during compile
ENABLE_CUDA = 1
USE_CUDA = 1

# whether enable OpenCL during compile
ENABLE_OPENCL = 0
USE_OPENCL = 0

# whether enable Metal during compile
ENABLE_METAL = 0
USE_METAL = 0

# whether build with LLVM support
# This requires llvm-config to be in your PATH
# Requires LLVM version >= 4.0
ENABLE_LLVM = 0
USE_LLVM = 0

#---------------------------------------------
# Contrib optional libraries.
#---------------------------------------------
# Whether use BLAS, choices: openblas, atlas, blas, apple
USE_BLAS = none
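# Illustrative note (editorial, not in the original commit): setting
# USE_BLAS = openblas makes make/contrib/cblas.mk add -lopenblas and compile
# src/contrib/cblas, which registers the tvm.contrib.cblas.matmul extern op.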

# add the path to CUDA library to link and compile flag
# if you have already added them to the environment variables.

@@ -0,0 +1,17 @@
CBLAS_CONTRIB_SRC = $(wildcard src/contrib/cblas/*.cc)
CBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(CBLAS_CONTRIB_SRC))

ifeq ($(USE_BLAS), openblas)
ADD_LDFLAGS += -lopenblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), atlas)
ADD_LDFLAGS += -lcblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), blas)
ADD_LDFLAGS += -lblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), apple)
ADD_CFLAGS += -I/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Versions/Current/Headers/
FRAMEWORKS += -framework Accelerate
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
endif

@@ -196,7 +196,7 @@ def register_func(func_name, f=None, override=False):
    return register


def get_global_func(name):
def get_global_func(name, allow_missing=False):
    """Get a global function by name

    Parameters

@@ -204,14 +204,24 @@ def get_global_func(name):
    name : str
        The name of the global function

    allow_missing : bool
        Whether to allow a missing function or to raise an error.

    Returns
    -------
    func : tvm.Function
        The function to be returned.
        The function to be returned, None if the function is missing.
    """
    handle = FunctionHandle()
    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
    return Function(handle, False)
    if handle.value:
        return Function(handle, False)
    else:
        if allow_missing:
            return None
        else:
            raise ValueError("Cannot find global function %s" % name)


def list_global_func_names():
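
A minimal usage sketch of the new allow_missing path (illustrative, not part of the commit; it mirrors how the test later in this commit probes for the extern op):

import tvm

# Returns None instead of raising ValueError when the function is absent.
f = tvm.get_global_func("tvm.contrib.cblas.matmul", allow_missing=True)
if f is None:
    print("cblas extern op not built; set USE_BLAS in config.mk")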
@@ -0,0 +1,34 @@
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with CBLAS.

    This function serves as an example of how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    n = lhs.shape[1] if transa else lhs.shape[0]
    m = rhs.shape[0] if transb else rhs.shape[1]
    return _api.extern(
        (n, m), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cblas.matmul",
            ins[0], ins[1], outs[0], transa, transb), name="C")
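
A short usage sketch (illustrative; the full, runnable version is the test file later in this commit). matmul builds an extern stage whose body is a packed call into the registered C function:

import tvm
from tvm.contrib import cblas

A = tvm.placeholder((1024, 128), name='A')
B = tvm.placeholder((128, 235), name='B')
C = cblas.matmul(A, B)               # extern stage -> tvm.contrib.cblas.matmul
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")  # needs USE_BLAS != none and LLVM enabled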
@@ -106,6 +106,12 @@ class Schedule(NodeBase):

        include_inputs : boolean, optional
            Whether to include input operations in the group if they are used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage representing the group; the user can use compute_at to move
            the attachment point of the group.
        """
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]

@@ -3,7 +3,7 @@
Header files in include are public APIs shared across modules.
There can be internal header files within each module that sit in src.

The current code modules in src.
## Modules
- common: Internal common utilities.
- api: API function registration.
- lang: The definition of DSL-related data structures.

@@ -0,0 +1,49 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file Use external cblas library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>

extern "C" {
#include <cblas.h>
}

namespace tvm {
namespace contrib {

using namespace runtime;

// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
    DLTensor* B = args[1];
    DLTensor* C = args[2];
    bool transa = args[3];
    bool transb = args[4];
    // call gemm for simple compact code.
    CHECK_EQ(A->ndim, 2);
    CHECK_EQ(B->ndim, 2);
    CHECK_EQ(C->ndim, 2);
    CHECK(C->strides == nullptr);
    CHECK(B->strides == nullptr);
    CHECK(A->strides == nullptr);
    CHECK(TypeMatch(A->dtype, kFloat, 32));
    CHECK(TypeMatch(B->dtype, kFloat, 32));
    CHECK(TypeMatch(C->dtype, kFloat, 32));
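    // Explanatory comment (editorial, not in the original commit): the tensors
    // are row-major, yet sgemm below is invoked as CblasColMajor with the
    // operands swapped. A row-major C is the column-major view of C^T, and
    // C^T = B^T * A^T, so computing B * A in column-major order writes
    // C = A * B in row-major order.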
    cblas_sgemm(CblasColMajor,
                transb ? CblasTrans : CblasNoTrans,
                transa ? CblasTrans : CblasNoTrans,
                transb ? B->shape[0] : B->shape[1],
                transa ? A->shape[1] : A->shape[0],
                transa ? B->shape[1] : B->shape[0],
                1.0f,
                static_cast<float*>(B->data), B->shape[1],
                static_cast<float*>(A->data), A->shape[1],
                0.0f,
                static_cast<float*>(C->data), C->shape[1]);
  });
}  // namespace contrib
}  // namespace tvm

@@ -105,9 +105,11 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
  API_BEGIN();
  const tvm::runtime::PackedFunc* fp =
      tvm::runtime::Registry::Get(name);
  CHECK(fp != nullptr)
      << "Cannot find global function " << name;
  *out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)
  if (fp != nullptr) {
    *out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)
  } else {
    *out = nullptr;
  }
  API_END();
}

@@ -0,0 +1,36 @@
import tvm
import numpy as np
from tvm.contrib import cblas

def test_matmul_add():
    n = 1024
    l = 128
    m = 235
    bias = tvm.var('bias', dtype=tvm.float32)
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    C = cblas.matmul(A, B)
    D = tvm.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
    s = tvm.create_schedule(D.op)

    def verify(target="llvm"):
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D, bias], target)
        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
        bb = 10.0
        f(a, b, d, bb)
        np.testing.assert_allclose(
            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb)
    verify()


if __name__ == "__main__":
    test_matmul_add()

@@ -18,17 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
fi

cp make/config.mk config.mk
echo "ENABLE_CUDA=0" >> config.mk
echo "USE_CUDA=0" >> config.mk

if [ ${TRAVIS_OS_NAME} == "osx" ]; then
    echo "ENABLE_OPENCL=1" >> config.mk
    echo "ENABLE_METAL=1" >> config.mk
    echo "USE_OPENCL=1" >> config.mk
    echo "USE_METAL=1" >> config.mk
else
    # use g++-4.8 for linux
    if [ ${CXX} == "g++" ]; then
        export CXX=g++-4.8
    fi
    echo "ENABLE_OPENCL=0" >> config.mk
    echo "USE_OPENCL=0" >> config.mk
fi

if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then

@@ -22,6 +22,7 @@ import numpy as np
#
# A **Schedule** is a set of transformations of computation that
# transforms the loops of computations in the program.
#

# declare some variables for use later
n = tvm.var('n')

@@ -50,7 +51,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))

######################################################################
# split
# --------------------------
# -----
# :code:`split` can split a specified axis into two axes by
# :code:`factor`.
A = tvm.placeholder((m,), name='A')

@@ -72,7 +73,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))

######################################################################
# tile
# --------------------------
# ----
# :code:`tile` helps you execute the computation tile by tile over two
# axes.
A = tvm.placeholder((m, n), name='A')

@@ -84,7 +85,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))

######################################################################
# fuse
# --------------------------
# ----
# :code:`fuse` can fuse two consecutive axes of one computation.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

@@ -98,7 +99,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))

######################################################################
# reorder
# --------------------------
# -------
# :code:`reorder` can reorder the axes in the specified order.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

@@ -112,7 +113,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))

######################################################################
# bind
# --------------------------
# ----
# :code:`bind` can bind a specified axis with a thread axis, often used
# in GPU programming.
A = tvm.placeholder((n,), name='A')

@@ -126,7 +127,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))

######################################################################
# compute_at
# --------------------------
# ----------
# For a schedule that consists of multiple operators, TVM will compute
# the tensors at the root separately by default.
A = tvm.placeholder((m,), name='A')

@@ -149,7 +150,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))

######################################################################
# compute_inline
# --------------------------
# --------------
# :code:`compute_inline` can mark one stage as inline; the body of the
# computation will then be expanded and inserted at the address where the
# tensor is required.

@@ -163,7 +164,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))

######################################################################
# compute_root
# --------------------------
# ------------
# :code:`compute_root` can move the computation of one stage to the root.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')

@@ -1,5 +1,7 @@
# Verilog Code Guideline

The Verilog backend is still in early alpha and not yet ready to use.

- Use ```my_port_name``` for variable naming.
- Always use a suffix to indicate the intended usage.