Move dense compute back to python (#364)
This commit is contained in:
Родитель
5884cd01f3
Коммит
f1e0a55abc
|
@ -51,6 +51,13 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
|
|||
|
||||
|
||||
# dense
|
||||
@reg.register_compute("dense")
|
||||
def compute_dense(attrs, inputs, _):
|
||||
"""Compute definition of dense"""
|
||||
if attrs.get_bool("use_bias"):
|
||||
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
|
||||
return topi.nn.dense(inputs[0], inputs[1])
|
||||
|
||||
@reg.register_schedule("dense")
|
||||
def schedule_dense(_, outs, target):
|
||||
"""Schedule definition of dense"""
|
||||
|
|
|
@ -82,21 +82,6 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
|
|||
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
|
||||
.set_attr<FInferShape>("FInferShape", DenseInferShape)
|
||||
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
Tensor bias_val;
|
||||
Tensor* bias;
|
||||
const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
|
||||
if (param.use_bias) {
|
||||
bias_val = inputs[2];
|
||||
bias = &bias_val;
|
||||
} else {
|
||||
bias = nullptr;
|
||||
}
|
||||
return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
|
||||
})
|
||||
.set_attr<FGradient>(
|
||||
"FGradient", [](const NodePtr& n,
|
||||
const std::vector<NodeEntry>& ograds) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче