[CODEGEN] Refactor common codegen, Verilog Codegen (#74)
* [CODEGEN] Refactor common codegen, Verilog Codegen * fix make * fix mk * update enable signal * change function name to at neg edge * Move test to correct place
This commit is contained in:
Родитель
9ebb57b331
Коммит
df6fcc509c
2
HalideIR
2
HalideIR
|
@ -1 +1 @@
|
||||||
Subproject commit 7efe0366e93c053d558415b72f9fe3f6545eb721
|
Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8
|
|
@ -148,6 +148,15 @@ struct IntSetNode : public Node {
|
||||||
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
|
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Detect if e can be rewritten as e = base + var * coeff
|
||||||
|
* Where coeff and base are invariant of var.
|
||||||
|
*
|
||||||
|
* \return [base, coeff] if it is possible, empty array if it is not.
|
||||||
|
*/
|
||||||
|
Array<Expr> DetectLinearEquation(Expr e, Var var);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Find an symbolic integer set that contains all possible values of
|
* \brief Find an symbolic integer set that contains all possible values of
|
||||||
* e given the domain of each iteration variables.
|
* e given the domain of each iteration variables.
|
||||||
|
|
|
@ -19,6 +19,19 @@ using ::tvm::Node;
|
||||||
using ::tvm::NodeRef;
|
using ::tvm::NodeRef;
|
||||||
using ::tvm::AttrVisitor;
|
using ::tvm::AttrVisitor;
|
||||||
|
|
||||||
|
/*! \brief Macro to make it easy to define node ref type given node */
|
||||||
|
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
|
||||||
|
class TypeName : public NodeRef { \
|
||||||
|
public: \
|
||||||
|
TypeName() {} \
|
||||||
|
explicit TypeName(std::shared_ptr<Node> n) : NodeRef(n) {} \
|
||||||
|
const NodeName* operator->() const { \
|
||||||
|
return static_cast<const NodeName*>(node_.get()); \
|
||||||
|
} \
|
||||||
|
using ContainerType = NodeName; \
|
||||||
|
}; \
|
||||||
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief save the node as well as all the node it depends on as json.
|
* \brief save the node as well as all the node it depends on as json.
|
||||||
* This can be used to serialize any TVM object
|
* This can be used to serialize any TVM object
|
||||||
|
|
|
@ -35,7 +35,6 @@ struct ChannelNode : public Node {
|
||||||
Var handle_var;
|
Var handle_var;
|
||||||
/*! \brief default data type in read/write */
|
/*! \brief default data type in read/write */
|
||||||
Type dtype;
|
Type dtype;
|
||||||
|
|
||||||
// visit all attributes
|
// visit all attributes
|
||||||
void VisitAttrs(AttrVisitor* v) final {
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
v->Visit("handle_var", &handle_var);
|
v->Visit("handle_var", &handle_var);
|
||||||
|
|
|
@ -103,10 +103,16 @@ constexpr const char* extern_op_scope = "extern_op_scope";
|
||||||
// Pipeline related attributes
|
// Pipeline related attributes
|
||||||
/*! \brief channel read scope */
|
/*! \brief channel read scope */
|
||||||
constexpr const char* channel_read_scope = "channel_read_scope";
|
constexpr const char* channel_read_scope = "channel_read_scope";
|
||||||
|
/*! \brief Advance step of channel after end of scope */
|
||||||
|
constexpr const char* channel_read_advance = "channel_read_advance";
|
||||||
/*! \brief channel write scope */
|
/*! \brief channel write scope */
|
||||||
constexpr const char* channel_write_scope = "channel_write_scope";
|
constexpr const char* channel_write_scope = "channel_write_scope";
|
||||||
/*! \brief pipeline module scope */
|
/*! \brief Advance step of channel after end of scope */
|
||||||
|
constexpr const char* channel_write_advance = "channel_write_advance";
|
||||||
|
/*! \brief pipeline stage scope, implies always execution */
|
||||||
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
|
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
|
||||||
|
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
|
||||||
|
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
|
||||||
} // namespace attr
|
} // namespace attr
|
||||||
|
|
||||||
/*! \brief namespace of TVM Intrinsic functions */
|
/*! \brief namespace of TVM Intrinsic functions */
|
||||||
|
|
|
@ -55,6 +55,14 @@ bool VerifySSA(const Stmt& ir);
|
||||||
*/
|
*/
|
||||||
bool HasSideEffect(const Expr& e);
|
bool HasSideEffect(const Expr& e);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Whether e expression used var.
|
||||||
|
* \param e The expression to be checked.
|
||||||
|
* \param v The variable.
|
||||||
|
* \return Whether e uses v.
|
||||||
|
*/
|
||||||
|
bool ExprUseVar(const Expr& e, const Var& v);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Convert a IR node to be SSA form.
|
* \brief Convert a IR node to be SSA form.
|
||||||
* \param stmt The source statement to be converted.
|
* \param stmt The source statement to be converted.
|
||||||
|
@ -115,9 +123,17 @@ Stmt RemoveNoOp(Stmt stmt);
|
||||||
/*!
|
/*!
|
||||||
* \brief Split statement into pipeine stages.
|
* \brief Split statement into pipeine stages.
|
||||||
* \param stmt The stmt to be splitted
|
* \param stmt The stmt to be splitted
|
||||||
|
* \param split_load Whether split load into its own stage.
|
||||||
* \return Transformed stmt.
|
* \return Transformed stmt.
|
||||||
*/
|
*/
|
||||||
Stmt SplitPipeline(Stmt stmt);
|
Stmt SplitPipeline(Stmt stmt, bool split_load);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Narrow channel access to smaller range.
|
||||||
|
* \param stmt The stmt to do access rewriting.
|
||||||
|
* \return Transformed stmt.
|
||||||
|
*/
|
||||||
|
Stmt NarrowChannelAccess(Stmt stmt);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief unroll the constant loops
|
* \brief unroll the constant loops
|
||||||
|
|
|
@ -5,6 +5,7 @@ from __future__ import absolute_import
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
from numbers import Number, Integral
|
from numbers import Number, Integral
|
||||||
|
|
||||||
from .._base import _LIB, check_call
|
from .._base import _LIB, check_call
|
||||||
|
@ -46,7 +47,14 @@ def convert_to_tvm_func(pyfunc):
|
||||||
""" ctypes function """
|
""" ctypes function """
|
||||||
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
||||||
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
||||||
rv = local_pyfunc(*pyargs)
|
# pylint: disable=broad-except
|
||||||
|
try:
|
||||||
|
rv = local_pyfunc(*pyargs)
|
||||||
|
except Exception:
|
||||||
|
msg = traceback.format_exc()
|
||||||
|
_LIB.TVMAPISetLastError(c_str(msg))
|
||||||
|
return -1
|
||||||
|
|
||||||
if rv is not None:
|
if rv is not None:
|
||||||
if isinstance(rv, tuple):
|
if isinstance(rv, tuple):
|
||||||
raise ValueError("PackedFunction can only support one reurn value")
|
raise ValueError("PackedFunction can only support one reurn value")
|
||||||
|
|
|
@ -4,10 +4,12 @@ from __future__ import absolute_import
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import ctypes
|
||||||
|
|
||||||
from .. import _api_internal
|
from .. import _api_internal
|
||||||
from .._base import string_types
|
from .._base import string_types
|
||||||
from .._ctypes._node import NodeBase, register_node
|
from .._ctypes._node import NodeBase, register_node
|
||||||
|
from .._ctypes._function import register_func
|
||||||
from . import testing
|
from . import testing
|
||||||
|
|
||||||
@register_node
|
@register_node
|
||||||
|
@ -46,7 +48,7 @@ class VPISession(NodeBase):
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return _api_internal._vpi_SessGetHandleByName(self, name)
|
return _api_internal._vpi_SessGetHandleByName(self, name)
|
||||||
|
|
||||||
def yield_until_posedge(self):
|
def yield_until_next_cycle(self):
|
||||||
"""Yield until next posedge"""
|
"""Yield until next posedge"""
|
||||||
for f in self.yield_callbacks:
|
for f in self.yield_callbacks:
|
||||||
f()
|
f()
|
||||||
|
@ -120,7 +122,8 @@ def search_path():
|
||||||
"""Get the search directory."""
|
"""Get the search directory."""
|
||||||
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
|
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
|
||||||
ver_path = [os.path.join(curr_path, '../../../verilog/')]
|
ver_path = [os.path.join(curr_path, '../../../verilog/')]
|
||||||
ver_path += [os.path.join(curr_path, '../../../tests/verilog/')]
|
ver_path += [os.path.join(curr_path, '../../../tests/verilog/unittest/')]
|
||||||
|
ver_path += [os.path.join(curr_path, '../../../tests/verilog/integration/')]
|
||||||
return ver_path
|
return ver_path
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,29 +181,41 @@ def compile_file(file_name, file_target, options=None):
|
||||||
raise ValueError("Compilation error:\n%s" % out)
|
raise ValueError("Compilation error:\n%s" % out)
|
||||||
|
|
||||||
|
|
||||||
def session(file_name):
|
def session(file_names, codes=None):
|
||||||
"""Create a new iverilog session by compile the file.
|
"""Create a new iverilog session by compile the file.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
file_name : str or list of str
|
file_names : str or list of str
|
||||||
The name of the file
|
The name of the file
|
||||||
|
|
||||||
|
codes : str or list of str
|
||||||
|
The code in str.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
sess : VPISession
|
sess : VPISession
|
||||||
The created session.
|
The created session.
|
||||||
"""
|
"""
|
||||||
if isinstance(file_name, string_types):
|
if isinstance(file_names, string_types):
|
||||||
file_name = [file_name]
|
file_names = [file_names]
|
||||||
|
|
||||||
for name in file_name:
|
path = testing.tempdir()
|
||||||
|
|
||||||
|
if codes:
|
||||||
|
if isinstance(codes, (list, tuple)):
|
||||||
|
codes = '\n'.join(codes)
|
||||||
|
fcode = path.relpath("temp_code.v")
|
||||||
|
with open(fcode, "w") as out_file:
|
||||||
|
out_file.write(codes)
|
||||||
|
file_names.append(fcode)
|
||||||
|
|
||||||
|
for name in file_names:
|
||||||
if not os.path.exists(name):
|
if not os.path.exists(name):
|
||||||
raise ValueError("Cannot find file %s" % name)
|
raise ValueError("Cannot find file %s" % name)
|
||||||
|
|
||||||
path = testing.tempdir()
|
target = path.relpath(os.path.basename(file_names[0].rsplit(".", 1)[0]))
|
||||||
target = path.relpath(os.path.basename(file_name[0].rsplit(".", 1)[0]))
|
compile_file(file_names, target)
|
||||||
compile_file(file_name, target)
|
|
||||||
vpi_path = _find_vpi_path()
|
vpi_path = _find_vpi_path()
|
||||||
|
|
||||||
cmd = ["vvp"]
|
cmd = ["vvp"]
|
||||||
|
@ -243,3 +258,43 @@ def session(file_name):
|
||||||
sess.proc = proc
|
sess.proc = proc
|
||||||
sess.execpath = path
|
sess.execpath = path
|
||||||
return sess
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
@register_func
|
||||||
|
def tvm_callback_verilog_simulator(code, *args):
|
||||||
|
"""Callback by TVM runtime to invoke verilog simulator
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
code : str
|
||||||
|
The verilog code to be simulated
|
||||||
|
|
||||||
|
args : list
|
||||||
|
Additional arguments to be set.
|
||||||
|
"""
|
||||||
|
libs = [
|
||||||
|
find_file("tvm_vpi_mmap.v")
|
||||||
|
]
|
||||||
|
sess = session(libs, code)
|
||||||
|
for i, value in enumerate(args):
|
||||||
|
vpi_h = sess.main["tvm_arg%d" % i]
|
||||||
|
if isinstance(value, ctypes.c_void_p):
|
||||||
|
int_value = int(value.value)
|
||||||
|
elif isinstance(value, int):
|
||||||
|
int_value = value
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Do not know how to handle value type %s" % type(value))
|
||||||
|
vpi_h.put_int(int_value)
|
||||||
|
|
||||||
|
rst = sess.main.rst
|
||||||
|
done = sess.main.done
|
||||||
|
# start driving
|
||||||
|
rst.put_int(1)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
rst.put_int(0)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
while not done.get_int():
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
sess.shutdown()
|
||||||
|
|
|
@ -26,6 +26,11 @@ TVM_REGISTER_API(_arith_EvalModular)
|
||||||
*ret = EvalModular(args[0], Map<Var, IntSet>());
|
*ret = EvalModular(args[0], Map<Var, IntSet>());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_arith_DetectLinearEquation)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = DetectLinearEquation(args[0], args[1]);
|
||||||
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API(_arith_DeduceBound)
|
TVM_REGISTER_API(_arith_DeduceBound)
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = DeduceBound(args[0], args[1],
|
*ret = DeduceBound(args[0], args[1],
|
||||||
|
|
|
@ -63,6 +63,7 @@ REGISTER_PASS1(CanonicalSimplify);
|
||||||
REGISTER_PASS4(Inline);
|
REGISTER_PASS4(Inline);
|
||||||
REGISTER_PASS2(StorageFlatten);
|
REGISTER_PASS2(StorageFlatten);
|
||||||
REGISTER_PASS1(VectorizeLoop);
|
REGISTER_PASS1(VectorizeLoop);
|
||||||
|
REGISTER_PASS2(ExprUseVar);
|
||||||
REGISTER_PASS2(UnrollLoop);
|
REGISTER_PASS2(UnrollLoop);
|
||||||
REGISTER_PASS2(StorageSync);
|
REGISTER_PASS2(StorageSync);
|
||||||
REGISTER_PASS4(MakeAPI);
|
REGISTER_PASS4(MakeAPI);
|
||||||
|
@ -71,7 +72,7 @@ REGISTER_PASS1(LiftAllocate);
|
||||||
REGISTER_PASS1(InjectVirtualThread);
|
REGISTER_PASS1(InjectVirtualThread);
|
||||||
REGISTER_PASS1(LoopPartition);
|
REGISTER_PASS1(LoopPartition);
|
||||||
REGISTER_PASS1(RemoveNoOp);
|
REGISTER_PASS1(RemoveNoOp);
|
||||||
REGISTER_PASS1(SplitPipeline);
|
REGISTER_PASS2(SplitPipeline);
|
||||||
|
REGISTER_PASS1(NarrowChannelAccess);
|
||||||
} // namespace ir
|
} // namespace ir
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file bound_deducer.cc
|
||||||
|
* \brief Utility to deduce bound of expression
|
||||||
|
*/
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/ir_pass.h>
|
||||||
|
#include <tvm/ir_visitor.h>
|
||||||
|
#include <tvm/ir_functor_ext.h>
|
||||||
|
#include <tvm/arithmetic.h>
|
||||||
|
#include "./compute_expr.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace arith {
|
||||||
|
|
||||||
|
using namespace ir;
|
||||||
|
|
||||||
|
// Linear equation, the components can be undefined.
|
||||||
|
struct LinearEqEntry {
|
||||||
|
Expr base;
|
||||||
|
Expr coeff;
|
||||||
|
};
|
||||||
|
|
||||||
|
class LinearEqDetector
|
||||||
|
: public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
|
||||||
|
public:
|
||||||
|
explicit LinearEqDetector(Var var)
|
||||||
|
: var_(var) {}
|
||||||
|
|
||||||
|
Array<Expr> Detect(const Expr& e) {
|
||||||
|
LinearEqEntry ret = VisitExpr(e, e);
|
||||||
|
if (fail_) return Array<Expr>();
|
||||||
|
if (!ret.base.defined()) {
|
||||||
|
ret.base = make_zero(var_.type());
|
||||||
|
}
|
||||||
|
if (!ret.coeff.defined()) {
|
||||||
|
ret.coeff = make_zero(var_.type());
|
||||||
|
}
|
||||||
|
return Array<Expr>{ret.base, ret.coeff};
|
||||||
|
}
|
||||||
|
|
||||||
|
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
|
||||||
|
if (fail_) return LinearEqEntry();
|
||||||
|
LinearEqEntry a = VisitExpr(op->a, op->a);
|
||||||
|
LinearEqEntry b = VisitExpr(op->b, op->b);
|
||||||
|
LinearEqEntry ret;
|
||||||
|
ret.base = AddCombine(a.base, b.base);
|
||||||
|
ret.coeff = AddCombine(a.coeff, b.coeff);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
|
||||||
|
if (fail_) return LinearEqEntry();
|
||||||
|
LinearEqEntry a = VisitExpr(op->a, op->a);
|
||||||
|
LinearEqEntry b = VisitExpr(op->b, op->b);
|
||||||
|
if (a.coeff.defined()) {
|
||||||
|
std::swap(a, b);
|
||||||
|
}
|
||||||
|
if (a.coeff.defined()) {
|
||||||
|
fail_ = true;
|
||||||
|
return LinearEqEntry();
|
||||||
|
}
|
||||||
|
LinearEqEntry ret;
|
||||||
|
ret.base = MulCombine(a.base, b.base);
|
||||||
|
ret.coeff = MulCombine(a.base, b.coeff);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
|
||||||
|
LinearEqEntry ret;
|
||||||
|
if (op == var_.get()) {
|
||||||
|
ret.coeff = make_const(op->type, 1);
|
||||||
|
} else {
|
||||||
|
ret.base = e;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final {
|
||||||
|
if (fail_) return LinearEqEntry();
|
||||||
|
if (ExprUseVar(e, var_)) {
|
||||||
|
fail_ = true;
|
||||||
|
return LinearEqEntry();
|
||||||
|
} else {
|
||||||
|
LinearEqEntry ret;
|
||||||
|
ret.base = e;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Var var_;
|
||||||
|
bool fail_{false};
|
||||||
|
// Combine by add
|
||||||
|
Expr AddCombine(Expr a, Expr b) {
|
||||||
|
if (!a.defined()) return b;
|
||||||
|
if (!b.defined()) return a;
|
||||||
|
return ComputeExpr<Add>(a, b);
|
||||||
|
}
|
||||||
|
Expr MulCombine(Expr a, Expr b) {
|
||||||
|
if (!a.defined()) return a;
|
||||||
|
if (!b.defined()) return b;
|
||||||
|
return ComputeExpr<Mul>(a, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Array<Expr> DetectLinearEquation(Expr e, Var var) {
|
||||||
|
return LinearEqDetector(var).Detect(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace arith
|
||||||
|
} // namespace tvm
|
|
@ -163,6 +163,7 @@ inline bool MatchPoint(const IntSet& a,
|
||||||
}
|
}
|
||||||
|
|
||||||
IntSet Union(const Array<IntSet>& sets) {
|
IntSet Union(const Array<IntSet>& sets) {
|
||||||
|
if (sets.size() == 0) return IntSet::nothing();
|
||||||
if (sets.size() == 1) return sets[0];
|
if (sets.size() == 1) return sets[0];
|
||||||
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
|
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
|
||||||
for (size_t i = 1; i < sets.size(); ++i) {
|
for (size_t i = 1; i < sets.size(); ++i) {
|
||||||
|
|
|
@ -18,11 +18,8 @@ void CodeGenC::Init(bool output_ssa) {
|
||||||
|
|
||||||
void CodeGenC::InitFuncState(LoweredFunc f) {
|
void CodeGenC::InitFuncState(LoweredFunc f) {
|
||||||
alloc_storage_scope_.clear();
|
alloc_storage_scope_.clear();
|
||||||
name_alloc_map_.clear();
|
|
||||||
ssa_assign_map_.clear();
|
|
||||||
var_idmap_.clear();
|
|
||||||
handle_data_type_.clear();
|
handle_data_type_.clear();
|
||||||
scope_mark_.clear();
|
CodeGenSourceBase::ClearFuncState();
|
||||||
}
|
}
|
||||||
void CodeGenC::AddFunction(LoweredFunc f) {
|
void CodeGenC::AddFunction(LoweredFunc f) {
|
||||||
// clear previous generated state.
|
// clear previous generated state.
|
||||||
|
@ -67,30 +64,6 @@ std::string CodeGenC::Finish() {
|
||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::string CodeGenC::SSAGetID(std::string src, Type t) {
|
|
||||||
if (name_alloc_map_.count(src)) return src;
|
|
||||||
auto it = ssa_assign_map_.find(src);
|
|
||||||
if (it != ssa_assign_map_.end()) {
|
|
||||||
if (scope_mark_.at(it->second.scope_id)) {
|
|
||||||
return it->second.vid;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this->PrintIndent();
|
|
||||||
SSAEntry e;
|
|
||||||
e.vid = GetUniqueName("_");
|
|
||||||
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
|
|
||||||
ssa_assign_map_[src] = e;
|
|
||||||
if (src.length() > 3 &&
|
|
||||||
src[0] == '(' && src[src.length() - 1] == ')') {
|
|
||||||
src = src.substr(1, src.length() - 2);
|
|
||||||
}
|
|
||||||
PrintType(t, stream);
|
|
||||||
stream << ' ' << e.vid << " = " << src << ";\n";
|
|
||||||
return e.vid;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
|
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
|
||||||
if (print_ssa_form_) {
|
if (print_ssa_form_) {
|
||||||
std::ostringstream temp;
|
std::ostringstream temp;
|
||||||
|
@ -101,88 +74,17 @@ void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CodeGenC::GetUniqueName(std::string prefix) {
|
void CodeGenC::PrintSSAAssign(
|
||||||
auto it = name_alloc_map_.find(prefix);
|
const std::string& target, const std::string& src, Type t) {
|
||||||
if (it != name_alloc_map_.end()) {
|
PrintType(t, stream);
|
||||||
while (true) {
|
stream << ' ' << target << " = ";
|
||||||
std::ostringstream os;
|
if (src.length() > 3 &&
|
||||||
os << prefix << (++it->second);
|
src[0] == '(' && src[src.length() - 1] == ')') {
|
||||||
std::string name = os.str();
|
stream << src.substr(1, src.length() - 2);
|
||||||
if (name_alloc_map_.count(name) == 0) {
|
|
||||||
prefix = name;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name_alloc_map_[prefix] = 0;
|
|
||||||
return prefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string CodeGenC::AllocVarID(const Variable* v) {
|
|
||||||
CHECK(!var_idmap_.count(v))
|
|
||||||
<< "Need input to be in SSA form dup " << v->name_hint;
|
|
||||||
std::string key = v->name_hint;
|
|
||||||
for (size_t i = 0; i < key.size(); ++i) {
|
|
||||||
if (key[i] == '.') key[i] = '_';
|
|
||||||
}
|
|
||||||
std::string vid = GetUniqueName(key);
|
|
||||||
var_idmap_[v] = vid;
|
|
||||||
return vid;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string CodeGenC::GetVarID(const Variable* v) const {
|
|
||||||
auto it = var_idmap_.find(v);
|
|
||||||
CHECK(it != var_idmap_.end())
|
|
||||||
<< "Find undefined Variable " << v->name_hint;
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
|
|
||||||
auto it = handle_data_type_.find(buf_var);
|
|
||||||
if (it == handle_data_type_.end()) return false;
|
|
||||||
return it->second == t;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
|
|
||||||
auto it = handle_data_type_.find(buf_var);
|
|
||||||
if (it == handle_data_type_.end()) {
|
|
||||||
handle_data_type_[buf_var] = t;
|
|
||||||
} else {
|
} else {
|
||||||
CHECK(it->second == t)
|
stream << src;
|
||||||
<< "conflicting buf var type";
|
|
||||||
}
|
}
|
||||||
}
|
stream << ";\n";
|
||||||
|
|
||||||
void CodeGenC::PrintIndent() {
|
|
||||||
for (int i = 0; i < this->indent; ++i) {
|
|
||||||
this->stream << ' ';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CodeGenC::MarkConst(std::string vid) {
|
|
||||||
if (print_ssa_form_) {
|
|
||||||
auto it = ssa_assign_map_.find(vid);
|
|
||||||
if (it == ssa_assign_map_.end()) {
|
|
||||||
SSAEntry e;
|
|
||||||
e.vid = vid;
|
|
||||||
e.scope_id = 0;
|
|
||||||
ssa_assign_map_[vid] = e;
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(it->second.vid, vid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int CodeGenC::BeginScope() {
|
|
||||||
int sid = static_cast<int>(scope_mark_.size());
|
|
||||||
scope_mark_.push_back(true);
|
|
||||||
indent += 2;
|
|
||||||
return sid;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CodeGenC::EndScope(int scope_id) {
|
|
||||||
scope_mark_[scope_id] = false;
|
|
||||||
indent -= 2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print a reference expression to a buffer.
|
// Print a reference expression to a buffer.
|
||||||
|
@ -229,6 +131,23 @@ void CodeGenC::PrintBufferRef(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
|
||||||
|
auto it = handle_data_type_.find(buf_var);
|
||||||
|
if (it == handle_data_type_.end()) return false;
|
||||||
|
return it->second == t;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
|
||||||
|
auto it = handle_data_type_.find(buf_var);
|
||||||
|
if (it == handle_data_type_.end()) {
|
||||||
|
handle_data_type_[buf_var] = t;
|
||||||
|
} else {
|
||||||
|
CHECK(it->second == t)
|
||||||
|
<< "conflicting buf var type";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void CodeGenC::PrintVecElemLoad(const std::string& vec,
|
void CodeGenC::PrintVecElemLoad(const std::string& vec,
|
||||||
Type t, int i,
|
Type t, int i,
|
||||||
std::ostream& os) { // NOLINT(*)
|
std::ostream& os) { // NOLINT(*)
|
||||||
|
@ -564,29 +483,32 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
|
||||||
|
|
||||||
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
|
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
|
||||||
int lanes = op->type.lanes();
|
int lanes = op->type.lanes();
|
||||||
|
std::string svalue = GetUniqueName("_");
|
||||||
|
// delcare type.
|
||||||
|
this->PrintIndent();
|
||||||
|
this->PrintType(op->type, stream);
|
||||||
|
stream << ' ' << svalue;
|
||||||
if (op->type.lanes() == 1) {
|
if (op->type.lanes() == 1) {
|
||||||
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os);
|
stream << " = ";
|
||||||
|
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
|
||||||
|
stream << ";\n";
|
||||||
} else {
|
} else {
|
||||||
Expr base;
|
Expr base;
|
||||||
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
|
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
|
||||||
this->PrintVecLoad(op->buffer_var.get(), op->type, base, os);
|
stream << " = ";
|
||||||
|
this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
|
||||||
|
stream << ";\n";
|
||||||
} else {
|
} else {
|
||||||
// Load elements seperately
|
// Load elements seperately
|
||||||
|
stream << ";\n";
|
||||||
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
|
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
|
||||||
std::string svalue = GetUniqueName("_");
|
|
||||||
{
|
|
||||||
// delcare type.
|
|
||||||
this->PrintIndent();
|
|
||||||
this->PrintType(op->type, stream);
|
|
||||||
stream << ' ' << svalue << ";\n";
|
|
||||||
}
|
|
||||||
std::string vid = GetVarID(op->buffer_var.get());
|
std::string vid = GetVarID(op->buffer_var.get());
|
||||||
Type elem_type = op->type.element_of();
|
Type elem_type = op->type.element_of();
|
||||||
for (int i = 0; i < lanes; ++i) {
|
for (int i = 0; i < lanes; ++i) {
|
||||||
std::ostringstream value_temp;
|
std::ostringstream value_temp;
|
||||||
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
|
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
|
||||||
value_temp << "((";
|
value_temp << "((";
|
||||||
PrintType(elem_type, os);
|
PrintType(elem_type, value_temp);
|
||||||
value_temp << "*)" << vid << ')';
|
value_temp << "*)" << vid << ')';
|
||||||
} else {
|
} else {
|
||||||
value_temp << vid;
|
value_temp << vid;
|
||||||
|
@ -596,9 +518,9 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
|
||||||
value_temp << ']';
|
value_temp << ']';
|
||||||
PrintVecElemStore(svalue, op->type, i, value_temp.str());
|
PrintVecElemStore(svalue, op->type, i, value_temp.str());
|
||||||
}
|
}
|
||||||
os << svalue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
os << svalue;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CodeGenC::VisitStmt_(const Store* op) {
|
void CodeGenC::VisitStmt_(const Store* op) {
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include "./codegen_source_base.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace codegen {
|
namespace codegen {
|
||||||
|
@ -25,7 +26,8 @@ using namespace ir;
|
||||||
*/
|
*/
|
||||||
class CodeGenC :
|
class CodeGenC :
|
||||||
public ExprFunctor<void(const Expr&, std::ostream&)>,
|
public ExprFunctor<void(const Expr&, std::ostream&)>,
|
||||||
public StmtFunctor<void(const Stmt&)> {
|
public StmtFunctor<void(const Stmt&)>,
|
||||||
|
public CodeGenSourceBase {
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
* \brief Initialize the code generator.
|
* \brief Initialize the code generator.
|
||||||
|
@ -64,26 +66,6 @@ class CodeGenC :
|
||||||
PrintExpr(n, os);
|
PrintExpr(n, os);
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
/*! \brief print the current indented value */
|
|
||||||
void PrintIndent();
|
|
||||||
/*!
|
|
||||||
* \brief Register constant value appeared in expresion tree
|
|
||||||
* This avoid generated a ssa id for each appearance of the value
|
|
||||||
* \param value The constant value.
|
|
||||||
*/
|
|
||||||
void MarkConst(std::string value);
|
|
||||||
/*!
|
|
||||||
* \brief Allocate a variable name for a newly defined var.
|
|
||||||
* \param v The variable.
|
|
||||||
* \return the variable name.
|
|
||||||
*/
|
|
||||||
std::string AllocVarID(const Variable* v);
|
|
||||||
/*!
|
|
||||||
* \brief Get a variable name.
|
|
||||||
* \param v The variable.
|
|
||||||
* \return the variable name.
|
|
||||||
*/
|
|
||||||
std::string GetVarID(const Variable* v) const;
|
|
||||||
// The following parts are overloadable print operations.
|
// The following parts are overloadable print operations.
|
||||||
/*!
|
/*!
|
||||||
* \brief Initialize codegen state for generating f.
|
* \brief Initialize codegen state for generating f.
|
||||||
|
@ -164,43 +146,12 @@ class CodeGenC :
|
||||||
virtual void PrintVecElemStore(
|
virtual void PrintVecElemStore(
|
||||||
const std::string& vec, Type t, int i, const std::string& value);
|
const std::string& vec, Type t, int i, const std::string& value);
|
||||||
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/*! \brief the stream to be printed */
|
|
||||||
std::ostringstream stream;
|
|
||||||
/*! \brief entry in ssa assign map */
|
|
||||||
struct SSAEntry {
|
|
||||||
/*! \brief The value id */
|
|
||||||
std::string vid;
|
|
||||||
/*! \brief The scope id */
|
|
||||||
int scope_id;
|
|
||||||
};
|
|
||||||
// print reference to a buffer as type t in index.
|
// print reference to a buffer as type t in index.
|
||||||
void PrintBufferRef(const Variable* buffer,
|
void PrintBufferRef(const Variable* buffer,
|
||||||
Type t, Expr index,
|
Type t, Expr index,
|
||||||
std::ostream& os); // NOLINT(*)
|
std::ostream& os); // NOLINT(*)
|
||||||
/*!
|
|
||||||
* \brief Get the SSA ID corresponds to src
|
|
||||||
* If necessary, generate new assignment
|
|
||||||
* \param src The source expression
|
|
||||||
* \param t The type of the expression.
|
|
||||||
*/
|
|
||||||
std::string SSAGetID(std::string src, Type t);
|
|
||||||
/*!
|
|
||||||
* \brief get a unique name with the corresponding prefix
|
|
||||||
* \param prefix The prefix of the name
|
|
||||||
* \return The returned name.
|
|
||||||
*/
|
|
||||||
std::string GetUniqueName(std::string prefix);
|
|
||||||
/*!
|
|
||||||
* \brief mark the beginning of a new scope
|
|
||||||
* \return The scope id.
|
|
||||||
*/
|
|
||||||
int BeginScope();
|
|
||||||
/*!
|
|
||||||
* \brief mark the end of an old scope.
|
|
||||||
* \param scope_id The scope id to be ended.
|
|
||||||
*/
|
|
||||||
void EndScope(int scope_id);
|
|
||||||
/*!
|
/*!
|
||||||
* \brief If buffer is allocated as type t.
|
* \brief If buffer is allocated as type t.
|
||||||
* \param buf_var The buffer variable.
|
* \param buf_var The buffer variable.
|
||||||
|
@ -213,30 +164,17 @@ class CodeGenC :
|
||||||
* \param t The type to be checked.
|
* \param t The type to be checked.
|
||||||
*/
|
*/
|
||||||
void RegisterHandleType(const Variable* buf_var, Type t);
|
void RegisterHandleType(const Variable* buf_var, Type t);
|
||||||
/*!
|
// override
|
||||||
* \brief Get the storage scope of buf_var.
|
void PrintSSAAssign(
|
||||||
* \param buf_var The buf_var to be queryed.
|
const std::string& target, const std::string& src, Type t) final;
|
||||||
* \return The storage scope.
|
|
||||||
*/
|
|
||||||
std::string GetStorageScope(const Variable* buf_var) const;
|
|
||||||
/*! \brief the storage scope of allocation */
|
/*! \brief the storage scope of allocation */
|
||||||
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
|
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief whether to print in SSA form */
|
/*! \brief whether to print in SSA form */
|
||||||
bool print_ssa_form_{true};
|
bool print_ssa_form_{true};
|
||||||
/*! \brief name allocation map */
|
|
||||||
std::unordered_map<std::string, int> name_alloc_map_;
|
|
||||||
/*! \brief assignment map of ssa */
|
|
||||||
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
|
|
||||||
/*! \brief name of each variable */
|
|
||||||
std::unordered_map<const Variable*, std::string> var_idmap_;
|
|
||||||
/*! \brief the data type of allocated buffers */
|
/*! \brief the data type of allocated buffers */
|
||||||
std::unordered_map<const Variable*, Type> handle_data_type_;
|
std::unordered_map<const Variable*, Type> handle_data_type_;
|
||||||
/*! \brief array to check whether we are inside certain scope */
|
|
||||||
std::vector<bool> scope_mark_;
|
|
||||||
/*! \brief The current indentation value */
|
|
||||||
int indent{0};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace codegen
|
} // namespace codegen
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file codegen_source_base.cc
|
||||||
|
*/
|
||||||
|
#include "./codegen_source_base.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
void CodeGenSourceBase::ClearFuncState() {
|
||||||
|
name_alloc_map_.clear();
|
||||||
|
ssa_assign_map_.clear();
|
||||||
|
var_idmap_.clear();
|
||||||
|
scope_mark_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
|
||||||
|
for (size_t i = 0; i < prefix.size(); ++i) {
|
||||||
|
if (prefix[i] == '.') prefix[i] = '_';
|
||||||
|
}
|
||||||
|
auto it = name_alloc_map_.find(prefix);
|
||||||
|
if (it != name_alloc_map_.end()) {
|
||||||
|
while (true) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << prefix << (++it->second);
|
||||||
|
std::string name = os.str();
|
||||||
|
if (name_alloc_map_.count(name) == 0) {
|
||||||
|
prefix = name;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name_alloc_map_[prefix] = 0;
|
||||||
|
return prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
|
||||||
|
if (name_alloc_map_.count(src)) return src;
|
||||||
|
auto it = ssa_assign_map_.find(src);
|
||||||
|
if (it != ssa_assign_map_.end()) {
|
||||||
|
if (scope_mark_.at(it->second.scope_id)) {
|
||||||
|
return it->second.vid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SSAEntry e;
|
||||||
|
e.vid = GetUniqueName("_");
|
||||||
|
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
|
||||||
|
ssa_assign_map_[src] = e;
|
||||||
|
this->PrintIndent();
|
||||||
|
PrintSSAAssign(e.vid, src, t);
|
||||||
|
return e.vid;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
|
||||||
|
CHECK(!var_idmap_.count(v))
|
||||||
|
<< "Need input to be in SSA form dup " << v->name_hint;
|
||||||
|
std::string key = v->name_hint;
|
||||||
|
std::string vid = GetUniqueName(key);
|
||||||
|
var_idmap_[v] = vid;
|
||||||
|
return vid;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CodeGenSourceBase::GetVarID(const Variable* v) const {
|
||||||
|
auto it = var_idmap_.find(v);
|
||||||
|
CHECK(it != var_idmap_.end())
|
||||||
|
<< "Find undefined Variable " << v->name_hint;
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenSourceBase::PrintIndent() {
|
||||||
|
for (int i = 0; i < indent_; ++i) {
|
||||||
|
this->stream << ' ';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenSourceBase::MarkConst(std::string vid) {
|
||||||
|
auto it = ssa_assign_map_.find(vid);
|
||||||
|
if (it == ssa_assign_map_.end()) {
|
||||||
|
SSAEntry e;
|
||||||
|
e.vid = vid;
|
||||||
|
e.scope_id = 0;
|
||||||
|
ssa_assign_map_[vid] = e;
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(it->second.vid, vid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int CodeGenSourceBase::BeginScope() {
|
||||||
|
int sid = static_cast<int>(scope_mark_.size());
|
||||||
|
scope_mark_.push_back(true);
|
||||||
|
indent_ += 2;
|
||||||
|
return sid;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenSourceBase::EndScope(int scope_id) {
|
||||||
|
scope_mark_[scope_id] = false;
|
||||||
|
indent_ -= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
|
@ -0,0 +1,105 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2018 by Contributors
|
||||||
|
* \file codegen_source_base.h
|
||||||
|
* \brief Common utilities to source code in text form.
|
||||||
|
*/
|
||||||
|
#ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
|
||||||
|
#define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
|
||||||
|
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/codegen.h>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief A base class to generate source code.
|
||||||
|
* Contains helper utilities to generate nest and ssa form.
|
||||||
|
*/
|
||||||
|
class CodeGenSourceBase {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief Register constant value appeared in expresion tree
|
||||||
|
* This avoid generated a ssa id for each appearance of the value
|
||||||
|
* \param value The constant value.
|
||||||
|
*/
|
||||||
|
void MarkConst(std::string value);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/*! \brief entry in ssa assign map */
|
||||||
|
struct SSAEntry {
|
||||||
|
/*! \brief The value id */
|
||||||
|
std::string vid;
|
||||||
|
/*! \brief The scope id, used to check if this entry is invalid. */
|
||||||
|
int scope_id;
|
||||||
|
};
|
||||||
|
/*! \brief Clear the states that might relates to function generation */
|
||||||
|
void ClearFuncState();
|
||||||
|
/*! \brief print the current indented value */
|
||||||
|
void PrintIndent();
|
||||||
|
/*!
|
||||||
|
* \brief Allocate a variable name for a newly defined var.
|
||||||
|
* \param v The variable.
|
||||||
|
* \return the variable name.
|
||||||
|
*/
|
||||||
|
std::string AllocVarID(const Variable* v);
|
||||||
|
/*!
|
||||||
|
* \brief Get a variable name.
|
||||||
|
* \param v The variable.
|
||||||
|
* \return the variable name.
|
||||||
|
*/
|
||||||
|
std::string GetVarID(const Variable* v) const;
|
||||||
|
/*!
|
||||||
|
* \brief Get the SSA ID corresponds to src
|
||||||
|
* If necessary, generate new assignment
|
||||||
|
* \param src The source expression
|
||||||
|
* \param t The type of the expression.
|
||||||
|
*/
|
||||||
|
std::string SSAGetID(std::string src, Type t);
|
||||||
|
/*!
|
||||||
|
* \brief get a unique name with the corresponding prefix
|
||||||
|
* \param prefix The prefix of the name
|
||||||
|
* \return The returned name.
|
||||||
|
*/
|
||||||
|
std::string GetUniqueName(std::string prefix);
|
||||||
|
/*!
|
||||||
|
* \brief mark the beginning of a new scope
|
||||||
|
* \return The scope id.
|
||||||
|
*/
|
||||||
|
int BeginScope();
|
||||||
|
/*!
|
||||||
|
* \brief mark the end of an old scope.
|
||||||
|
* \param scope_id The scope id to be ended.
|
||||||
|
*/
|
||||||
|
void EndScope(int scope_id);
|
||||||
|
/*!
|
||||||
|
* \brief Print assignment of src to the id in ssa entry.
|
||||||
|
* \param target id of target variable.
|
||||||
|
* \param src The source expression.
|
||||||
|
* \param t The type of target.
|
||||||
|
*/
|
||||||
|
virtual void PrintSSAAssign(
|
||||||
|
const std::string& target, const std::string& src, Type t) = 0;
|
||||||
|
|
||||||
|
/*! \brief the stream to be printed */
|
||||||
|
std::ostringstream stream;
|
||||||
|
/*! \brief name of each variable */
|
||||||
|
std::unordered_map<const Variable*, std::string> var_idmap_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief assignment map of ssa */
|
||||||
|
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
|
||||||
|
/*! \brief name allocation map */
|
||||||
|
std::unordered_map<std::string, int> name_alloc_map_;
|
||||||
|
/*! \brief array to check whether we are inside certain scope */
|
||||||
|
std::vector<bool> scope_mark_;
|
||||||
|
/*! \brief The current indentation value */
|
||||||
|
int indent_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
||||||
|
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
|
|
@ -0,0 +1,724 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file codegen_verilog.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/ir_pass.h>
|
||||||
|
#include <cctype>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include "./codegen_verilog.h"
|
||||||
|
#include "../../arithmetic/compute_expr.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
namespace verilog {
|
||||||
|
|
||||||
|
using namespace ir;
|
||||||
|
|
||||||
|
void CodeGenVerilog::Init() {
|
||||||
|
stream << "`include \"tvm_marcos.v\"\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::InitFuncState(LoweredFunc f) {
|
||||||
|
CodeGenSourceBase::ClearFuncState();
|
||||||
|
cmap_.clear();
|
||||||
|
tvm_vpi_modules_.clear();
|
||||||
|
done_sigs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::AddFunction(LoweredFunc f) {
|
||||||
|
// clear previous generated state.
|
||||||
|
this->InitFuncState(f);
|
||||||
|
// skip the first underscore, so SSA variable starts from _1
|
||||||
|
GetUniqueName("_");
|
||||||
|
GetUniqueName("rst");
|
||||||
|
GetUniqueName("clk");
|
||||||
|
GetUniqueName("done");
|
||||||
|
GetUniqueName("enable");
|
||||||
|
GetUniqueName("all_input_valid");
|
||||||
|
// print out function body.
|
||||||
|
int func_scope = this->BeginScope();
|
||||||
|
|
||||||
|
// Stich things up.
|
||||||
|
stream << "module " << f->name << "(\n";
|
||||||
|
PrintDecl("clk", kInput, Bool(1), "");
|
||||||
|
stream << ",\n";
|
||||||
|
PrintDecl("rst", kInput, Bool(1), "");
|
||||||
|
VerilogFuncEntry entry;
|
||||||
|
for (size_t i = 0; i < f->args.size(); ++i) {
|
||||||
|
stream << ",\n";
|
||||||
|
Var v = f->args[i];
|
||||||
|
std::string vid = AllocVarID(v.get());
|
||||||
|
entry.arg_ids.push_back(vid);
|
||||||
|
entry.arg_types.push_back(v.type());
|
||||||
|
PrintDecl(vid, kInput, v.type(), "");
|
||||||
|
}
|
||||||
|
stream << ",\n";
|
||||||
|
PrintDecl("done", kOutput, Bool(1), "");
|
||||||
|
stream << "\n);\n";
|
||||||
|
this->CodeGen(MakePipeline(f));
|
||||||
|
PrintAssignAnd("done", done_sigs_);
|
||||||
|
this->EndScope(func_scope);
|
||||||
|
this->PrintIndent();
|
||||||
|
stream << "endmodule\n";
|
||||||
|
entry.vpi_modules = std::move(tvm_vpi_modules_);
|
||||||
|
functions_[f->name] = entry;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string VerilogCodeGenModule::AppendSimMain(
|
||||||
|
const std::string& func_name) const {
|
||||||
|
// Add main function for simulator hook
|
||||||
|
const VerilogFuncEntry& entry = fmap.at(func_name);
|
||||||
|
std::ostringstream stream;
|
||||||
|
stream << code;
|
||||||
|
stream << "\n"
|
||||||
|
<< "module main();\n"
|
||||||
|
<< " `TVM_DEFINE_TEST_SIGNAL(clk, rst)\n";
|
||||||
|
// print out function body.
|
||||||
|
std::vector<std::string> sargs;
|
||||||
|
for (size_t i = 0; i < entry.arg_types.size(); ++i) {
|
||||||
|
Type t = entry.arg_types[i];
|
||||||
|
std::ostringstream sarg;
|
||||||
|
sarg << "tvm_arg" << i;
|
||||||
|
std::string vid = sarg.str();
|
||||||
|
stream << " reg";
|
||||||
|
if (t.bits() > 1) {
|
||||||
|
stream << "[" << t.bits() - 1 << ":0]";
|
||||||
|
}
|
||||||
|
stream << " " << vid << ";\n";
|
||||||
|
sargs.push_back(vid);
|
||||||
|
}
|
||||||
|
stream << " wire done;\n";
|
||||||
|
stream << "\n " << func_name << " dut(\n"
|
||||||
|
<< " .clk(clk),\n"
|
||||||
|
<< " .rst(rst),\n";
|
||||||
|
|
||||||
|
for (size_t i = 0; i < entry.arg_ids.size(); ++i) {
|
||||||
|
stream << " ." << entry.arg_ids[i] << '('
|
||||||
|
<< sargs[i] << "),\n";
|
||||||
|
}
|
||||||
|
stream << " .done(done)\n"
|
||||||
|
<< " );\n";
|
||||||
|
|
||||||
|
|
||||||
|
stream << " initial begin\n"
|
||||||
|
<< " $tvm_session(clk";
|
||||||
|
for (const std::string& mvpi : entry.vpi_modules) {
|
||||||
|
stream << ", dut." << mvpi;
|
||||||
|
}
|
||||||
|
stream << ");\n"
|
||||||
|
<< " end\n";
|
||||||
|
stream << "endmodule\n";
|
||||||
|
return stream.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogCodeGenModule CodeGenVerilog::Finish() {
|
||||||
|
VerilogCodeGenModule m;
|
||||||
|
m.code = stream.str();
|
||||||
|
m.fmap = std::move(functions_);
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::PrintDecl(
|
||||||
|
const std::string& vid, VerilogVarType vtype, Type dtype,
|
||||||
|
const char* suffix, bool indent) {
|
||||||
|
if (indent) PrintIndent();
|
||||||
|
switch (vtype) {
|
||||||
|
case kReg: stream << "reg "; break;
|
||||||
|
case kWire: stream << "wire "; break;
|
||||||
|
case kInput: stream << "input "; break;
|
||||||
|
case kOutput: stream << "output "; break;
|
||||||
|
default: LOG(FATAL) << "unsupported vtype=" << vtype;
|
||||||
|
}
|
||||||
|
int bits = dtype.bits();
|
||||||
|
// bits for handle type.
|
||||||
|
if (dtype.is_handle()) {
|
||||||
|
bits = 64;
|
||||||
|
}
|
||||||
|
if (bits > 1) {
|
||||||
|
stream << "[" << bits - 1 << ":0] ";
|
||||||
|
}
|
||||||
|
stream << vid << suffix;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::PrintSSAAssign(
|
||||||
|
const std::string& target, const std::string& src, Type t) {
|
||||||
|
// add target to list of declaration.
|
||||||
|
PrintDecl(target, kWire, t, ";\n", false);
|
||||||
|
PrintAssign(target, src);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::PrintAssign(
|
||||||
|
const std::string& target, const std::string& src) {
|
||||||
|
PrintIndent();
|
||||||
|
stream << "assign " << target << " = ";
|
||||||
|
if (src.length() > 3 &&
|
||||||
|
src[0] == '(' && src[src.length() - 1] == ')') {
|
||||||
|
stream << src.substr(1, src.length() - 2);
|
||||||
|
} else {
|
||||||
|
stream << src;
|
||||||
|
}
|
||||||
|
stream << ";\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::PrintAssignAnd(
|
||||||
|
const std::string& target, const std::vector<std::string>& conds) {
|
||||||
|
if (conds.size() != 0) {
|
||||||
|
std::ostringstream os_valid;
|
||||||
|
for (size_t i = 0; i < conds.size(); ++i) {
|
||||||
|
if (i != 0) os_valid << " && ";
|
||||||
|
os_valid << conds[i];
|
||||||
|
}
|
||||||
|
PrintAssign(target, os_valid.str());
|
||||||
|
} else {
|
||||||
|
PrintAssign(target, "1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::PrintLine(const std::string& line) {
|
||||||
|
PrintIndent();
|
||||||
|
stream << line << '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::MakeBinary(Type t,
|
||||||
|
VerilogValue a,
|
||||||
|
VerilogValue b,
|
||||||
|
const char *opstr) {
|
||||||
|
CHECK_EQ(t.lanes(), 1)
|
||||||
|
<< "Do not yet support vectorized op";
|
||||||
|
CHECK(t.is_int() || t.is_uint())
|
||||||
|
<< "Only support integer operations";
|
||||||
|
std::ostringstream os;
|
||||||
|
os << a.vid << ' ' << opstr << ' '<< b.vid;
|
||||||
|
return GetSSAValue(os.str(), t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline VerilogValue IntConst(const T* op, CodeGenVerilog* p) {
|
||||||
|
if (op->type.bits() <= 32 && op->type.lanes() == 1) {
|
||||||
|
std::ostringstream temp;
|
||||||
|
temp << op->value;
|
||||||
|
p->MarkConst(temp.str());
|
||||||
|
return VerilogValue(temp.str(), kConst, op->type);
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Do not support integer constant type " << op->type;
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const IntImm *op) {
|
||||||
|
return IntConst(op, this);
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const UIntImm *op) {
|
||||||
|
return IntConst(op, this);
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const FloatImm *op) {
|
||||||
|
LOG(FATAL) << "Donot support float constant in Verilog";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const StringImm *op) {
|
||||||
|
LOG(FATAL) << "Donot support string constant in Verilog";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Cast *op) {
|
||||||
|
LOG(FATAL) << "Type cast not supported";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Variable *op) {
|
||||||
|
return VerilogValue(GetVarID(op), kReg, op->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Add *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "+");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Sub *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "-");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Mul *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "*");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Div *op) {
|
||||||
|
int shift;
|
||||||
|
if (is_const_power_of_two_integer(op->b, &shift) &&
|
||||||
|
(op->type.is_int() || op->type.is_uint())) {
|
||||||
|
return MakeValue(op->a >> make_const(op->b.type(), shift));
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "do not support synthesis division";
|
||||||
|
}
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Mod *op) {
|
||||||
|
LOG(FATAL) << "do not support synthesis Mod";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Min *op) {
|
||||||
|
LOG(FATAL) << "not supported";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Max *op) {
|
||||||
|
LOG(FATAL) << "not supported";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const EQ *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "==");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const NE *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "!=");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const LT *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "<");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const LE *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "<=");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const GT *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), ">");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const GE *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), ">=");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const And *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "&&");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Or *op) {
|
||||||
|
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "||");
|
||||||
|
}
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Not *op) {
|
||||||
|
VerilogValue value = MakeValue(op->a);
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "(!" << value.vid << ")";
|
||||||
|
return GetSSAValue(os.str(), op->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Call *op) {
|
||||||
|
if (op->is_intrinsic(Call::bitwise_and)) {
|
||||||
|
return MakeBinary(
|
||||||
|
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "&");
|
||||||
|
} else if (op->is_intrinsic(Call::bitwise_xor)) {
|
||||||
|
return MakeBinary(
|
||||||
|
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "^");
|
||||||
|
} else if (op->is_intrinsic(Call::bitwise_or)) {
|
||||||
|
return MakeBinary(
|
||||||
|
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "|");
|
||||||
|
} else if (op->is_intrinsic(Call::bitwise_not)) {
|
||||||
|
VerilogValue value = MakeValue(op->args[0]);
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "(~" << value.vid << ")";
|
||||||
|
return GetSSAValue(os.str(), op->type);
|
||||||
|
} else if (op->is_intrinsic(Call::shift_left)) {
|
||||||
|
return MakeBinary(
|
||||||
|
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "<<");
|
||||||
|
} else if (op->is_intrinsic(Call::shift_right)) {
|
||||||
|
return MakeBinary(
|
||||||
|
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), ">>");
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Cannot generate call type " << op->name;
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Let* op) {
|
||||||
|
VerilogValue value = MakeValue(op->value);
|
||||||
|
CHECK(!var_idmap_.count(op->var.get()));
|
||||||
|
var_idmap_[op->var.get()] = value.vid;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Ramp* op) {
|
||||||
|
LOG(FATAL) << "Ramp: not supported ";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Broadcast* op) {
|
||||||
|
LOG(FATAL) << "Broadcast: not supported ";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
VerilogValue CodeGenVerilog::VisitExpr_(const Select* op) {
|
||||||
|
LOG(FATAL) << "Select: not supported ";
|
||||||
|
return VerilogValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::CodeGen(const Pipeline& pipeline) {
|
||||||
|
// setup channel map.
|
||||||
|
for (auto kv : pipeline->channels) {
|
||||||
|
ChannelEntry e; e.block = kv.second;
|
||||||
|
cmap_[kv.first.get()] = e;
|
||||||
|
}
|
||||||
|
for (ComputeBlock stage : pipeline->stages) {
|
||||||
|
const Store* store = stage->body.as<Store>();
|
||||||
|
CHECK(store);
|
||||||
|
const Load* load = store->value.as<Load>();
|
||||||
|
if (load) {
|
||||||
|
MakeLoadToFIFO(stage, store, load);
|
||||||
|
} else {
|
||||||
|
MakeStore(stage, store);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& kv : cmap_) {
|
||||||
|
MakeChannelUnit(kv.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CodeGenVerilog::SignalEntry
|
||||||
|
CodeGenVerilog::MakeLoop(const Array<Stmt>& loop) {
|
||||||
|
SignalEntry sig;
|
||||||
|
// do not use init signal for now.
|
||||||
|
std::string init = "0";
|
||||||
|
std::string lp_ready = GetUniqueName("lp_tmp_sig");
|
||||||
|
sig.ready = GetUniqueName("loop_ready");
|
||||||
|
sig.valid = GetUniqueName("loop_valid");
|
||||||
|
PrintLine("// loop logic");
|
||||||
|
PrintDecl(lp_ready, kWire, Bool(1));
|
||||||
|
PrintDecl(sig.ready, kWire, Bool(1));
|
||||||
|
|
||||||
|
std::string end_loop = lp_ready;
|
||||||
|
for (size_t i = loop.size(); i != 0; --i) {
|
||||||
|
const For* for_op = loop[i - 1].as<For>();
|
||||||
|
int bits = for_op->loop_var.type().bits();
|
||||||
|
VerilogValue min = MakeValue(for_op->min);
|
||||||
|
VerilogValue extent = MakeValue(for_op->extent);
|
||||||
|
CHECK(min.vtype == kConst && extent.vtype == kConst)
|
||||||
|
<< "Only support constant loop domain";
|
||||||
|
|
||||||
|
std::string vid = AllocVarID(for_op->loop_var.get());
|
||||||
|
std::string finish = GetUniqueName(vid + "_finish");
|
||||||
|
this->PrintIndent();
|
||||||
|
stream <<"`NONSTOP_LOOP(" << vid << ", " << bits << ", " << init
|
||||||
|
<< ", " << end_loop << ", " << finish
|
||||||
|
<< ", " << min.vid << ", " << extent.vid << ")\n";
|
||||||
|
end_loop = finish;
|
||||||
|
}
|
||||||
|
if (loop.size() != 0) {
|
||||||
|
std::string local_ready = GetUniqueName("lp_tmp_sig");
|
||||||
|
this->PrintIndent();
|
||||||
|
stream <<"`WRAP_LOOP_ONCE(" << init << ", " << sig.valid
|
||||||
|
<< ", " << sig.ready << ", " << end_loop << ", " << local_ready << ")\n";
|
||||||
|
PrintAssign(lp_ready, local_ready);
|
||||||
|
}
|
||||||
|
return sig;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeStageInputs(
|
||||||
|
const ComputeBlock& block,
|
||||||
|
const std::string& enable,
|
||||||
|
std::string* out_all_input_valid) {
|
||||||
|
std::vector<SignalEntry> sigs;
|
||||||
|
sigs.push_back(MakeLoop(block->loop));
|
||||||
|
// Input data path.
|
||||||
|
PrintLine("// stage inputs");
|
||||||
|
for (auto kv : block->inputs) {
|
||||||
|
const Var& var = kv.first;
|
||||||
|
const StageInput& arg = kv.second;
|
||||||
|
std::string vid = AllocVarID(var.get());
|
||||||
|
this->PrintDecl(vid, kWire, var.type());
|
||||||
|
if (arg->input_type == kGlobalConst ||
|
||||||
|
arg->input_type == kLoopVar) {
|
||||||
|
PrintAssign(vid, GetVarID(arg->var.get()));
|
||||||
|
} else if (arg->input_type == kChannel) {
|
||||||
|
std::string vid_valid = GetUniqueName(vid + "_valid");
|
||||||
|
std::string vid_ready = GetUniqueName(vid + "_ready");
|
||||||
|
this->PrintDecl(vid_valid, kWire, Bool(1));
|
||||||
|
this->PrintDecl(vid_ready, kWire, Bool(1));
|
||||||
|
ChannelEntry* e = GetChannelInfo(arg->var.get());
|
||||||
|
// TODO(tqchen, thierry) add one cache here.
|
||||||
|
e->AssignPort("read_data", vid, var.type());
|
||||||
|
e->AssignPort("read_valid", vid_valid, Bool(1));
|
||||||
|
e->AssignPort("read_ready", vid_ready, Bool(1));
|
||||||
|
e->AssignPort("read_addr", "0", Int(1));
|
||||||
|
sigs.push_back(SignalEntry{vid_valid, vid_ready});
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Unknown input type";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PrintLine("// stage input stall");
|
||||||
|
std::string all_input_valid = GetUniqueName("all_input_valid");
|
||||||
|
this->PrintDecl(all_input_valid, kWire, Bool(1));
|
||||||
|
// forward all valid
|
||||||
|
std::vector<std::string> valid_conds;
|
||||||
|
for (const SignalEntry& e : sigs) {
|
||||||
|
if (e.valid.length() != 0) {
|
||||||
|
valid_conds.push_back(e.valid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PrintAssignAnd(all_input_valid, valid_conds);
|
||||||
|
// input ready signal
|
||||||
|
for (size_t i = 0; i < sigs.size(); ++i) {
|
||||||
|
if (sigs[i].ready.length() == 0) continue;
|
||||||
|
std::vector<std::string> conds = {enable};
|
||||||
|
for (size_t j = 0; j < sigs.size(); ++j) {
|
||||||
|
if (j != i && sigs[j].valid.length() != 0) {
|
||||||
|
conds.push_back(sigs[j].valid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PrintAssignAnd(sigs[i].ready, conds);
|
||||||
|
}
|
||||||
|
*out_all_input_valid = all_input_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeDelay(const std::string& dst,
|
||||||
|
const std::string& src,
|
||||||
|
Type dtype,
|
||||||
|
int delay,
|
||||||
|
const std::string& enable) {
|
||||||
|
PrintIndent();
|
||||||
|
stream << "`DELAY(" << dst << ", " << src << ", "
|
||||||
|
<< dtype.bits() << ", " << delay << ", " << enable << ")\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeStore(const ComputeBlock& block,
|
||||||
|
const Store* store) {
|
||||||
|
std::string all_input_valid;
|
||||||
|
std::string enable = GetUniqueName("enable");
|
||||||
|
this->PrintDecl(enable, kWire, Bool(1));
|
||||||
|
MakeStageInputs(block, enable, &all_input_valid);
|
||||||
|
// Data path
|
||||||
|
PrintLine("// data path");
|
||||||
|
VerilogValue value = MakeValue(store->value);
|
||||||
|
VerilogValue index = MakeValue(store->index);
|
||||||
|
PrintLine("// control and retiming");
|
||||||
|
ChannelEntry* write_entry = GetChannelInfo(store->buffer_var.get());
|
||||||
|
// TODO(tqchen, thierry) add delay model from expression.a
|
||||||
|
int delay = 2;
|
||||||
|
std::string ch_name = write_entry->block->channel->handle_var->name_hint;
|
||||||
|
std::string write_addr = GetUniqueName(ch_name + ".write_addr");
|
||||||
|
std::string write_ready = GetUniqueName(ch_name + ".write_ready");
|
||||||
|
std::string write_valid = GetUniqueName(ch_name + ".write_valid");
|
||||||
|
std::string write_data = GetUniqueName(ch_name + ".write_data");
|
||||||
|
PrintDecl(write_addr, kWire, store->index.type());
|
||||||
|
PrintDecl(write_ready, kWire, Bool(1));
|
||||||
|
PrintDecl(write_valid, kWire, Bool(1));
|
||||||
|
PrintDecl(write_data, kWire, store->value.type());
|
||||||
|
|
||||||
|
MakeDelay(write_addr, index.vid, store->index.type(), delay, enable);
|
||||||
|
MakeDelay(write_data, value.vid, store->value.type(), delay, enable);
|
||||||
|
MakeDelay(write_valid, all_input_valid, Bool(1), delay, enable);
|
||||||
|
PrintAssign(enable, "!" + write_valid + " || " + write_ready);
|
||||||
|
write_entry->AssignPort("write_addr", write_addr, store->index.type());
|
||||||
|
write_entry->AssignPort("write_ready", write_ready, Bool(1));
|
||||||
|
write_entry->AssignPort("write_valid", write_valid, Bool(1));
|
||||||
|
write_entry->AssignPort("write_data", write_data, store->value.type());
|
||||||
|
// The triggers
|
||||||
|
for (size_t i = 0; i < block->triggers.size(); ++i) {
|
||||||
|
SignalTrigger trigger = block->triggers[i];
|
||||||
|
CHECK(trigger->predicate.type() == Bool(1));
|
||||||
|
ChannelEntry* trigger_ch = GetChannelInfo(trigger->channel_var.get());
|
||||||
|
std::string port = trigger_ch->SignalPortName(trigger->signal_index);
|
||||||
|
VerilogValue v = MakeValue(trigger->predicate);
|
||||||
|
// Assign constant trigger.
|
||||||
|
if (v.vtype == kConst) {
|
||||||
|
trigger_ch->AssignPort(port, v.vid, Bool(1));
|
||||||
|
} else {
|
||||||
|
// non-constant trigger
|
||||||
|
CHECK_EQ(trigger_ch, write_entry)
|
||||||
|
<< "Can only triggger conditional event at write channel";
|
||||||
|
std::string v_trigger = GetUniqueName(ch_name + "." + port);
|
||||||
|
MakeDelay(v_trigger, v.vid, Bool(1), delay, enable);
|
||||||
|
write_entry->AssignPort(port, v_trigger, Bool(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeLoadToFIFO(const ComputeBlock& block,
|
||||||
|
const Store* store,
|
||||||
|
const Load* load) {
|
||||||
|
ChannelEntry* write_entry = GetChannelInfo(store->buffer_var.get());
|
||||||
|
ChannelEntry* load_entry = GetChannelInfo(load->buffer_var.get());
|
||||||
|
std::string all_input_valid;
|
||||||
|
std::string enable = GetUniqueName("enable");
|
||||||
|
this->PrintDecl(enable, kWire, Bool(1));
|
||||||
|
MakeStageInputs(block, enable, &all_input_valid);
|
||||||
|
// data path
|
||||||
|
PrintLine("// data path");
|
||||||
|
VerilogValue index = MakeValue(load->index);
|
||||||
|
// control and retiming
|
||||||
|
PrintLine("// control and retiming");
|
||||||
|
// TODO(tqchen, thierry) add delay model from expression
|
||||||
|
int delay = 1;
|
||||||
|
std::string read_ch_name = load_entry->block->channel->handle_var->name_hint;
|
||||||
|
std::string write_ch_name = write_entry->block->channel->handle_var->name_hint;
|
||||||
|
std::string read_addr = GetUniqueName(read_ch_name + ".read_addr");
|
||||||
|
std::string read_data = GetUniqueName(read_ch_name + ".read_data");
|
||||||
|
std::string read_valid = GetUniqueName(read_ch_name + ".read_valid");
|
||||||
|
std::string index_valid = GetUniqueName(read_ch_name + ".index_valid");
|
||||||
|
std::string write_ready = GetUniqueName(write_ch_name + ".write_ready");
|
||||||
|
std::string data_valid = GetUniqueName(read_ch_name + ".data_valid");
|
||||||
|
std::string valid_delay = GetUniqueName(read_ch_name + ".valid_delay");
|
||||||
|
PrintDecl(read_addr, kWire, load->index.type());
|
||||||
|
PrintDecl(read_data, kWire, load->type);
|
||||||
|
PrintDecl(read_valid, kWire, Bool(1));
|
||||||
|
PrintDecl(index_valid, kWire, Bool(1));
|
||||||
|
PrintDecl(data_valid, kWire, Bool(1));
|
||||||
|
MakeDelay(read_addr, index.vid, load->index.type(), delay, enable);
|
||||||
|
MakeDelay(index_valid, all_input_valid, Bool(1), delay, enable);
|
||||||
|
PrintAssignAnd(data_valid, {read_valid, index_valid});
|
||||||
|
// The read ports.
|
||||||
|
load_entry->AssignPort("read_addr", read_addr, load->index.type());
|
||||||
|
load_entry->AssignPort("read_data", read_data, load->type);
|
||||||
|
load_entry->AssignPort("read_valid", read_valid, Bool(1));
|
||||||
|
// The write ports.
|
||||||
|
write_entry->AssignPort("write_ready", write_ready, Bool(1));
|
||||||
|
write_entry->AssignPort("write_data", read_data, load->type);
|
||||||
|
write_entry->AssignPort("write_valid", valid_delay, Bool(1));
|
||||||
|
write_entry->AssignPort("write_addr", "0", Int(1));
|
||||||
|
// The not stall condition.
|
||||||
|
PrintAssignAnd(enable, {write_ready, read_valid});
|
||||||
|
// The ready signal
|
||||||
|
PrintIndent();
|
||||||
|
stream << "`BUFFER_READ_VALID_DELAY(" << valid_delay << ", " << data_valid
|
||||||
|
<< ", " << write_ready << ")\n";
|
||||||
|
// The triggers
|
||||||
|
for (size_t i = 0; i < block->triggers.size(); ++i) {
|
||||||
|
SignalTrigger trigger = block->triggers[i];
|
||||||
|
CHECK(trigger->predicate.type() == Bool(1));
|
||||||
|
ChannelEntry* trigger_ch = GetChannelInfo(trigger->channel_var.get());
|
||||||
|
std::string port = trigger_ch->SignalPortName(trigger->signal_index);
|
||||||
|
VerilogValue v = MakeValue(trigger->predicate);
|
||||||
|
// Assign constant trigger.
|
||||||
|
if (v.vtype == kConst) {
|
||||||
|
trigger_ch->AssignPort(port, v.vid, Bool(1));
|
||||||
|
} else {
|
||||||
|
// non-constant trigger
|
||||||
|
CHECK_EQ(trigger_ch, load_entry)
|
||||||
|
<< "Can only triggger conditional event at load channel";
|
||||||
|
std::string v_trigger = GetUniqueName(read_ch_name + "." + port);
|
||||||
|
MakeDelay(v_trigger, v.vid, Bool(1), delay, enable);
|
||||||
|
load_entry->AssignPort(port, v_trigger, Bool(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeChannelUnit(const ChannelEntry& ch) {
|
||||||
|
if (ch.block->read_window == 0) {
|
||||||
|
// This is a memory map
|
||||||
|
MakeChannelMemMap(ch);
|
||||||
|
} else if (ch.block->read_window == 1 &&
|
||||||
|
ch.block->write_window == 1) {
|
||||||
|
MakeChannelFIFO(ch);
|
||||||
|
} else {
|
||||||
|
// general Buffer
|
||||||
|
MakeChannelBuffer(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeChannelMemMap(const ChannelEntry& ch) {
|
||||||
|
Var ch_var = ch.block->channel->handle_var;
|
||||||
|
std::string dut = GetUniqueName(ch_var->name_hint + ".mmap");
|
||||||
|
std::string mmap_addr = GetVarID(ch_var.get());
|
||||||
|
|
||||||
|
tvm_vpi_modules_.push_back(dut);
|
||||||
|
if (ch.ports.count("read_addr")) {
|
||||||
|
CHECK(!ch.ports.count("write_addr"))
|
||||||
|
<< "Cannot read/write to same RAM";
|
||||||
|
const PortEntry& read_addr = ch.GetPort("read_addr");
|
||||||
|
const PortEntry& read_data = ch.GetPort("read_data");
|
||||||
|
const PortEntry& read_valid = ch.GetPort("read_valid");
|
||||||
|
stream << " // channel setup for " << ch_var << "\n"
|
||||||
|
<< " tvm_vpi_read_mmap # (\n"
|
||||||
|
<< " .DATA_WIDTH(" << read_data.dtype.bits() << "),\n"
|
||||||
|
<< " .ADDR_WIDTH(" << read_addr.dtype.bits() << "),\n"
|
||||||
|
<< " .BASE_ADDR_WIDTH(" << ch_var.type().bits() << ")\n"
|
||||||
|
<< " ) " << dut << " (\n"
|
||||||
|
<< " .clk(clk),\n"
|
||||||
|
<< " .rst(rst),\n"
|
||||||
|
<< " .addr(" << read_addr.value << "),\n"
|
||||||
|
<< " .data_out(" << read_data.value << "),\n"
|
||||||
|
<< " .mmap_addr(" << mmap_addr << ")\n"
|
||||||
|
<< " );\n";
|
||||||
|
PrintAssign(read_valid.value, "1");
|
||||||
|
} else if (ch.ports.count("write_addr")) {
|
||||||
|
const PortEntry& write_addr = ch.GetPort("write_addr");
|
||||||
|
const PortEntry& write_data = ch.GetPort("write_data");
|
||||||
|
const PortEntry& write_valid = ch.GetPort("write_valid");
|
||||||
|
const PortEntry& write_ready = ch.GetPort("write_ready");
|
||||||
|
stream << " // channel setup for " << ch_var << "\n"
|
||||||
|
<< " tvm_vpi_write_mmap # (\n"
|
||||||
|
<< " .DATA_WIDTH(" << write_data.dtype.bits() << "),\n"
|
||||||
|
<< " .ADDR_WIDTH(" << write_addr.dtype.bits() << "),\n"
|
||||||
|
<< " .BASE_ADDR_WIDTH(" << ch_var.type().bits() << ")\n"
|
||||||
|
<< " ) " << dut << " (\n"
|
||||||
|
<< " .clk(clk),\n"
|
||||||
|
<< " .rst(rst),\n"
|
||||||
|
<< " .addr(" << write_addr.value << "),\n"
|
||||||
|
<< " .data_in(" << write_data.value << "),\n"
|
||||||
|
<< " .en(" << write_valid.value << "),\n"
|
||||||
|
<< " .mmap_addr(" << mmap_addr << ")\n"
|
||||||
|
<< " );\n";
|
||||||
|
PrintAssign(write_ready.value, "1");
|
||||||
|
// additional control signals
|
||||||
|
for (size_t i = 0; i < ch.block->ctrl_signals.size(); ++i) {
|
||||||
|
ControlSignal sig = ch.block->ctrl_signals[i];
|
||||||
|
CHECK_EQ(sig->ctrl_type, kComputeFinish);
|
||||||
|
std::string port = ch.SignalPortName(i);
|
||||||
|
done_sigs_.push_back(ch.GetPort(port).value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeChannelFIFO(const ChannelEntry& ch) {
|
||||||
|
Var ch_var = ch.block->channel->handle_var;
|
||||||
|
std::string dut = GetUniqueName(ch_var->name_hint + ".fifo_reg");
|
||||||
|
|
||||||
|
const PortEntry& write_data = ch.GetPort("write_data");
|
||||||
|
const PortEntry& write_valid = ch.GetPort("write_valid");
|
||||||
|
const PortEntry& write_ready = ch.GetPort("write_ready");
|
||||||
|
|
||||||
|
const PortEntry& read_data = ch.GetPort("read_data");
|
||||||
|
const PortEntry& read_valid = ch.GetPort("read_valid");
|
||||||
|
const PortEntry& read_ready = ch.GetPort("read_ready");
|
||||||
|
|
||||||
|
CHECK_EQ(write_data.dtype, read_data.dtype);
|
||||||
|
|
||||||
|
stream << " // channel setup for " << ch_var << "\n"
|
||||||
|
<< " `CACHE_REG(" << write_data.dtype.bits()
|
||||||
|
<< ", " << write_data.value
|
||||||
|
<< ", " << write_valid.value
|
||||||
|
<< ", " << write_ready.value
|
||||||
|
<< ", " << read_data.value
|
||||||
|
<< ", " << read_valid.value
|
||||||
|
<< ", " << read_ready.value
|
||||||
|
<< ")\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::MakeChannelBuffer(const ChannelEntry& ch) {
|
||||||
|
LOG(FATAL) << "not implemeneted";
|
||||||
|
}
|
||||||
|
|
||||||
|
CodeGenVerilog::ChannelEntry*
|
||||||
|
CodeGenVerilog::GetChannelInfo(const Variable* var) {
|
||||||
|
auto it = cmap_.find(var);
|
||||||
|
CHECK(it != cmap_.end())
|
||||||
|
<< "cannot find channel for var " << var->name_hint;
|
||||||
|
return &(it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenVerilog::ChannelEntry::AssignPort(
|
||||||
|
std::string port, std::string value, Type dtype) {
|
||||||
|
CHECK(!ports.count(port))
|
||||||
|
<< "port " << port
|
||||||
|
<< " of channel " << block->channel << " has already been connected";
|
||||||
|
ports[port] = PortEntry{value, dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
const CodeGenVerilog::PortEntry&
|
||||||
|
CodeGenVerilog::ChannelEntry::GetPort(const std::string& port) const {
|
||||||
|
auto it = ports.find(port);
|
||||||
|
CHECK(it != ports.end())
|
||||||
|
<< "port " << port
|
||||||
|
<< " of channel " << block->channel << " has not been connected";
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CodeGenVerilog::ChannelEntry::SignalPortName(int index) const {
|
||||||
|
CHECK_LT(static_cast<size_t>(index), block->ctrl_signals.size());
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "ctrl_port" << index;
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
} // namespace verilog
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
|
@ -0,0 +1,223 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file codegen_verilog.h
|
||||||
|
* \brief Generate verilog code.
|
||||||
|
*/
|
||||||
|
#ifndef TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
|
||||||
|
#define TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
|
||||||
|
|
||||||
|
#include <tvm/base.h>
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/ir_functor_ext.h>
|
||||||
|
#include <tvm/codegen.h>
|
||||||
|
#include <tvm/lowered_func.h>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "./verilog_ir.h"
|
||||||
|
#include "../codegen_source_base.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
namespace verilog {
|
||||||
|
using namespace ir;
|
||||||
|
|
||||||
|
/* \brief The variable type in register.*/
|
||||||
|
enum VerilogVarType {
|
||||||
|
kWire,
|
||||||
|
kInput,
|
||||||
|
kOutput,
|
||||||
|
kReg,
|
||||||
|
kConst
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief The verilog value */
|
||||||
|
struct VerilogValue {
|
||||||
|
/*! \brief The variable id */
|
||||||
|
std::string vid;
|
||||||
|
/*! \brief The variable type */
|
||||||
|
VerilogVarType vtype{kReg};
|
||||||
|
/*! \brief The data type it encodes */
|
||||||
|
Type dtype;
|
||||||
|
VerilogValue() {}
|
||||||
|
VerilogValue(std::string vid, VerilogVarType vtype, Type dtype)
|
||||||
|
: vid(vid), vtype(vtype), dtype(dtype) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief Information of each procedure function generated */
|
||||||
|
struct VerilogFuncEntry {
|
||||||
|
/*! \brief The original functions */
|
||||||
|
std::vector<Type> arg_types;
|
||||||
|
/*! \brief The real argument ids of the function */
|
||||||
|
std::vector<std::string> arg_ids;
|
||||||
|
/*! \brief The VPI Modules in the function */
|
||||||
|
std::vector<std::string> vpi_modules;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief The code module of generated verilog code.
|
||||||
|
*/
|
||||||
|
class VerilogCodeGenModule {
|
||||||
|
public:
|
||||||
|
/*! \brief the code of each modoules */
|
||||||
|
std::string code;
|
||||||
|
/*! \brief map of functions */
|
||||||
|
std::unordered_map<std::string, VerilogFuncEntry> fmap;
|
||||||
|
/*!
|
||||||
|
* \brief Generate a code that append simulator function to call func_name.
|
||||||
|
* \param func_name The function to be called.
|
||||||
|
* \return The generated code.
|
||||||
|
*/
|
||||||
|
std::string AppendSimMain(const std::string& func_name) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Verilog generator
|
||||||
|
*/
|
||||||
|
class CodeGenVerilog :
|
||||||
|
public ExprFunctor<VerilogValue(const Expr&)>,
|
||||||
|
public CodeGenSourceBase {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief Initialize the code generator.
|
||||||
|
* \param output_ssa Whether output SSA.
|
||||||
|
*/
|
||||||
|
void Init();
|
||||||
|
/*!
|
||||||
|
* \brief Add the function to the generated module.
|
||||||
|
* \param f The function to be compiled.
|
||||||
|
*/
|
||||||
|
void AddFunction(LoweredFunc f);
|
||||||
|
/*!
|
||||||
|
* \brief Finalize the compilation and return the code.
|
||||||
|
* \return The code.
|
||||||
|
*/
|
||||||
|
VerilogCodeGenModule Finish();
|
||||||
|
/*!
|
||||||
|
* \brief Transform expression to verilog value.
|
||||||
|
* \param n The expression to be printed.
|
||||||
|
*/
|
||||||
|
VerilogValue MakeValue(const Expr& n) {
|
||||||
|
return VisitExpr(n);
|
||||||
|
}
|
||||||
|
// The following parts are overloadable print operations.
|
||||||
|
// expression
|
||||||
|
VerilogValue VisitExpr_(const Variable* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Let* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Call* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Add* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Sub* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Mul* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Div* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Mod* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Min* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Max* op) final;
|
||||||
|
VerilogValue VisitExpr_(const EQ* op) final;
|
||||||
|
VerilogValue VisitExpr_(const NE* op) final;
|
||||||
|
VerilogValue VisitExpr_(const LT* op) final;
|
||||||
|
VerilogValue VisitExpr_(const LE* op) final;
|
||||||
|
VerilogValue VisitExpr_(const GT* op) final;
|
||||||
|
VerilogValue VisitExpr_(const GE* op) final;
|
||||||
|
VerilogValue VisitExpr_(const And* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Or* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Cast* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Not* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Select* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Ramp* op) final;
|
||||||
|
VerilogValue VisitExpr_(const Broadcast* op) final;
|
||||||
|
VerilogValue VisitExpr_(const IntImm* op) final;
|
||||||
|
VerilogValue VisitExpr_(const UIntImm* op) final;
|
||||||
|
VerilogValue VisitExpr_(const FloatImm* op) final;
|
||||||
|
VerilogValue VisitExpr_(const StringImm* op) final;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InitFuncState(LoweredFunc f);
|
||||||
|
void PrintDecl(const std::string& vid, VerilogVarType vtype, Type dtype,
|
||||||
|
const char* suffix = ";\n", bool indent = true);
|
||||||
|
void PrintAssign(
|
||||||
|
const std::string& target, const std::string& src);
|
||||||
|
void PrintAssignAnd(
|
||||||
|
const std::string& target, const std::vector<std::string>& conds);
|
||||||
|
void PrintLine(const std::string& line);
|
||||||
|
void PrintSSAAssign(
|
||||||
|
const std::string& target, const std::string& src, Type t) final;
|
||||||
|
// make binary op
|
||||||
|
VerilogValue MakeBinary(Type t, VerilogValue a, VerilogValue b, const char* opstr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Hand shake signal name.
|
||||||
|
// These name can be empty.
|
||||||
|
// Indicate that the signal is always true
|
||||||
|
// or do not need to take these signals.
|
||||||
|
struct SignalEntry {
|
||||||
|
std::string valid;
|
||||||
|
std::string ready;
|
||||||
|
};
|
||||||
|
// Information about port
|
||||||
|
struct PortEntry {
|
||||||
|
// The port value
|
||||||
|
std::string value;
|
||||||
|
// The data type
|
||||||
|
Type dtype;
|
||||||
|
};
|
||||||
|
// Channel setup
|
||||||
|
struct ChannelEntry {
|
||||||
|
// The channel block
|
||||||
|
ChannelBlock block;
|
||||||
|
// The port map, on how port is assigned.
|
||||||
|
std::unordered_map<std::string, PortEntry> ports;
|
||||||
|
// Assign port to be valueo
|
||||||
|
void AssignPort(std::string port, std::string value, Type dtype);
|
||||||
|
// Assign port to be valueo
|
||||||
|
const PortEntry& GetPort(const std::string& port) const;
|
||||||
|
// Signal port name
|
||||||
|
std::string SignalPortName(int index) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get wire ssa value from s
|
||||||
|
VerilogValue GetSSAValue(std::string s, Type dtype) {
|
||||||
|
VerilogValue ret;
|
||||||
|
ret.vid = SSAGetID(s, dtype);
|
||||||
|
ret.vtype = kWire;
|
||||||
|
ret.dtype = dtype;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
void CodeGen(const Pipeline& pipeine);
|
||||||
|
// codegen the delays
|
||||||
|
void MakeDelay(const std::string& dst,
|
||||||
|
const std::string& src,
|
||||||
|
Type dtype,
|
||||||
|
int delay,
|
||||||
|
const std::string& not_stall);
|
||||||
|
// codegen the loop macros
|
||||||
|
SignalEntry MakeLoop(const Array<Stmt>& loop);
|
||||||
|
// codegen the loop macros
|
||||||
|
void MakeStageInputs(const ComputeBlock& block,
|
||||||
|
const std::string& not_stall,
|
||||||
|
std::string* out_all_input_valid);
|
||||||
|
// codegen compute block
|
||||||
|
void MakeStore(const ComputeBlock& block, const Store* store);
|
||||||
|
// Codegen of load statement into FIFO
|
||||||
|
void MakeLoadToFIFO(const ComputeBlock& block,
|
||||||
|
const Store* store,
|
||||||
|
const Load* load);
|
||||||
|
// Make channel unit.
|
||||||
|
void MakeChannelUnit(const ChannelEntry& ch);
|
||||||
|
void MakeChannelFIFO(const ChannelEntry& ch);
|
||||||
|
void MakeChannelBuffer(const ChannelEntry& ch);
|
||||||
|
void MakeChannelMemMap(const ChannelEntry& ch);
|
||||||
|
// Get channel information
|
||||||
|
ChannelEntry* GetChannelInfo(const Variable* var);
|
||||||
|
// channel setup map.
|
||||||
|
std::unordered_map<const Variable*, ChannelEntry> cmap_;
|
||||||
|
// list of vpi modules to be hooked.
|
||||||
|
std::vector<std::string> tvm_vpi_modules_;
|
||||||
|
// The signals for done.
|
||||||
|
std::vector<std::string> done_sigs_;
|
||||||
|
// The verilog function.
|
||||||
|
std::unordered_map<std::string, VerilogFuncEntry> functions_;
|
||||||
|
};
|
||||||
|
} // namespace verilog
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
||||||
|
#endif // TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
|
|
@ -0,0 +1,284 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file verilog_ir.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/ir_pass.h>
|
||||||
|
#include <tvm/ir_visitor.h>
|
||||||
|
#include <tvm/ir_mutator.h>
|
||||||
|
#include "./verilog_ir.h"
|
||||||
|
#include "../../arithmetic/compute_expr.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
namespace verilog {
|
||||||
|
|
||||||
|
using namespace ir;
|
||||||
|
|
||||||
|
ControlSignal ControlSignalNode::make(
|
||||||
|
ControlSignalType type, int advance_size) {
|
||||||
|
auto n = std::make_shared<ControlSignalNode>();
|
||||||
|
n->ctrl_type = type;
|
||||||
|
n->advance_size = advance_size;
|
||||||
|
return ControlSignal(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
StageInput StageInputNode::make(Var var, StageInputType input_type) {
|
||||||
|
std::shared_ptr<StageInputNode> n = std::make_shared<StageInputNode>();
|
||||||
|
n->var = var;
|
||||||
|
n->input_type = input_type;
|
||||||
|
return StageInput(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace stage inputs by placeholder, update the input map.
|
||||||
|
class StageInputReplacer : public IRMutator {
|
||||||
|
public:
|
||||||
|
explicit StageInputReplacer(
|
||||||
|
const std::unordered_map<const Variable*, StageInput>& var_info)
|
||||||
|
: var_info_(var_info) {}
|
||||||
|
|
||||||
|
Expr Mutate_(const Variable* op, const Expr& e) final {
|
||||||
|
if (replace_.count(op)) {
|
||||||
|
return replace_.at(op);
|
||||||
|
}
|
||||||
|
auto it = var_info_.find(op);
|
||||||
|
if (it == var_info_.end()) return e;
|
||||||
|
Var new_var(it->second->var->name_hint + ".sync", op->type);
|
||||||
|
inputs_.Set(new_var, it->second);
|
||||||
|
replace_[op] = new_var;
|
||||||
|
return new_var;
|
||||||
|
}
|
||||||
|
Expr Mutate_(const Load* op, const Expr& e) final {
|
||||||
|
CHECK(is_zero(op->index))
|
||||||
|
<< "Load should be in its own stage.";
|
||||||
|
if (replace_.count(op->buffer_var.get())) {
|
||||||
|
return replace_.at(op->buffer_var.get());
|
||||||
|
}
|
||||||
|
auto it = var_info_.find(op->buffer_var.get());
|
||||||
|
CHECK(it != var_info_.end())
|
||||||
|
<< "Load from unknown channel";
|
||||||
|
Var data(it->second->var->name_hint + ".load.sync", op->type);
|
||||||
|
inputs_.Set(data, it->second);
|
||||||
|
replace_[op->buffer_var.get()] = data;
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
// inputs that get replaced.
|
||||||
|
Map<Var, StageInput> inputs_;
|
||||||
|
// replacement map
|
||||||
|
std::unordered_map<const Variable*, Var> replace_;
|
||||||
|
// Variable replacement plan.
|
||||||
|
const std::unordered_map<const Variable*, StageInput>& var_info_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief Extract module block */
|
||||||
|
class PipelineExtractor: public IRVisitor {
|
||||||
|
public:
|
||||||
|
Pipeline Extract(LoweredFunc f) {
|
||||||
|
// Initialize the memory map channels
|
||||||
|
// TODO(tqchen) move the logic to explicit specification.
|
||||||
|
for (auto arg : f->args) {
|
||||||
|
if (arg.type().is_handle()) {
|
||||||
|
arg_handle_[arg.get()] = arg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pipeline_ = std::make_shared<PipelineNode>();
|
||||||
|
this->Visit(f->body);
|
||||||
|
// setup channels
|
||||||
|
for (const auto &kv : cmap_) {
|
||||||
|
pipeline_->channels.Set(
|
||||||
|
kv.second.node->channel->handle_var,
|
||||||
|
ChannelBlock(kv.second.node));
|
||||||
|
}
|
||||||
|
pipeline_->args = f->args;
|
||||||
|
return Pipeline(pipeline_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit_(const AttrStmt* op) final {
|
||||||
|
if (op->type_key == attr::pipeline_stage_scope) {
|
||||||
|
CHECK(!in_pipeline_stage_);
|
||||||
|
in_pipeline_stage_ = true;
|
||||||
|
trigger_.emplace_back(std::make_pair(loop_.size(), op));
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
trigger_.pop_back();
|
||||||
|
in_pipeline_stage_ = false;
|
||||||
|
} else if (op->type_key == attr::channel_read_advance ||
|
||||||
|
op->type_key == attr::channel_write_advance) {
|
||||||
|
trigger_.emplace_back(std::make_pair(loop_.size(), op));
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
trigger_.pop_back();
|
||||||
|
} else if (op->type_key == attr::channel_read_scope ||
|
||||||
|
op->type_key == attr::channel_write_scope) {
|
||||||
|
Channel ch(op->node.node_);
|
||||||
|
ChannelEntry& cb = cmap_[ch->handle_var.get()];
|
||||||
|
if (cb.node != nullptr) {
|
||||||
|
CHECK(cb.node->channel.same_as(ch));
|
||||||
|
} else {
|
||||||
|
cb.node = std::make_shared<ChannelBlockNode>();
|
||||||
|
cb.node->channel = ch;
|
||||||
|
}
|
||||||
|
if (op->type_key == attr::channel_read_scope) {
|
||||||
|
CHECK_EQ(cb.read_ref_count, 0)
|
||||||
|
<< "One channel can only be read from one consumer";
|
||||||
|
++cb.read_ref_count;
|
||||||
|
CHECK(arith::GetConstInt(op->value, &(cb.node->read_window)))
|
||||||
|
<< "Only supprt constant read window";
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(cb.write_ref_count, 0)
|
||||||
|
<< "One channel can only be write by one producer";
|
||||||
|
++cb.write_ref_count;
|
||||||
|
CHECK(arith::GetConstInt(op->value, &(cb.node->write_window)))
|
||||||
|
<< "Only supprt constant write window";
|
||||||
|
}
|
||||||
|
var_info_[ch->handle_var.get()] =
|
||||||
|
StageInputNode::make(ch->handle_var, kChannel);
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
var_info_.erase(ch->handle_var.get());
|
||||||
|
} else {
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Visit_(const Block* op) final {
|
||||||
|
CHECK(!in_pipeline_stage_)
|
||||||
|
<< "Do not support serial execution inside pipeline";
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
void Visit_(const IfThenElse* op) final {
|
||||||
|
LOG(FATAL) << "Not implemeneted";
|
||||||
|
}
|
||||||
|
void Visit_(const For* op) final {
|
||||||
|
if (in_pipeline_stage_) {
|
||||||
|
loop_.push_back(
|
||||||
|
For::make(op->loop_var, op->min, op->extent,
|
||||||
|
op->for_type, op->device_api, Evaluate::make(0)));
|
||||||
|
var_info_[op->loop_var.get()] =
|
||||||
|
StageInputNode::make(Var(op->loop_var.node_), kLoopVar);
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
var_info_.erase(op->loop_var.get());
|
||||||
|
loop_.pop_back();
|
||||||
|
} else {
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Visit_(const Store* op) final {
|
||||||
|
// Check the access pattern
|
||||||
|
Channel arg_write =
|
||||||
|
CheckArgHandleAccess(op->buffer_var.get(), op->value.type(), false);
|
||||||
|
this->Visit(op->value);
|
||||||
|
// The replace logic
|
||||||
|
StageInputReplacer repl(var_info_);
|
||||||
|
// Setup the compute block.
|
||||||
|
std::shared_ptr<ComputeBlockNode> compute =
|
||||||
|
std::make_shared<ComputeBlockNode>();
|
||||||
|
compute->loop = Array<Stmt>(loop_);
|
||||||
|
// setup the advance triggers
|
||||||
|
for (const auto& e : trigger_) {
|
||||||
|
const AttrStmt* attr = e.second;
|
||||||
|
Channel ch;
|
||||||
|
if (attr->type_key == attr::pipeline_stage_scope) {
|
||||||
|
ch = arg_write;
|
||||||
|
if (!ch.defined()) continue;
|
||||||
|
} else {
|
||||||
|
ch = Channel(attr->node.node_);
|
||||||
|
}
|
||||||
|
std::shared_ptr<SignalTriggerNode> trigger
|
||||||
|
= std::make_shared<SignalTriggerNode>();
|
||||||
|
trigger->channel_var = ch->handle_var;
|
||||||
|
// predicate for the trigger
|
||||||
|
Expr predicate = const_true();
|
||||||
|
for (size_t i = e.first; i < loop_.size(); ++i) {
|
||||||
|
const For* loop = loop_[i].as<For>();
|
||||||
|
predicate = predicate &&
|
||||||
|
(loop->loop_var == (loop->extent - 1));
|
||||||
|
}
|
||||||
|
trigger->predicate = ir::Simplify(predicate);
|
||||||
|
// Add the signal back to the channels.
|
||||||
|
ChannelEntry& cb = cmap_.at(ch->handle_var.get());
|
||||||
|
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
|
||||||
|
// Grab the advance constant size.
|
||||||
|
int trigger_size;
|
||||||
|
if (attr->type_key == attr::pipeline_stage_scope) {
|
||||||
|
cb.node->ctrl_signals.push_back(
|
||||||
|
ControlSignalNode::make(kComputeFinish, 0));
|
||||||
|
} else if (attr->type_key == attr::channel_read_advance) {
|
||||||
|
CHECK(arith::GetConstInt(attr->value, &trigger_size))
|
||||||
|
<< "Only support constant advance size";
|
||||||
|
cb.node->ctrl_signals.push_back(
|
||||||
|
ControlSignalNode::make(kReadAdvance, trigger_size));
|
||||||
|
} else {
|
||||||
|
CHECK(arith::GetConstInt(attr->value, &trigger_size))
|
||||||
|
<< "Only support constant advance size";
|
||||||
|
cb.node->ctrl_signals.push_back(
|
||||||
|
ControlSignalNode::make(kWriteAdvance, trigger_size));
|
||||||
|
}
|
||||||
|
compute->triggers.push_back(SignalTrigger(trigger));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we are writing to FIFO.
|
||||||
|
const Load* load = op->value.as<Load>();
|
||||||
|
if (is_zero(op->index) && load) {
|
||||||
|
compute->body = Store::make(
|
||||||
|
op->buffer_var,
|
||||||
|
Load::make(load->type, load->buffer_var, repl.Mutate(load->index)),
|
||||||
|
op->index);
|
||||||
|
} else {
|
||||||
|
compute->body = Store::make(
|
||||||
|
op->buffer_var, repl.Mutate(op->value), repl.Mutate(op->index));
|
||||||
|
}
|
||||||
|
compute->inputs = repl.inputs_;
|
||||||
|
pipeline_->stages.push_back(ComputeBlock(compute));
|
||||||
|
}
|
||||||
|
void Visit_(const LetStmt* op) final {
|
||||||
|
LOG(FATAL) << "cannot pass through let";
|
||||||
|
}
|
||||||
|
void Visit_(const Evaluate* op) final {
|
||||||
|
LOG(FATAL) << "Not implemeneted";
|
||||||
|
}
|
||||||
|
void Visit_(const Allocate* op) final {
|
||||||
|
CHECK(!in_pipeline_stage_);
|
||||||
|
}
|
||||||
|
void Visit_(const AssertStmt* op) final {
|
||||||
|
LOG(FATAL) << "Not implemeneted";
|
||||||
|
}
|
||||||
|
void Visit_(const Load* op) final {
|
||||||
|
CheckArgHandleAccess(op->buffer_var.get(), op->type, true);
|
||||||
|
}
|
||||||
|
Channel CheckArgHandleAccess(const Variable* var, Type dtype, bool read_access) {
|
||||||
|
if (!arg_handle_.count(var)) return Channel();
|
||||||
|
CHECK(!cmap_.count(var))
|
||||||
|
<< "Multiple access to the same handle";
|
||||||
|
ChannelEntry& cb = cmap_[var];
|
||||||
|
cb.node = std::make_shared<ChannelBlockNode>();
|
||||||
|
cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
|
||||||
|
return cb.node->channel;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The channel information.
|
||||||
|
struct ChannelEntry {
|
||||||
|
std::shared_ptr<ChannelBlockNode> node;
|
||||||
|
int read_ref_count{0};
|
||||||
|
int write_ref_count{0};
|
||||||
|
};
|
||||||
|
// Whether we are inside the pipeline stage.
|
||||||
|
bool in_pipeline_stage_{false};
|
||||||
|
// The current loop nest
|
||||||
|
std::vector<Stmt> loop_;
|
||||||
|
// Advance signal trigger
|
||||||
|
std::vector<std::pair<size_t, const AttrStmt*> > trigger_;
|
||||||
|
// Read write scope
|
||||||
|
std::vector<const AttrStmt*> channel_scope_;
|
||||||
|
// The loop index.
|
||||||
|
std::unordered_map<const Variable*, StageInput> var_info_;
|
||||||
|
// The channel entry;
|
||||||
|
std::unordered_map<const Variable*, ChannelEntry> cmap_;
|
||||||
|
// The argument handle map
|
||||||
|
std::unordered_map<const Variable*, Var> arg_handle_;
|
||||||
|
// The result block.
|
||||||
|
std::shared_ptr<PipelineNode> pipeline_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Pipeline MakePipeline(LoweredFunc f) {
|
||||||
|
return PipelineExtractor().Extract(f);
|
||||||
|
}
|
||||||
|
} // namespace verilog
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
|
@ -0,0 +1,188 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file verilog_ir.h
|
||||||
|
* \brief A lowered IR that resembles verilog blocks,
|
||||||
|
* This is data structure before final codegen.
|
||||||
|
*/
|
||||||
|
#ifndef TVM_CODEGEN_VERILOG_VERILOG_IR_H_
|
||||||
|
#define TVM_CODEGEN_VERILOG_VERILOG_IR_H_
|
||||||
|
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/channel.h>
|
||||||
|
#include <tvm/lowered_func.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
namespace verilog {
|
||||||
|
|
||||||
|
/*! \brief The data argument type */
|
||||||
|
enum StageInputType : int {
|
||||||
|
/*! \brief Data channel input. */
|
||||||
|
kChannel,
|
||||||
|
/*! \brief Loop variable generated by compute block. */
|
||||||
|
kLoopVar,
|
||||||
|
/*! \brief Global constant. */
|
||||||
|
kGlobalConst
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief The data argument type */
|
||||||
|
enum ControlSignalType : int {
|
||||||
|
// Read advance signal
|
||||||
|
kReadAdvance,
|
||||||
|
// Write advance signal
|
||||||
|
kWriteAdvance,
|
||||||
|
// Pipeline stage finish signal
|
||||||
|
kComputeFinish
|
||||||
|
};
|
||||||
|
|
||||||
|
class ControlSignal;
|
||||||
|
class StageInput;
|
||||||
|
class SignalTrigger;
|
||||||
|
|
||||||
|
/*! \brief The control signal of a channel */
|
||||||
|
struct ControlSignalNode : public Node {
|
||||||
|
/*! \brief The control signal type */
|
||||||
|
ControlSignalType ctrl_type;
|
||||||
|
/*! \brief Advance size of the signal */
|
||||||
|
int advance_size{0};
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("ctrl_type", &ctrl_type);
|
||||||
|
v->Visit("advance_size", &advance_size);
|
||||||
|
}
|
||||||
|
static ControlSignal make(ControlSignalType ctrl_type, int advance_size);
|
||||||
|
static constexpr const char* _type_key = "VerilogControlSignal";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(ControlSignalNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(ControlSignal, ControlSignalNode);
|
||||||
|
|
||||||
|
/*! \brief Information about channel. */
|
||||||
|
struct ChannelBlockNode : public Node {
|
||||||
|
/*! \brief The channel we are refer to */
|
||||||
|
Channel channel;
|
||||||
|
/*! \brief Read window */
|
||||||
|
int read_window{0};
|
||||||
|
/*! \brief Write window */
|
||||||
|
int write_window{0};
|
||||||
|
/*! \brief Control signals in the channel */
|
||||||
|
Array<ControlSignal> ctrl_signals;
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("channel", &channel);
|
||||||
|
v->Visit("read_window", &read_window);
|
||||||
|
v->Visit("write_window", &write_window);
|
||||||
|
v->Visit("ctrl_signals", &ctrl_signals);
|
||||||
|
}
|
||||||
|
static constexpr const char* _type_key = "VerilogChannelBlock";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(ChannelBlockNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(ChannelBlock, ChannelBlockNode);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Input to the compute block.
|
||||||
|
* These represents the data values that need to be shared;
|
||||||
|
*/
|
||||||
|
struct StageInputNode : public Node {
|
||||||
|
/*!
|
||||||
|
* \brief The corresponding var of the input
|
||||||
|
* For loop and global const it is the var.
|
||||||
|
* For channel this corresponds to the channel handle.
|
||||||
|
*/
|
||||||
|
Var var;
|
||||||
|
/*! \brief The type of the input. */
|
||||||
|
StageInputType input_type;
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("var", &var);
|
||||||
|
v->Visit("input_type", &input_type);
|
||||||
|
}
|
||||||
|
// constructor
|
||||||
|
static StageInput make(Var var, StageInputType input_type);
|
||||||
|
static constexpr const char* _type_key = "VerilogStageInput";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(StageInputNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(StageInput, StageInputNode);
|
||||||
|
|
||||||
|
/*! \brief The trigger signal for certain channel */
|
||||||
|
struct SignalTriggerNode : public Node {
|
||||||
|
/*! \brief The channel handle variable */
|
||||||
|
Var channel_var;
|
||||||
|
/*! \brief Boolean predicate to trigger the signal */
|
||||||
|
Expr predicate;
|
||||||
|
/*! \brief siginal index of the channel */
|
||||||
|
int signal_index;
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("channel_var", &channel_var);
|
||||||
|
v->Visit("predicate", &predicate);
|
||||||
|
v->Visit("signal_index", &signal_index);
|
||||||
|
}
|
||||||
|
// constructor
|
||||||
|
static constexpr const char* _type_key = "VerilogSignalTrigger";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(SignalTriggerNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(SignalTrigger, SignalTriggerNode);
|
||||||
|
|
||||||
|
/*! \brief compute block for verilog */
|
||||||
|
struct ComputeBlockNode : public Node {
|
||||||
|
/*! \brief The body of the block. */
|
||||||
|
Stmt body;
|
||||||
|
/*! \brief The loop nest around the body, each is a For with no_op as body */
|
||||||
|
Array<Stmt> loop;
|
||||||
|
/*! \brief The channel advance trigger */
|
||||||
|
Array<SignalTrigger> triggers;
|
||||||
|
/*! \brief The input variables that need to be synced. */
|
||||||
|
Map<Var, StageInput> inputs;
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("body", &body);
|
||||||
|
v->Visit("loop", &loop);
|
||||||
|
v->Visit("triggers", &triggers);
|
||||||
|
v->Visit("inputs", &inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr const char* _type_key = "VerilogComputeBlock";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(ComputeBlockNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(ComputeBlock, ComputeBlockNode);
|
||||||
|
|
||||||
|
/*! \brief Codeblock for verilog module. */
|
||||||
|
struct PipelineNode : public Node {
|
||||||
|
/*! \brief arguments to the module */
|
||||||
|
Array<Var> args;
|
||||||
|
/*! \brief Computation stages */
|
||||||
|
Array<ComputeBlock> stages;
|
||||||
|
/*! \brief The data channels */
|
||||||
|
Map<Var, ChannelBlock> channels;
|
||||||
|
|
||||||
|
// visit all attributes
|
||||||
|
void VisitAttrs(AttrVisitor* v) final {
|
||||||
|
v->Visit("args", &args);
|
||||||
|
v->Visit("stages", &stages);
|
||||||
|
v->Visit("channels", &channels);
|
||||||
|
}
|
||||||
|
static constexpr const char* _type_key = "VerilogPipeline";
|
||||||
|
TVM_DECLARE_NODE_TYPE_INFO(PipelineNode, Node);
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_DEFINE_NODE_REF(Pipeline, PipelineNode);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Build a lowered verilog pipeline given function.
|
||||||
|
* \param f The function to be transformed.
|
||||||
|
* \param The created verilog pipeline.
|
||||||
|
*/
|
||||||
|
Pipeline MakePipeline(LoweredFunc f);
|
||||||
|
} // namespace verilog
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
||||||
|
#endif // TVM_CODEGEN_VERILOG_VERILOG_IR_H_
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file verilog_module.cc
|
||||||
|
* \brief Build verilog source code.
|
||||||
|
*/
|
||||||
|
#include <tvm/runtime/packed_func.h>
|
||||||
|
#include <tvm/codegen.h>
|
||||||
|
#include <mutex>
|
||||||
|
#include "./codegen_verilog.h"
|
||||||
|
#include "../../runtime/file_util.h"
|
||||||
|
#include "../../runtime/meta_data.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
namespace verilog {
|
||||||
|
using runtime::TVMArgs;
|
||||||
|
using runtime::TVMRetValue;
|
||||||
|
using runtime::PackedFunc;
|
||||||
|
|
||||||
|
// Simulator function
|
||||||
|
class VerilogModuleNode : public runtime::ModuleNode {
|
||||||
|
public:
|
||||||
|
VerilogModuleNode() : fmt_("v") {}
|
||||||
|
const char* type_key() const {
|
||||||
|
return "verilog";
|
||||||
|
}
|
||||||
|
void PreCompile(const std::string& name, TVMContext ctx) final {
|
||||||
|
}
|
||||||
|
|
||||||
|
PackedFunc GetFunction(
|
||||||
|
const std::string& name,
|
||||||
|
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
|
||||||
|
CHECK(sptr_to_self.get() == this);
|
||||||
|
if (name == runtime::symbol::tvm_entry_setdevice) {
|
||||||
|
return PackedFunc([](const TVMArgs& args, TVMRetValue* rv){});
|
||||||
|
}
|
||||||
|
CHECK(m_.fmap.count(name)) << "Cannot find function " << name << " in the module";
|
||||||
|
|
||||||
|
auto f = [sptr_to_self, name, this](const runtime::TVMArgs& args, TVMRetValue* rv) {
|
||||||
|
auto* fsim = runtime::Registry::Get("tvm_callback_verilog_simulator");
|
||||||
|
CHECK(fsim != nullptr)
|
||||||
|
<< "tvm_callback_verilog_simulator is not registered,"
|
||||||
|
<<" did you import tvm.addon.verilog?";
|
||||||
|
std::string code = m_.AppendSimMain(name);
|
||||||
|
|
||||||
|
if (const auto* f = runtime::Registry::Get("tvm_callback_verilog_postproc")) {
|
||||||
|
code = (*f)(code).operator std::string();
|
||||||
|
}
|
||||||
|
std::vector<TVMValue> values;
|
||||||
|
std::vector<int> codes;
|
||||||
|
TVMValue v;
|
||||||
|
v.v_str = code.c_str();
|
||||||
|
values.push_back(v);
|
||||||
|
codes.push_back(kStr);
|
||||||
|
for (int i = 0; i < args.num_args; ++i) {
|
||||||
|
values.push_back(args.values[i]);
|
||||||
|
codes.push_back(args.type_codes[i]);
|
||||||
|
}
|
||||||
|
fsim->CallPacked(TVMArgs(&values[0], &codes[0], args.num_args + 1), rv);
|
||||||
|
};
|
||||||
|
return PackedFunc(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SaveToFile(const std::string& file_name,
|
||||||
|
const std::string& format) final {
|
||||||
|
LOG(FATAL) << "not implemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetSource(const std::string& format) final {
|
||||||
|
return m_.code;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Init(const Array<LoweredFunc>& funcs) {
|
||||||
|
CodeGenVerilog cg;
|
||||||
|
cg.Init();
|
||||||
|
for (LoweredFunc f : funcs) {
|
||||||
|
cg.AddFunction(f);
|
||||||
|
}
|
||||||
|
m_ = cg.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// the verilog code. data
|
||||||
|
VerilogCodeGenModule m_;
|
||||||
|
// format;
|
||||||
|
std::string fmt_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_build_verilog)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
std::shared_ptr<VerilogModuleNode> n =
|
||||||
|
std::make_shared<VerilogModuleNode>();
|
||||||
|
n->Init(args[0]);
|
||||||
|
*rv = runtime::Module(n);
|
||||||
|
});
|
||||||
|
} // namespace verilog
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
|
@ -186,8 +186,8 @@ class VPIMemoryInterface {
|
||||||
read_unit_bytes_ = read_bits / 8U;
|
read_unit_bytes_ = read_bits / 8U;
|
||||||
write_unit_bytes_ = write_bits / 8U;
|
write_unit_bytes_ = write_bits / 8U;
|
||||||
}
|
}
|
||||||
// Callback at post-edge.
|
// Callback at neg-edge.
|
||||||
void AtPosEedge() {
|
void AtNegEdge() {
|
||||||
// reset
|
// reset
|
||||||
if (in_rst_.get_int()) {
|
if (in_rst_.get_int()) {
|
||||||
CHECK_EQ(pending_read_.size, 0U);
|
CHECK_EQ(pending_read_.size, 0U);
|
||||||
|
@ -358,7 +358,7 @@ class VPIReadMemMap : public VPIMemMapBase {
|
||||||
void Init(VPIHandle module) {
|
void Init(VPIHandle module) {
|
||||||
VPIMemMapBase::Init(module, "reg_data");
|
VPIMemMapBase::Init(module, "reg_data");
|
||||||
}
|
}
|
||||||
void AtPosEedge() {
|
void AtNegEdge() {
|
||||||
void* ptr = RealAddr();
|
void* ptr = RealAddr();
|
||||||
if (ptr == nullptr) return;
|
if (ptr == nullptr) return;
|
||||||
size_t nwords = (unit_bytes_ + 3) / 4;
|
size_t nwords = (unit_bytes_ + 3) / 4;
|
||||||
|
@ -373,7 +373,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
|
||||||
VPIMemMapBase::Init(module, "data_in");
|
VPIMemMapBase::Init(module, "data_in");
|
||||||
enable_ = module["en"];
|
enable_ = module["en"];
|
||||||
}
|
}
|
||||||
void AtPosEedge() {
|
void AtNegEdge() {
|
||||||
if (!enable_.get_int() || rst_.get_int()) return;
|
if (!enable_.get_int() || rst_.get_int()) return;
|
||||||
void* ptr = RealAddr();
|
void* ptr = RealAddr();
|
||||||
CHECK(ptr != nullptr)
|
CHECK(ptr != nullptr)
|
||||||
|
@ -398,7 +398,7 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
|
||||||
p->Init(m);
|
p->Init(m);
|
||||||
LOG(INFO) << "Hook " << m.name() << " to tvm vpi simulation...";
|
LOG(INFO) << "Hook " << m.name() << " to tvm vpi simulation...";
|
||||||
PackedFunc pf([p](const runtime::TVMArgs&, runtime::TVMRetValue*) {
|
PackedFunc pf([p](const runtime::TVMArgs&, runtime::TVMRetValue*) {
|
||||||
p->AtPosEedge();
|
p->AtNegEdge();
|
||||||
});
|
});
|
||||||
*rv = pf;
|
*rv = pf;
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,13 +139,25 @@ MakeLoopNest(const Stage& stage,
|
||||||
nest[i + 1].emplace_back(
|
nest[i + 1].emplace_back(
|
||||||
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
|
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
|
||||||
value_map[iv] = var;
|
value_map[iv] = var;
|
||||||
|
} else if (iv->thread_tag == "pipeline") {
|
||||||
|
// pipeline marker.
|
||||||
|
CHECK(is_zero(dom->min));
|
||||||
|
CHECK(is_one(dom->extent));
|
||||||
|
// annotate the extent of the IterVar
|
||||||
|
nest[i + 1].emplace_back(
|
||||||
|
AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
|
||||||
|
value_map[iv] = dom->min;
|
||||||
} else {
|
} else {
|
||||||
// Always restrict threaded IterVar to starts from 0.
|
// Always restrict threaded IterVar to starts from 0.
|
||||||
CHECK(is_zero(dom->min));
|
CHECK(is_zero(dom->min));
|
||||||
// annotate the extent of the IterVar
|
// annotate the extent of the IterVar
|
||||||
nest[i + 1].emplace_back(
|
nest[i + 1].emplace_back(
|
||||||
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
|
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
|
||||||
value_map[iv] = var;
|
if (is_one(dom->extent)) {
|
||||||
|
value_map[iv] = dom->min;
|
||||||
|
} else {
|
||||||
|
value_map[iv] = var;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// annotate the extent of the IterVar
|
// annotate the extent of the IterVar
|
||||||
if (!new_loop_var) {
|
if (!new_loop_var) {
|
||||||
|
|
|
@ -0,0 +1,223 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file narrow_channel_access.cc
|
||||||
|
* \brief Narrow channel access to a smaller range
|
||||||
|
* when possible by bringing it to the internal loop.
|
||||||
|
*/
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/ir_pass.h>
|
||||||
|
#include <tvm/ir_visitor.h>
|
||||||
|
#include <tvm/ir_mutator.h>
|
||||||
|
#include <tvm/arithmetic.h>
|
||||||
|
#include <tvm/channel.h>
|
||||||
|
#include "./ir_util.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace ir {
|
||||||
|
using namespace arith;
|
||||||
|
|
||||||
|
// Bound deducer for channel access.
|
||||||
|
class ChannelAccessBound : public IRVisitor {
|
||||||
|
public:
|
||||||
|
ChannelAccessBound(const Variable* buf_var, bool read_access)
|
||||||
|
: buf_var_(buf_var), read_access_(read_access) {}
|
||||||
|
|
||||||
|
void Visit_(const Store* op) final {
|
||||||
|
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
|
||||||
|
ret_.emplace_back(EvalSet(op->index, dom_map_));
|
||||||
|
}
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
void Visit_(const For* op) final {
|
||||||
|
CHECK(is_zero(op->min));
|
||||||
|
// We know that the extent of the loop won't depend on relaxed scope.
|
||||||
|
// TODO(tqchen) have a verification pass.
|
||||||
|
dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->extent - 1);
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
void Visit_(const Load* op) final {
|
||||||
|
if (read_access_ && buf_var_ == op->buffer_var.get()) {
|
||||||
|
ret_.emplace_back(EvalSet(op->index, dom_map_));
|
||||||
|
}
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
void Visit_(const Let* op) final {
|
||||||
|
LOG(FATAL) << "cannot pass through let";
|
||||||
|
}
|
||||||
|
void Visit_(const LetStmt* op) final {
|
||||||
|
LOG(FATAL) << "cannot pass through let";
|
||||||
|
}
|
||||||
|
IntSet Eval(const Stmt& stmt) {
|
||||||
|
Visit(stmt);
|
||||||
|
return Union(ret_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The buffer variable.
|
||||||
|
const Variable* buf_var_;
|
||||||
|
// read or write
|
||||||
|
bool read_access_{true};
|
||||||
|
// Box
|
||||||
|
std::vector<IntSet> ret_;
|
||||||
|
// Domain map.
|
||||||
|
std::unordered_map<const Variable*, IntSet> dom_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ChannelAccessIndexRewriter : public IRMutator {
|
||||||
|
public:
|
||||||
|
ChannelAccessIndexRewriter(const Variable* buf_var,
|
||||||
|
Expr min,
|
||||||
|
bool read_access)
|
||||||
|
: buf_var_(buf_var), min_(min), read_access_(read_access) {}
|
||||||
|
Expr Mutate_(const Load* op, const Expr& e) final {
|
||||||
|
Expr expr = IRMutator::Mutate_(op, e);
|
||||||
|
op = expr.as<Load>();
|
||||||
|
if (read_access_ && buf_var_ == op->buffer_var.get()) {
|
||||||
|
return Load::make(
|
||||||
|
op->type, op->buffer_var, ir::Simplify(op->index - min_));
|
||||||
|
} else {
|
||||||
|
return expr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Stmt Mutate_(const Store* op, const Stmt& s) final {
|
||||||
|
Stmt stmt = IRMutator::Mutate_(op, s);
|
||||||
|
op = stmt.as<Store>();
|
||||||
|
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
|
||||||
|
return Store::make(
|
||||||
|
op->buffer_var, op->value, ir::Simplify(op->index - min_));
|
||||||
|
} else {
|
||||||
|
return stmt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The buffer variable.
|
||||||
|
const Variable* buf_var_;
|
||||||
|
// The min bound.
|
||||||
|
Expr min_;
|
||||||
|
// read or write
|
||||||
|
bool read_access_{true};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Rewrite channel access pattern.
|
||||||
|
class ChannelAccessRewriter : public IRMutator {
|
||||||
|
public:
|
||||||
|
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
|
||||||
|
Stmt ret;
|
||||||
|
const AttrStmt* adv = op->body.as<AttrStmt>();
|
||||||
|
if ((op->type_key == ir::attr::channel_read_scope &&
|
||||||
|
adv && adv->type_key == ir::attr::channel_read_advance) ||
|
||||||
|
(op->type_key == ir::attr::channel_write_scope &&
|
||||||
|
adv && adv->type_key == ir::attr::channel_write_advance)) {
|
||||||
|
RewriteEntry e;
|
||||||
|
e.window = op;
|
||||||
|
e.advance = adv;
|
||||||
|
e.read_access = op->type_key == ir::attr::channel_read_scope;
|
||||||
|
tasks_.push_back(e);
|
||||||
|
ret = IRMutator::Mutate_(op, s);
|
||||||
|
if (tasks_.back().rewrite_success) {
|
||||||
|
ret = ret.as<AttrStmt>()->body.as<AttrStmt>()->body;
|
||||||
|
}
|
||||||
|
tasks_.pop_back();
|
||||||
|
return ret;
|
||||||
|
} else {
|
||||||
|
return IRMutator::Mutate_(op, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt Mutate_(const For* op, const Stmt& s) final {
|
||||||
|
std::vector<RewriteEntry> tasks;
|
||||||
|
std::swap(tasks_, tasks);
|
||||||
|
Stmt body = op->body;
|
||||||
|
std::vector<Stmt> nest;
|
||||||
|
for (RewriteEntry& e : tasks) {
|
||||||
|
body = RewriteAccess(op, body, &e, &nest);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!body.same_as(op->body)) {
|
||||||
|
body = Mutate(body);
|
||||||
|
body = For::make(
|
||||||
|
op->loop_var, op->min, op->extent,
|
||||||
|
op->for_type, op->device_api, body);
|
||||||
|
body = MergeNest(nest, body);
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(nest.size(), 0U);
|
||||||
|
body = IRMutator::Mutate_(op, s);
|
||||||
|
}
|
||||||
|
std::swap(tasks_, tasks);
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct RewriteEntry {
|
||||||
|
bool read_access;
|
||||||
|
const AttrStmt* window;
|
||||||
|
const AttrStmt* advance;
|
||||||
|
bool rewrite_success{false};
|
||||||
|
};
|
||||||
|
|
||||||
|
Stmt RewriteAccess(const For* for_op,
|
||||||
|
Stmt body,
|
||||||
|
RewriteEntry* e,
|
||||||
|
std::vector<Stmt>* outer_nest) {
|
||||||
|
const AttrStmt* adv_op = e->advance;
|
||||||
|
const Expr& window = e->window->value;
|
||||||
|
bool read_access = e->read_access;
|
||||||
|
Var var(for_op->loop_var);
|
||||||
|
Channel ch(adv_op->node.node_);
|
||||||
|
ChannelAccessBound acc(ch->handle_var.get(), read_access);
|
||||||
|
IntSet iset = acc.Eval(for_op->body);
|
||||||
|
Range r = iset.cover_range(Range::make_with_min_extent(0, window));
|
||||||
|
r = Range::make_with_min_extent(
|
||||||
|
ir::Simplify(r->min), ir::Simplify(r->extent));
|
||||||
|
if (ExprUseVar(r->extent, var)) return body;
|
||||||
|
Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
|
||||||
|
if (linear_eq.size() == 0) return body;
|
||||||
|
Expr base = linear_eq[0];
|
||||||
|
Expr coeff = linear_eq[1];
|
||||||
|
if (!is_zero(base)) return body;
|
||||||
|
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
|
||||||
|
if (!can_prove(left >= 0)) return body;
|
||||||
|
// rewrite access index.
|
||||||
|
ChannelAccessIndexRewriter rw(
|
||||||
|
ch->handle_var.get(), var * coeff, read_access);
|
||||||
|
body = rw.Mutate(body);
|
||||||
|
|
||||||
|
if (read_access) {
|
||||||
|
body = AttrStmt::make(
|
||||||
|
ch, ir::attr::channel_read_scope, r->extent,
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_read_advance, coeff,
|
||||||
|
body));
|
||||||
|
} else {
|
||||||
|
body = AttrStmt::make(
|
||||||
|
ch, ir::attr::channel_write_scope, r->extent,
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_write_advance, coeff,
|
||||||
|
body));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_zero(left)) {
|
||||||
|
Stmt no_op = Evaluate::make(0);
|
||||||
|
if (read_access) {
|
||||||
|
outer_nest->emplace_back(
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_read_advance, left, no_op));
|
||||||
|
} else {
|
||||||
|
outer_nest->emplace_back(
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_write_advance, left, no_op));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e->rewrite_success = true;
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<RewriteEntry> tasks_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Stmt NarrowChannelAccess(Stmt stmt) {
|
||||||
|
return ChannelAccessRewriter().Mutate(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace tvm
|
|
@ -15,6 +15,7 @@ class IRSideEffect : public IRVisitor {
|
||||||
public:
|
public:
|
||||||
void Visit(const NodeRef& e) final {
|
void Visit(const NodeRef& e) final {
|
||||||
if (has_side_effect_) return;
|
if (has_side_effect_) return;
|
||||||
|
IRVisitor::Visit(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Visit_(const Call* op) final {
|
void Visit_(const Call* op) final {
|
||||||
|
@ -55,5 +56,39 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
|
||||||
}
|
}
|
||||||
return m.Mutate(stmt);
|
return m.Mutate(stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ExprUseVarVisitor : public IRVisitor {
|
||||||
|
public:
|
||||||
|
explicit ExprUseVarVisitor(const Variable* var)
|
||||||
|
: var_(var) {}
|
||||||
|
|
||||||
|
void Visit(const NodeRef& e) final {
|
||||||
|
if (use_var_) return;
|
||||||
|
IRVisitor::Visit(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit_(const Variable* op) final {
|
||||||
|
if (op == var_) {
|
||||||
|
use_var_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit_(const Load* op) final {
|
||||||
|
if (op->buffer_var.get() == var_) {
|
||||||
|
use_var_ = true;
|
||||||
|
}
|
||||||
|
IRVisitor::Visit_(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
const Variable* var_;
|
||||||
|
bool use_var_{false};
|
||||||
|
};
|
||||||
|
|
||||||
|
bool ExprUseVar(const Expr& e, const Var& v) {
|
||||||
|
ExprUseVarVisitor visitor(v.get());
|
||||||
|
visitor.Visit(e);
|
||||||
|
return visitor.use_var_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ir
|
} // namespace ir
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
|
@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
|
||||||
class HostDeviceSplitter : public IRMutator {
|
class HostDeviceSplitter : public IRMutator {
|
||||||
public:
|
public:
|
||||||
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
||||||
if (op->type_key == "thread_extent") {
|
if (op->type_key == attr::thread_extent ||
|
||||||
IterVar iv(op->node.node_);
|
op->type_key == attr::pipeline_exec_scope) {
|
||||||
return SplitDeviceFunc(s);
|
return SplitDeviceFunc(s);
|
||||||
}
|
}
|
||||||
return IRMutator::Mutate_(op, s);
|
return IRMutator::Mutate_(op, s);
|
||||||
|
@ -195,7 +195,6 @@ class HostDeviceSplitter : public IRMutator {
|
||||||
n->name = os.str();
|
n->name = os.str();
|
||||||
n->args = m.undefined_;
|
n->args = m.undefined_;
|
||||||
n->thread_axis = m.thread_axis_;
|
n->thread_axis = m.thread_axis_;
|
||||||
CHECK_NE(m.thread_extent_.size(), 0U);
|
|
||||||
|
|
||||||
// improve the handle data type
|
// improve the handle data type
|
||||||
for (Var arg : n->args) {
|
for (Var arg : n->args) {
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include <tvm/ir_mutator.h>
|
#include <tvm/ir_mutator.h>
|
||||||
#include <tvm/channel.h>
|
#include <tvm/channel.h>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
#include "./ir_util.h"
|
#include "./ir_util.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
@ -18,14 +19,38 @@ namespace ir {
|
||||||
class MarkChannelAccess : public IRMutator {
|
class MarkChannelAccess : public IRMutator {
|
||||||
public:
|
public:
|
||||||
MarkChannelAccess(
|
MarkChannelAccess(
|
||||||
const std::unordered_map<const Variable*, Channel>& cmap)
|
const std::unordered_map<const Variable*, Channel>& cmap,
|
||||||
: cmap_(cmap) {}
|
const std::unordered_map<const Variable*, Channel>& fifo_map)
|
||||||
|
: cmap_(cmap), fifo_map_(fifo_map) {}
|
||||||
|
using IRMutator::Mutate;
|
||||||
|
Stmt Mutate(Stmt stmt) final {
|
||||||
|
Stmt ret = IRMutator::Mutate(stmt);
|
||||||
|
if (read_fifos_.size() != 0) {
|
||||||
|
for (const Variable* v : read_fifos_) {
|
||||||
|
Channel ch = fifo_map_.at(v);
|
||||||
|
ret = ReadChannel(ch, 1, ret);
|
||||||
|
}
|
||||||
|
read_fifos_.clear();
|
||||||
|
}
|
||||||
|
if (write_fifos_.size() != 0) {
|
||||||
|
for (const Variable* v : write_fifos_) {
|
||||||
|
Channel ch = fifo_map_.at(v);
|
||||||
|
ret = WriteChannel(ch, 1, ret);
|
||||||
|
}
|
||||||
|
write_fifos_.clear();
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
Expr Mutate_(const Load *op, const Expr& e) final {
|
Expr Mutate_(const Load *op, const Expr& e) final {
|
||||||
auto it = rmap_.find(op->buffer_var.get());
|
auto it = rmap_.find(op->buffer_var.get());
|
||||||
if (it != rmap_.end()) {
|
if (it != rmap_.end()) {
|
||||||
++it->second.read_count;
|
++it->second.read_count;
|
||||||
}
|
}
|
||||||
|
if (fifo_map_.count(op->buffer_var.get())) {
|
||||||
|
read_fifos_.insert(op->buffer_var.get());
|
||||||
|
CHECK(!write_fifos_.count(op->buffer_var.get()));
|
||||||
|
}
|
||||||
return IRMutator::Mutate_(op, e);
|
return IRMutator::Mutate_(op, e);
|
||||||
}
|
}
|
||||||
Stmt Mutate_(const Store *op, const Stmt& s) final {
|
Stmt Mutate_(const Store *op, const Stmt& s) final {
|
||||||
|
@ -33,6 +58,10 @@ class MarkChannelAccess : public IRMutator {
|
||||||
if (it != rmap_.end()) {
|
if (it != rmap_.end()) {
|
||||||
++it->second.write_count;
|
++it->second.write_count;
|
||||||
}
|
}
|
||||||
|
if (fifo_map_.count(op->buffer_var.get())) {
|
||||||
|
write_fifos_.insert(op->buffer_var.get());
|
||||||
|
CHECK(!read_fifos_.count(op->buffer_var.get()));
|
||||||
|
}
|
||||||
return IRMutator::Mutate_(op, s);
|
return IRMutator::Mutate_(op, s);
|
||||||
}
|
}
|
||||||
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
|
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
|
||||||
|
@ -79,51 +108,90 @@ class MarkChannelAccess : public IRMutator {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rw.write_count) {
|
if (rw.write_count) {
|
||||||
return AttrStmt::make(
|
return WriteChannel(ch, alloc_size, body);
|
||||||
ch, ir::attr::channel_write_scope, alloc_size, body);
|
|
||||||
} else {
|
} else {
|
||||||
CHECK(rw.read_count);
|
CHECK(rw.read_count);
|
||||||
return AttrStmt::make(
|
return ReadChannel(ch, alloc_size, body);
|
||||||
ch, ir::attr::channel_read_scope, alloc_size, body);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Stmt ReadChannel(Channel ch, Expr size, Stmt body) {
|
||||||
|
return AttrStmt::make(
|
||||||
|
ch, ir::attr::channel_read_scope, size,
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_read_advance, size,
|
||||||
|
body));
|
||||||
|
}
|
||||||
|
Stmt WriteChannel(Channel ch, Expr size, Stmt body) {
|
||||||
|
return AttrStmt::make(
|
||||||
|
ch, ir::attr::channel_write_scope, size,
|
||||||
|
AttrStmt::make(ch, ir::attr::channel_write_advance, size,
|
||||||
|
body));
|
||||||
|
}
|
||||||
struct Entry {
|
struct Entry {
|
||||||
int read_count{0};
|
int read_count{0};
|
||||||
int write_count{0};
|
int write_count{0};
|
||||||
};
|
};
|
||||||
// The channels of each allocation.
|
// The channels of each allocation.
|
||||||
const std::unordered_map<const Variable*, Channel>& cmap_;
|
const std::unordered_map<const Variable*, Channel>& cmap_;
|
||||||
|
// FIFO map.
|
||||||
|
const std::unordered_map<const Variable*, Channel>& fifo_map_;
|
||||||
// the result.
|
// the result.
|
||||||
std::unordered_map<const Variable*, Entry> rmap_;
|
std::unordered_map<const Variable*, Entry> rmap_;
|
||||||
|
// Accessed FIFOs
|
||||||
|
std::unordered_set<const Variable*> read_fifos_, write_fifos_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Mark the statment of each stage.
|
// Mark the statment of each stage.
|
||||||
class StageSplitter : public IRMutator {
|
class StageSplitter : public IRMutator {
|
||||||
public:
|
public:
|
||||||
|
using IRMutator::Mutate;
|
||||||
|
explicit StageSplitter(bool split_load)
|
||||||
|
: split_load_(split_load) {}
|
||||||
|
|
||||||
Stmt Mutate(Stmt stmt) final {
|
Stmt Mutate(Stmt stmt) final {
|
||||||
nest_.push_back(stmt);
|
nest_.push_back(stmt);
|
||||||
Stmt ret = IRMutator::Mutate(stmt);
|
Stmt ret = IRMutator::Mutate(stmt);
|
||||||
nest_.pop_back();
|
nest_.pop_back();
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) {
|
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
|
||||||
if (!op->is_producer) return IRMutator::Mutate_(op, s);
|
if (!op->is_producer) {
|
||||||
|
return Mutate(op->body);
|
||||||
|
}
|
||||||
Stmt body = Mutate(op->body);
|
Stmt body = Mutate(op->body);
|
||||||
stages_.emplace_back(BuildStage(body, op->func));
|
stages_.emplace_back(BuildStage(body, op->func));
|
||||||
return Evaluate::make(0);
|
return Evaluate::make(0);
|
||||||
}
|
}
|
||||||
|
Expr Mutate_(const Load* op, const Expr& e) final {
|
||||||
|
if (!split_load_) return IRMutator::Mutate_(op, e);
|
||||||
|
std::ostringstream cname;
|
||||||
|
cname << "fifo." << temp_fifo_count_++;
|
||||||
|
// Create FIFO channel for load.
|
||||||
|
Channel ch = ChannelNode::make(Var(cname.str(), Handle()), op->type);
|
||||||
|
Expr index = Mutate(op->index);
|
||||||
|
Stmt provide = Store::make(
|
||||||
|
ch->handle_var,
|
||||||
|
Load::make(op->type, op->buffer_var, index), 0);
|
||||||
|
Stmt temp = nest_.back(); nest_.pop_back();
|
||||||
|
stages_.emplace_back(BuildStage(provide, ch));
|
||||||
|
nest_.push_back(temp);
|
||||||
|
fifo_map_[ch->handle_var.get()] = ch;
|
||||||
|
return Load::make(op->type, ch->handle_var, 0);
|
||||||
|
}
|
||||||
|
|
||||||
Stmt Split(Stmt stmt) {
|
Stmt Split(Stmt stmt, const ProducerConsumer* env) {
|
||||||
stmt = Mutate(stmt);
|
stmt = Mutate(stmt);
|
||||||
stmt = RemoveNoOp(stmt);
|
if (env) {
|
||||||
CHECK(is_no_op(stmt));
|
stages_.emplace_back(BuildStage(stmt, env->func));
|
||||||
|
} else {
|
||||||
|
stmt = RemoveNoOp(stmt);
|
||||||
|
CHECK(is_no_op(stmt));
|
||||||
|
}
|
||||||
CHECK_NE(stages_.size(), 0);
|
CHECK_NE(stages_.size(), 0);
|
||||||
stmt = stages_.back();
|
stmt = stages_.back();
|
||||||
for (size_t i = stages_.size() - 1; i != 0; --i) {
|
for (size_t i = stages_.size() - 1; i != 0; --i) {
|
||||||
stmt = Block::make(stages_[i - 1], stmt);
|
stmt = Block::make(stages_[i - 1], stmt);
|
||||||
}
|
}
|
||||||
stmt = MarkChannelAccess(cmap_).Mutate(stmt);
|
stmt = MarkChannelAccess(cmap_, fifo_map_).Mutate(stmt);
|
||||||
return RemoveNoOp(stmt);
|
return RemoveNoOp(stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,10 +252,52 @@ class StageSplitter : public IRMutator {
|
||||||
std::vector<Stmt> stages_;
|
std::vector<Stmt> stages_;
|
||||||
// channel map
|
// channel map
|
||||||
std::unordered_map<const Variable*, Channel> cmap_;
|
std::unordered_map<const Variable*, Channel> cmap_;
|
||||||
|
// Whether split load into a temp fifo.
|
||||||
|
bool split_load_{true};
|
||||||
|
// Counter for temp FIFOs.
|
||||||
|
size_t temp_fifo_count_{0};
|
||||||
|
// fifo map
|
||||||
|
std::unordered_map<const Variable*, Channel> fifo_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Stmt SplitPipeline(Stmt stmt) {
|
class PipelineSplitter : public IRMutator {
|
||||||
return StageSplitter().Split(stmt);
|
public:
|
||||||
|
explicit PipelineSplitter(bool split_load)
|
||||||
|
: split_load_(split_load) {}
|
||||||
|
|
||||||
|
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
|
||||||
|
if (op->type_key == ir::attr::pipeline_exec_scope) {
|
||||||
|
CHECK_LE(env_.size(), 1U);
|
||||||
|
const ProducerConsumer* env = nullptr;
|
||||||
|
if (env_.size() == 1) {
|
||||||
|
std::swap(env_[0], env);
|
||||||
|
}
|
||||||
|
Stmt body = StageSplitter(split_load_).Split(
|
||||||
|
op->body, env);
|
||||||
|
if (body.same_as(op->body)) return s;
|
||||||
|
return AttrStmt::make(
|
||||||
|
op->node, op->type_key, op->value, body);
|
||||||
|
} else {
|
||||||
|
return IRMutator::Mutate_(op, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) {
|
||||||
|
env_.push_back(op);
|
||||||
|
Stmt ret = IRMutator::Mutate_(op, s);
|
||||||
|
if (env_.back() == nullptr) {
|
||||||
|
ret = ret.as<ProducerConsumer>()->body;
|
||||||
|
}
|
||||||
|
env_.pop_back();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool split_load_;
|
||||||
|
std::vector<const ProducerConsumer *> env_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Stmt SplitPipeline(Stmt stmt, bool split_load) {
|
||||||
|
return PipelineSplitter(split_load).Mutate(stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ir
|
} // namespace ir
|
||||||
|
|
|
@ -283,8 +283,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
||||||
if (fin == nullptr) {
|
if (fin == nullptr) {
|
||||||
*out = new PackedFunc(
|
*out = new PackedFunc(
|
||||||
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
|
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
|
||||||
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||||
args.num_args, rv, resource_handle);
|
args.num_args, rv, resource_handle);
|
||||||
|
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// wrap it in a shared_ptr, with fin as deleter.
|
// wrap it in a shared_ptr, with fin as deleter.
|
||||||
|
@ -292,8 +293,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
||||||
std::shared_ptr<void> rpack(resource_handle, fin);
|
std::shared_ptr<void> rpack(resource_handle, fin);
|
||||||
*out = new PackedFunc(
|
*out = new PackedFunc(
|
||||||
[func, rpack](TVMArgs args, TVMRetValue* rv) {
|
[func, rpack](TVMArgs args, TVMRetValue* rv) {
|
||||||
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||||
args.num_args, rv, rpack.get());
|
args.num_args, rv, rpack.get());
|
||||||
|
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
API_END();
|
API_END();
|
||||||
|
|
|
@ -275,8 +275,7 @@ void InferRootBound(const Stage& stage,
|
||||||
// special optimization to remove trivial loop
|
// special optimization to remove trivial loop
|
||||||
if (is_one(vrange->extent)) {
|
if (is_one(vrange->extent)) {
|
||||||
up_state[iv] = IntSet::single_point(vrange->min);
|
up_state[iv] = IntSet::single_point(vrange->min);
|
||||||
}
|
} else if (fix_value && !ScopeRelax(iv, stage->scope)) {
|
||||||
if (fix_value && !ScopeRelax(iv, stage->scope)) {
|
|
||||||
up_state[iv] = IntSet::single_point(iv->var);
|
up_state[iv] = IntSet::single_point(iv->var);
|
||||||
} else {
|
} else {
|
||||||
up_state[iv] = IntSet::range(vrange);
|
up_state[iv] = IntSet::range(vrange);
|
||||||
|
|
|
@ -26,7 +26,7 @@ def mybuild(fapi, target="llvm"):
|
||||||
|
|
||||||
def test_dot():
|
def test_dot():
|
||||||
nn = 12
|
nn = 12
|
||||||
n = tvm.Var('n')
|
n = tvm.convert(nn)
|
||||||
A = tvm.placeholder((n,), name='A')
|
A = tvm.placeholder((n,), name='A')
|
||||||
B = tvm.placeholder((n,), name='B')
|
B = tvm.placeholder((n,), name='B')
|
||||||
k = tvm.reduce_axis((0, n), 'k')
|
k = tvm.reduce_axis((0, n), 'k')
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
import tvm
|
||||||
|
|
||||||
|
def test_basic():
|
||||||
|
a = tvm.Var("a")
|
||||||
|
b = tvm.Var("b")
|
||||||
|
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a)
|
||||||
|
assert m[1].value == 4
|
||||||
|
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0
|
||||||
|
|
||||||
|
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a)
|
||||||
|
assert len(m) == 0
|
||||||
|
|
||||||
|
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a)
|
||||||
|
assert m[1].value == 5
|
||||||
|
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_basic()
|
|
@ -29,3 +29,13 @@ def test_convert_ssa():
|
||||||
assert(not tvm.ir_pass.VerifySSA(z))
|
assert(not tvm.ir_pass.VerifySSA(z))
|
||||||
z_ssa = tvm.ir_pass.ConvertSSA(z)
|
z_ssa = tvm.ir_pass.ConvertSSA(z)
|
||||||
assert(tvm.ir_pass.VerifySSA(z_ssa))
|
assert(tvm.ir_pass.VerifySSA(z_ssa))
|
||||||
|
|
||||||
|
|
||||||
|
def test_expr_use_var():
|
||||||
|
x = tvm.Var('x')
|
||||||
|
assert(tvm.ir_pass.ExprUseVar(x+1, x))
|
||||||
|
assert(not tvm.ir_pass.ExprUseVar(1+10, x))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_expr_use_var()
|
||||||
|
|
|
@ -1,5 +1,22 @@
|
||||||
import tvm
|
import tvm
|
||||||
|
|
||||||
|
def lower(s, args):
|
||||||
|
binds = {}
|
||||||
|
arg_list = []
|
||||||
|
|
||||||
|
for x in args:
|
||||||
|
assert isinstance(x, tvm.tensor.Tensor)
|
||||||
|
buf = tvm.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
|
||||||
|
binds[x] = buf
|
||||||
|
arg_list.append(buf)
|
||||||
|
s.normalize()
|
||||||
|
bounds = tvm.schedule.InferBound(s)
|
||||||
|
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||||
|
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
|
||||||
|
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
|
||||||
|
stmt = tvm.ir_pass.Simplify(stmt)
|
||||||
|
return stmt
|
||||||
|
|
||||||
def test_basic_pipeline():
|
def test_basic_pipeline():
|
||||||
n = tvm.convert(128)
|
n = tvm.convert(128)
|
||||||
A = tvm.placeholder((n,), name='A')
|
A = tvm.placeholder((n,), name='A')
|
||||||
|
@ -12,20 +29,37 @@ def test_basic_pipeline():
|
||||||
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
|
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
|
||||||
|
|
||||||
s = tvm.Schedule(B.op)
|
s = tvm.Schedule(B.op)
|
||||||
xo, xi = s[B].split(B.op.axis[0], factor=4)
|
px = tvm.thread_axis((0, 1), "pipeline")
|
||||||
|
xo, xi = s[B].split(B.op.axis[0], outer=px)
|
||||||
|
xo, xi = s[B].split(xi, factor=4)
|
||||||
for S in stages:
|
for S in stages:
|
||||||
s[S].compute_at(s[B], xo)
|
s[S].compute_at(s[B], xo)
|
||||||
|
|
||||||
# Lowering
|
stmt = lower(s, [A, B])
|
||||||
bounds = tvm.schedule.InferBound(s)
|
stmt = tvm.ir_pass.SplitPipeline(stmt, False)
|
||||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
print(stmt)
|
||||||
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
|
stmt = tvm.ir_pass.NarrowChannelAccess(stmt)
|
||||||
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
|
|
||||||
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb})
|
|
||||||
stmt = tvm.ir_pass.Simplify(stmt)
|
|
||||||
stmt = tvm.ir_pass.SplitPipeline(stmt)
|
|
||||||
print(stmt)
|
print(stmt)
|
||||||
assert(tvm.ir_pass.VerifySSA(stmt))
|
assert(tvm.ir_pass.VerifySSA(stmt))
|
||||||
|
|
||||||
|
def test_conv1d():
|
||||||
|
n = tvm.Var('n')
|
||||||
|
A = tvm.compute((n+2), lambda i: 1, name='A')
|
||||||
|
def computeB(ii):
|
||||||
|
i = ii + 1
|
||||||
|
return A[i-1] + A[i] + A[i+1]
|
||||||
|
B = tvm.compute(n, computeB, name='B')
|
||||||
|
s = tvm.Schedule(B.op)
|
||||||
|
px = tvm.thread_axis((0, 1), "pipeline")
|
||||||
|
xo, xi = s[B].split(B.op.axis[0], outer=px)
|
||||||
|
s[A].compute_at(s[B], px)
|
||||||
|
stmt = lower(s, [B])
|
||||||
|
stmt = tvm.ir_pass.SplitPipeline(stmt, False)
|
||||||
|
print(stmt)
|
||||||
|
stmt = tvm.ir_pass.NarrowChannelAccess(stmt)
|
||||||
|
print(stmt)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_basic_pipeline()
|
test_basic_pipeline()
|
||||||
|
test_conv1d()
|
||||||
|
|
|
@ -36,7 +36,8 @@ if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then
|
||||||
make -f tests/travis/packages.mk iverilog
|
make -f tests/travis/packages.mk iverilog
|
||||||
make verilog || exit -1
|
make verilog || exit -1
|
||||||
make all || exit -1
|
make all || exit -1
|
||||||
nosetests -v tests/verilog || exit -1
|
nosetests -v tests/verilog/unittest || exit -1
|
||||||
|
nosetests -v tests/verilog/integration || exit -1
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
import tvm
|
||||||
|
from tvm.addon import testing, verilog
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def lower(s, args, name):
|
||||||
|
binds = {}
|
||||||
|
arg_list = []
|
||||||
|
|
||||||
|
for x in args:
|
||||||
|
assert isinstance(x, tvm.tensor.Tensor)
|
||||||
|
buf = tvm.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
|
||||||
|
binds[x] = buf
|
||||||
|
arg_list.append(buf)
|
||||||
|
s.normalize()
|
||||||
|
bounds = tvm.schedule.InferBound(s)
|
||||||
|
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||||
|
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
|
||||||
|
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
|
||||||
|
stmt = tvm.ir_pass.Simplify(stmt)
|
||||||
|
stmt = tvm.ir_pass.SplitPipeline(stmt, True)
|
||||||
|
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0)
|
||||||
|
return fapi
|
||||||
|
|
||||||
|
@tvm.register_func
|
||||||
|
def tvm_callback_verilog_postproc(code):
|
||||||
|
"""Hook to inspect the verilog code before actually run it"""
|
||||||
|
print(code)
|
||||||
|
return code
|
||||||
|
|
||||||
|
def test_add_pipeline():
|
||||||
|
nn = 128
|
||||||
|
n = tvm.convert(nn)
|
||||||
|
A = tvm.placeholder((n,), name='A', dtype='int32')
|
||||||
|
B = tvm.placeholder((n,), name='B', dtype='int32')
|
||||||
|
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
|
||||||
|
s = tvm.Schedule(C.op)
|
||||||
|
|
||||||
|
grid_x = tvm.thread_axis((0, 1), "pipeline")
|
||||||
|
_, x = s[C].split(C.op.axis[0], outer=grid_x)
|
||||||
|
fapi = lower(s, [A, B, C], "myadd")
|
||||||
|
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
|
||||||
|
print(fsplits[1].body)
|
||||||
|
print("------")
|
||||||
|
|
||||||
|
def check_target(device, host="stackvm"):
|
||||||
|
if not tvm.codegen.enabled(host):
|
||||||
|
return
|
||||||
|
if not tvm.codegen.enabled(device):
|
||||||
|
return
|
||||||
|
ctx = tvm.vpi(0)
|
||||||
|
mhost = tvm.codegen.build(fsplits[0], host)
|
||||||
|
mdev = tvm.codegen.build(fsplits[1:], device)
|
||||||
|
mhost.import_module(mdev)
|
||||||
|
code = mdev.get_source()
|
||||||
|
f = mhost.entry_func
|
||||||
|
# launch the kernel.
|
||||||
|
n = nn
|
||||||
|
a = tvm.nd.array((np.random.uniform(size=n) * 128).astype(A.dtype), ctx)
|
||||||
|
b = tvm.nd.array((np.random.uniform(size=n) * 128).astype(A.dtype), ctx)
|
||||||
|
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
|
||||||
|
f(a, b, c)
|
||||||
|
print("Check correctness...")
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
c.asnumpy(), a.asnumpy() + b.asnumpy())
|
||||||
|
check_target("verilog")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_add_pipeline()
|
|
@ -1,29 +0,0 @@
|
||||||
import tvm
|
|
||||||
from tvm.addon import verilog
|
|
||||||
|
|
||||||
def test_counter():
|
|
||||||
# Start a new session by run simulation on test_counter.v
|
|
||||||
# Find file will search root/verilog and root/tests/verilog
|
|
||||||
sess = verilog.session([
|
|
||||||
verilog.find_file("test_counter.v"),
|
|
||||||
verilog.find_file("example_counter.v")
|
|
||||||
])
|
|
||||||
# Get the handles by their names
|
|
||||||
rst = sess.main.rst
|
|
||||||
counter = sess.main.counter
|
|
||||||
cnt = sess.main["counter_unit1"]
|
|
||||||
assert(counter.name == "main.counter")
|
|
||||||
assert(counter.size == 4)
|
|
||||||
rst.put_int(1)
|
|
||||||
# This will advance the cycle to next pos-edge of clk.
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
rst.put_int(0)
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
# get value of counter.
|
|
||||||
assert(counter.get_int() == i)
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_counter()
|
|
|
@ -1,35 +0,0 @@
|
||||||
import tvm
|
|
||||||
from tvm.addon import verilog
|
|
||||||
|
|
||||||
def test_loop():
|
|
||||||
sess = verilog.session([
|
|
||||||
verilog.find_file("test_loop.v")
|
|
||||||
])
|
|
||||||
# Get the handles by their names
|
|
||||||
rst = sess.main.rst
|
|
||||||
init = sess.main.init
|
|
||||||
iter0 = sess.main.iter0
|
|
||||||
iter1 = sess.main.iter1
|
|
||||||
enable = sess.main.enable
|
|
||||||
invalid = sess.main.done
|
|
||||||
|
|
||||||
rst.put_int(1)
|
|
||||||
# This will advance the cycle to next pos-edge of clk.
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
rst.put_int(0)
|
|
||||||
init.put_int(1)
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
enable.put_int(1)
|
|
||||||
init.put_int(0)
|
|
||||||
|
|
||||||
for i in range(0, 3):
|
|
||||||
for j in range(0, 4):
|
|
||||||
while invalid.get_int():
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
assert(iter1.get_int() == i)
|
|
||||||
assert(iter0.get_int() == j)
|
|
||||||
sess.yield_until_posedge()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_loop()
|
|
|
@ -1,24 +0,0 @@
|
||||||
`include "tvm_marcos.v"
|
|
||||||
|
|
||||||
module main();
|
|
||||||
parameter PER = 10;
|
|
||||||
reg clk;
|
|
||||||
reg rst;
|
|
||||||
wire init;
|
|
||||||
wire done;
|
|
||||||
wire enable;
|
|
||||||
|
|
||||||
`NORMAL_LOOP_LEAF(iter0, 4, init0, enable, iter0_done, 0, 4, 1)
|
|
||||||
`NORMAL_LOOP_NEST(iter1, 4, init, iter0_done, iter1_done, 0, 3, 1, init0)
|
|
||||||
|
|
||||||
assign done = iter0_done;
|
|
||||||
|
|
||||||
always begin
|
|
||||||
#(PER/2) clk =~ clk;
|
|
||||||
end
|
|
||||||
|
|
||||||
initial begin
|
|
||||||
// This will allow tvm session to be called every cycle.
|
|
||||||
$tvm_session(clk);
|
|
||||||
end
|
|
||||||
endmodule
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
import tvm
|
||||||
|
from tvm.addon import verilog
|
||||||
|
from testing_util import FIFODelayedWriter, FIFODelayedReader
|
||||||
|
|
||||||
|
def run_with_lag(n, read_lag, write_lag):
|
||||||
|
data = list(range(n))
|
||||||
|
# head ptr of a
|
||||||
|
sess = verilog.session([
|
||||||
|
verilog.find_file("test_cache_reg.v")
|
||||||
|
])
|
||||||
|
rst = sess.main.rst
|
||||||
|
in_data = sess.main.in_data
|
||||||
|
in_valid = sess.main.in_valid
|
||||||
|
in_ready = sess.main.in_ready
|
||||||
|
|
||||||
|
out_data = sess.main.out_data
|
||||||
|
out_valid = sess.main.out_valid
|
||||||
|
out_ready = sess.main.out_ready
|
||||||
|
# hook up reader
|
||||||
|
reader = FIFODelayedReader(out_data, out_valid, out_ready, read_lag)
|
||||||
|
writer = FIFODelayedWriter(in_data, in_valid, in_ready, data, write_lag)
|
||||||
|
rst.put_int(1)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
rst.put_int(0)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
sess.yield_callbacks.append(reader)
|
||||||
|
sess.yield_callbacks.append(writer)
|
||||||
|
timeout = sum(read_lag) + sum(write_lag) + n + 10
|
||||||
|
for t in range(timeout):
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
if len(reader.data) == n:
|
||||||
|
break
|
||||||
|
assert tuple(reader.data) == tuple(range(n))
|
||||||
|
assert len(writer.data) == 0
|
||||||
|
sess.shutdown()
|
||||||
|
|
||||||
|
def test_fifo():
|
||||||
|
n = 20
|
||||||
|
# slow reader
|
||||||
|
run_with_lag(n, read_lag=[3,4,8], write_lag=[])
|
||||||
|
# slow writer
|
||||||
|
run_with_lag(n, read_lag=[0], write_lag=[0, 2, 10])
|
||||||
|
# mix
|
||||||
|
run_with_lag(n, read_lag=[3, 4, 8], write_lag=[0, 2, 10])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fifo()
|
|
@ -0,0 +1,20 @@
|
||||||
|
`include "tvm_marcos.v"
|
||||||
|
|
||||||
|
module main();
|
||||||
|
`TVM_DEFINE_TEST_SIGNAL(clk, rst)
|
||||||
|
|
||||||
|
reg[31:0] in_data;
|
||||||
|
wire[31:0] out_data;
|
||||||
|
wire in_ready;
|
||||||
|
reg in_valid;
|
||||||
|
reg out_ready;
|
||||||
|
wire out_valid;
|
||||||
|
|
||||||
|
`CACHE_REG(32, in_data, in_valid, in_ready,
|
||||||
|
out_data, out_valid, out_ready)
|
||||||
|
|
||||||
|
initial begin
|
||||||
|
// This will allow tvm session to be called every cycle.
|
||||||
|
$tvm_session(clk);
|
||||||
|
end
|
||||||
|
endmodule
|
|
@ -0,0 +1,53 @@
|
||||||
|
import tvm
|
||||||
|
from tvm.addon import verilog
|
||||||
|
|
||||||
|
def test_counter():
|
||||||
|
# Start a new session by run simulation on test_counter.v
|
||||||
|
# Find file will search root/verilog and root/tests/verilog
|
||||||
|
sess = verilog.session([
|
||||||
|
verilog.find_file("test_counter.v"),
|
||||||
|
verilog.find_file("example_counter.v")
|
||||||
|
])
|
||||||
|
# Get the handles by their names
|
||||||
|
rst = sess.main.rst
|
||||||
|
counter = sess.main.counter
|
||||||
|
cnt = sess.main["counter_unit1"]
|
||||||
|
assert(counter.name == "main.counter")
|
||||||
|
assert(counter.size == 4)
|
||||||
|
rst.put_int(1)
|
||||||
|
# This will advance the cycle to next pos-edge of clk.
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
rst.put_int(0)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
# get value of counter.
|
||||||
|
assert(counter.get_int() == i)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
|
|
||||||
|
def test_scratch():
|
||||||
|
sess = verilog.session([
|
||||||
|
verilog.find_file("test_counter.v"),
|
||||||
|
verilog.find_file("example_counter.v")
|
||||||
|
])
|
||||||
|
# Get the handles by their names
|
||||||
|
rst = sess.main.rst
|
||||||
|
counter = sess.main.counter
|
||||||
|
rst.put_int(1)
|
||||||
|
# This will advance the cycle to next pos-edge of clk.
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
rst.put_int(0)
|
||||||
|
temp = 0
|
||||||
|
for i in range(10):
|
||||||
|
if rst.get_int():
|
||||||
|
rst.put_int(0)
|
||||||
|
temp = counter.get_int()
|
||||||
|
elif counter.get_int() == 3:
|
||||||
|
rst.put_int(1)
|
||||||
|
print("counter=%d, temp=%d" % (counter.get_int(), temp))
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_scratch()
|
||||||
|
test_counter()
|
|
@ -1,13 +1,11 @@
|
||||||
module main();
|
`include "tvm_marcos.v"
|
||||||
parameter PER = 10;
|
|
||||||
reg clk;
|
|
||||||
reg rst;
|
|
||||||
wire [3:0] counter;
|
|
||||||
|
|
||||||
|
module main();
|
||||||
|
`TVM_DEFINE_TEST_SIGNAL(clk, rst)
|
||||||
|
|
||||||
|
wire[3:0] counter;
|
||||||
counter counter_unit1(.clk(clk), .rst(rst), .out(counter));
|
counter counter_unit1(.clk(clk), .rst(rst), .out(counter));
|
||||||
always begin
|
|
||||||
#(PER/2) clk =~ clk;
|
|
||||||
end
|
|
||||||
initial begin
|
initial begin
|
||||||
// This will allow tvm session to be called every cycle.
|
// This will allow tvm session to be called every cycle.
|
||||||
$tvm_session(clk);
|
$tvm_session(clk);
|
|
@ -0,0 +1,30 @@
|
||||||
|
import tvm
|
||||||
|
from tvm.addon import verilog
|
||||||
|
|
||||||
|
def test_loop():
|
||||||
|
sess = verilog.session([
|
||||||
|
verilog.find_file("test_loop.v")
|
||||||
|
])
|
||||||
|
# Get the handles by their names
|
||||||
|
rst = sess.main.rst
|
||||||
|
iter0 = sess.main.iter0
|
||||||
|
iter1 = sess.main.iter1
|
||||||
|
ready = sess.main.ready
|
||||||
|
|
||||||
|
rst.put_int(1)
|
||||||
|
ready.put_int(1)
|
||||||
|
# This will advance the cycle to next pos-edge of clk.
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
rst.put_int(0)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
|
for k in range(0, 1):
|
||||||
|
for i in range(0, 3):
|
||||||
|
for j in range(0, 4):
|
||||||
|
assert(iter1.get_int() == i)
|
||||||
|
assert(iter0.get_int() == j)
|
||||||
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_loop()
|
|
@ -0,0 +1,19 @@
|
||||||
|
`include "tvm_marcos.v"
|
||||||
|
|
||||||
|
module main();
|
||||||
|
`TVM_DEFINE_TEST_SIGNAL(clk, rst)
|
||||||
|
|
||||||
|
reg ready;
|
||||||
|
wire lp_ready;
|
||||||
|
|
||||||
|
`NONSTOP_LOOP(iter0, 4, 0, lp_ready, iter0_finish, 0, 4)
|
||||||
|
`NONSTOP_LOOP(iter1, 4, 0, iter0_finish, iter1_finish, 0, 3)
|
||||||
|
`WRAP_LOOP_ONCE(0, valid, ready, iter1_finish, loop_ready)
|
||||||
|
assign lp_ready = loop_ready;
|
||||||
|
|
||||||
|
|
||||||
|
initial begin
|
||||||
|
// This will allow tvm session to be called every cycle.
|
||||||
|
$tvm_session(clk);
|
||||||
|
end
|
||||||
|
endmodule
|
|
@ -51,7 +51,7 @@ def test_ram_read():
|
||||||
host_read_addr = sess.main.read_addr
|
host_read_addr = sess.main.read_addr
|
||||||
host_read_size = sess.main.read_size
|
host_read_size = sess.main.read_size
|
||||||
rst.put_int(1)
|
rst.put_int(1)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
rst.put_int(0)
|
rst.put_int(0)
|
||||||
# hook up reader
|
# hook up reader
|
||||||
reader = FIFOReader(read_data, read_valid)
|
reader = FIFOReader(read_data, read_valid)
|
||||||
|
@ -61,18 +61,18 @@ def test_ram_read():
|
||||||
host_read_addr.put_int(a_ptr)
|
host_read_addr.put_int(a_ptr)
|
||||||
host_read_size.put_int(a.shape[0])
|
host_read_size.put_int(a.shape[0])
|
||||||
|
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
# second read request
|
# second read request
|
||||||
host_read_addr.put_int(a_ptr + 2)
|
host_read_addr.put_int(a_ptr + 2)
|
||||||
host_read_size.put_int(a.shape[0] - 2)
|
host_read_size.put_int(a.shape[0] - 2)
|
||||||
|
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
host_read_req.put_int(0)
|
host_read_req.put_int(0)
|
||||||
read_en.put_int(1)
|
read_en.put_int(1)
|
||||||
|
|
||||||
# yield until read is done
|
# yield until read is done
|
||||||
for i in range(a.shape[0] * 3):
|
for i in range(a.shape[0] * 3):
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
# check if result matches
|
# check if result matches
|
||||||
r = np.concatenate((a_np, a_np[2:]))
|
r = np.concatenate((a_np, a_np[2:]))
|
||||||
np.testing.assert_equal(np.array(reader.data), r)
|
np.testing.assert_equal(np.array(reader.data), r)
|
||||||
|
@ -105,7 +105,7 @@ def test_ram_write():
|
||||||
host_write_size = sess.main.write_size
|
host_write_size = sess.main.write_size
|
||||||
|
|
||||||
rst.put_int(1)
|
rst.put_int(1)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
rst.put_int(0)
|
rst.put_int(0)
|
||||||
# hook up writeer
|
# hook up writeer
|
||||||
writer = FIFOWriter(write_data, write_en, write_ready, w_data)
|
writer = FIFOWriter(write_data, write_en, write_ready, w_data)
|
||||||
|
@ -116,12 +116,12 @@ def test_ram_write():
|
||||||
host_write_addr.put_int(a_ptr + offset)
|
host_write_addr.put_int(a_ptr + offset)
|
||||||
host_write_size.put_int(a.shape[0] - offset)
|
host_write_size.put_int(a.shape[0] - offset)
|
||||||
|
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
host_write_req.put_int(0)
|
host_write_req.put_int(0)
|
||||||
|
|
||||||
# yield until write is done
|
# yield until write is done
|
||||||
for i in range(a.shape[0]+2):
|
for i in range(a.shape[0]+2):
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
# check if result matches
|
# check if result matches
|
||||||
np.testing.assert_equal(a.asnumpy()[2:], r_data)
|
np.testing.assert_equal(a.asnumpy()[2:], r_data)
|
|
@ -25,18 +25,18 @@ def test_mmap():
|
||||||
|
|
||||||
# setup memory map.
|
# setup memory map.
|
||||||
rst.put_int(1)
|
rst.put_int(1)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
rst.put_int(0)
|
rst.put_int(0)
|
||||||
write_en.put_int(0)
|
write_en.put_int(0)
|
||||||
mmap_addr.put_int(a_ptr)
|
mmap_addr.put_int(a_ptr)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
# read test
|
# read test
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
read_addr.put_int(i)
|
read_addr.put_int(i)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
# read addr get set this cycle
|
# read addr get set this cycle
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
# get the data out
|
# get the data out
|
||||||
assert(read_data.get_int() == i)
|
assert(read_data.get_int() == i)
|
||||||
|
|
||||||
|
@ -45,9 +45,9 @@ def test_mmap():
|
||||||
write_addr.put_int(i)
|
write_addr.put_int(i)
|
||||||
write_en.put_int(1)
|
write_en.put_int(1)
|
||||||
write_data.put_int(i + 1)
|
write_data.put_int(i + 1)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
write_en.put_int(0)
|
write_en.put_int(0)
|
||||||
sess.yield_until_posedge()
|
sess.yield_until_next_cycle()
|
||||||
|
|
||||||
np.testing.assert_equal(a.asnumpy(), a_np + 1)
|
np.testing.assert_equal(a.asnumpy(), a_np + 1)
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Common utilities for test"""
|
||||||
|
|
||||||
|
class FIFODelayedReader(object):
|
||||||
|
"""Reader that have specified ready lag."""
|
||||||
|
def __init__(self, read_data, read_valid, read_ready, lag):
|
||||||
|
self.read_data = read_data
|
||||||
|
self.read_valid = read_valid
|
||||||
|
self.read_ready = read_ready
|
||||||
|
self.read_ready.put_int(1)
|
||||||
|
self.lag = list(reversed(lag))
|
||||||
|
self.data = []
|
||||||
|
self.wait_counter = 0
|
||||||
|
self.wait_state = False
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""Logic as if always at pos-edge"""
|
||||||
|
if not self.wait_state:
|
||||||
|
if (self.read_ready.get_int() and
|
||||||
|
self.read_valid.get_int()):
|
||||||
|
self.data.append(self.read_data.get_int())
|
||||||
|
self.wait_counter = self.lag.pop() if self.lag else 0
|
||||||
|
self.wait_state = True
|
||||||
|
|
||||||
|
if self.wait_state:
|
||||||
|
if self.wait_counter == 0:
|
||||||
|
self.read_ready.put_int(1)
|
||||||
|
self.wait_state = False
|
||||||
|
else:
|
||||||
|
self.wait_counter -= 1
|
||||||
|
self.read_ready.put_int(0)
|
||||||
|
|
||||||
|
|
||||||
|
class FIFODelayedWriter(object):
|
||||||
|
"""Auxiliary class to write to FIFO """
|
||||||
|
def __init__(self, write_data, write_valid, write_ready, data, lag):
|
||||||
|
self.write_data = write_data
|
||||||
|
self.write_valid = write_valid
|
||||||
|
self.write_ready = write_ready
|
||||||
|
self.write_valid.put_int(0)
|
||||||
|
self.lag = list(reversed(lag))
|
||||||
|
self.data = list(reversed(data))
|
||||||
|
self.wait_counter = 0
|
||||||
|
self.wait_state = True
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
"""Logic as if always at pos-edge"""
|
||||||
|
if not self.wait_state:
|
||||||
|
if self.write_ready.get_int():
|
||||||
|
self.wait_counter = self.lag.pop() if self.lag else 0
|
||||||
|
self.wait_state = True
|
||||||
|
|
||||||
|
if self.wait_state:
|
||||||
|
if self.wait_counter == 0:
|
||||||
|
if self.data:
|
||||||
|
self.write_valid.put_int(1)
|
||||||
|
self.write_data.put_int(self.data.pop())
|
||||||
|
self.wait_state = False
|
||||||
|
else:
|
||||||
|
self.write_valid.put_int(0)
|
||||||
|
else:
|
||||||
|
self.write_valid.put_int(0)
|
||||||
|
self.wait_counter -= 1
|
|
@ -1,70 +1,120 @@
|
||||||
// Leaf of a normal loop nest
|
// Nonstop version of loop
|
||||||
// Starts at done = 1
|
// Always keeps looping when increase == true
|
||||||
// Need init to reset to done = 0
|
// At end is a signal to indicate the next cycle is end
|
||||||
// increases when enabled = 1
|
// Use that to signal parent loop to advance.
|
||||||
`define NORMAL_LOOP_LEAF(iter, width, init, enable, done, min, max, incr)\
|
`define NONSTOP_LOOP(iter, width, init, ready, finish, min, extent)\
|
||||||
reg [width-1:0] iter;\
|
reg [width-1:0] iter;\
|
||||||
reg valid;\
|
wire finish;\
|
||||||
reg done;\
|
|
||||||
always@(posedge clk) begin\
|
always@(posedge clk) begin\
|
||||||
if(rst) begin\
|
if (rst || init) begin\
|
||||||
iter <= 0;\
|
|
||||||
done <= 1;\
|
|
||||||
end else if(init) begin\
|
|
||||||
iter <= (min);\
|
iter <= (min);\
|
||||||
done <= 0;\
|
end else if(ready) begin\
|
||||||
end else if(done) begin\
|
if (iter != ((extent)-1)) begin\
|
||||||
iter <= 0;\
|
iter <= iter + 1;\
|
||||||
done <= 1;\
|
|
||||||
end else if(enable) begin\
|
|
||||||
if (iter < ((max)-(incr))) begin\
|
|
||||||
iter <= iter + (incr);\
|
|
||||||
done <= 0;\
|
|
||||||
end else begin\
|
end else begin\
|
||||||
iter <= 0;\
|
iter <= (min);\
|
||||||
done <= 1;\
|
|
||||||
end\
|
end\
|
||||||
end else begin\
|
end else begin\
|
||||||
iter <= iter;\
|
iter <= iter;\
|
||||||
done <= done;\
|
|
||||||
end\
|
end\
|
||||||
end
|
end\
|
||||||
|
assign finish = (ready && (iter == (extent) - 1));
|
||||||
|
|
||||||
// Normal loop nest that can connect to a child which is a normal loop
|
|
||||||
`define NORMAL_LOOP_NEST(iter, width, init, body_done, done, min, max, incr, body_init)\
|
// Wrap a nonstop loop to normal loop that loop only once.
|
||||||
reg [width-1:0] iter;\
|
// Use done signal to control the non-stop body to stop.
|
||||||
reg done;\
|
// The init and done behaves like normal loop
|
||||||
reg body_init;\
|
`define WRAP_LOOP_ONCE(init, valid, ready, body_finish, body_ready)\
|
||||||
|
reg valid;\
|
||||||
|
wire body_ready;\
|
||||||
|
always@(posedge clk) begin\
|
||||||
|
if (rst || init) begin\
|
||||||
|
valid <= 1;\
|
||||||
|
end else if(body_finish) begin\
|
||||||
|
valid <= 0;\
|
||||||
|
end else begin\
|
||||||
|
valid <= valid;\
|
||||||
|
end\
|
||||||
|
end\
|
||||||
|
assign body_ready = (valid && ready);
|
||||||
|
|
||||||
|
// Assign dst as src delayed by specific cycles.
|
||||||
|
`define DELAY(dst, src, width, delay, not_stall)\
|
||||||
|
reg [(width)*(delay)-1:0] src``_dly_chain;\
|
||||||
always@(posedge clk) begin\
|
always@(posedge clk) begin\
|
||||||
if(rst) begin\
|
if(rst) begin\
|
||||||
iter <= 0;\
|
src``_dly_chain <= 0;\
|
||||||
done <= 1;\
|
end else if (not_stall) begin\
|
||||||
body_init <= 0;\
|
src``_dly_chain[(width)-1:0] <= src;\
|
||||||
end else if(init) begin\
|
if((delay) != 1) begin\
|
||||||
iter <= (min);\
|
src``_dly_chain[(delay)*(width)-1:(width)] <= src``_dly_chain[((delay)-1)*(width)-1:0];\
|
||||||
done <= 0;\
|
|
||||||
body_init <= 1;\
|
|
||||||
end else if(done) begin\
|
|
||||||
iter <= 0;\
|
|
||||||
done <= 1;\
|
|
||||||
body_init <= 0;\
|
|
||||||
end else if (body_init) begin\
|
|
||||||
iter <= iter;\
|
|
||||||
done <= done;\
|
|
||||||
body_init <= 0;\
|
|
||||||
end else if (body_done) begin\
|
|
||||||
if (iter < ((max)-(incr))) begin\
|
|
||||||
iter <= iter + (incr);\
|
|
||||||
done <= 0;\
|
|
||||||
body_init <= 1;\
|
|
||||||
end else begin\
|
|
||||||
iter <= 0;\
|
|
||||||
done <= 1;\
|
|
||||||
body_init <= 0;\
|
|
||||||
end\
|
end\
|
||||||
end else begin\
|
end else begin\
|
||||||
iter <= iter;\
|
src``_dly_chain <= src``_dly_chain;\
|
||||||
done <= done;\
|
|
||||||
body_init <= 0;\
|
|
||||||
end\
|
end\
|
||||||
end
|
end\
|
||||||
|
assign dst = src``_dly_chain[(delay)*(width)-1:((delay)-1)*(width)];
|
||||||
|
|
||||||
|
// TVM generate clock signal
|
||||||
|
`define TVM_DEFINE_TEST_SIGNAL(clk, rst)\
|
||||||
|
parameter PER = 10;\
|
||||||
|
reg clk;\
|
||||||
|
reg rst;\
|
||||||
|
always begin\
|
||||||
|
#(PER/2) clk =~ clk;\
|
||||||
|
end
|
||||||
|
|
||||||
|
// Control logic on buffer/RAM read valid.
|
||||||
|
// This delays the valid signal by one cycle and retain it when write_ready == 0
|
||||||
|
`define BUFFER_READ_VALID_DELAY(dst, data_valid, write_ready)\
|
||||||
|
reg dst;\
|
||||||
|
always@(posedge clk) begin\
|
||||||
|
if(rst) begin\
|
||||||
|
dst <= 0;\
|
||||||
|
end else if (write_ready) begin\
|
||||||
|
dst <= (data_valid);\
|
||||||
|
end else begin\
|
||||||
|
dst <= dst;\
|
||||||
|
end\
|
||||||
|
end\
|
||||||
|
|
||||||
|
// A cache register that add one cycle lag to the ready signal
|
||||||
|
// This allows the signal to flow more smoothly
|
||||||
|
`define CACHE_REG(width, in_data, in_valid, in_ready, out_data, out_valid, out_ready)\
|
||||||
|
reg [width-1:0] out_data``_state_;\
|
||||||
|
reg [width-1:0] out_data``_overflow_;\
|
||||||
|
reg out_valid``_state_;\
|
||||||
|
reg out_valid``_overflow_;\
|
||||||
|
always@(posedge clk) begin\
|
||||||
|
if(rst) begin\
|
||||||
|
out_valid``_overflow_ <= 0;\
|
||||||
|
out_valid``_state_ <= 0;\
|
||||||
|
end else if (out_valid``_overflow_) begin\
|
||||||
|
if (out_ready) begin\
|
||||||
|
out_valid``_state_ <= 1;\
|
||||||
|
out_data``_state_ <= out_data``_overflow_;\
|
||||||
|
out_valid``_overflow_ <= 0;\
|
||||||
|
out_data``_overflow_ <= 0;\
|
||||||
|
end else begin\
|
||||||
|
out_valid``_state_ <= 1;\
|
||||||
|
out_data``_state_ <= out_data``_state_;\
|
||||||
|
out_valid``_overflow_ <= out_valid``_overflow_;\
|
||||||
|
out_data``_overflow_ <= out_data``_overflow_;\
|
||||||
|
end\
|
||||||
|
end else begin\
|
||||||
|
if (!out_ready && out_valid``_state_) begin\
|
||||||
|
out_valid``_state_ <= 1;\
|
||||||
|
out_data``_state_ <= out_data``_state_;\
|
||||||
|
out_valid``_overflow_ <= in_valid;\
|
||||||
|
out_data``_overflow_ <= in_data;\
|
||||||
|
end else begin\
|
||||||
|
out_valid``_state_ <= in_valid;\
|
||||||
|
out_data``_state_ <= in_data;\
|
||||||
|
out_valid``_overflow_ <= out_valid``_overflow_;\
|
||||||
|
out_data``_overflow_ <= out_data``_overflow_;\
|
||||||
|
end\
|
||||||
|
end\
|
||||||
|
end\ // always@ (posedge clk)
|
||||||
|
assign in_ready = !out_valid``_overflow_;\
|
||||||
|
assign out_data = out_data``_state_;\
|
||||||
|
assign out_valid = out_valid``_state_;
|
||||||
|
|
|
@ -43,9 +43,9 @@ class IPCClient {
|
||||||
PutInt(clock_, 0);
|
PutInt(clock_, 0);
|
||||||
}
|
}
|
||||||
int Callback() {
|
int Callback() {
|
||||||
if (GetInt(clock_)) {
|
if (!GetInt(clock_)) {
|
||||||
try {
|
try {
|
||||||
return AtPosEedge();
|
return AtNegEdge();
|
||||||
} catch (const std::runtime_error& e) {
|
} catch (const std::runtime_error& e) {
|
||||||
reader_.Close();
|
reader_.Close();
|
||||||
writer_.Close();
|
writer_.Close();
|
||||||
|
@ -57,8 +57,11 @@ class IPCClient {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// called at positive edge.
|
// called at neg edge.
|
||||||
int AtPosEedge() {
|
int AtNegEdge() {
|
||||||
|
// This is actually called at neg-edge
|
||||||
|
// The put values won't take effect until next neg-edge.
|
||||||
|
// This allow us to see the registers before snc
|
||||||
writer_.Write(kPosEdgeTrigger);
|
writer_.Write(kPosEdgeTrigger);
|
||||||
VPICallCode rcode;
|
VPICallCode rcode;
|
||||||
VPIRawHandle handle;
|
VPIRawHandle handle;
|
||||||
|
@ -149,10 +152,10 @@ class IPCClient {
|
||||||
s_vpi_time time_s;
|
s_vpi_time time_s;
|
||||||
time_s.type = vpiSimTime;
|
time_s.type = vpiSimTime;
|
||||||
time_s.high = 0;
|
time_s.high = 0;
|
||||||
time_s.low = 0;
|
time_s.low = 10;
|
||||||
value_s.format = vpiVectorVal;
|
value_s.format = vpiVectorVal;
|
||||||
value_s.value.vector = &svec_buf_[0];
|
value_s.value.vector = &svec_buf_[0];
|
||||||
vpi_put_value(h, &value_s, &time_s, vpiInertialDelay);
|
vpi_put_value(h, &value_s, &time_s, vpiTransportDelay);
|
||||||
writer_.Write(kSuccess);
|
writer_.Write(kSuccess);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -202,10 +205,10 @@ class IPCClient {
|
||||||
s_vpi_time time_s;
|
s_vpi_time time_s;
|
||||||
time_s.type = vpiSimTime;
|
time_s.type = vpiSimTime;
|
||||||
time_s.high = 0;
|
time_s.high = 0;
|
||||||
time_s.low = 0;
|
time_s.low = 10;
|
||||||
value_s.format = vpiIntVal;
|
value_s.format = vpiIntVal;
|
||||||
value_s.value.integer = value;
|
value_s.value.integer = value;
|
||||||
vpi_put_value(h, &value_s, &time_s, vpiInertialDelay);
|
vpi_put_value(h, &value_s, &time_s, vpiTransportDelay);
|
||||||
}
|
}
|
||||||
// Handles
|
// Handles
|
||||||
vpiHandle clock_;
|
vpiHandle clock_;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
VPI_CFLAGS=`iverilog-vpi --cflags`
|
VPI_CFLAGS=`iverilog-vpi --cflags`
|
||||||
VPI_LDLAGS=`iverilog-vpi --ldlags`
|
VPI_LDFLAGS=`iverilog-vpi --ldflags`
|
||||||
|
|
||||||
VER_SRCS = $(wildcard verilog/*.v)
|
VER_SRCS = $(wildcard verilog/*.v)
|
||||||
|
|
||||||
|
@ -7,4 +7,4 @@ VER_LIBS=lib/tvm_vpi.vpi
|
||||||
|
|
||||||
lib/tvm_vpi.vpi: verilog/tvm_vpi.cc verilog/tvm_vpi.h
|
lib/tvm_vpi.vpi: verilog/tvm_vpi.cc verilog/tvm_vpi.h
|
||||||
@mkdir -p $(@D)
|
@mkdir -p $(@D)
|
||||||
$(CXX) $(VPI_CFLAGS) $(CFLAGS) -shared -o $@ $(filter %.cc, $^) $(LDFLAGS) $(VPI_LDFLAGS)
|
$(CXX) $(VPI_CFLAGS) $(CFLAGS) -o $@ $(filter %.cc, $^) $(LDFLAGS) $(VPI_LDFLAGS)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче