[Relay] Bitserial ops (#3844)
* Added arm_cpu NHWC schedules. * Fixed kernel shape legalization. * Added bitserial ops to relay. * Snapshot and more missing files. * Added dense testing. * Added tests * Added ASF header to new files. * cc lint * Pylint change. * pylint fixes. * Change arm legalize test. * Added assert check to arm legalize. * Added better documentation, fixed some bad style * Reverted arm conv2d nhwc changes.
This commit is contained in:
Родитель
73dc5ac379
Коммит
d08c74caf6
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tvm/relay/attrs/bitserial.h
|
||||
* \brief Auxiliary attributes for bitserial operators.
|
||||
*/
|
||||
|
||||
#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
|
||||
#define TVM_RELAY_ATTRS_BITSERIAL_H_
|
||||
|
||||
#include <tvm/attrs.h>
|
||||
#include <tvm/relay/base.h>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*! \brief Attributes used in bitpack operators */
|
||||
struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
  /*! \brief Number of bits to quantize with. */
  int bits;
  /*! \brief Axis that is compressed by the packing. */
  int pack_axis;
  /*! \brief Position of the new axis that holds the packed bits. */
  int bit_axis;
  /*! \brief Integer type the bits are packed into. */
  DataType pack_type;
  /*! \brief Name of the operation. */
  std::string name;

  TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
    TVM_ATTR_FIELD(bits)
        .set_default(1)
        .describe("Number of bits to quantize with.");
    TVM_ATTR_FIELD(pack_axis)
        .set_default(1)
        .describe("Axis that should be compressed, typically channels.");
    TVM_ATTR_FIELD(bit_axis)
        .set_default(-1)
        .describe("New axis for packed bits.");
    TVM_ATTR_FIELD(pack_type)
        .set_default(NullValue<DataType>())
        .describe("Type of int to pack bits into.");
    TVM_ATTR_FIELD(name)
        .set_default("BitPack")
        .describe("Name of operation.");
  }
};
|
||||
|
||||
/*! \brief Attributes used in bitserial convolution operators */
|
||||
struct BinaryConv2DAttrs : public tvm::AttrsNode<BinaryConv2DAttrs> {
  /*! \brief Strides of the convolution. */
  Array<IndexExpr> strides;
  /*! \brief Zero padding applied on both sides of each spatial dimension. */
  Array<IndexExpr> padding;
  /*! \brief Number of output channels, needed for shape inference. */
  IndexExpr channels;
  /*! \brief Dimensions of the convolution window. */
  Array<IndexExpr> kernel_size;
  /*! \brief Number of bits the activations are packed with. */
  int activation_bits;
  /*! \brief Number of bits the kernel is packed with. */
  int weight_bits;
  /*! \brief Dimension ordering of the input data. */
  std::string data_layout;
  /*! \brief Dimension ordering of the kernel data. */
  std::string kernel_layout;
  /*! \brief Datatype the bits are packed into. */
  DataType pack_dtype;
  /*! \brief Output datatype. */
  DataType out_dtype;
  /*! \brief Whether unipolar (true) or bipolar (false) quantization is used. */
  bool unipolar;

  TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") {
    TVM_ATTR_FIELD(strides)
        .set_default(Array<IndexExpr>({1, 1}))
        .describe("Specifies the strides of the convolution.");
    // Adjacent string literals are concatenated verbatim, so the separating
    // space has to be part of one of the pieces.
    TVM_ATTR_FIELD(padding)
        .set_default(Array<IndexExpr>({0, 0}))
        .describe(
            "If padding is non-zero the input is implicitly zero-padded "
            "on both sides for padding number of points.");
    TVM_ATTR_FIELD(kernel_size)
        .set_default(Array<IndexExpr>({3, 3}))
        .describe("Specifies the dimensions of the convolution window.");
    TVM_ATTR_FIELD(channels)
        .set_default(NullValue<IndexExpr>())
        .describe("Number of output channels, needed for shape inference.");
    TVM_ATTR_FIELD(activation_bits)
        .set_default(1)
        .describe("Number of bits activation should be packed with.");
    TVM_ATTR_FIELD(weight_bits)
        .set_default(1)
        .describe("Number of bits kernel should be packed with.");
    TVM_ATTR_FIELD(data_layout)
        .set_default("NCHW")
        .describe("Dimension ordering of input data, can be 'NCHW' or 'NHWC'.");
    TVM_ATTR_FIELD(kernel_layout)
        .set_default("OIHW")
        .describe("Dimension ordering of kernel data, can be 'OIHW' or 'HWIO'.");
    TVM_ATTR_FIELD(pack_dtype)
        .set_default(NullValue<DataType>())
        .describe("Datatype to pack bits into.");
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output datatype.");
    TVM_ATTR_FIELD(unipolar)
        .set_default(true)
        .describe("Whether to use unipolar or bipolar quantization.");
  }
};
|
||||
|
||||
/*! \brief Attributes for bitserial dense operator */
|
||||
struct BinaryDenseAttrs : public tvm::AttrsNode<BinaryDenseAttrs> {
  /*! \brief Number of hidden units of the dense transformation. */
  IndexExpr units;
  /*! \brief Number of bits to pack for the incoming tensor. */
  int data_bits;
  /*! \brief Number of bits to pack for the weight tensor. */
  int weight_bits;
  /*! \brief Datatype the bits are packed into before computation. */
  DataType pack_dtype;
  /*! \brief Output data type. */
  DataType out_dtype;
  /*! \brief Whether unipolar (true) or bipolar (false) quantization is used for inputs. */
  bool unipolar;

  TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") {
    // `units` deliberately has no default declared here.
    TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
    TVM_ATTR_FIELD(data_bits).set_default(1).describe(
        "Number of bits to pack for incoming tensor.");
    TVM_ATTR_FIELD(weight_bits).set_default(1).describe(
        "Number of bits to pack for weight tensor.");
    TVM_ATTR_FIELD(pack_dtype).set_default(NullValue<DataType>()).describe(
        "Datatype to pack bits into before computation.");
    TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
    TVM_ATTR_FIELD(unipolar).set_default(true).describe(
        "Whether to use unipolar or bipolar quantization for inputs.");
  }
};
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_ATTRS_BITSERIAL_H_
|
|
@ -600,3 +600,120 @@ def schedule_deformable_conv2d(attrs, outs, target):
|
|||
|
||||
|
||||
reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
||||
|
||||
@reg.register_compute("nn.bitpack")
def compute_bitpack(attrs, inputs, out_dtype, target):
    """Compute definition for bitpack.

    Forwards the first input together with the bits, pack_axis, bit_axis,
    pack_type and name attributes to topi.nn.bitpack under the given target
    and returns the result wrapped in a single-element list.
    """
    with target:
        packed = topi.nn.bitpack(inputs[0],
                                 attrs.bits,
                                 attrs.pack_axis,
                                 attrs.bit_axis,
                                 attrs.pack_type,
                                 attrs.name)
    return [packed]
|
||||
|
||||
@reg.register_schedule("nn.bitpack")
def schedule_bitpack(attrs, outs, target):
    """Schedule definition for bitpack; delegates to the generic topi schedule."""
    with target:
        sch = topi.generic.schedule_bitpack(outs)
    return sch
|
||||
|
||||
reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)
|
||||
|
||||
|
||||
@reg.register_compute("nn.bitserial_conv2d")
def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition for bitserial conv2d.

    Picks the NCHW or NHWC topi implementation from attrs.data_layout and
    raises ValueError for any other layout. The output dtype is taken from
    attrs.out_dtype, which replaces the `out_dtype` argument.
    """
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    out_dtype = attrs.out_dtype
    data_layout = attrs.data_layout

    # Select the layout-specific implementation first, then call it once.
    if data_layout == 'NCHW':
        conv_impl = topi.nn.bitserial_conv2d_nchw
    elif data_layout == 'NHWC':
        conv_impl = topi.nn.bitserial_conv2d_nhwc
    else:
        raise ValueError("Data layout not supported.")

    with target:
        out = conv_impl(
            inputs[0], inputs[1], strides, padding, attrs.activation_bits,
            attrs.weight_bits, attrs.pack_dtype, out_dtype, attrs.unipolar)
    return [out]
|
||||
|
||||
|
||||
@reg.register_schedule("nn.bitserial_conv2d")
def schedule_bitserial_conv2d(attrs, outs, target):
    """Schedule definition for bitserial conv2d.

    Dispatches on attrs.data_layout ('NCHW' or 'NHWC') to the matching generic
    topi schedule; any other layout raises ValueError.
    """
    data_layout = attrs.data_layout
    if data_layout not in ('NCHW', 'NHWC'):
        raise ValueError("Data layout not supported.")

    with target:
        if data_layout == 'NCHW':
            return topi.generic.schedule_bitserial_conv2d_nchw(outs)
        return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
|
||||
|
||||
@reg.register_legalize("nn.bitserial_conv2d")
def legalize_bitserial_conv2d(attrs, inputs, types):
    """Legalize the bitserial_conv2d op by delegating to topi.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of the current convolution.
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized.
    types : list of types
        List of input and output types.

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr, as returned by topi.nn.bitserial_conv2d_legalize.
    """
    legalized = topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)
    return legalized
|
||||
|
||||
|
||||
reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
||||
|
||||
# bitserial_dense
|
||||
@reg.register_compute("nn.bitserial_dense")
def compute_bitserial_dense(attrs, inputs, out_type, target):
    """Compute definition of bitserial_dense.

    Calls topi.nn.bitserial_dense on the two inputs with the bit widths,
    pack dtype and polarity taken from attrs. When attrs.out_dtype compares
    equal to the empty string, the dtype of the first input is used instead.
    """
    result_dtype = attrs.out_dtype
    if result_dtype == "":
        result_dtype = inputs[0].dtype

    dense = topi.nn.bitserial_dense(inputs[0], inputs[1],
                                    attrs.data_bits, attrs.weight_bits,
                                    attrs.pack_dtype, result_dtype,
                                    attrs.unipolar)
    return [dense]
|
||||
|
||||
|
||||
@reg.register_schedule("nn.bitserial_dense")
def schedule_bitserial_dense(attrs, outputs, target):
    """Schedule definition of bitserial_dense; delegates to the generic topi schedule."""
    with target:
        sch = topi.generic.schedule_bitserial_dense(outputs)
    return sch
|
||||
|
||||
|
||||
# Use the bare OpPattern name, matching the other register_pattern calls in this file.
reg.register_pattern("nn.bitserial_dense", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
|
|
@ -1459,3 +1459,165 @@ def deformable_conv2d(data,
|
|||
return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
|
||||
deformable_groups, groups, channels, kernel_size, data_layout,
|
||||
kernel_layout, out_layout, out_dtype)
|
||||
|
||||
|
||||
def bitpack(data,
            bits=1,
            pack_axis=1,
            bit_axis=2,
            pack_type="uint32",
            name="BitPack"):
    r"""Pack a tensor for use in bitserial operations.

    Values along ``pack_axis`` of the input are quantized and packed together
    into elements of ``pack_type``; the resulting bitplanes are placed on a
    new axis at position ``bit_axis``.

    Example: for data of shape [1, 64, 128, 128] with pack_axis=1, bit_axis=4,
    pack_type=uint8 and bits=2, the output has shape [1, 8, 128, 128, 2].
    Axis 1 shrinks by a factor of 8 because each packed value is an 8-bit
    uint8, and axis 4 holds the two bitplanes of the quantized input. The
    result can be fed directly to a bitserial operation.

    Parameters
    ----------
    data : tvm.relay.expr
        The incoming tensor to be packed.

    bits : int
        Number of bits that should be packed.

    pack_axis : int
        Axis that should be decomposed and packed.

    bit_axis : int
        New axis containing bitplane.

    pack_type : str
        Datatype to pack bits into.

    name : str, optional
        Name of the operation.

    Returns
    -------
    result : tvm.relay.Expr
        The packed tensor.
    """
    packed = _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name)
    return packed
|
||||
|
||||
|
||||
def bitserial_conv2d(data,
                     weight,
                     strides=(1, 1),
                     padding=(0, 0),
                     channels=None,
                     kernel_size=(3, 3),
                     activation_bits=1,
                     weight_bits=1,
                     data_layout='NCHW',
                     kernel_layout='OIHW',
                     pack_dtype='uint32',
                     out_dtype='int16',
                     unipolar=True):
    r"""2D convolution computed with bitserial arithmetic.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    activation_bits : int
        Number of bits to pack for activations.

    weight_bits : int
        Number of bits to pack for weights.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the kernel.

    pack_dtype : str, optional
        Datatype to pack bits into.

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    unipolar : bool, optional
        Whether to use unipolar or bipolar quantization.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # Arguments are passed positionally, in the same order as the signature.
    conv = _make.bitserial_conv2d(data, weight, strides, padding, channels,
                                  kernel_size, activation_bits, weight_bits,
                                  data_layout, kernel_layout, pack_dtype,
                                  out_dtype, unipolar)
    return conv
