[UTILS, DOC] Use TVM file downloading utility, conv2d tutorial (#48)

This commit is contained in:
Thierry Moreau 2018-06-22 13:40:22 -07:00 коммит произвёл Tianqi Chen
Родитель d1128cedfb
Коммит 4ba6bd50dd
9 изменённых файлов: 156 добавлений и 108 удалений

Просмотреть файл

@ -2,12 +2,20 @@
Follow the first two parts of the [Installation Guide](../../../docs/how_to/install.md) to make sure that the VTA python libraries are installed, and that the RPC server is running on the Pynq FPGA dev board.
Simply run the following python script:
We recommend leaving the `config.json` to its default parameterization (of course you can change the target between "sim" and "pynq").
Simply run the example program. We rely on pickle to store parameters which now only works with python2.
```bash
python imagenet_predict.py
python2 imagenet_predict.py
```
This will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`.
The script will first download the following files into the `_data/` directory:
* `cat.jpg` which provides a test sample for the ImageNet classifier
* `quantize_graph.json` which describes the NNVM graph of the 8-bit ResNet-18
* `quantize_params.pkl` which contains the network parameters
* `synset.txt` which contains the ImageNet categories
Next, it will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`.
The script reports runtime measured on the Pynq board (in seconds), and the top-1 result category:
```

Просмотреть файл

@ -1,17 +1,18 @@
# some standard imports
import nnvm
import tvm
from nnvm.compiler import graph_attr
import vta
import vta.testing
import os
import numpy as np
from PIL import Image
import pickle
import json
import logging
import wget
from PIL import Image
from nnvm.compiler import graph_attr
from tvm.contrib import graph_runtime, rpc, util
from tvm.contrib.download import download
bfactor = 1
cfactor = 16
@ -20,15 +21,20 @@ verbose = False
debug_fpga_only = False
# Obtain model and hardware files (they're too large to check-in)
# Download them into _data dir
data_dir = "_data/"
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
TEST_FILE = 'cat.jpg'
CATEG_FILE = 'synset.txt'
RESNET_GRAPH_FILE = 'quantize_graph.json'
RESNET_PARAMS_FILE = 'quantize_params.pkl'
# Create data dir
if not os.path.exists(data_dir):
os.makedirs(data_dir)
# Download files
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE]:
if not os.path.isfile(file):
print ("Downloading {}".format(file))
wget.download(url+file)
download(os.path.join(url, file), os.path.join(data_dir, file))
if verbose:
logging.basicConfig(level=logging.DEBUG)
@ -40,8 +46,8 @@ target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+
if vta.get_env().TARGET == "sim":
target_host = "llvm"
synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
synset = eval(open(os.path.join(data_dir, CATEG_FILE)).read())
image = Image.open(os.path.join(data_dir, TEST_FILE)).resize((224, 224))
def transform_image(image):
image = np.array(image) - np.array([123., 117., 104.])
@ -88,9 +94,9 @@ print('x', x.shape)
import nnvm.compiler
np.random.seed(0)
sym = nnvm.graph.load_json(
open(os.path.join(RESNET_GRAPH_FILE)).read())
open(os.path.join(data_dir, RESNET_GRAPH_FILE)).read())
params = pickle.load(
open(os.path.join(RESNET_PARAMS_FILE)))
open(os.path.join(data_dir, RESNET_PARAMS_FILE), 'rb'))
shape_dict = {"data": x.shape}
dtype_dict = {"data": 'float32'}

Просмотреть файл

@ -2,9 +2,16 @@
from __future__ import absolute_import as _abs
import os
import urllib
import sys
from tvm.contrib.download import download
from .environment import get_env
if sys.version_info >= (3,):
import urllib.error as urllib2
else:
import urllib2
# bitstream repo
BITSTREAM_URL = "https://github.com/uwsaml/vta-distro/raw/master/bitstreams/"
@ -41,15 +48,25 @@ def download_bitstream():
url = os.path.join(BITSTREAM_URL, env.TARGET)
url = os.path.join(url, env.HW_VER)
url = os.path.join(url, env.BITSTREAM)
# Check that the bitstream is accessible from the server
if urllib.urlopen(url).getcode() == 404:
# Raise error - the solution when this happens it to build your own bitstream and add it
# to your VTA_CACHE_PATH
raise RuntimeError(
"Error: {} is not available. It appears that this configuration has not been built."
.format(url))
else:
urllib.urlretrieve(url, bit)
success = True
try:
download(url, bit)
except urllib2.HTTPError as err:
if err.code == 404:
raise RuntimeError(
# Raise error - the solution when this happens is to build your
# own bitstream and add it to your $VTA_CACHE_PATH
"{} is not available. It appears that this configuration \
bistream has not been cached. Please compile your own bitstream (see hardware \
compilation guide to get Xilinx toolchains setup) and add it to your \
$VTA_CACHE_PATH. Alternatively edit your config.json back to its default \
settings. You can see the list of available bitstreams under {}"
.format(url, BITSTREAM_URL))
else:
raise RuntimeError(
# This could happen when trying to access the URL behind a proxy
"Something went wrong when trying to access {}. Check your \
internet connection or proxy settings."
.format(url))
return success

Просмотреть файл

@ -15,23 +15,34 @@ def run(run_func):
"""
env = get_env()
# Run on local sim rpc if necessary
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
env.TARGET = "sim"
remote = rpc.connect("localhost", local_rpc)
run_func(env, remote)
else:
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
if env.TARGET == "sim":
# Talk to local RPC if necessary to debug RPC server.
# Compile vta on your host with make at the root.
# Make sure TARGET is set to "sim" in the config.json file.
# Then launch the RPC server on the host machine
# with ./apps/pynq_rpc/start_rpc_server.sh
# Set your VTA_LOCAL_SIM_RPC environment variable to
# the port it's listening to, e.g. 9090
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote = rpc.connect("localhost", local_rpc)
run_func(env, remote)
else:
# Make sure simulation library exists
# If this fails, build vta on host (make)
# with TARGET="sim" in the config.json file.
assert simulator.enabled()
run_func(env, rpc.LocalSession())
# Run on PYNQ if env variable exists
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
if host:
env.TARGET = "pynq"
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
port = int(port)
remote = rpc.connect(host, port)
run_func(env, remote)
elif env.TARGET == "pynq":
# Run on PYNQ if env variable exists
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
if host and port:
remote = rpc.connect(host, port)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")

Просмотреть файл

@ -18,7 +18,8 @@ def test_gemm():
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
num_ops = channel * channel * batch_size
# To compute number of ops, use a x2 factor for FMA
num_ops = 2 * channel * channel * batch_size
ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
@ -157,14 +158,14 @@ def test_gemm():
def gemm_normal(print_ir):
mock = env.mock
print("----- GEMM GFLOPS End-to-End Test-------")
print("----- GEMM GOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir, True)
@ -177,7 +178,7 @@ def test_gemm():
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
@ -190,7 +191,7 @@ def test_gemm():
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
@ -204,7 +205,7 @@ def test_gemm():
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
@ -219,7 +220,7 @@ def test_gemm():
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
@ -235,7 +236,7 @@ def test_gemm():
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)

Просмотреть файл

@ -42,6 +42,7 @@ def test_vta_conv2d():
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
# To compute number of ops, use a x2 factor for FMA
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
@ -118,7 +119,7 @@ def test_vta_conv2d():
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
conv_normal(False)

Просмотреть файл

@ -46,10 +46,10 @@ def get_insn_count(layer, sched):
env = vta.get_env()
b, h, w, ci, co = sched
b_factor = b
h_factor = layer.height / h
w_factor = layer.width / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
h_factor = layer.height // h
w_factor = layer.width // w
ci_factor = layer.in_filter // (ci * env.BLOCK_IN)
co_factor = layer.out_filter // (co * env.BLOCK_OUT)
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
output_xfers = b_factor * h_factor * w_factor * co_factor
@ -69,11 +69,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
factors.append(i)
return factors
# Scheduling exploration
batch_factors = find_factors(int(np.ceil(float(layer.batch) / env.BATCH)))
height_factors = find_factors(layer.height / layer.hstride)
width_factors = find_factors(layer.width / layer.wstride)
cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / env.BLOCK_IN)))
cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / env.BLOCK_OUT)))
batch_factors = find_factors(layer.batch // env.BATCH)
height_factors = find_factors(layer.height // layer.hstride)
width_factors = find_factors(layer.width // layer.wstride)
cin_factors = find_factors(layer.in_filter // env.BLOCK_IN)
cout_factors = find_factors(layer.out_filter // env.BLOCK_OUT)
ht_factors = [1, 2]
cot_factors = [1, 2]
# Explore schedules
@ -124,7 +124,7 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
if input_tile_elems*input_elem_size_b <= input_brams_capacity_b/(cot*ht) and \
weight_tile_elems*weight_elem_size_b <= weight_brams_capacity_b and \
output_tile_elems*output_elem_size_b <= output_brams_capacity_b/(cot*ht) and \
insn_count <= env.MAX_XFER / 16 and \
insn_count <= env.MAX_XFER // 16 and \
h > 2 and w > 2:
schedule = Schedule(oc_factor=co, ko_factor=ci, h_factor=h,
w_factor=w, oc_nthread=cot, h_nthread=ht)
@ -154,19 +154,19 @@ def get_data_movementB(sched, layer):
weight_tile_elems = layer.hkernel * layer.wkernel * ci
output_tile_elems = b * h * w * co
# Derive factors
b_factor = int(np.ceil(float(layer.batch) / (b * env.BATCH)))
h_factor = (layer.height / layer.hstride) / h
w_factor = (layer.width / layer.wstride) / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
b_factor = layer.batch // (b * env.BATCH)
h_factor = (layer.height // layer.hstride) // h
w_factor = (layer.width // layer.wstride) // w
ci_factor = int(np.ceil(float(layer.in_filter) // (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) // (co * env.BLOCK_OUT)))
# Derive transfers
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
output_xfers = b_factor * h_factor * w_factor * co_factor
# Compute total transfer sizes
input_xfer_B = input_tile_elems * input_xfers * input_elem_size_b / 8
weight_xfer_B = weight_tile_elems * weight_xfers * weight_elem_size_b / 8
output_xfer_B = output_tile_elems * output_xfers * output_elem_size_b / 8
input_xfer_B = input_tile_elems * input_xfers * input_elem_size_b // 8
weight_xfer_B = weight_tile_elems * weight_xfers * weight_elem_size_b // 8
output_xfer_B = output_tile_elems * output_xfers * output_elem_size_b // 8
total_xfer_B = input_xfer_B + weight_xfer_B + output_xfer_B
return total_xfer_B
@ -175,13 +175,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
assert batch_size % env.BATCH == 0
assert wl.in_filter % env.BLOCK_IN == 0
assert wl.out_filter % env.BLOCK_OUT == 0
data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
data_shape = (batch_size // env.BATCH, wl.in_filter // env.BLOCK_IN,
wl.height, wl.width, env.BATCH, env.BLOCK_IN)
kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
kernel_shape = (wl.out_filter // env.BLOCK_OUT, wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
res_shape = (batch_size//env.BATCH, wl.out_filter//env.BLOCK_OUT,
res_shape = (batch_size // env.BATCH, wl.out_filter // env.BLOCK_OUT,
fout_height, fout_width, env.BATCH, env.BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
@ -201,7 +201,8 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
kernel_buf[co, ko, di, dj, ci, ki].astype(env.acc_dtype),
axis=[ko, di, dj, ki]),
name="res_cnv")
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
# res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res_shf = topi.right_shift(res_cnv, 8)
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(env.inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
total_xfer_B = get_data_movementB(sched, wl)
@ -310,7 +311,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
log_frame["key"].append(key)
log_frame["layer"].append(layer)
log_frame["total-data"].append(total_xfer_B)
@ -347,7 +348,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
log_frame["skip-alu-gops"].append(gops)
log_frame["skip-alu-cost"].append(cost.mean)
@ -365,7 +366,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
log_frame["gemm-gops"].append(gops)
log_frame["gemm-cost"].append(cost.mean)
with vta.build_config():
@ -382,7 +383,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
log_frame["alu-gops"].append(gops)
log_frame["alu-cost"].append(cost.mean)
with vta.build_config():
@ -401,7 +402,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
bandwith = (batch_size * wl.in_filter * wl.height *
wl.width * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["ld-inp-gbits"].append(bandwith)
log_frame["ld-inp-cost"].append(cost.mean)
@ -421,7 +422,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
bandwith = (wl.out_filter * wl.in_filter * wl.hkernel *
wl.wkernel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["ld-wgt-gbits"].append(bandwith)
log_frame["ld-wgt-cost"].append(cost.mean)
@ -441,7 +442,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
bandwith = (batch_size * wl.out_filter * fout_height *
fout_width * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
print("\tTime cost = %g sec/op, %g GOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["st-out-gbits"].append(bandwith)
log_frame["st-out-cost"].append(cost.mean)
@ -460,7 +461,7 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (
print("\tTime cost = %g sec/op, %g GOPS" % (
cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
@ -532,10 +533,11 @@ for x in resnet_schedules:
key = "resnet-cfg[{}-{}]".format(l, plan)
test_conv2d_chwv(l, key, batch_size, resnet[l], plan, log_frame, profile)
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
log_df[k] = log_frame[k]
print(log_df)
log_df.to_csv("conv2d.csv")
if profile:
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
log_df[k] = log_frame[k]
print(log_df)
log_df.to_csv("conv2d.csv")

Просмотреть файл

@ -91,6 +91,7 @@ elif env.TARGET == "sim":
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/tensor_core.png
# :align: center
# :width: 480px
#
# The dimensions of that matrix-matrix multiplication are specified in
# the :code:`config.json` configuration file.
@ -109,6 +110,7 @@ elif env.TARGET == "sim":
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/data_tiling.png
# :align: center
# :width: 480px
#
# We first define the variables :code:`m`, :code:`n`, :code:`o` to represent
# the shape of the matrix multiplication. These variables are multiplicative
@ -119,7 +121,6 @@ elif env.TARGET == "sim":
# 1 implies that our compute building block is vector-matrix multiply).
#
######################################################################
# .. note::
#

Просмотреть файл

@ -66,7 +66,7 @@ elif env.TARGET == "sim":
# :code:`BATCH`, :code:`BLOCK_IN`, and :code:`BLOCK_OUT` respectively.
#
# We've added extra operators to the matrix multiplication that apply
# shifting and clipping to the output in order to mimic the a fixed-point
# shifting and clipping to the output in order to mimic a fixed-point
# matrix multiplication followed by a rectified linear activation.
# We describe the TVM dataflow graph of the fully connected layer below:
#
@ -152,7 +152,7 @@ res = tvm.compute(output_shape,
# Those include:
#
# - Computation blocking
# - Computation lowering to VTA hardware intrinsics
# - Lowering to VTA hardware intrinsics
# Create TVM schedule
@ -161,8 +161,8 @@ s = tvm.create_schedule(res.op)
print(tvm.lower(s, [data, weight, res], simple_mode=True))
######################################################################
# Tiling the Computation
# ~~~~~~~~~~~~~~~~~~~~~~
# Blocking the Computation
# ~~~~~~~~~~~~~~~~~~~~~~~~
# The matrix multiplication is by default too large for activations or weights
# to fit on VTA's on-chip buffers all at once.
# We block the (1, 1024) by (1024, 1024) matrix multiplication into
@ -180,8 +180,7 @@ print(tvm.lower(s, [data, weight, res], simple_mode=True))
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/blocking.png
# :align: center
# :height: 367px
# :width: 387px
# :width: 480px
#
# .. note::
#
@ -236,7 +235,7 @@ s[res_shr].compute_at(s[res], oc_out)
s[res_max].compute_at(s[res], oc_out)
s[res_min].compute_at(s[res], oc_out)
# Apply additional loop split along input channel axis
# Apply additional loop split along reduction axis (input channel)
b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis
ic_out, ic_inn = s[res_gemm].split(ic, i_block)
@ -273,6 +272,8 @@ s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)
s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy)
# Use DMA copy pragma on SRAM->DRAM operation
# (this implies that these copies should be performed along b_inn,
# or result axis 2)
s[res].pragma(s[res].op.axis[2], env.dma_copy)
######################################################################
@ -313,21 +314,21 @@ f = remote.load_module("gemm.o")
# Get the remote device context
ctx = remote.ext_dev(0)
# Initialize the A and B arrays randomly in the int range of (-128, 128]
data = np.random.randint(
# Initialize the data and weight arrays randomly in the int range of [-128, 128)
data_np = np.random.randint(
-128, 128, size=(batch_size, in_channels)).astype(data.dtype)
weight = np.random.randint(
weight_np = np.random.randint(
-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)
# Apply packing to the A and B arrays from a 2D to a 4D packed layout
data_packed = data.reshape(batch_size // env.BATCH,
env.BATCH,
in_channels // env.BLOCK_IN,
env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight.reshape(out_channels // env.BLOCK_OUT,
env.BLOCK_OUT,
in_channels // env.BLOCK_IN,
env.BLOCK_IN).transpose((0, 2, 1, 3))
# Apply packing to the data and weight arrays from a 2D to a 4D packed layout
data_packed = data_np.reshape(batch_size // env.BATCH,
env.BATCH,
in_channels // env.BLOCK_IN,
env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_np.reshape(out_channels // env.BLOCK_OUT,
env.BLOCK_OUT,
in_channels // env.BLOCK_IN,
env.BLOCK_IN).transpose((0, 2, 1, 3))
# Format the input/output arrays with tvm.nd.array to the DLPack standard
data_nd = tvm.nd.array(data_packed, ctx)
@ -338,8 +339,8 @@ res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
f(data_nd, weight_nd, res_nd)
# Verify against numpy implementation
res_ref = np.dot(data.astype(env.acc_dtype),
weight.T.astype(env.acc_dtype))
res_ref = np.dot(data_np.astype(env.acc_dtype),
weight_np.T.astype(env.acc_dtype))
res_ref = res_ref >> env.INP_WIDTH
res_ref = np.clip(res_ref, 0, inp_max)
res_ref = res_ref.astype(res.dtype)