[OP/LANG] Support Extern Call, more regression tests (#69)
* [OP/LANG] Support Extern Call, more regression tests * [TEST] Include pylintrc
This commit is contained in:
Родитель
b19e01bf27
Коммит
2548cedcb8
13
Makefile
13
Makefile
|
@ -1,3 +1,5 @@
|
|||
ROOTDIR = $(CURDIR)
|
||||
|
||||
ifndef config
|
||||
ifneq ("$(wildcard ./config.mk)","")
|
||||
config ?= config.mk
|
||||
|
@ -9,7 +11,7 @@ endif
|
|||
include $(config)
|
||||
|
||||
# specify tensor path
|
||||
.PHONY: clean all test doc
|
||||
.PHONY: clean all test doc pylint cpplint lint
|
||||
|
||||
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
|
||||
|
||||
|
@ -99,8 +101,13 @@ $(LIB_HALIDE_IR): LIBHALIDEIR
|
|||
LIBHALIDEIR:
|
||||
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
|
||||
|
||||
lint:
|
||||
python2 dmlc-core/scripts/lint.py tvm all include src python
|
||||
cpplint:
|
||||
python2 dmlc-core/scripts/lint.py tvm cpp include src
|
||||
|
||||
pylint:
|
||||
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
|
||||
|
||||
lint: cpplint pylint
|
||||
|
||||
doc:
|
||||
doxygen docs/Doxyfile
|
||||
|
|
|
@ -98,6 +98,8 @@ constexpr const char* loop_scope = "loop_scope";
|
|||
constexpr const char* scan_update_scope = "scan_update_scope";
|
||||
/*! \brief Mark of scan init scope */
|
||||
constexpr const char* scan_init_scope = "scan_init_scope";
|
||||
/*! \brief extern operator scope */
|
||||
constexpr const char* extern_op_scope = "extern_op_scope";
|
||||
// Pipeline related attributes
|
||||
/*! \brief channel read scope */
|
||||
constexpr const char* channel_read_scope = "channel_read_scope";
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "./tensor.h"
|
||||
#include "./schedule.h"
|
||||
#include "./arithmetic.h"
|
||||
#include "./buffer.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
|
@ -307,6 +308,62 @@ class ScanOpNode : public OperationNode {
|
|||
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief External computation that cannot be splitted.
|
||||
*/
|
||||
class ExternOpNode : public OperationNode {
|
||||
public:
|
||||
/*! \brief The input tensors */
|
||||
Array<Tensor> inputs;
|
||||
/*! \brief Symbolic placeholder representationinputs */
|
||||
Array<Buffer> input_placeholders;
|
||||
/*! \brief Symbolic placeholder representation of outputs */
|
||||
Array<Buffer> output_placeholders;
|
||||
/*! \brief the statement that generates the computation. */
|
||||
Stmt body;
|
||||
|
||||
/*! \brief constructor */
|
||||
ExternOpNode() {}
|
||||
// override functions
|
||||
int num_outputs() const final;
|
||||
Array<IterVar> root_iter_vars() const final;
|
||||
Type output_dtype(size_t i) const final;
|
||||
Array<Expr> output_shape(size_t i) const final;
|
||||
Array<Tensor> InputTensors() const final;
|
||||
Operation ReplaceInputs(
|
||||
const Operation& self,
|
||||
const std::unordered_map<Tensor, Tensor>& rmap) const final;
|
||||
void PropBoundToInputs(
|
||||
const Operation& self,
|
||||
const std::unordered_map<const Variable*, IntSet>& dom_map,
|
||||
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
|
||||
void GatherBound(
|
||||
const Operation& self,
|
||||
const GraphContext& graph_ctx,
|
||||
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
|
||||
std::unordered_map<IterVar, Range>* out_dom_map) const final;
|
||||
Stmt BuildRealize(
|
||||
const Operation& self,
|
||||
const std::unordered_map<IterVar, Range>& realize_map,
|
||||
const Stmt& body) const final;
|
||||
Stmt BuildProvide(
|
||||
const Stage& stage,
|
||||
const std::unordered_map<IterVar, Range>& dom_map) const final;
|
||||
|
||||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("name", &name);
|
||||
v->Visit("inputs", &inputs);
|
||||
v->Visit("body", &body);
|
||||
}
|
||||
static Operation make(std::string name,
|
||||
Array<Tensor> inputs,
|
||||
Array<Buffer> input_placeholders,
|
||||
Array<Buffer> output_placeholders,
|
||||
Stmt body);
|
||||
|
||||
static constexpr const char* _type_key = "ExternOp";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
|
||||
};
|
||||
|
||||
/*! \brief The compute function to specify the input source of a Tensor */
|
||||
using FCompute = std::function<Expr (const Array<Var>& i)>;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, no-member
|
||||
# pylint: disable=invalid-name
|
||||
""" ctypes library of nnvm and helper functions """
|
||||
from __future__ import absolute_import
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""Util to compile with C++ code"""
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import as _abs
|
||||
import sys
|
||||
import subprocess
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# pylint: disable=invalid-name, too-many-locals
|
||||
# pylint: disable=invalid-name
|
||||
"""Util to compile with NVCC"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import os
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# pylint: disable=protected-access, no-member, invalid-name
|
||||
# pylint: disable=redefined-builtin, undefined-variable, unused-import
|
||||
"""Functions defined in TVM."""
|
||||
# pylint: disable=invalid-name,unused-import,redefined-builtin
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
from numbers import Integral as _Integral
|
||||
|
@ -162,8 +161,8 @@ def scan(init, update, state_placeholder, name="scan"):
|
|||
|
||||
Returns
|
||||
-------
|
||||
tensor: tensor.Tensor
|
||||
The created tensor
|
||||
tensor: Tensor or list of Tensors
|
||||
The created tensor or tuple of tensors if it contains multiple outputs.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
@ -187,7 +186,77 @@ def scan(init, update, state_placeholder, name="scan"):
|
|||
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
|
||||
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
|
||||
res = [op.output(i) for i in range(len(update))]
|
||||
return (res[0] if len(res) == 1 else res)
|
||||
return res[0] if len(res) == 1 else res
|
||||
|
||||
|
||||
def extern(shape, inputs, fcompute,
           name="extern", dtype=None):
    """Compute several tensors via an extern function.

    Parameters
    ----------
    shape: Shape tuple or list of shapes.
        The shape of the outputs.

    inputs: list of Tensor
        The inputs

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.

    name: str, optional
        The name hint of the tensor

    dtype: str or list of str, optional
        The data types of outputs,
        by default dtype will be same as inputs.
        A single string is applied to every output.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or list of tensors if it contains multiple outputs.

    Raises
    ------
    ValueError
        If an input is not a Tensor, if the output type cannot be inferred
        from the inputs, or if the number of dtypes does not match the
        number of output shapes.
    """
    # A single shape is a flat sequence of dimensions. Dimensions may be
    # symbolic (Expr) or plain integers, so both must be recognised here.
    if isinstance(shape[0], (_expr.Expr, _Integral)):
        shape = [shape]
    input_placeholders = []
    output_placeholders = []
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
        input_placeholders.append(
            Buffer(t.shape, t.dtype, t.op.name))
        types.add(t.dtype)

    if dtype is None:
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        infered_type = types.pop()
        dtype = [infered_type for _ in shape]
    elif isinstance(dtype, str):
        # Zipping a bare string would pair shapes with single characters.
        dtype = [dtype for _ in shape]

    # zip() below would otherwise silently drop the unmatched outputs.
    if len(dtype) != len(shape):
        raise ValueError("Number of dtypes does not match number of output shapes")

    for shp, dt in zip(shape, dtype):
        output_placeholders.append(Buffer(shp, dt, name))
    body = fcompute(input_placeholders, output_placeholders)
    # fcompute may return a bare expression; wrap it into a statement.
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

    op = _api_internal._ExternOp(
        name, inputs, input_placeholders, output_placeholders, body)
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res
|
||||
|
||||
|
||||
def call_packed(*args):
    """Build an expression that calls an external packed function.

    Parameters
    ----------
    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The "tvm_call_packed" call expression built by _make.Call.
    """
    packed_args = convert(args)
    call = _make.Call(int32, "tvm_call_packed", packed_args, 4, None, 0)
    return call
|
||||
|
||||
|
||||
def Buffer(shape, dtype=None,
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# pylint: disable=protected-access, no-member
|
||||
"""Arithmetic data structure and utility"""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
Eventually some of these pipelines will be moved to C++.
|
||||
But the first pipeline will be kept in python for ease of change and evolving.
|
||||
"""
|
||||
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
|
||||
|
||||
from . import api
|
||||
from . import tensor
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# pylint: disable=protected-access, no-member
|
||||
"""Collection structure in the high level DSL."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# pylint: disable=protected-access, no-member, missing-docstring
|
||||
"""Expression class"""
|
||||
# pylint: disable=missing-docstring
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
from . import make as _make
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Runtime module related stuffs"""
|
||||
# pylint: disable=unused-import, invalid-name, undefined-variable
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._function import ModuleBase, _init_module_module
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
This is a simplified runtime API for quick testing and prototyping.
|
||||
"""
|
||||
# pylint: disable=unused-import, invalid-name
|
||||
# pylint: disable=invalid-name,unused-import
|
||||
from __future__ import absolute_import as _abs
|
||||
import numpy as _np
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# pylint: disable=protected-access, no-member
|
||||
"""Collection structure in the high level DSL."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# pylint: disable=protected-access, no-member, missing-docstring
|
||||
"""Statement classes"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# pylint: disable=protected-access, no-member, invalid-name
|
||||
"""Tensor related abstractions"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
|
||||
|
@ -90,3 +89,8 @@ class ComputeOp(Operation):
|
|||
class ScanOp(Operation):
|
||||
"""Scan operation."""
|
||||
pass
|
||||
|
||||
@register_node
class ExternOp(Operation):
    """Extern operation.

    Adds no members of its own on top of Operation.
    """
|
||||
|
|
|
@ -183,6 +183,15 @@ TVM_REGISTER_API(_ScanOp)
|
|||
args[4]);
|
||||
});
|
||||
|
||||
// API entry _ExternOp: forwards its five packed arguments to ExternOpNode::make.
TVM_REGISTER_API(_ExternOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
    // argument order: name, inputs, input_placeholders, output_placeholders, body
    *ret = ExternOpNode::make(
        args[0], args[1], args[2], args[3], args[4]);
  });
|
||||
|
||||
TVM_REGISTER_API(_OpGetOutput)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = args[0].operator Operation().output(
|
||||
|
|
|
@ -1236,6 +1236,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
|
|||
buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
|
||||
CHECK(!var_map_.count(op->buffer_var.get()));
|
||||
var_map_[op->buffer_var.get()] = buf;
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
|
||||
|
|
|
@ -8,9 +8,8 @@
|
|||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <unordered_set>
|
||||
#include "./make_loop.h"
|
||||
#include "./op_util.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
|
@ -101,40 +100,12 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// replacer to replace tensors
|
||||
class TensorReplacer : public ir::IRMutator {
|
||||
public:
|
||||
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
|
||||
: vmap_(vmap) {}
|
||||
Expr Mutate_(const ir::Call* op, const Expr& e) {
|
||||
if (op->call_type == ir::Call::Halide) {
|
||||
Tensor t = Operation(op->func.node_).output(op->value_index);
|
||||
auto it = vmap_.find(t);
|
||||
if (it != vmap_.end()) {
|
||||
Expr ret = ir::Call::make(
|
||||
op->type, it->second->op->name, op->args,
|
||||
op->call_type, it->second->op, it->second->value_index);
|
||||
found = true;
|
||||
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
|
||||
}
|
||||
}
|
||||
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
|
||||
// whether it is found.
|
||||
bool found{false};
|
||||
|
||||
private:
|
||||
const std::unordered_map<Tensor, Tensor>& vmap_;
|
||||
};
|
||||
|
||||
Operation ComputeOpNode::ReplaceInputs(
|
||||
const Operation& self,
|
||||
const std::unordered_map<Tensor, Tensor>& rmap) const {
|
||||
CHECK_EQ(self.operator->(), this);
|
||||
TensorReplacer repl(rmap);
|
||||
Expr new_body = repl.Mutate(this->body);
|
||||
if (repl.found) {
|
||||
Expr new_body = op::ReplaceTensor(this->body, rmap);
|
||||
if (!new_body.same_as(this->body)) {
|
||||
return ComputeOpNode::make(name, axis, new_body);
|
||||
} else {
|
||||
return self;
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \brief External computation rule.
|
||||
* \file extern_op.cc
|
||||
*/
|
||||
#include <tvm/operation.h>
|
||||
#include <tvm/arithmetic.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <unordered_set>
|
||||
#include "./op_util.h"
|
||||
|
||||
namespace tvm {
|
||||
using namespace ir;
|
||||
// ExternOpNode
|
||||
// Printer hook: writes "extern(<name>, <node pointer>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ExternOpNode* node, IRPrinter* printer) {
    printer->stream << "extern(" << node->name << ", " << node << ")";
  });

// Register the node type.
TVM_REGISTER_NODE_TYPE(ExternOpNode);
|
||||
|
||||
int ExternOpNode::num_outputs() const {
|
||||
return static_cast<int>(output_placeholders.size());
|
||||
}
|
||||
|
||||
// An extern op exposes no iteration variables: the result is always empty.
Array<IterVar> ExternOpNode::root_iter_vars() const {
  return Array<IterVar>();
}
|
||||
|
||||
// The data type of the i-th output is the dtype of its placeholder buffer.
Type ExternOpNode::output_dtype(size_t i) const {
  const Buffer placeholder = output_placeholders[i];
  return placeholder->dtype;
}
|
||||
|
||||
// The shape of the i-th output is the shape of its placeholder buffer.
Array<Expr> ExternOpNode::output_shape(size_t i) const {
  const Buffer placeholder = output_placeholders[i];
  return placeholder->shape;
}
|
||||
|
||||
|
||||
// Build an ExternOpNode after validating that every input placeholder
// agrees with its tensor: same dtype, the very same shape array, and no
// explicit strides.
Operation ExternOpNode::make(std::string name,
                             Array<Tensor> inputs,
                             Array<Buffer> input_placeholders,
                             Array<Buffer> output_placeholders,
                             Stmt body) {
  // validate the input/placeholder pairing first
  CHECK_EQ(inputs.size(), input_placeholders.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
    CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
    CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
  }
  // then populate the node
  auto node = std::make_shared<ExternOpNode>();
  node->name = name;
  node->inputs = inputs;
  node->input_placeholders = input_placeholders;
  node->output_placeholders = output_placeholders;
  node->body = body;
  return Operation(node);
}
|
||||
|
||||
// The tensors this operation reads are exactly the stored inputs.
Array<Tensor> ExternOpNode::InputTensors() const {
  return this->inputs;
}
|
||||
|
||||
// Return an operation in which every tensor found in rmap is substituted,
// both inside the body and in the input list. When nothing changed the
// original operation (self) is handed back.
Operation ExternOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto replaced = std::make_shared<ExternOpNode>(*this);
  replaced->body = op::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < replaced->inputs.size(); ++i) {
    auto hit = rmap.find(replaced->inputs[i]);
    if (hit != rmap.end()) {
      replaced->inputs.Set(i, hit->second);
    }
  }

  const bool body_unchanged = body.same_as(replaced->body);
  const bool inputs_unchanged = inputs.same_as(replaced->inputs);
  if (body_unchanged && inputs_unchanged) {
    return self;
  }
  return Operation(replaced);
}
|
||||
|
||||
void ExternOpNode::PropBoundToInputs(
|
||||
const Operation& self,
|
||||
const std::unordered_map<const Variable*, IntSet>& dom_map,
|
||||
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
|
||||
for (Tensor t : this->inputs) {
|
||||
auto it = out_dom_map->find(t);
|
||||
if (it == out_dom_map->end()) continue;
|
||||
TensorDom& dom = it->second;
|
||||
for (size_t i = 0; i < t->shape.size(); ++i) {
|
||||
dom.data[i].emplace_back(IntSet::range(
|
||||
Range::make_with_min_extent(
|
||||
make_const(t->shape[i].type(), 0), t->shape[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExternOpNode::GatherBound(
|
||||
const Operation& self,
|
||||
const GraphContext& graph_ctx,
|
||||
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
|
||||
std::unordered_map<IterVar, Range>* out_dom_map) const {
|
||||
}
|
||||
|
||||
// Wrap body in one Realize node per output. Each output is realized over
// its full shape, i.e. [0, shape[i]) in every dimension; later outputs
// end up as the outer Realize nodes.
Stmt ExternOpNode::BuildRealize(
    const Operation& self,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(self.operator->(), this);
  Stmt result = body;
  const int noutputs = num_outputs();
  for (int k = 0; k < noutputs; ++k) {
    Tensor out = self.output(k);
    Halide::Internal::Region region;
    for (size_t dim = 0; dim < out->shape.size(); ++dim) {
      Range full = Range::make_with_min_extent(
          make_const(out->shape[dim].type(), 0), out->shape[dim]);
      region.push_back(full);
    }
    result = ir::Realize::make(
        out->op, out->value_index, out->dtype,
        region, const_true(), result);
  }
  return result;
}
|
||||
|
||||
// Emit the extern body wrapped in an extern_op_scope attribute that
// carries the operation as node and the operation name as value.
Stmt ExternOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  CHECK_EQ(stage->op.operator->(), this);
  Expr scope_value = StringImm::make(name);
  Stmt scoped = AttrStmt::make(
      stage->op, ir::attr::extern_op_scope, scope_value, body);
  return scoped;
}
|
||||
} // namespace tvm
|
|
@ -1,12 +1,13 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \brief Utility to make loop nest.
|
||||
* \file make_loop.cc
|
||||
* \file op_util.cc
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/operation.h>
|
||||
#include "./make_loop.h"
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include "./op_util.h"
|
||||
#include "../arithmetic/compute_expr.h"
|
||||
|
||||
namespace tvm {
|
||||
|
@ -231,5 +232,45 @@ std::vector<Stmt> MakeBoundCheck(
|
|||
return nest;
|
||||
}
|
||||
|
||||
|
||||
// replacer to replace tensors
|
||||
class TensorReplacer : public ir::IRMutator {
|
||||
public:
|
||||
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
|
||||
: vmap_(vmap) {}
|
||||
Expr Mutate_(const ir::Call* op, const Expr& e) {
|
||||
if (op->call_type == ir::Call::Halide) {
|
||||
Tensor t = Operation(op->func.node_).output(op->value_index);
|
||||
auto it = vmap_.find(t);
|
||||
if (it != vmap_.end()) {
|
||||
Expr ret = ir::Call::make(
|
||||
op->type, it->second->op->name, op->args,
|
||||
op->call_type, it->second->op, it->second->value_index);
|
||||
found = true;
|
||||
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
|
||||
}
|
||||
}
|
||||
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
|
||||
// whether it is found.
|
||||
bool found{false};
|
||||
|
||||
private:
|
||||
const std::unordered_map<Tensor, Tensor>& vmap_;
|
||||
};
|
||||
|
||||
// Statement overload: hands back the input statement itself when no
// tensor reference was replaced.
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer replacer(replace);
  Stmt mutated = replacer.Mutate(stmt);
  if (!replacer.found) {
    return stmt;
  }
  return mutated;
}
|
||||
// Expression overload: hands back the input expression itself when no
// tensor reference was replaced.
Expr ReplaceTensor(Expr expr,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer replacer(replace);
  Expr mutated = replacer.Mutate(expr);
  if (!replacer.found) {
    return expr;
  }
  return mutated;
}
|
||||
} // namespace op
|
||||
} // namespace tvm
|
|
@ -1,10 +1,10 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file make_loop.h
|
||||
* \brief Utility to make loop nest from schedule stage info.
|
||||
* \file op_util.h
|
||||
* \brief Common utility used in operator construction.
|
||||
*/
|
||||
#ifndef TVM_OP_MAKE_LOOP_H_
|
||||
#define TVM_OP_MAKE_LOOP_H_
|
||||
#ifndef TVM_OP_OP_UTIL_H_
|
||||
#define TVM_OP_OP_UTIL_H_
|
||||
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/schedule.h>
|
||||
|
@ -50,6 +50,22 @@ MakeBoundCheck(const Stage& stage,
|
|||
bool skip_ivar_domain,
|
||||
const std::unordered_set<IterVar>& skip_iter,
|
||||
const std::unordered_map<IterVar, Expr>& value_map);
|
||||
|
||||
/*!
|
||||
* \brief Replace the tensor reference in stmt by the replace map.
|
||||
* \param stmt The statement to be processed.
|
||||
* \param replace The replacement rule.
|
||||
*/
|
||||
Stmt ReplaceTensor(Stmt stmt,
|
||||
const std::unordered_map<Tensor, Tensor>& replace);
|
||||
/*!
|
||||
* \brief Replace the tensor reference in expr by the replace map.
|
||||
* \param expr The expression to be processed.
|
||||
* \param replace The replacement rule.
|
||||
*/
|
||||
Expr ReplaceTensor(Expr expr,
|
||||
const std::unordered_map<Tensor, Tensor>& replace);
|
||||
|
||||
} // namespace op
|
||||
} // namespace tvm
|
||||
#endif // TVM_OP_MAKE_LOOP_H_
|
||||
#endif // TVM_OP_OP_UTIL_H_
|
|
@ -6,7 +6,7 @@
|
|||
#include <tvm/operation.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include "./make_loop.h"
|
||||
#include "./op_util.h"
|
||||
#include "../schedule/graph.h"
|
||||
|
||||
namespace tvm {
|
||||
|
|
|
@ -89,7 +89,7 @@ class AllocateLifter : public IRMutator {
|
|||
};
|
||||
|
||||
Stmt LiftAllocate(Stmt stmt) {
|
||||
return AllocateLifter().Mutate(stmt);
|
||||
return AllocateLifter().Lift(stmt);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
|
|
@ -3,8 +3,11 @@
|
|||
* \file storage_flatten.cc
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/buffer.h>
|
||||
#include <tvm/operation.h>
|
||||
#include <unordered_map>
|
||||
#include "../runtime/thread_storage_scope.h"
|
||||
|
||||
|
@ -25,6 +28,16 @@ class StorageFlattener : public IRMutator {
|
|||
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
|
||||
}
|
||||
}
|
||||
// Store: after the default mutation, retarget the store when its buffer
// variable has an entry in extern_buf_remap_.
Stmt Mutate_(const Store* op, const Stmt& s) final {
  Stmt mutated = IRMutator::Mutate_(op, s);
  op = mutated.as<Store>();
  auto remap = extern_buf_remap_.find(op->buffer_var.get());
  if (remap == extern_buf_remap_.end()) {
    return mutated;
  }
  return Store::make(remap->second, op->value, op->index);
}
|
||||
|
||||
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
|
||||
if (op->type_key == attr::realize_scope) {
|
||||
|
@ -37,6 +50,8 @@ class StorageFlattener : public IRMutator {
|
|||
Stmt stmt = IRMutator::Mutate_(op, s);
|
||||
curr_thread_scope_.pop_back();
|
||||
return stmt;
|
||||
} else if (op->type_key == attr::extern_op_scope) {
|
||||
return HandleExternOp(op);
|
||||
}
|
||||
return IRMutator::Mutate_(op, s);
|
||||
}
|
||||
|
@ -95,6 +110,26 @@ class StorageFlattener : public IRMutator {
|
|||
}
|
||||
}
|
||||
|
||||
// Load: after the default mutation, retarget the load when its buffer
// variable has an entry in extern_buf_remap_.
Expr Mutate_(const Load* op, const Expr& e) final {
  Expr mutated = IRMutator::Mutate_(op, e);
  op = mutated.as<Load>();
  auto remap = extern_buf_remap_.find(op->buffer_var.get());
  if (remap == extern_buf_remap_.end()) {
    return mutated;
  }
  return Load::make(op->type, remap->second, op->index);
}
|
||||
|
||||
// Variable: substitute the mapped variable when this one has an entry in
// extern_buf_remap_, otherwise keep the expression untouched.
Expr Mutate_(const Variable* op, const Expr& e) final {
  auto remap = extern_buf_remap_.find(op);
  if (remap == extern_buf_remap_.end()) {
    return e;
  }
  return remap->second;
}
|
||||
|
||||
Expr Mutate_(const Call* op, const Expr& olde) final {
|
||||
Expr expr = IRMutator::Mutate_(op, olde);
|
||||
op = expr.as<Call>();
|
||||
|
@ -113,6 +148,28 @@ class StorageFlattener : public IRMutator {
|
|||
}
|
||||
|
||||
private:
|
||||
// Flatten the body of an extern_op_scope attribute: map the data variable
// of every output/input placeholder to the data variable of the buffer
// recorded in buf_map_, mutate the body with that mapping active, then
// drop the mapping again. The attribute itself is not re-emitted.
Stmt HandleExternOp(const AttrStmt* op) {
  const ExternOpNode* ext_op = op->node.as<ExternOpNode>();
  CHECK(ext_op);
  Operation func(op->node.node_);
  // the remap table must not be active already
  CHECK_EQ(extern_buf_remap_.size(), 0U);
  // outputs are keyed by (this op, output index)
  const size_t num_out = ext_op->output_placeholders.size();
  for (size_t i = 0; i < num_out; ++i) {
    TensorKey key{func, static_cast<int>(i)};
    CHECK(buf_map_.count(key));
    const Variable* placeholder_var = ext_op->output_placeholders[i]->data.get();
    extern_buf_remap_[placeholder_var] = buf_map_.at(key).buffer->data;
  }
  // inputs are keyed by (producer op, value index)
  const size_t num_in = ext_op->inputs.size();
  for (size_t i = 0; i < num_in; ++i) {
    TensorKey key{ext_op->inputs[i]->op, ext_op->inputs[i]->value_index};
    CHECK(buf_map_.count(key));
    const Variable* placeholder_var = ext_op->input_placeholders[i]->data.get();
    extern_buf_remap_[placeholder_var] = buf_map_.at(key).buffer->data;
  }
  Stmt flattened = Mutate(op->body);
  extern_buf_remap_.clear();
  return flattened;
}
|
||||
|
||||
// The buffer entry in the flatten map
|
||||
struct BufferEntry {
|
||||
// the buffer of storage
|
||||
|
@ -139,6 +196,7 @@ class StorageFlattener : public IRMutator {
|
|||
}
|
||||
};
|
||||
// The buffer assignment map
|
||||
std::unordered_map<const Variable*, Var> extern_buf_remap_;
|
||||
std::unordered_map<TensorKey, BufferEntry> buf_map_;
|
||||
std::unordered_map<const Node*, std::string> storage_scope_;
|
||||
// The current thread scope.
|
||||
|
|
|
@ -0,0 +1,407 @@
|
|||
[MASTER]
|
||||
|
||||
# Specify a configuration file.
|
||||
#rcfile=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
|
||||
# Add files or directories to the blacklist. They should be base names, not
|
||||
# paths.
|
||||
ignore=CVS
|
||||
|
||||
# Add files or directories matching the regex patterns to the blacklist. The
|
||||
# regex matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=8
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code
|
||||
extension-pkg-whitelist=numpy,opencv
|
||||
|
||||
# Allow optimization of some AST trees. This will activate a peephole AST
|
||||
# optimizer, which will apply various small optimizations. For instance, it can
|
||||
# be used to obtain the result of joining multiple strings with the addition
|
||||
# operator. Joining a lot of strings can lead to a maximum recursion error in
|
||||
# Pylint and this flag can prevent that. It has one side effect, the resulting
|
||||
# AST will be different than the one from reality. This option is deprecated
|
||||
# and it will be removed in Pylint 2.0.
|
||||
optimize-ast=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=indexing-exception,old-raise-syntax
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Put messages in a separate file for each module / package specified on the
|
||||
# command line instead of printing them on stdout. Reports (if any) will be
|
||||
# written in a file name "pylint_global.[txt|html]". This option is deprecated
|
||||
# and it will be removed in Pylint 2.0.
|
||||
files-output=no
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=100
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=no
|
||||
|
||||
# List of optional constructs for which whitespace checking is disabled. `dict-
|
||||
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
||||
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
||||
# `empty-line` allows space-only lines.
|
||||
no-space-check=trailing-comma,dict-separator
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=1000
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it work,
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take into consideration, separated by a comma.
|
||||
notes=FIXME,XXX,TODO
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid defining new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six.moves,future.builtins
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=i,j,_,a,b,op,x,y,wd,lr,kv,k,v,s,p,h,c,m,n,X,t,g,f
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
|
||||
|
||||
# Naming hint for module names
|
||||
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
|
||||
|
||||
# Naming hint for constant names
|
||||
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
|
||||
|
||||
# Naming hint for inline iteration names
|
||||
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for method names
|
||||
method-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
|
||||
|
||||
# Naming hint for class attribute names
|
||||
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for argument names
|
||||
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for attribute names
|
||||
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for variable names
|
||||
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Naming hint for function names
|
||||
function-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=[A-Z_][a-zA-Z0-9]+$
|
||||
|
||||
# Naming hint for class names
|
||||
class-name-hint=[A-Z_][a-zA-Z0-9]+$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=^_
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[ELIF]
|
||||
|
||||
# Maximum number of nested blocks for function / method body
|
||||
max-nested-blocks=5
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,__new__,setUp
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,_fields,_replace,_source,_make
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=optparse
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[DESIGN]
|
||||
|
||||
# Maximum number of arguments for function / method
|
||||
max-args=5
|
||||
|
||||
# Argument names that match this expression will be ignored. Default to name
|
||||
# with leading underscore
|
||||
ignored-argument-names=_.*
|
||||
|
||||
# Maximum number of locals for function / method body
|
||||
max-locals=15
|
||||
|
||||
# Maximum number of return / yield for function / method body
|
||||
max-returns=6
|
||||
|
||||
# Maximum number of branches for function / method body
|
||||
max-branches=12
|
||||
|
||||
# Maximum number of statements in function / method body
|
||||
max-statements=50
|
||||
|
||||
# Maximum number of parents for a class (see R0901).
|
||||
max-parents=7
|
||||
|
||||
# Maximum number of attributes for a class (see R0902).
|
||||
max-attributes=7
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=0
|
||||
|
||||
# Maximum number of public methods for a class (see R0904).
|
||||
max-public-methods=20
|
||||
|
||||
# Maximum number of boolean expressions in an if statement
|
||||
max-bool-expr=5
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=Exception
|
|
@ -0,0 +1,37 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
def test_add_pipeline():
    """Check an extern op whose hand-written IR adds 1 to every element.

    The op is declared with ``tvm.extern``, scheduled, built for the
    "llvm" target (when that backend is enabled) and its output is
    compared against ``a + 1`` computed with numpy.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')

    def extern_generator(ins, outs):
        """Manually write the IR for the extern function, add pipeline"""
        i = tvm.Var('i')
        # Element-wise body: outs[0][i] = ins[0][i] + 1
        plus_one = tvm.make.Load(A.dtype, ins[0].data, i) + 1
        body = tvm.make.Store(outs[0].data, plus_one, i)
        # Loop i over [0, n); the two trailing zeros are passed through
        # unchanged from the original call (presumably the loop type and
        # device API selectors -- confirm against tvm.make.For).
        return tvm.make.For(i, 0, n, 0, 0, body)

    C = tvm.extern(A.shape, [A], extern_generator, name='C')
    s = tvm.Schedule(C.op)

    def check_llvm():
        """Build for LLVM and verify the result; no-op without LLVM."""
        if not tvm.codegen.enabled("llvm"):
            return
        # Build the kernel.
        kernel = tvm.build(s, [A, C], "llvm")
        cpu = tvm.cpu(0)
        # Allocate the input/output arrays and launch the kernel.
        src = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), cpu)
        dst = tvm.nd.array(np.zeros(nn, dtype=C.dtype), cpu)
        kernel(src, dst)
        expected = src.asnumpy() + 1
        np.testing.assert_allclose(dst.asnumpy(), expected)

    check_llvm()
|
||||
|
||||
|
||||
# Script entry point: run the test when this file is executed directly.
if __name__ == "__main__":
|
||||
test_add_pipeline()
|
|
@ -8,10 +8,7 @@ def test_llvm_add_pipeline():
|
|||
B = tvm.placeholder((n,), name='B')
|
||||
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
|
||||
s = tvm.Schedule(C.op)
|
||||
print(s[C])
|
||||
print("a?")
|
||||
xo, xi = s[C].split(C.op.axis[0], factor=4)
|
||||
print("a?")
|
||||
s[C].parallel(xo)
|
||||
s[C].vectorize(xi)
|
||||
def check_llvm():
|
||||
|
@ -83,12 +80,31 @@ def test_llvm_madd_pipeline():
|
|||
check_llvm(4, 0, 1)
|
||||
check_llvm(4, 0, 3)
|
||||
|
||||
def test_llvm_temp_space():
    """Check a two-stage compute chain (A -> B -> C) built for LLVM.

    ``B = A + 1`` is an intermediate stage that is not passed to
    ``tvm.build``; only ``A`` and ``C = B + 1`` are arguments. The
    result is compared against ``a + 1 + 1`` computed with numpy.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A(i) + 1, name='B')
    C = tvm.compute(A.shape, lambda i: B(i) + 1, name='C')
    s = tvm.Schedule(C.op)

    def check_llvm():
        """Build for LLVM and verify the result; no-op without LLVM."""
        if not tvm.codegen.enabled("llvm"):
            return
        # Build the kernel.
        kernel = tvm.build(s, [A, C], "llvm")
        cpu = tvm.cpu(0)
        # Allocate the input/output arrays and launch the kernel.
        src = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), cpu)
        dst = tvm.nd.array(np.zeros(nn, dtype=C.dtype), cpu)
        kernel(src, dst)
        expected = src.asnumpy() + 1 + 1
        np.testing.assert_allclose(dst.asnumpy(), expected)

    check_llvm()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("a")
|
||||
test_llvm_add_pipeline()
|
||||
print("a")
|
||||
test_llvm_flip_pipeline()
|
||||
print("a")
|
||||
test_llvm_madd_pipeline()
|
||||
test_llvm_temp_space()
|
||||
|
|
|
@ -1,11 +1,6 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
def tvm_call_packed(*args):
|
||||
args = tvm.convert(args)
|
||||
return tvm.make.Call("int32", "tvm_call_packed", args, 4, None, 0)
|
||||
|
||||
|
||||
def run_jit(fapi, check):
|
||||
for target in ["llvm", "stackvm"]:
|
||||
if not tvm.codegen.enabled(target):
|
||||
|
@ -24,7 +19,7 @@ def test_stack_vm_basic():
|
|||
|
||||
n = tvm.Var('n')
|
||||
Ab = tvm.Buffer((n, ), tvm.float32)
|
||||
stmt = tvm.make.Evaluate(tvm_call_packed("tvm_call_back_get_shape", Ab.shape[0]))
|
||||
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
|
||||
run_jit(fapi, lambda f: f(a))
|
||||
|
||||
|
@ -46,7 +41,7 @@ def test_stack_vm_loop():
|
|||
tvm.make.Store(Ab.data,
|
||||
tvm.make.Load(dtype, Ab.data, i) + 1,
|
||||
i + 1),
|
||||
tvm.make.Evaluate(tvm_call_packed("tvm_stack_vm_print", i))))
|
||||
tvm.make.Evaluate(tvm.call_packed("tvm_stack_vm_print", i))))
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
|
||||
a = tvm.nd.array(np.zeros(10, dtype=dtype))
|
||||
def check(f):
|
||||
|
|
|
@ -80,6 +80,30 @@ def test_scan_multi_out():
|
|||
zz = tvm.load_json(json_str)
|
||||
assert isinstance(zz, tvm.tensor.ScanOp)
|
||||
|
||||
def test_extern():
    """Check that a single-output ``tvm.extern`` yields the requested shape."""
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')

    def extern_func(ins, outs):
        """Emit a packed call to "myadd" on the input/output buffers."""
        first_input = ins[0]
        # The generator receives Buffer objects.
        assert isinstance(first_input, tvm.schedule.Buffer)
        return tvm.call_packed("myadd", first_input.data, outs[0].data, m)

    B = tvm.extern((m,), [A], extern_func)
    result_shape = tuple(B.shape)
    assert result_shape == (m,)
|
||||
|
||||
|
||||
def test_extern_multi_out():
    """Check that ``tvm.extern`` with two output shapes returns two tensors."""
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    B = tvm.compute((m,), lambda i: A[i] * 10)

    def extern_func(ins, outs):
        """Emit a packed call to "myadd" writing both output buffers."""
        first_input = ins[0]
        # The generator receives Buffer objects.
        assert isinstance(first_input, tvm.schedule.Buffer)
        return tvm.call_packed(
            "myadd", first_input.data, outs[0].data, outs[1].data, m)

    out_shapes = [A.shape, A.shape]
    res = tvm.extern(out_shapes, [A, B], extern_func)
    assert len(res) == 2
    # The second result tensor refers to output slot 1 of the extern op.
    assert res[1].value_index == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_conv1d()
|
||||
|
@ -88,3 +112,5 @@ if __name__ == "__main__":
|
|||
test_tensor_reduce()
|
||||
test_tensor_scan()
|
||||
test_scan_multi_out()
|
||||
test_extern()
|
||||
test_extern_multi_out()
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
|
||||
if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
|
||||
if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
|
||||
make lint || exit -1
|
||||
echo "Check codestyle of c++ code..."
|
||||
make cpplint || exit -1
|
||||
echo "Check codestyle of python code..."
|
||||
make pylint || exit -1
|
||||
echo "Check documentations of c++ code..."
|
||||
make doc 2>log.txt
|
||||
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
|
||||
|
|
Загрузка…
Ссылка в новой задаче