|
||||
|
||||
|
||||
def bitserial_dense(data,
                    weight,
                    units=None,
                    data_bits=1,
                    weight_bits=1,
                    pack_dtype='uint32',
                    out_dtype='int16',
                    unipolar=True):
    """Bitserial Dense operator.

    Multiplies two quantized matrices using a fast bitserial algorithm.

    .. math::

    `Y = X * W`

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    units : int, optional
        Number of hidden units of the dense transformation.

    data_bits : int
        Number of bits incoming tensor should be packed with.

    weight_bits : int
        Number of bits weight tensor should be packed with.

    pack_dtype : str, optional
        Datatype to pack individual bits into before computation.

    out_dtype : str, optional
        Specifies the output data type for mixed precision dense.

    unipolar : bool, optional
        Whether to use unipolar or bipolar quantization for inputs.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # Arguments are passed positionally, in the same order as the signature.
    dense = _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
                                  pack_dtype, out_dtype, unipolar)
    return dense
|
||||
|
|
|
@ -264,3 +264,18 @@ class MaxPool2DAttrs(Attrs):
|
|||
@register_relay_attr_node
|
||||
class AvgPool2DAttrs(Attrs):
|
||||
"""Attributes used in avg_pool2d operators"""
|
||||
|
||||
|
||||
@register_relay_attr_node
class BitPackAttrs(Attrs):
    """Attributes used in bitpack operator.

    Fields of the matching attrs node: bits, pack_axis, bit_axis, pack_type, name.
    """
|
||||
|
||||
|
||||
@register_relay_attr_node
class BinaryConv2DAttrs(Attrs):
    """Attributes used in bitserial conv2d operators.

    Fields of the matching attrs node: strides, padding, channels, kernel_size,
    activation_bits, weight_bits, data_layout, kernel_layout, pack_dtype,
    out_dtype, unipolar.
    """
|
||||
|
||||
|
||||
@register_relay_attr_node
class BinaryDenseAttrs(Attrs):
    """Attributes used in bitserial dense operators.

    Fields of the matching attrs node: units, data_bits, weight_bits,
    pack_dtype, out_dtype, unipolar.
    """
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file bitserial.cc
|
||||
* \brief Property def of bitserial operators.
|
||||
*/
|
||||
|
||||
#include <tvm/data_layout.h>
|
||||
#include <tvm/relay/attrs/bitserial.h>
|
||||
#include <tvm/relay/op.h>
|
||||
|
||||
#include "../../pass/alter_op_layout.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
// relay.nn.bitpack
|
||||
TVM_REGISTER_NODE_TYPE(BitPackAttrs);
|
||||
|
||||
template <typename T>
|
||||
Array<Array<Layout>> BinaryConv2DInferCorrectLayout(const Attrs& attrs,
|
||||
const Array<Layout>& new_in_layouts,
|
||||
const Array<Layout>& old_in_layouts,
|
||||
const Array<Array<IndexExpr>>& old_in_shapes) {
|
||||
const T* params = attrs.as<T>();
|
||||
|
||||
// We always make other operators to fit the layouts of convolution layers
|
||||
// So this inference ignores all inputs
|
||||
return Array<Array<Layout>>{{params->data_layout, params->kernel_layout}, {params->data_layout}};
|
||||
}
|
||||
|
||||
bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
const BitPackAttrs* param = attrs.as<BitPackAttrs>();
|
||||
CHECK_EQ(types.size(), 2);
|
||||
const auto* data = types[0].as<TensorTypeNode>();
|
||||
CHECK(data);
|
||||
int ndim = data->shape.size();
|
||||
int bits = param->bits;
|
||||
int pack_axis = param->pack_axis;
|
||||
int bit_axis = param->bit_axis;
|
||||
DataType pack_type = param->pack_type;
|
||||
|
||||
int pack_bits = pack_type.bits();
|
||||
|
||||
Array<IndexExpr> out_shape;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
if (i == bit_axis) {
|
||||
out_shape.push_back(bits);
|
||||
if (i == pack_axis) {
|
||||
out_shape.push_back(data->shape[i] / pack_bits);
|
||||
} else {
|
||||
out_shape.push_back(data->shape[i]);
|
||||
}
|
||||
} else if (i == pack_axis) {
|
||||
out_shape.push_back(data->shape[i] / pack_bits);
|
||||
} else {
|
||||
out_shape.push_back(data->shape[i]);
|
||||
}
|
||||
}
|
||||
// Add extra check for last axis expansion.
|
||||
if (bit_axis == ndim) {
|
||||
out_shape.push_back(bits);
|
||||
}
|
||||
|
||||
reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type));
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief Positional builder for an nn.bitpack call; registered below for the frontend FFI. */
Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type,
                 std::string name) {
  static const Op& bitpack_op = Op::Get("nn.bitpack");

  auto pack_attrs = make_node<BitPackAttrs>();
  pack_attrs->name = name;
  pack_attrs->pack_type = pack_type;
  pack_attrs->bit_axis = bit_axis;
  pack_attrs->pack_axis = pack_axis;
  pack_attrs->bits = bits;

  return CallNode::make(bitpack_op, {data}, Attrs(pack_attrs), {});
}
|
||||
|
||||
// Expose the positional constructor to the Python frontend.
TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);

// Operator registration: one tensor input, typed by BitPackRel.
RELAY_REGISTER_OP("nn.bitpack")
.describe(R"code(Bitpack layer that prepares data for bitserial operations.

This layer packs the bits of an input into a single datatype, allowing
efficient implementation of bitserial operations.

- **data**: Input tensor of any shape, dimension that is to be
            packed must be divisible by the bit width of pack_type.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.BitPackAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(2)
.add_type_rel("BitPack", BitPackRel);
|
||||
|
||||
// relay.nn.bitserial_conv2d
|
||||
TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs);
|
||||
|
||||
bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
CHECK_EQ(types.size(), 3);
|
||||
const auto* data = types[0].as<TensorTypeNode>();
|
||||
if (data == nullptr) return false;
|
||||
|
||||
const BinaryConv2DAttrs* param = attrs.as<BinaryConv2DAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
|
||||
static const Layout kNCHW("NCHW");
|
||||
|
||||
const Layout in_layout(param->data_layout);
|
||||
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
|
||||
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
|
||||
CHECK(param->channels.defined());
|
||||
CHECK(param->kernel_size.defined());
|
||||
Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
|
||||
oshape.Set(
|
||||
2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1);
|
||||
oshape.Set(
|
||||
3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1);
|
||||
DataType out_dtype = param->out_dtype;
|
||||
oshape = trans_in_layout.BackwardShape(oshape);
|
||||
// assign output type
|
||||
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Positional relay function to create binaryconv2d operator
|
||||
// used by frontend FFI.
|
||||
Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                      IndexExpr channels, Array<IndexExpr> kernel_size, int activation_bits,
                      int weight_bits, std::string data_layout, std::string kernel_layout,
                      DataType pack_dtype, DataType out_dtype, bool unipolar) {
  static const Op& conv_op = Op::Get("nn.bitserial_conv2d");

  auto conv_attrs = make_node<BinaryConv2DAttrs>();

  // Geometry of the convolution.
  conv_attrs->strides = std::move(strides);
  conv_attrs->padding = std::move(padding);
  conv_attrs->channels = std::move(channels);
  conv_attrs->kernel_size = std::move(kernel_size);

  // Layouts.
  conv_attrs->data_layout = std::move(data_layout);
  conv_attrs->kernel_layout = std::move(kernel_layout);

  // Quantization settings.
  conv_attrs->activation_bits = activation_bits;
  conv_attrs->weight_bits = weight_bits;
  conv_attrs->pack_dtype = std::move(pack_dtype);
  conv_attrs->out_dtype = std::move(out_dtype);
  conv_attrs->unipolar = unipolar;

  return CallNode::make(conv_op, {data, weight}, Attrs(conv_attrs), {});
}
|
||||
|
||||
// Expose the positional constructor to the Python frontend.
TVM_REGISTER_API("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D);

// Operator registration: two tensor inputs (data, weight), typed by
// BinaryConv2DRel, with layout inference pinned to the attribute layouts.
RELAY_REGISTER_OP("nn.bitserial_conv2d")
.describe(R"code(2D convolution using packed binary computation.

This layer creates a convolution kernel that is convolved with the
layer input using bitserial computation. This enables faster processing
on some platforms.

- **data**: 4D input tensor that can be either `NCHW` or `NHWC` layout.

- **weight**: Weight tensor that can either be prepacked (5D) or unpacked (4D).
              When data is NCHW, weight is expected to be OIHW or OIHWi.
              When data is NHWC weight is expected to be HWIO or HWIOi.

- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BinaryConv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("BinaryConv2D", BinaryConv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>);
|
||||
|
||||
// relay.nn.bitserial_dense
|
||||
TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs);
|
||||
|
||||
bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
CHECK_EQ(types.size(), 3);
|
||||
const auto* data = types[0].as<TensorTypeNode>();
|
||||
if (data == nullptr) return false;
|
||||
|
||||
const BinaryDenseAttrs* param = attrs.as<BinaryDenseAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
|
||||
CHECK(static_cast<int>(data->shape.size()) != 0);
|
||||
CHECK(param->units.defined());
|
||||
|
||||
Array<tvm::Expr> oshape = data->shape;
|
||||
oshape.Set((oshape.size() - 1), param->units);
|
||||
|
||||
DataType out_dtype = param->out_dtype;
|
||||
if (out_dtype.bits() == 0) {
|
||||
out_dtype = data->dtype;
|
||||
}
|
||||
|
||||
// Assign output type.
|
||||
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Positional relay function to create bitserial dense operator used by frontend FFI.
|
||||
Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits,
                     DataType pack_dtype, DataType out_dtype, bool unipolar) {
  static const Op& dense_op = Op::Get("nn.bitserial_dense");

  auto dense_attrs = make_node<BinaryDenseAttrs>();
  dense_attrs->unipolar = unipolar;
  dense_attrs->out_dtype = out_dtype;
  dense_attrs->pack_dtype = pack_dtype;
  dense_attrs->weight_bits = weight_bits;
  dense_attrs->data_bits = data_bits;
  dense_attrs->units = units;

  return CallNode::make(dense_op, {data, weight}, Attrs(dense_attrs), {});
}
|
||||
|
||||
// Expose the positional constructor to the Python frontend.
TVM_REGISTER_API("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense);

// Operator registration: two tensor inputs (data, weight), typed by
// BinaryDenseRel. The data argument is declared n-dimensional to agree with
// the shape description below and with BinaryDenseRel, which accepts any
// rank of at least one.
RELAY_REGISTER_OP("nn.bitserial_dense")
.describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BinaryDenseAttrs")
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
.add_type_rel("BinaryDense", BinaryDenseRel);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -337,6 +337,16 @@ def test_dense():
|
|||
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
|
||||
|
||||
|
||||
def test_bitserial_dense():
    """Type inference for bitserial_dense with symbolic batch and input dims."""
    m, k = tvm.var("m"), tvm.var("k")
    x = relay.var("x", relay.TensorType((m, k), "int16"))
    # Weight is laid out as (units, input_dim), as documented for the op.
    w = relay.var("w", relay.TensorType((32, k), "int16"))
    y = relay.nn.bitserial_dense(x, w, units=32)
    # The units attribute must show up in the text form of the call.
    assert "units=32" in y.astext()
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((m, 32), "int16")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_concatenate()
|
||||
test_bias_add()
|
||||
|
@ -349,3 +359,4 @@ if __name__ == "__main__":
|
|||
test_dropout()
|
||||
test_batch_norm()
|
||||
test_dense()
|
||||
test_bitserial_dense()
|
||||
|
|
|
@ -105,8 +105,8 @@ def test_conv2d_run():
|
|||
except_targets=None,
|
||||
**attrs):
|
||||
if except_targets is None:
|
||||
except_targets = []
|
||||
|
||||
except_targets = []
|
||||
|
||||
x = relay.var("x", shape=dshape, dtype=dtype)
|
||||
w = relay.var("w", dtype=dtype)
|
||||
y = relay.nn.conv2d(x, w,
|
||||
|
@ -599,12 +599,35 @@ def test_conv2d_int8_intrinsics():
|
|||
assert "vpmulld" in asm and "vpadd" in asm
|
||||
|
||||
|
||||
def test_bitserial_conv2d_infer_type():
    """Shape inference for bitserial_conv2d with a symbolic batch dimension."""
    batch = tvm.var("n")
    in_channels, height, width = 32, 224, 224
    data = relay.var(
        "x", relay.ty.TensorType((batch, in_channels, height, width), "int16"))
    weight = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16"))
    conv = relay.nn.bitserial_conv2d(
        data, weight, kernel_size=(3, 3), padding=(0, 0), channels=32)
    inferred = run_infer_type(conv)
    # 224 - 3 + 1 = 222 along each spatial dimension.
    expected = relay.TensorType((batch, 32, 222, 222), "int16")
    assert inferred.checked_type == expected
|
||||
|
||||
|
||||
def test_bitpack_infer_type():
    """Shape inference for bitpack: the packed axis shrinks, a bit axis is added."""
    in_shape = (32, 32, 128, 128)
    data = relay.var("x", relay.ty.TensorType(in_shape, "int16"))
    packed = relay.nn.bitpack(data, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1)
    inferred = run_infer_type(packed)
    # Axis 1 is divided by the 16 bits of uint16; the new last axis has extent bits=1.
    expected = relay.TensorType((32, 2, 128, 128, 1), "uint16")
    assert inferred.checked_type == expected
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pool2d()
|
||||
test_avg_pool2d_no_count_pad()
|
||||
test_lrn()
|
||||
test_l2_normalize()
|
||||
test_conv2d_infer_type()
|
||||
test_bitpack_infer_type()
|
||||
test_upsampling_infer_type()
|
||||
test_flatten_infer_type()
|
||||
test_pad_infer_type()
|
||||
|
@ -612,6 +635,7 @@ if __name__ == "__main__":
|
|||
test_conv2d_transpose_infer_type()
|
||||
test_conv2d_transpose_run()
|
||||
test_conv2d_run()
|
||||
test_bitserial_conv2d_infer_type()
|
||||
test_batch_flatten()
|
||||
test_upsampling()
|
||||
test_conv2d_int8_intrinsics()
|
||||
|
|
|
@ -14,14 +14,15 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=invalid-name,unused-variable,invalid-name
|
||||
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
|
||||
"""Bitserial conv2d schedule on arm cpu"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import tvm
|
||||
from tvm import autotvm
|
||||
from tvm import relay
|
||||
from .. import tag
|
||||
from ..nn.pad import pad
|
||||
from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc
|
||||
from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize
|
||||
from ..nn.bitserial_util import bitpack, binary_op_multiplier
|
||||
from ..nn.util import get_pad_tuple
|
||||
from ..util import get_const_int, get_const_tuple
|
||||
|
@ -350,3 +351,40 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
|
|||
|
||||
traverse(outs[0].op)
|
||||
return s
|
||||
|
||||
@bitserial_conv2d_legalize.register("arm_cpu")
def _bitserial_conv2d_legalize(attrs, inputs, arg_types):
    """Legalizes Bitserial Conv2D op.

    For NHWC data with a 4-D kernel, the kernel is transposed to HWIO when its
    layout is HWOI or OIHW, and the convolution is rebuilt with
    kernel_layout set to 'HWIO'. In every other case nothing is changed.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr, or None when the op is left untouched.
    """
    # Only NHWC data is rewritten.
    if attrs['data_layout'] != 'NHWC':
        return None

    data, kernel = inputs
    # NOTE(review): `kernel.data` presumably only exists for constant kernels - confirm.
    if len(kernel.data.shape) != 4:
        return None

    # HWIO layout is expected for NHWC input.
    kernel_layout = attrs['kernel_layout']
    if kernel_layout == 'HWOI':
        # Handle HWOI layout. This is common in TF depthwise conv2d graph.
        kernel = relay.transpose(kernel, axes=(0, 1, 3, 2))
    elif kernel_layout == 'OIHW':
        kernel = relay.transpose(kernel, axes=(2, 3, 1, 0))

    # Set new attrs for the transposed conv.
    new_attrs = {key: attrs[key] for key in attrs.keys()}
    new_attrs['kernel_layout'] = 'HWIO'
    return relay.nn.bitserial_conv2d(data, kernel, **new_attrs)
|
||||
|
|
|
@ -470,6 +470,23 @@ def schedule_binarize_pack(outs):
|
|||
return _default_schedule(outs, False)
|
||||
|
||||
|
||||
@tvm.target.override_native_generic_func("schedule_bitpack")
def schedule_bitpack(outs):
    """Generic schedule for bitpack.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of bitpack,
        given as an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, False)
    return sch
|
||||
|
||||
|
||||
@tvm.target.override_native_generic_func("schedule_binary_dense")
|
||||
def schedule_binary_dense(outs):
|
||||
"""Schedule for binary_dense
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
|
||||
# pylint: disable=unused-argument, redefined-builtin
|
||||
"""Bitserial Conv2D operators"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import tvm
|
||||
|
@ -65,7 +66,10 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
|
|||
"""
|
||||
assert isinstance(stride, int) or len(stride) == 2
|
||||
Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
|
||||
Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
|
||||
if len(filter.shape) == 4:
|
||||
Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
|
||||
else:
|
||||
Filter_q = filter
|
||||
batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
|
||||
num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape
|
||||
|
||||
|
@ -414,3 +418,24 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
|
|||
return tvm.compute(oshape, lambda n, h, w, co:
|
||||
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
|
||||
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
|
||||
|
||||
@tvm.target.generic_func
def bitserial_conv2d_legalize(attrs, inputs, types):
    """Legalizes Bitserial Conv2D op.

    This is the generic fallback: it performs no rewrite and returns None.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr; always None in this default implementation.
    """
    # not to change by default
    return None
|
||||
|
|
Загрузка…
Ссылка в новой задаче